diff options
-rw-r--r-- | data/tokenizers/tokens_german.txt | 38 | ||||
-rw-r--r-- | swr2_asr/utils/data.py | 41 | ||||
-rw-r--r-- | swr2_asr/utils/decoder.py | 103 | ||||
-rw-r--r-- | swr2_asr/utils/tokenizer.py | 14 |
4 files changed, 138 insertions, 58 deletions
diff --git a/data/tokenizers/tokens_german.txt b/data/tokenizers/tokens_german.txt new file mode 100644 index 0000000..57f2c3a --- /dev/null +++ b/data/tokenizers/tokens_german.txt @@ -0,0 +1,38 @@ +_ +<BLANK> +<UNK> +<SPACE> +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +é +à +ä +ö +ß +ü +- +' diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 19605f6..74cd572 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -6,7 +6,7 @@ import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -343,26 +343,21 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - + + def create_lexicon(vocab_counts_path, lexicon_path): - - words_list = [] - with open(vocab_counts_path, 'r') as file: - for line in file: - - words = line.split() - if len(words) >= 1: - - word = words[0] - words_list.append(word) - - with open(lexicon_path, 'w') as file: - for word in words_list: - file.write(f"{word} ") - for char in word: - file.write(char + ' ') - file.write("<SPACE>") - - - -
\ No newline at end of file + """Create a lexicon from the vocab_counts.txt file""" + words_list = [] + with open(vocab_counts_path, "r", encoding="utf-8") as file: + for line in file: + words = line.split() + if len(words) >= 1: + word = words[0] + words_list.append(word) + + with open(lexicon_path, "w", encoding="utf-8") as file: + for word in words_list: + file.write(f"{word} ") + for char in word: + file.write(char + " ") + file.write("<SPACE>") diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index e6b5852..098f6a4 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,14 +1,18 @@ """Decoder for CTC-based ASR.""" "" -import torch - -from swr2_asr.utils.tokenizer import CharTokenizer -from swr2_asr.utils.data import create_lexicon import os + +import torch from torchaudio.datasets.utils import _extract_tar from torchaudio.models.decoder import ctc_decoder -LEXICON = "lexicon.txt" + +from swr2_asr.utils.data import create_lexicon +from swr2_asr.utils.tokenizer import CharTokenizer + + # TODO: refactor to use torch CTC decoder class -def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): +def greedy_decoder( + output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True +): # pylint: disable=redefined-outer-name """Greedily decode a sequence.""" blank_label = tokenizer.get_blank_token() arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member @@ -26,27 +30,72 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll return decodes, targets -# TODO: add beam search decoder +def beam_search_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + tokens_path: str, + lang_model_path: str, + language: str, + hparams: dict, # pylint: disable=redefined-outer-name +): + """Beam search decoder.""" + + n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["n_gram"], + hparams["beam_size"], + hparams["beam_threshold"], + hparams["n_best"], + hparams["lm_weight"], + hparams["word_score"], + ) + + if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz" + torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") + _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) -def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path): - if not os.path.isdir(lang_model_path): - url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz" - torch.hub.download_url_to_file( - url, "data/mls_lm_german.tar.gz" ) - _extract_tar("data/mls_lm_german.tar.gz", overwrite=True) - if not os.path.isfile(tokenizer_txt_path): - tokenizer.create_txt(tokenizer_txt_path) - - lexicon_path= os.join(lang_model_path, LEXICON) + if not os.path.isfile(tokens_path): + tokenizer.create_tokens_txt(tokens_path) + + lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt") if not os.path.isfile(lexicon_path): - occurences_path = os.join(lang_model_path,"vocab_counts.txt") + occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt") create_lexicon(occurences_path, lexicon_path) - lm_path = os.join(lang_model_path,"3-gram_lm.apa") - decoder = ctc_decoder(lexicon = lexicon_path, - tokenizer = tokenizer_txt_path, - lm =lm_path, - blank_token = '_', - nbest =1, - sil_token= '<SPACE>', - unk_word = '<UNK>') - return decoder
\ No newline at end of file + + lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa") + + decoder = ctc_decoder( + lexicon=lexicon_path, + tokens=tokens_path, + lm=lm_path, + blank_token="_", + sil_token="<SPACE>", + unk_word="<UNK>", + nbest=n_best, + beam_size=beam_size, + beam_threshold=beam_threshold, + lm_weight=lm_weight, + word_score=word_score, + ) + return decoder + + +if __name__ == "__main__": + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt") + + hparams = { + "n_gram": 3, + "beam_size": 100, + "beam_threshold": 100, + "n_best": 1, + "lm_weight": 0.5, + "word_score": 1.0, + } + + beam_search_decoder( + tokenizer, + "data/tokenizers/tokens_german.txt", + "data", + "german", + hparams, + ) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 9abf57d..b1de83e 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,11 +120,9 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - - #TO DO check about the weird unknown tokens etc. - def create_txt(self,path:str): - with open(path, 'w',encoding="utf-8") as file: - for key,value in self.char_map(): - file.write(f"{key}\n") -
\ No newline at end of file + + def create_tokens_txt(self, path: str): + """Create a txt file with all the characters""" + with open(path, "w", encoding="utf-8") as file: + for char, _ in self.char_map.items(): + file.write(f"{char}\n") |