diff options
Diffstat (limited to 'swr2_asr/utils/decoder.py')
-rw-r--r-- | swr2_asr/utils/decoder.py | 103 |
1 files changed, 76 insertions, 27 deletions
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index e6b5852..098f6a4 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,14 +1,18 @@ """Decoder for CTC-based ASR.""" "" -import torch - -from swr2_asr.utils.tokenizer import CharTokenizer -from swr2_asr.utils.data import create_lexicon import os + +import torch from torchaudio.datasets.utils import _extract_tar from torchaudio.models.decoder import ctc_decoder -LEXICON = "lexicon.txt" + +from swr2_asr.utils.data import create_lexicon +from swr2_asr.utils.tokenizer import CharTokenizer + + # TODO: refactor to use torch CTC decoder class -def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): +def greedy_decoder( + output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True +): # pylint: disable=redefined-outer-name """Greedily decode a sequence.""" blank_label = tokenizer.get_blank_token() arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member @@ -26,27 +30,72 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll return decodes, targets -# TODO: add beam search decoder +def beam_search_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + tokens_path: str, + lang_model_path: str, + language: str, + hparams: dict, # pylint: disable=redefined-outer-name +): + """Beam search decoder.""" + + n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["n_gram"], + hparams["beam_size"], + hparams["beam_threshold"], + hparams["n_best"], + hparams["lm_weight"], + hparams["word_score"], + ) + + if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz" + torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") + _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) -def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path): - if not os.path.isdir(lang_model_path): - url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz" - torch.hub.download_url_to_file( - url, "data/mls_lm_german.tar.gz" ) - _extract_tar("data/mls_lm_german.tar.gz", overwrite=True) - if not os.path.isfile(tokenizer_txt_path): - tokenizer.create_txt(tokenizer_txt_path) - - lexicon_path= os.join(lang_model_path, LEXICON) + if not os.path.isfile(tokens_path): + tokenizer.create_tokens_txt(tokens_path) + + lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt") if not os.path.isfile(lexicon_path): - occurences_path = os.join(lang_model_path,"vocab_counts.txt") + occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt") create_lexicon(occurences_path, lexicon_path) - lm_path = os.join(lang_model_path,"3-gram_lm.apa") - decoder = ctc_decoder(lexicon = lexicon_path, - tokenizer = tokenizer_txt_path, - lm =lm_path, - blank_token = '_', - nbest =1, - sil_token= '<SPACE>', - unk_word = '<UNK>') - return decoder
\ No newline at end of file + + lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa") + + decoder = ctc_decoder( + lexicon=lexicon_path, + tokens=tokens_path, + lm=lm_path, + blank_token="_", + sil_token="<SPACE>", + unk_word="<UNK>", + nbest=n_best, + beam_size=beam_size, + beam_threshold=beam_threshold, + lm_weight=lm_weight, + word_score=word_score, + ) + return decoder + + +if __name__ == "__main__": + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt") + + hparams = { + "n_gram": 3, + "beam_size": 100, + "beam_threshold": 100, + "n_best": 1, + "lm_weight": 0.5, + "word_score": 1.0, + } + + beam_search_decoder( + tokenizer, + "data/tokenizers/tokens_german.txt", + "data", + "german", + hparams, + ) |