 config.philipp.yaml       |  55
 config.yaml               |  43
 swr2_asr/inference.py     |  35
 swr2_asr/utils/decoder.py | 121
 4 files changed, 161 insertions(+), 93 deletions(-)
diff --git a/config.philipp.yaml b/config.philipp.yaml
index f72ce2e..38a68f8 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -1,34 +1,45 @@
+dataset:
+  download: True
+  dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir
+  language_name: "mls_german_opus"
+  limited_supervision: True # set to True if you want to use limited supervision
+  dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%)
+  shuffle: True
+
 model:
   n_cnn_layers: 3
   n_rnn_layers: 5
   rnn_dim: 512
   n_feats: 128 # number of mel features
   stride: 2
-  dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
-
-training:
-  learning_rate: 0.0005
-  batch_size: 32 # recommended to be the maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
-  epochs: 150
-  eval_every_n: 5 # evaluate every n epochs
-  num_workers: 4 # number of workers for dataloader
-  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-
-dataset:
-  download: true
-  dataset_root_path: "data" # files will be downloaded into this dir
-  language_name: "mls_german_opus"
-  limited_supervision: false # set to True if you want to use limited supervision
-  dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%)
-  shuffle: true
+  dropout: 0.6 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
 
 tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
 
-checkpoints:
-  model_load_path: "data/runs/epoch31" # path to load model from
-  model_save_path: "data/runs/epoch" # path to save model to
+decoder:
+  type: "lm" # greedy, or lm (beam search)
+
+  lm: # config for lm decoder
+    language_model_path: "data" # path where model and supplementary files are stored
+    language: "german"
+    n_gram: 3 # n-gram size of the language model, 3 or 5
+    beam_size: 50
+    beam_threshold: 50
+    n_best: 1
+    lm_weight: 2
+    word_score: 0
+
+training:
+  learning_rate: 0.0005
+  batch_size: 8 # recommended to be the maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
+  epochs: 3
+  eval_every_n: 3 # evaluate every n epochs
+  num_workers: 8 # number of workers for dataloader
+
+checkpoints: # use "~" to disable saving/loading
+  model_load_path: "YOUR/PATH" # path to load model from
+  model_save_path: "YOUR/PATH" # path to save model to
 
 inference:
-  model_load_path: "data/runs/epoch30" # path to load model from
-  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file
+  model_load_path: "data/epoch67" # path to load model from
\ No newline at end of file
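The new decoder block is consumed at inference time through the decoder_factory introduced in swr2_asr/utils/decoder.py below. A minimal sketch of that wiring, assuming PyYAML and that the paths named in the config exist:

import yaml

from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.tokenizer import CharTokenizer

with open("config.philipp.yaml", "r", encoding="utf-8") as config_file:
    config = yaml.safe_load(config_file)

tokenizer = CharTokenizer.from_file(config["tokenizer"]["tokenizer_path"])
decoder_config = config.get("decoder", {})

# type: "lm" selects the beam search decoder and reads the nested lm section;
# type: "greedy" ignores it.
decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)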
diff --git a/config.yaml b/config.yaml
index 41b473c..d248d43 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,3 +1,11 @@
+dataset:
+  download: True
+  dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
+  language_name: "mls_german_opus"
+  limited_supervision: False # set to True if you want to use limited supervision
+  dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+  shuffle: True
+
 model:
   n_cnn_layers: 3
   n_rnn_layers: 5
@@ -6,32 +14,33 @@ model:
   stride: 2
   dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
 
+tokenizer:
+  tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
+
+decoder:
+  type: "greedy" # greedy, or lm (beam search)
+
+  lm: # config for lm decoder
+    language_model_path: "data" # path where model and supplementary files are stored
+    language: "german"
+    n_gram: 3 # n-gram size of the language model, 3 or 5
+    beam_size: 50
+    beam_threshold: 50
+    n_best: 1
+    lm_weight: 2
+    word_score: 0
+
 training:
-  learning_rate: 5e-4
+  learning_rate: 0.0005
   batch_size: 8 # recommended to be the maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
   epochs: 3
   eval_every_n: 3 # evaluate every n epochs
   num_workers: 8 # number of workers for dataloader
 
-dataset:
-  download: True
-  dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
-  language_name: "mls_german_opus"
-  limited_supervision: False # set to True if you want to use limited supervision
-  dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
-  shuffle: True
-
-tokenizer:
-  tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml"
-
-checkpoints:
+checkpoints: # use "~" to disable saving/loading
   model_load_path: "YOUR/PATH" # path to load model from
   model_save_path: "YOUR/PATH" # path to save model to
 
 inference:
   model_load_path: "YOUR/PATH" # path to load model from
-  beam_width: 10 # beam width for beam search
-  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-
-lang_model:
-  path: "data/mls_lm_german" # path where model and supplementary files are stored
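For reference, type: "greedy" selects plain CTC collapsing: merge repeated indices, then drop blanks. A toy sketch of that rule (the token ids and blank id are made up):

import torch

blank = 0
argmaxes = torch.tensor([1, 1, 0, 2, 2, 2, 0, 1])  # per-frame argmax over classes

decoded = []
previous = None
for index in argmaxes.tolist():
    if index != blank and index != previous:
        decoded.append(index)  # keep the first index of each run, skip blanks
    previous = index

print(decoded)  # [1, 2, 1]: repeats merged, blanks removed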
"cuda": device = "cuda" if torch.cuda.is_available() else "cpu" + elif inference_config.get("device", "") == "mps": + device = "mps" + else: + device = "cpu" device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) @@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None: spec = spec.unsqueeze(0) spec = spec.transpose(1, 2) spec = spec.unsqueeze(0) + spec = spec.to(device) output = model(spec) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) # (batch, time, n_class) - decoded_preds = greedy_decoder(output, tokenizer) - print(decoded_preds) + decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config) + + preds = decoder(output) + preds = " ".join(preds[0][0].words).strip() + + print(preds) if __name__ == "__main__": diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 098f6a4..2b6d29b 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,4 +1,5 @@ """Decoder for CTC-based ASR.""" "" +from dataclasses import dataclass import os import torch @@ -9,37 +10,39 @@ from swr2_asr.utils.data import create_lexicon from swr2_asr.utils.tokenizer import CharTokenizer -# TODO: refactor to use torch CTC decoder class -def greedy_decoder( - output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True -): # pylint: disable=redefined-outer-name - """Greedily decode a sequence.""" - blank_label = tokenizer.get_blank_token() - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist())) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes, targets - - -def beam_search_decoder( +@dataclass +class DecoderOutput: + """Decoder output.""" + + words: list[str] + + +def decoder_factory(decoder_type: str = "greedy") -> callable: + """Decoder factory.""" + if decoder_type == "greedy": + return get_greedy_decoder + if decoder_type == "lm": + return get_beam_search_decoder + raise NotImplementedError + + +def get_greedy_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + *_, +): + """Greedy decoder.""" + return GreedyDecoder(tokenizer) + + +def get_beam_search_decoder( tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name - tokens_path: str, - lang_model_path: str, - language: str, hparams: dict, # pylint: disable=redefined-outer-name ): """Beam search decoder.""" - - n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams = hparams.get("lm", {}) + language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["language"], + hparams["language_model_path"], hparams["n_gram"], hparams["beam_size"], hparams["beam_threshold"], @@ -53,6 +56,7 @@ def beam_search_decoder( torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) + tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt") if not os.path.isfile(tokens_path): tokenizer.create_tokens_txt(tokens_path) @@ -79,11 +83,66 @@ def beam_search_decoder( return decoder +class GreedyDecoder: + """Greedy decoder.""" + + def __init__(self, tokenizer: CharTokenizer): # 
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index 098f6a4..2b6d29b 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,4 +1,5 @@
 """Decoder for CTC-based ASR."""
+from dataclasses import dataclass
 import os
 
 import torch
@@ -9,37 +10,39 @@ from swr2_asr.utils.data import create_lexicon
 from swr2_asr.utils.tokenizer import CharTokenizer
 
 
-# TODO: refactor to use torch CTC decoder class
-def greedy_decoder(
-    output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True
-):  # pylint: disable=redefined-outer-name
-    """Greedily decode a sequence."""
-    blank_label = tokenizer.get_blank_token()
-    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
-    decodes = []
-    targets = []
-    for i, args in enumerate(arg_maxes):
-        decode = []
-        targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
-        for j, index in enumerate(args):
-            if index != blank_label:
-                if collapse_repeated and j != 0 and index == args[j - 1]:
-                    continue
-                decode.append(index.item())
-        decodes.append(tokenizer.decode(decode))
-    return decodes, targets
-
-
-def beam_search_decoder(
+@dataclass
+class DecoderOutput:
+    """Decoder output."""
+
+    words: list[str]
+
+
+def decoder_factory(decoder_type: str = "greedy") -> callable:
+    """Decoder factory."""
+    if decoder_type == "greedy":
+        return get_greedy_decoder
+    if decoder_type == "lm":
+        return get_beam_search_decoder
+    raise NotImplementedError
+
+
+def get_greedy_decoder(
+    tokenizer: CharTokenizer,  # pylint: disable=redefined-outer-name
+    *_,
+):
+    """Greedy decoder."""
+    return GreedyDecoder(tokenizer)
+
+
+def get_beam_search_decoder(
     tokenizer: CharTokenizer,  # pylint: disable=redefined-outer-name
-    tokens_path: str,
-    lang_model_path: str,
-    language: str,
     hparams: dict,  # pylint: disable=redefined-outer-name
 ):
     """Beam search decoder."""
-
-    n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+    hparams = hparams.get("lm", {})
+    language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+        hparams["language"],
+        hparams["language_model_path"],
         hparams["n_gram"],
         hparams["beam_size"],
         hparams["beam_threshold"],
@@ -53,6 +56,7 @@ def beam_search_decoder(
         torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz")
         _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True)
 
+    tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt")
     if not os.path.isfile(tokens_path):
         tokenizer.create_tokens_txt(tokens_path)
 
@@ -79,11 +83,66 @@ def beam_search_decoder(
     return decoder
 
 
+class GreedyDecoder:
+    """Greedy decoder."""
+
+    def __init__(self, tokenizer: CharTokenizer):  # pylint: disable=redefined-outer-name
+        self.tokenizer = tokenizer
+
+    def __call__(
+        self, output, greedy_type: str = "inference", labels=None, label_lengths=None
+    ):  # pylint: disable=redefined-outer-name
+        """Greedily decode a sequence."""
+        if greedy_type == "train":
+            res = self.train(output, labels, label_lengths)
+        elif greedy_type == "inference":
+            res = self.inference(output)
+        else:
+            raise NotImplementedError
+
+        res = [[DecoderOutput(words=res)]]
+        return res
+
+    def train(self, output, labels, label_lengths):
+        """Greedily decode a sequence with known labels."""
+        blank_label = self.tokenizer.get_blank_token()
+        arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+        decodes = []
+        targets = []
+        for i, args in enumerate(arg_maxes):
+            decode = []
+            targets.append(self.tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+            for j, index in enumerate(args):
+                if index != blank_label:
+                    if j != 0 and index == args[j - 1]:
+                        continue
+                    decode.append(index.item())
+            decodes.append(self.tokenizer.decode(decode))
+        return decodes, targets
+
+    def inference(self, output):
+        """Greedily decode a sequence."""
+        collapse_repeated = True
+        arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+        blank_label = self.tokenizer.get_blank_token()
+        decodes = []
+        for args in arg_maxes:
+            decode = []
+            for j, index in enumerate(args):
+                if index != blank_label:
+                    if collapse_repeated and j != 0 and index == args[j - 1]:
+                        continue
+                    decode.append(index.item())
+            decodes.append(self.tokenizer.decode(decode))
+
+        return decodes
+
+
 if __name__ == "__main__":
     tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
     tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")
 
     hparams = {
+        "language": "german",
+        "language_model_path": "data",
         "n_gram": 3,
         "beam_size": 100,
         "beam_threshold": 100,
@@ -92,10 +151,4 @@ if __name__ == "__main__":
         "word_score": 1.0,
     }
 
-    beam_search_decoder(
-        tokenizer,
-        "data/tokenizers/tokens_german.txt",
-        "data",
-        "german",
-        hparams,
-    )
+    get_beam_search_decoder(tokenizer, {"lm": hparams})
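The lm hyperparameters mirror the arguments of torchaudio's ctc_decoder, which get_beam_search_decoder appears to wrap. A hedged sketch of the equivalent direct construction; the lexicon and ARPA file paths are assumptions about the downloaded MLS layout, not taken from this diff:

from torchaudio.models.decoder import ctc_decoder

decoder = ctc_decoder(
    lexicon="data/mls_lm_german/lexicon.txt",  # assumed; the repo builds one via create_lexicon
    tokens="data/mls_lm_german/tokens.txt",
    lm="data/mls_lm_german/3-gram_lm.arpa",    # n_gram: 3 would select a 3-gram model
    nbest=1,
    beam_size=50,
    beam_threshold=50,
    lm_weight=2.0,
    word_score=0.0,
)
# Hypotheses expose .words, which inference.py joins into the final transcript.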