author    | JoJoBarthold2 | 2023-09-18 12:19:31 +0200
committer | JoJoBarthold2 | 2023-09-18 12:19:31 +0200
commit    | 0f14789f1c33d55dc270bcd154201cce2c4d516e (patch)
tree      | 7d3e80fc80233eba2005399dc85afaa52452a854 /swr2_asr/utils
parent    | ec8bfe9df205608282e5297635363fc8fc8fe55b (diff)
reset commit history
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r-- | swr2_asr/utils/data.py    | 6
-rw-r--r-- | swr2_asr/utils/decoder.py | 29
2 files changed, 31 insertions, 4 deletions
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index f484bdd..19605f6 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -344,7 +344,7 @@ class MLSDataset(Dataset):
             idx,
         )  # type: ignore

-    def create_lexicon(vocab_counts_path, lexicon_path):
+def create_lexicon(vocab_counts_path, lexicon_path):
     words_list = []
     with open(vocab_counts_path, 'r') as file:
@@ -361,6 +361,8 @@ class MLSDataset(Dataset):
             file.write(f"{word} ")
             for char in word:
                 file.write(char + ' ')
-            file.write("|")
+            file.write("<SPACE>")
+
+
\ No newline at end of file
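The data.py change swaps the word-boundary marker written after each lexicon entry from "|" to "<SPACE>", matching the sil_token handed to ctc_decoder in decoder.py below. Since the hunks only show fragments of create_lexicon, here is a minimal sketch of the whole function; the vocab-counts reading loop and the "word count" line layout are assumptions, not part of this commit:

# Sketch only: the reading loop and the "word count" layout of vocab_counts.txt
# are assumptions; the diff above only shows the lexicon-writing loop.
def create_lexicon(vocab_counts_path, lexicon_path):
    words_list = []
    with open(vocab_counts_path, "r", encoding="utf-8") as file:
        for line in file:
            word = line.split()[0]  # assumed layout: "word count"
            words_list.append(word)
    with open(lexicon_path, "w", encoding="utf-8") as file:
        for word in words_list:
            file.write(f"{word} ")      # the word itself
            for char in word:
                file.write(char + " ")  # its character spelling
            file.write("<SPACE>\n")     # boundary token; one entry per line

Each entry then reads like "wort w o r t <SPACE>", the flashlight-style lexicon layout that torchaudio's lexicon-based decoder consumes; the trailing token has to be the same string as the decoder's sil_token, which is why "|" was replaced.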
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index ef8de49..e6b5852 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -2,8 +2,11 @@
 import torch

 from swr2_asr.utils.tokenizer import CharTokenizer
-
-
+from swr2_asr.utils.data import create_lexicon
+import os
+from torchaudio.datasets.utils import _extract_tar
+from torchaudio.models.decoder import ctc_decoder
+LEXICON = "lexicon.txt"
 # TODO: refactor to use torch CTC decoder class
 def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
     """Greedily decode a sequence."""
@@ -25,3 +28,25 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll


 # TODO: add beam search decoder
+def beam_search_decoder(output, tokenizer: CharTokenizer, tokenizer_txt_path, lang_model_path):
+    if not os.path.isdir(lang_model_path):
+        url = "https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz"
+        torch.hub.download_url_to_file(
+            url, "data/mls_lm_german.tar.gz")
+        _extract_tar("data/mls_lm_german.tar.gz", overwrite=True)
+    if not os.path.isfile(tokenizer_txt_path):
+        tokenizer.create_txt(tokenizer_txt_path)
+
+    lexicon_path = os.path.join(lang_model_path, LEXICON)
+    if not os.path.isfile(lexicon_path):
+        occurrences_path = os.path.join(lang_model_path, "vocab_counts.txt")
+        create_lexicon(occurrences_path, lexicon_path)
+    lm_path = os.path.join(lang_model_path, "3-gram_lm.arpa")
+    decoder = ctc_decoder(lexicon=lexicon_path,
+                          tokens=tokenizer_txt_path,
+                          lm=lm_path,
+                          blank_token='_',
+                          nbest=1,
+                          sil_token='<SPACE>',
+                          unk_word='<UNK>')
+    return decoder
\ No newline at end of file
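For reference, a usage sketch of the new beam_search_decoder. Everything repo-specific here is a placeholder (the tokenizer construction, both paths, the 32-token emission width); the call shape of torchaudio's CTCDecoder, CPU float emissions of shape (batch, frame, num_tokens), is the real API:

# Usage sketch, not part of the commit; placeholder names are marked below.
import torch

from swr2_asr.utils.decoder import beam_search_decoder
from swr2_asr.utils.tokenizer import CharTokenizer

tokenizer = CharTokenizer()  # placeholder construction; only create_txt() is used
decoder = beam_search_decoder(
    output=None,                           # currently unused inside the function
    tokenizer=tokenizer,
    tokenizer_txt_path="data/tokens.txt",  # placeholder path
    lang_model_path="data/mls_lm_german",  # where the LM tarball gets extracted
)

# CTCDecoder consumes CPU float emissions of shape (batch, frame, num_tokens)
emissions = torch.randn(1, 50, 32).log_softmax(-1)  # stand-in for model output
best = decoder(emissions)[0][0]  # nbest=1: best hypothesis for the first utterance
print(" ".join(best.words))

On first use the function downloads and unpacks the MLS German LM archive and builds lexicon.txt next to it, so the initial call is slow; later calls reuse the files already on disk.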