From 80544737bc4338bd0cde305b8bccb7c5209e1bdc Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Sat, 16 Sep 2023 15:23:56 +0200 Subject: added a method to create a txt for the ctc decoder --- swr2_asr/utils/tokenizer.py | 6 ++++++ 1 file changed, 6 insertions(+) (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 1cc7b84..ee89cdb 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,3 +120,9 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer + + def create_txt(self,path:str): + with open(path, 'w',encoding="utf-8") as file: + for key,value in self.char_map(): + file.write(f"{key}\n") + \ No newline at end of file -- cgit v1.2.3 From ea42dd50f167307d52fb128823904fe46f1118ec Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Sat, 16 Sep 2023 15:25:19 +0200 Subject: added a todo --- swr2_asr/utils/tokenizer.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index ee89cdb..9abf57d 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -121,6 +121,8 @@ class CharTokenizer: load_tokenizer.index_map[int(index)] = char return load_tokenizer + + #TO DO check about the weird unknown tokens etc. def create_txt(self,path:str): with open(path, 'w',encoding="utf-8") as file: for key,value in self.char_map(): -- cgit v1.2.3 From 9475900a1085b8277808b0a0b1555c59f7eb6d36 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 12:44:34 +0200 Subject: small fixes --- data/tokenizers/tokens_german.txt | 38 ++++++++++++++ swr2_asr/utils/data.py | 41 +++++++-------- swr2_asr/utils/decoder.py | 103 ++++++++++++++++++++++++++++---------- swr2_asr/utils/tokenizer.py | 14 +++--- 4 files changed, 138 insertions(+), 58 deletions(-) create mode 100644 data/tokenizers/tokens_german.txt (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/data/tokenizers/tokens_german.txt b/data/tokenizers/tokens_german.txt new file mode 100644 index 0000000..57f2c3a --- /dev/null +++ b/data/tokenizers/tokens_german.txt @@ -0,0 +1,38 @@ +_ + + + +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +é +à +ä +ö +ß +ü +- +' diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 19605f6..74cd572 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -6,7 +6,7 @@ import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -343,26 +343,21 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - + + def create_lexicon(vocab_counts_path, lexicon_path): - - words_list = [] - with open(vocab_counts_path, 'r') as file: - for line in file: - - words = line.split() - if len(words) >= 1: - - word = words[0] - words_list.append(word) - - with open(lexicon_path, 'w') as file: - for word in words_list: - file.write(f"{word} ") - for char in word: - file.write(char + ' ') - file.write("") - - - - \ No newline at end of file + """Create a lexicon from the vocab_counts.txt file""" + words_list = [] + with open(vocab_counts_path, "r", encoding="utf-8") as file: + for line in file: + words = line.split() + if len(words) >= 1: + word = words[0] + 
words_list.append(word) + + with open(lexicon_path, "w", encoding="utf-8") as file: + for word in words_list: + file.write(f"{word} ") + for char in word: + file.write(char + " ") + file.write("") diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index e6b5852..098f6a4 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,14 +1,18 @@ """Decoder for CTC-based ASR.""" "" -import torch - -from swr2_asr.utils.tokenizer import CharTokenizer -from swr2_asr.utils.data import create_lexicon import os + +import torch from torchaudio.datasets.utils import _extract_tar from torchaudio.models.decoder import ctc_decoder -LEXICON = "lexicon.txt" + +from swr2_asr.utils.data import create_lexicon +from swr2_asr.utils.tokenizer import CharTokenizer + + # TODO: refactor to use torch CTC decoder class -def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): +def greedy_decoder( + output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True +): # pylint: disable=redefined-outer-name """Greedily decode a sequence.""" blank_label = tokenizer.get_blank_token() arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member @@ -26,27 +30,72 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll return decodes, targets -# TODO: add beam search decoder +def beam_search_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + tokens_path: str, + lang_model_path: str, + language: str, + hparams: dict, # pylint: disable=redefined-outer-name +): + """Beam search decoder.""" + + n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["n_gram"], + hparams["beam_size"], + hparams["beam_threshold"], + hparams["n_best"], + hparams["lm_weight"], + hparams["word_score"], + ) + + if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz" + torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") + _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) -def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path): - if not os.path.isdir(lang_model_path): - url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz" - torch.hub.download_url_to_file( - url, "data/mls_lm_german.tar.gz" ) - _extract_tar("data/mls_lm_german.tar.gz", overwrite=True) - if not os.path.isfile(tokenizer_txt_path): - tokenizer.create_txt(tokenizer_txt_path) - - lexicon_path= os.join(lang_model_path, LEXICON) + if not os.path.isfile(tokens_path): + tokenizer.create_tokens_txt(tokens_path) + + lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt") if not os.path.isfile(lexicon_path): - occurences_path = os.join(lang_model_path,"vocab_counts.txt") + occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt") create_lexicon(occurences_path, lexicon_path) - lm_path = os.join(lang_model_path,"3-gram_lm.apa") - decoder = ctc_decoder(lexicon = lexicon_path, - tokenizer = tokenizer_txt_path, - lm =lm_path, - blank_token = '_', - nbest =1, - sil_token= '', - unk_word = '') - return decoder \ No newline at end of file + + lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa") + + decoder = ctc_decoder( + lexicon=lexicon_path, + tokens=tokens_path, + lm=lm_path, + blank_token="_", + sil_token="", + unk_word="", + nbest=n_best, + beam_size=beam_size, + 
beam_threshold=beam_threshold, + lm_weight=lm_weight, + word_score=word_score, + ) + return decoder + + +if __name__ == "__main__": + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt") + + hparams = { + "n_gram": 3, + "beam_size": 100, + "beam_threshold": 100, + "n_best": 1, + "lm_weight": 0.5, + "word_score": 1.0, + } + + beam_search_decoder( + tokenizer, + "data/tokenizers/tokens_german.txt", + "data", + "german", + hparams, + ) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 9abf57d..b1de83e 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,11 +120,9 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - - #TO DO check about the weird unknown tokens etc. - def create_txt(self,path:str): - with open(path, 'w',encoding="utf-8") as file: - for key,value in self.char_map(): - file.write(f"{key}\n") - \ No newline at end of file + + def create_tokens_txt(self, path: str): + """Create a txt file with all the characters""" + with open(path, "w", encoding="utf-8") as file: + for char, _ in self.char_map.items(): + file.write(f"{char}\n") -- cgit v1.2.3 From c09ff76ba6f4c5dd5de64a401efcd27449150aec Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 15:05:21 +0200 Subject: added support for lm decoder during training --- config.philipp.yaml | 10 +++++----- pyproject.toml | 2 +- swr2_asr/train.py | 32 ++++++++++++++++++-------------- swr2_asr/utils/decoder.py | 2 +- swr2_asr/utils/tokenizer.py | 8 ++++++++ 5 files changed, 33 insertions(+), 21 deletions(-) (limited to 'swr2_asr/utils/tokenizer.py') diff --git a/config.philipp.yaml b/config.philipp.yaml index 38a68f8..608720f 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -3,7 +3,7 @@ dataset: dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir language_name: "mls_german_opus" limited_supervision: True # set to True if you want to use limited supervision - dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%) + dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) shuffle: True model: @@ -33,13 +33,13 @@ decoder: training: learning_rate: 0.0005 batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 3 # evaluate every n epochs + epochs: 100 + eval_every_n: 1 # evaluate every n epochs num_workers: 8 # number of workers for dataloader checkpoints: # use "~" to disable saving/loading - model_load_path: "YOUR/PATH" # path to load model from - model_save_path: "YOUR/PATH" # path to save model to + model_load_path: "data/epoch67" # path to load model from + model_save_path: ~ # path to save model to inference: model_load_path: "data/epoch67" # path to load model from \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f6d19dd..2f26e5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ target-version = "py310" line-length = 100 [tool.poetry.scripts] -train = "swr2_asr.train:run_cli" +train = "swr2_asr.train:main" train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer" train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer" diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9c7ede9..1e57ba0 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -12,10 +12,9 @@ from tqdm.autonotebook 
import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split -from swr2_asr.utils.decoder import greedy_decoder -from swr2_asr.utils.tokenizer import CharTokenizer - +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.loss_scores import cer, wer +from swr2_asr.utils.tokenizer import CharTokenizer class IterMeter: @@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: # get values from test_args: model, device, test_loader, criterion, tokenizer, decoder = test_args.values() - if decoder == "greedy": - decoder = greedy_decoder - model.eval() test_loss = 0 test_cer, test_wer = [], [] @@ -141,12 +137,13 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths, tokenizer - ) + decoded_targets = tokenizer.decode_batch(labels) + decoded_preds = decoder(output.transpose(0, 1)) for j, _ in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + if j >= len(decoded_targets): + break + test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) @@ -187,6 +184,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) + decoder_config = config_dict.get("decoder", {}) if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) @@ -262,12 +260,19 @@ def main(config_path: str): if checkpoints_config["model_load_path"] is not None: checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + + model.load_state_dict(state_dict) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() + decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config) + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): train_args: TrainArgs = { "model": model, @@ -283,14 +288,13 @@ def main(config_path: str): train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = { "model": model, "device": device, "test_loader": valid_loader, "criterion": criterion, "tokenizer": tokenizer, - "decoder": "greedy", + "decoder": decoder, } if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 2b6d29b..6ffdef2 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,6 +1,6 @@ """Decoder for CTC-based ASR.""" "" -from dataclasses import dataclass import os +from dataclasses import dataclass import torch from torchaudio.datasets.utils import _extract_tar diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index b1de83e..4e3fddd 100644 --- a/swr2_asr/utils/tokenizer.py +++ 
b/swr2_asr/utils/tokenizer.py @@ -29,9 +29,17 @@ class CharTokenizer: """Use a character map and convert integer labels to an text sequence""" string = [] for i in labels: + i = int(i) string.append(self.index_map[i]) return "".join(string).replace("", " ") + def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for label in labels: + string.append(self.decode(label)) + return string + def get_vocab_size(self) -> int: """Get the number of unique characters in the dataset""" return len(self.char_map) -- cgit v1.2.3
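
Usage sketch (not one of the patches above): a minimal, hedged example of how the beam-search decoder introduced in these commits can be driven end to end. It mirrors the __main__ block added to swr2_asr/utils/decoder.py and the decoder call in swr2_asr/train.py's test(); the random emissions tensor and its shape are illustrative placeholders standing in for real model output, and the paths assume the MLS German LM archive is downloaded into data/ by beam_search_decoder itself.

import torch

from swr2_asr.utils.decoder import beam_search_decoder
from swr2_asr.utils.tokenizer import CharTokenizer

tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")

hparams = {
    "n_gram": 3,
    "beam_size": 100,
    "beam_threshold": 100,
    "n_best": 1,
    "lm_weight": 0.5,
    "word_score": 1.0,
}

# Builds a torchaudio CTCDecoder; on first use this downloads the MLS LM
# and writes the tokens and lexicon files if they are missing.
decoder = beam_search_decoder(
    tokenizer,
    "data/tokenizers/tokens_german.txt",
    "data",
    "german",
    hparams,
)

# torchaudio's CTCDecoder expects CPU float32 emissions shaped
# (batch, time, vocab). In train.py this is output.transpose(0, 1);
# here a random tensor stands in for real model output.
emissions = torch.randn(1, 50, tokenizer.get_vocab_size()).log_softmax(dim=-1)
hypotheses = decoder(emissions)  # list[list[CTCHypothesis]]
transcript = " ".join(hypotheses[0][0].words)
print(transcript)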