From ec8bfe9df205608282e5297635363fc8fc8fe55b Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Sat, 16 Sep 2023 15:48:16 +0200 Subject: created a method that prints the lexicon --- swr2_asr/utils/data.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'swr2_asr/utils/data.py') diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index d551c98..f484bdd 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -343,3 +343,24 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore + + def create_lexicon(vocab_counts_path, lexicon_path): + + words_list = [] + with open(vocab_counts_path, 'r') as file: + for line in file: + + words = line.split() + if len(words) >= 1: + + word = words[0] + words_list.append(word) + + with open(lexicon_path, 'w') as file: + for word in words_list: + file.write(f"{word} ") + for char in word: + file.write(char + ' ') + file.write("|") + + \ No newline at end of file -- cgit v1.2.3 From 0f14789f1c33d55dc270bcd154201cce2c4d516e Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Mon, 18 Sep 2023 12:19:31 +0200 Subject: reset commit history --- config.yaml | 5 ++++- swr2_asr/utils/data.py | 6 ++++-- swr2_asr/utils/decoder.py | 29 +++++++++++++++++++++++++++-- 3 files changed, 35 insertions(+), 5 deletions(-) (limited to 'swr2_asr/utils/data.py') diff --git a/config.yaml b/config.yaml index e5ff43a..41b473c 100644 --- a/config.yaml +++ b/config.yaml @@ -31,4 +31,7 @@ checkpoints: inference: model_load_path: "YOUR/PATH" # path to load model from beam_width: 10 # beam width for beam search - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically + +lang_model: + path: "data/mls_lm_german" #path where model and supplementary files are stored diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index f484bdd..19605f6 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -344,7 +344,7 @@ class MLSDataset(Dataset): idx, ) # type: ignore - def create_lexicon(vocab_counts_path, lexicon_path): +def create_lexicon(vocab_counts_path, lexicon_path): words_list = [] with open(vocab_counts_path, 'r') as file: @@ -361,6 +361,8 @@ class MLSDataset(Dataset): file.write(f"{word} ") for char in word: file.write(char + ' ') - file.write("|") + file.write("") + + \ No newline at end of file diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index ef8de49..e6b5852 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -2,8 +2,11 @@ import torch from swr2_asr.utils.tokenizer import CharTokenizer - - +from swr2_asr.utils.data import create_lexicon +import os +from torchaudio.datasets.utils import _extract_tar +from torchaudio.models.decoder import ctc_decoder +LEXICON = "lexicon.txt" # TODO: refactor to use torch CTC decoder class def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): """Greedily decode a sequence.""" @@ -25,3 +28,25 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll # TODO: add beam search decoder +def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path): + if not os.path.isdir(lang_model_path): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz" + torch.hub.download_url_to_file( + url, "data/mls_lm_german.tar.gz" ) + _extract_tar("data/mls_lm_german.tar.gz", overwrite=True) + if not os.path.isfile(tokenizer_txt_path): + tokenizer.create_txt(tokenizer_txt_path) + + lexicon_path= os.join(lang_model_path, LEXICON) + if not os.path.isfile(lexicon_path): + occurences_path = os.join(lang_model_path,"vocab_counts.txt") + create_lexicon(occurences_path, lexicon_path) + lm_path = os.join(lang_model_path,"3-gram_lm.apa") + decoder = ctc_decoder(lexicon = lexicon_path, + tokenizer = tokenizer_txt_path, + lm =lm_path, + blank_token = '_', + nbest =1, + sil_token= '', + unk_word = '') + return decoder \ No newline at end of file -- cgit v1.2.3 From 9475900a1085b8277808b0a0b1555c59f7eb6d36 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 12:44:34 +0200 Subject: small fixes --- data/tokenizers/tokens_german.txt | 38 ++++++++++++++ swr2_asr/utils/data.py | 41 +++++++-------- swr2_asr/utils/decoder.py | 103 ++++++++++++++++++++++++++++---------- swr2_asr/utils/tokenizer.py | 14 +++--- 4 files changed, 138 insertions(+), 58 deletions(-) create mode 100644 data/tokenizers/tokens_german.txt (limited to 'swr2_asr/utils/data.py') diff --git a/data/tokenizers/tokens_german.txt b/data/tokenizers/tokens_german.txt new file mode 100644 index 0000000..57f2c3a --- /dev/null +++ b/data/tokenizers/tokens_german.txt @@ -0,0 +1,38 @@ +_ + + + +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +é +à +ä +ö +ß +ü +- +' diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 19605f6..74cd572 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -6,7 +6,7 @@ import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -343,26 +343,21 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - + + def create_lexicon(vocab_counts_path, lexicon_path): - - words_list = [] - with open(vocab_counts_path, 'r') as file: - for line in file: - - words = line.split() - if len(words) >= 1: - - word = words[0] - words_list.append(word) - - with open(lexicon_path, 'w') as file: - for word in words_list: - file.write(f"{word} ") - for char in word: - file.write(char + ' ') - file.write("") - - - - \ No newline at end of file + """Create a lexicon from the vocab_counts.txt file""" + words_list = [] + with open(vocab_counts_path, "r", encoding="utf-8") as file: + for line in file: + words = line.split() + if len(words) >= 1: + word = words[0] + words_list.append(word) + + with open(lexicon_path, "w", encoding="utf-8") as file: + for word in words_list: + file.write(f"{word} ") + for char in word: + file.write(char + " ") + file.write("") diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index e6b5852..098f6a4 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,14 +1,18 @@ """Decoder for CTC-based ASR.""" "" -import torch - -from swr2_asr.utils.tokenizer import CharTokenizer -from swr2_asr.utils.data import create_lexicon import os + +import torch from torchaudio.datasets.utils import _extract_tar from torchaudio.models.decoder import ctc_decoder -LEXICON = "lexicon.txt" + +from swr2_asr.utils.data import create_lexicon +from swr2_asr.utils.tokenizer import CharTokenizer + + # TODO: refactor to use torch CTC decoder class -def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): +def greedy_decoder( + output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True +): # pylint: disable=redefined-outer-name """Greedily decode a sequence.""" blank_label = tokenizer.get_blank_token() arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member @@ -26,27 +30,72 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll return decodes, targets -# TODO: add beam search decoder +def beam_search_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + tokens_path: str, + lang_model_path: str, + language: str, + hparams: dict, # pylint: disable=redefined-outer-name +): + """Beam search decoder.""" + + n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["n_gram"], + hparams["beam_size"], + hparams["beam_threshold"], + hparams["n_best"], + hparams["lm_weight"], + hparams["word_score"], + ) + + if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz" + torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") + _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) -def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path): - if not os.path.isdir(lang_model_path): - url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz" - torch.hub.download_url_to_file( - url, "data/mls_lm_german.tar.gz" ) - _extract_tar("data/mls_lm_german.tar.gz", overwrite=True) - if not os.path.isfile(tokenizer_txt_path): - tokenizer.create_txt(tokenizer_txt_path) - - lexicon_path= os.join(lang_model_path, LEXICON) + if not os.path.isfile(tokens_path): + tokenizer.create_tokens_txt(tokens_path) + + lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt") if not os.path.isfile(lexicon_path): - occurences_path = os.join(lang_model_path,"vocab_counts.txt") + occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt") create_lexicon(occurences_path, lexicon_path) - lm_path = os.join(lang_model_path,"3-gram_lm.apa") - decoder = ctc_decoder(lexicon = lexicon_path, - tokenizer = tokenizer_txt_path, - lm =lm_path, - blank_token = '_', - nbest =1, - sil_token= '', - unk_word = '') - return decoder \ No newline at end of file + + lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa") + + decoder = ctc_decoder( + lexicon=lexicon_path, + tokens=tokens_path, + lm=lm_path, + blank_token="_", + sil_token="", + unk_word="", + nbest=n_best, + beam_size=beam_size, + beam_threshold=beam_threshold, + lm_weight=lm_weight, + word_score=word_score, + ) + return decoder + + +if __name__ == "__main__": + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt") + + hparams = { + "n_gram": 3, + "beam_size": 100, + "beam_threshold": 100, + "n_best": 1, + "lm_weight": 0.5, + "word_score": 1.0, + } + + beam_search_decoder( + tokenizer, + "data/tokenizers/tokens_german.txt", + "data", + "german", + hparams, + ) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 9abf57d..b1de83e 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,11 +120,9 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - - #TO DO check about the weird unknown tokens etc. - def create_txt(self,path:str): - with open(path, 'w',encoding="utf-8") as file: - for key,value in self.char_map(): - file.write(f"{key}\n") - \ No newline at end of file + + def create_tokens_txt(self, path: str): + """Create a txt file with all the characters""" + with open(path, "w", encoding="utf-8") as file: + for char, _ in self.char_map.items(): + file.write(f"{char}\n") -- cgit v1.2.3