From 80544737bc4338bd0cde305b8bccb7c5209e1bdc Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Sat, 16 Sep 2023 15:23:56 +0200 Subject: added a method to create a txt for the ctc decoder --- swr2_asr/utils/decoder.py | 1 + swr2_asr/utils/tokenizer.py | 6 ++++++ 2 files changed, 7 insertions(+) (limited to 'swr2_asr/utils') diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index fcddb79..ef8de49 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -24,3 +24,4 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll # TODO: add beam search decoder + diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 1cc7b84..ee89cdb 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,3 +120,9 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer + + def create_txt(self,path:str): + with open(path, 'w',encoding="utf-8") as file: + for key,value in self.char_map(): + file.write(f"{key}\n") + \ No newline at end of file -- cgit v1.2.3 From ea42dd50f167307d52fb128823904fe46f1118ec Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Sat, 16 Sep 2023 15:25:19 +0200 Subject: added a todo --- swr2_asr/utils/tokenizer.py | 2 ++ 1 file changed, 2 insertions(+) (limited to 'swr2_asr/utils') diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index ee89cdb..9abf57d 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -121,6 +121,8 @@ class CharTokenizer: load_tokenizer.index_map[int(index)] = char return load_tokenizer + + #TO DO check about the weird unknown tokens etc. def create_txt(self,path:str): with open(path, 'w',encoding="utf-8") as file: for key,value in self.char_map(): -- cgit v1.2.3 From ec8bfe9df205608282e5297635363fc8fc8fe55b Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Sat, 16 Sep 2023 15:48:16 +0200 Subject: created a method that prints the lexicon --- swr2_asr/utils/data.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) (limited to 'swr2_asr/utils') diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index d551c98..f484bdd 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -343,3 +343,24 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore + + def create_lexicon(vocab_counts_path, lexicon_path): + + words_list = [] + with open(vocab_counts_path, 'r') as file: + for line in file: + + words = line.split() + if len(words) >= 1: + + word = words[0] + words_list.append(word) + + with open(lexicon_path, 'w') as file: + for word in words_list: + file.write(f"{word} ") + for char in word: + file.write(char + ' ') + file.write("|") + + \ No newline at end of file -- cgit v1.2.3 From 0f14789f1c33d55dc270bcd154201cce2c4d516e Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Mon, 18 Sep 2023 12:19:31 +0200 Subject: reset commit history --- config.yaml | 5 ++++- swr2_asr/utils/data.py | 6 ++++-- swr2_asr/utils/decoder.py | 29 +++++++++++++++++++++++++++-- 3 files changed, 35 insertions(+), 5 deletions(-) (limited to 'swr2_asr/utils') diff --git a/config.yaml b/config.yaml index e5ff43a..41b473c 100644 --- a/config.yaml +++ b/config.yaml @@ -31,4 +31,7 @@ checkpoints: inference: model_load_path: "YOUR/PATH" # path to load model from beam_width: 10 # beam width for beam search - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file + device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically + +lang_model: + path: "data/mls_lm_german" #path where model and supplementary files are stored diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index f484bdd..19605f6 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -344,7 +344,7 @@ class MLSDataset(Dataset): idx, ) # type: ignore - def create_lexicon(vocab_counts_path, lexicon_path): +def create_lexicon(vocab_counts_path, lexicon_path): words_list = [] with open(vocab_counts_path, 'r') as file: @@ -361,6 +361,8 @@ class MLSDataset(Dataset): file.write(f"{word} ") for char in word: file.write(char + ' ') - file.write("|") + file.write("") + + \ No newline at end of file diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index ef8de49..e6b5852 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -2,8 +2,11 @@ import torch from swr2_asr.utils.tokenizer import CharTokenizer - - +from swr2_asr.utils.data import create_lexicon +import os +from torchaudio.datasets.utils import _extract_tar +from torchaudio.models.decoder import ctc_decoder +LEXICON = "lexicon.txt" # TODO: refactor to use torch CTC decoder class def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): """Greedily decode a sequence.""" @@ -25,3 +28,25 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll # TODO: add beam search decoder +def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path): + if not os.path.isdir(lang_model_path): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz" + torch.hub.download_url_to_file( + url, "data/mls_lm_german.tar.gz" ) + _extract_tar("data/mls_lm_german.tar.gz", overwrite=True) + if not os.path.isfile(tokenizer_txt_path): + tokenizer.create_txt(tokenizer_txt_path) + + lexicon_path= os.join(lang_model_path, LEXICON) + if not os.path.isfile(lexicon_path): + occurences_path = os.join(lang_model_path,"vocab_counts.txt") + create_lexicon(occurences_path, lexicon_path) + lm_path = os.join(lang_model_path,"3-gram_lm.apa") + decoder = ctc_decoder(lexicon = lexicon_path, + tokenizer = tokenizer_txt_path, + lm =lm_path, + blank_token = '_', + nbest =1, + sil_token= '', + unk_word = '') + return decoder \ No newline at end of file -- cgit v1.2.3 From 0d38af95f0058875d42dd261a287856ba84d3ce6 Mon Sep 17 00:00:00 2001 From: Marvin Borner Date: Sat, 16 Sep 2023 15:55:41 +0200 Subject: Added better visualization --- swr2_asr/utils/visualization.py | 36 ++++++++++++++++++++++++++---------- 1 file changed, 26 insertions(+), 10 deletions(-) (limited to 'swr2_asr/utils') diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py index a55d0d5..b288c5a 100644 --- a/swr2_asr/utils/visualization.py +++ b/swr2_asr/utils/visualization.py @@ -4,19 +4,35 @@ import matplotlib.pyplot as plt import torch -def plot(epochs, path): +def plot(path): """Plots the losses over the epochs""" - losses = [] + train_losses = [] test_losses = [] cers = [] wers = [] - for epoch in range(1, epochs + 1): - current_state = torch.load(path + str(epoch)) - losses.append(current_state["loss"]) - test_losses.append(current_state["test_loss"]) - cers.append(current_state["avg_cer"]) - wers.append(current_state["avg_wer"]) - plt.plot(losses) - plt.plot(test_losses) + epoch = 5 + while True: + try: + current_state = torch.load(path + str(epoch), map_location=torch.device("cpu")) + except FileNotFoundError: + break + train_losses.append((epoch, current_state["train_loss"].item())) + test_losses.append((epoch, current_state["test_loss"])) + cers.append((epoch, current_state["avg_cer"])) + wers.append((epoch, current_state["avg_wer"])) + epoch += 5 + + plt.plot(*zip(*train_losses), label="train_loss") + plt.plot(*zip(*test_losses), label="test_loss") + plt.plot(*zip(*cers), label="cer") + plt.plot(*zip(*wers), label="wer") + plt.xlabel("epoch") + plt.ylabel("score") + plt.title("Model performance for 5n epochs") + plt.legend() plt.savefig("losses.svg") + + +if __name__ == "__main__": + plot("data/runs/epoch") -- cgit v1.2.3 From 9475900a1085b8277808b0a0b1555c59f7eb6d36 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 12:44:34 +0200 Subject: small fixes --- data/tokenizers/tokens_german.txt | 38 ++++++++++++++ swr2_asr/utils/data.py | 41 +++++++-------- swr2_asr/utils/decoder.py | 103 ++++++++++++++++++++++++++++---------- swr2_asr/utils/tokenizer.py | 14 +++--- 4 files changed, 138 insertions(+), 58 deletions(-) create mode 100644 data/tokenizers/tokens_german.txt (limited to 'swr2_asr/utils') diff --git a/data/tokenizers/tokens_german.txt b/data/tokenizers/tokens_german.txt new file mode 100644 index 0000000..57f2c3a --- /dev/null +++ b/data/tokenizers/tokens_german.txt @@ -0,0 +1,38 @@ +_ + + + +a +b +c +d +e +f +g +h +i +j +k +l +m +n +o +p +q +r +s +t +u +v +w +x +y +z +é +à +ä +ö +ß +ü +- +' diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 19605f6..74cd572 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -6,7 +6,7 @@ import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -343,26 +343,21 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore - + + def create_lexicon(vocab_counts_path, lexicon_path): - - words_list = [] - with open(vocab_counts_path, 'r') as file: - for line in file: - - words = line.split() - if len(words) >= 1: - - word = words[0] - words_list.append(word) - - with open(lexicon_path, 'w') as file: - for word in words_list: - file.write(f"{word} ") - for char in word: - file.write(char + ' ') - file.write("") - - - - \ No newline at end of file + """Create a lexicon from the vocab_counts.txt file""" + words_list = [] + with open(vocab_counts_path, "r", encoding="utf-8") as file: + for line in file: + words = line.split() + if len(words) >= 1: + word = words[0] + words_list.append(word) + + with open(lexicon_path, "w", encoding="utf-8") as file: + for word in words_list: + file.write(f"{word} ") + for char in word: + file.write(char + " ") + file.write("") diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index e6b5852..098f6a4 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,14 +1,18 @@ """Decoder for CTC-based ASR.""" "" -import torch - -from swr2_asr.utils.tokenizer import CharTokenizer -from swr2_asr.utils.data import create_lexicon import os + +import torch from torchaudio.datasets.utils import _extract_tar from torchaudio.models.decoder import ctc_decoder -LEXICON = "lexicon.txt" + +from swr2_asr.utils.data import create_lexicon +from swr2_asr.utils.tokenizer import CharTokenizer + + # TODO: refactor to use torch CTC decoder class -def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): +def greedy_decoder( + output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True +): # pylint: disable=redefined-outer-name """Greedily decode a sequence.""" blank_label = tokenizer.get_blank_token() arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member @@ -26,27 +30,72 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll return decodes, targets -# TODO: add beam search decoder +def beam_search_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + tokens_path: str, + lang_model_path: str, + language: str, + hparams: dict, # pylint: disable=redefined-outer-name +): + """Beam search decoder.""" + + n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["n_gram"], + hparams["beam_size"], + hparams["beam_threshold"], + hparams["n_best"], + hparams["lm_weight"], + hparams["word_score"], + ) + + if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz" + torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") + _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) -def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path): - if not os.path.isdir(lang_model_path): - url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz" - torch.hub.download_url_to_file( - url, "data/mls_lm_german.tar.gz" ) - _extract_tar("data/mls_lm_german.tar.gz", overwrite=True) - if not os.path.isfile(tokenizer_txt_path): - tokenizer.create_txt(tokenizer_txt_path) - - lexicon_path= os.join(lang_model_path, LEXICON) + if not os.path.isfile(tokens_path): + tokenizer.create_tokens_txt(tokens_path) + + lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt") if not os.path.isfile(lexicon_path): - occurences_path = os.join(lang_model_path,"vocab_counts.txt") + occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt") create_lexicon(occurences_path, lexicon_path) - lm_path = os.join(lang_model_path,"3-gram_lm.apa") - decoder = ctc_decoder(lexicon = lexicon_path, - tokenizer = tokenizer_txt_path, - lm =lm_path, - blank_token = '_', - nbest =1, - sil_token= '', - unk_word = '') - return decoder \ No newline at end of file + + lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa") + + decoder = ctc_decoder( + lexicon=lexicon_path, + tokens=tokens_path, + lm=lm_path, + blank_token="_", + sil_token="", + unk_word="", + nbest=n_best, + beam_size=beam_size, + beam_threshold=beam_threshold, + lm_weight=lm_weight, + word_score=word_score, + ) + return decoder + + +if __name__ == "__main__": + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt") + + hparams = { + "n_gram": 3, + "beam_size": 100, + "beam_threshold": 100, + "n_best": 1, + "lm_weight": 0.5, + "word_score": 1.0, + } + + beam_search_decoder( + tokenizer, + "data/tokenizers/tokens_german.txt", + "data", + "german", + hparams, + ) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 9abf57d..b1de83e 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -120,11 +120,9 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer - - - #TO DO check about the weird unknown tokens etc. - def create_txt(self,path:str): - with open(path, 'w',encoding="utf-8") as file: - for key,value in self.char_map(): - file.write(f"{key}\n") - \ No newline at end of file + + def create_tokens_txt(self, path: str): + """Create a txt file with all the characters""" + with open(path, "w", encoding="utf-8") as file: + for char, _ in self.char_map.items(): + file.write(f"{char}\n") -- cgit v1.2.3 From d5e482b7dc3d8b6acc48a883ae9b53b354fa1715 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 14:25:36 +0200 Subject: decoder changes --- config.philipp.yaml | 55 ++++++++++++--------- config.yaml | 43 +++++++++------- swr2_asr/inference.py | 35 ++++++-------- swr2_asr/utils/decoder.py | 121 +++++++++++++++++++++++++++++++++------------- 4 files changed, 161 insertions(+), 93 deletions(-) (limited to 'swr2_asr/utils') diff --git a/config.philipp.yaml b/config.philipp.yaml index f72ce2e..38a68f8 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -1,34 +1,45 @@ +dataset: + download: True + dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: True # set to True if you want to use limited supervision + dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%) + shuffle: True + model: n_cnn_layers: 3 n_rnn_layers: 5 rnn_dim: 512 n_feats: 128 # number of mel features stride: 2 - dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets - -training: - learning_rate: 0.0005 - batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 150 - eval_every_n: 5 # evaluate every n epochs - num_workers: 4 # number of workers for dataloader - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically - -dataset: - download: true - dataset_root_path: "data" # files will be downloaded into this dir - language_name: "mls_german_opus" - limited_supervision: false # set to True if you want to use limited supervision - dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%) - shuffle: true + dropout: 0.6 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" -checkpoints: - model_load_path: "data/runs/epoch31" # path to load model from - model_save_path: "data/runs/epoch" # path to save model to +decoder: + type: "lm" # greedy, or lm (beam search) + + lm: # config for lm decoder + language_model_path: "data" # path where model and supplementary files are stored + language: "german" + n_gram: 3 # n-gram size of the language model, 3 or 5 + beam_size: 50 + beam_threshold: 50 + n_best: 1 + lm_weight: 2 + word_score: 0 + +training: + learning_rate: 0.0005 + batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) + epochs: 3 + eval_every_n: 3 # evaluate every n epochs + num_workers: 8 # number of workers for dataloader + +checkpoints: # use "~" to disable saving/loading + model_load_path: "YOUR/PATH" # path to load model from + model_save_path: "YOUR/PATH" # path to save model to inference: - model_load_path: "data/runs/epoch30" # path to load model from - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file + model_load_path: "data/epoch67" # path to load model from \ No newline at end of file diff --git a/config.yaml b/config.yaml index 41b473c..d248d43 100644 --- a/config.yaml +++ b/config.yaml @@ -1,3 +1,11 @@ +dataset: + download: True + dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir + language_name: "mls_german_opus" + limited_supervision: False # set to True if you want to use limited supervision + dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) + shuffle: True + model: n_cnn_layers: 3 n_rnn_layers: 5 @@ -6,32 +14,33 @@ model: stride: 2 dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets +tokenizer: + tokenizer_path: "data/tokenizers/char_tokenizer_german.json" + +decoder: + type: "greedy" # greedy, or lm (beam search) + + lm: # config for lm decoder + language_model_path: "data" # path where model and supplementary files are stored + language: "german" + n_gram: 3 # n-gram size of the language model, 3 or 5 + beam_size: 50 + beam_threshold: 50 + n_best: 1 + lm_weight: 2, + word_score: 0, + training: - learning_rate: 5e-4 + learning_rate: 0.0005 batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) epochs: 3 eval_every_n: 3 # evaluate every n epochs num_workers: 8 # number of workers for dataloader -dataset: - download: True - dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir - language_name: "mls_german_opus" - limited_supervision: False # set to True if you want to use limited supervision - dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) - shuffle: True - -tokenizer: - tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml" - -checkpoints: +checkpoints: # use "~" to disable saving/loading model_load_path: "YOUR/PATH" # path to load model from model_save_path: "YOUR/PATH" # path to save model to inference: model_load_path: "YOUR/PATH" # path to load model from - beam_width: 10 # beam width for beam search - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically -lang_model: - path: "data/mls_lm_german" #path where model and supplementary files are stored diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index 3c58af0..64a6eeb 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -6,25 +6,10 @@ import torchaudio import yaml from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.tokenizer import CharTokenizer -def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True): - """Greedily decode a sequence.""" - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.get_blank_token() - decodes = [] - for args in arg_maxes: - decode = [] - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes - - @click.command() @click.option( "--config_path", @@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None: model_config = config_dict.get("model", {}) tokenizer_config = config_dict.get("tokenizer", {}) inference_config = config_dict.get("inference", {}) + decoder_config = config_dict.get("decoder", {}) - if inference_config["device"] == "cpu": + if inference_config.get("device", "") == "cpu": device = "cpu" - elif inference_config["device"] == "cuda": + elif inference_config.get("device", "") == "cuda": device = "cuda" if torch.cuda.is_available() else "cpu" + elif inference_config.get("device", "") == "mps": + device = "mps" + else: + device = "cpu" device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) @@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None: spec = spec.unsqueeze(0) spec = spec.transpose(1, 2) spec = spec.unsqueeze(0) + spec = spec.to(device) output = model(spec) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) # (batch, time, n_class) - decoded_preds = greedy_decoder(output, tokenizer) - print(decoded_preds) + decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config) + + preds = decoder(output) + preds = " ".join(preds[0][0].words).strip() + + print(preds) if __name__ == "__main__": diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 098f6a4..2b6d29b 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,4 +1,5 @@ """Decoder for CTC-based ASR.""" "" +from dataclasses import dataclass import os import torch @@ -9,37 +10,39 @@ from swr2_asr.utils.data import create_lexicon from swr2_asr.utils.tokenizer import CharTokenizer -# TODO: refactor to use torch CTC decoder class -def greedy_decoder( - output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True -): # pylint: disable=redefined-outer-name - """Greedily decode a sequence.""" - blank_label = tokenizer.get_blank_token() - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist())) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes, targets - - -def beam_search_decoder( +@dataclass +class DecoderOutput: + """Decoder output.""" + + words: list[str] + + +def decoder_factory(decoder_type: str = "greedy") -> callable: + """Decoder factory.""" + if decoder_type == "greedy": + return get_greedy_decoder + if decoder_type == "lm": + return get_beam_search_decoder + raise NotImplementedError + + +def get_greedy_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + *_, +): + """Greedy decoder.""" + return GreedyDecoder(tokenizer) + + +def get_beam_search_decoder( tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name - tokens_path: str, - lang_model_path: str, - language: str, hparams: dict, # pylint: disable=redefined-outer-name ): """Beam search decoder.""" - - n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams = hparams.get("lm", {}) + language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["language"], + hparams["language_model_path"], hparams["n_gram"], hparams["beam_size"], hparams["beam_threshold"], @@ -53,6 +56,7 @@ def beam_search_decoder( torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) + tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt") if not os.path.isfile(tokens_path): tokenizer.create_tokens_txt(tokens_path) @@ -79,11 +83,66 @@ def beam_search_decoder( return decoder +class GreedyDecoder: + """Greedy decoder.""" + + def __init__(self, tokenizer: CharTokenizer): # pylint: disable=redefined-outer-name + self.tokenizer = tokenizer + + def __call__( + self, output, greedy_type: str = "inference", labels=None, label_lengths=None + ): # pylint: disable=redefined-outer-name + """Greedily decode a sequence.""" + if greedy_type == "train": + res = self.train(output, labels, label_lengths) + if greedy_type == "inference": + res = self.inference(output) + + res = [[DecoderOutput(words=res)]] + return res + + def train(self, output, labels, label_lengths): + """Greedily decode a sequence with known labels.""" + blank_label = tokenizer.get_blank_token() + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append(self.tokenizer.decode(labels[i][: label_lengths[i]].tolist())) + for j, index in enumerate(args): + if index != blank_label: + if j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(self.tokenizer.decode(decode)) + return decodes, targets + + def inference(self, output): + """Greedily decode a sequence.""" + collapse_repeated = True + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + blank_label = self.tokenizer.get_blank_token() + decodes = [] + for args in arg_maxes: + decode = [] + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(self.tokenizer.decode(decode)) + + return decodes + + if __name__ == "__main__": tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt") hparams = { + "language": "german", + "lang_model_path": "data", "n_gram": 3, "beam_size": 100, "beam_threshold": 100, @@ -92,10 +151,4 @@ if __name__ == "__main__": "word_score": 1.0, } - beam_search_decoder( - tokenizer, - "data/tokenizers/tokens_german.txt", - "data", - "german", - hparams, - ) + get_beam_search_decoder(tokenizer, hparams) -- cgit v1.2.3 From c09ff76ba6f4c5dd5de64a401efcd27449150aec Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 15:05:21 +0200 Subject: added support for lm decoder during training --- config.philipp.yaml | 10 +++++----- pyproject.toml | 2 +- swr2_asr/train.py | 32 ++++++++++++++++++-------------- swr2_asr/utils/decoder.py | 2 +- swr2_asr/utils/tokenizer.py | 8 ++++++++ 5 files changed, 33 insertions(+), 21 deletions(-) (limited to 'swr2_asr/utils') diff --git a/config.philipp.yaml b/config.philipp.yaml index 38a68f8..608720f 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -3,7 +3,7 @@ dataset: dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir language_name: "mls_german_opus" limited_supervision: True # set to True if you want to use limited supervision - dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%) + dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) shuffle: True model: @@ -33,13 +33,13 @@ decoder: training: learning_rate: 0.0005 batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 3 # evaluate every n epochs + epochs: 100 + eval_every_n: 1 # evaluate every n epochs num_workers: 8 # number of workers for dataloader checkpoints: # use "~" to disable saving/loading - model_load_path: "YOUR/PATH" # path to load model from - model_save_path: "YOUR/PATH" # path to save model to + model_load_path: "data/epoch67" # path to load model from + model_save_path: ~ # path to save model to inference: model_load_path: "data/epoch67" # path to load model from \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f6d19dd..2f26e5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ target-version = "py310" line-length = 100 [tool.poetry.scripts] -train = "swr2_asr.train:run_cli" +train = "swr2_asr.train:main" train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer" train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer" diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9c7ede9..1e57ba0 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split -from swr2_asr.utils.decoder import greedy_decoder -from swr2_asr.utils.tokenizer import CharTokenizer - +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.loss_scores import cer, wer +from swr2_asr.utils.tokenizer import CharTokenizer class IterMeter: @@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: # get values from test_args: model, device, test_loader, criterion, tokenizer, decoder = test_args.values() - if decoder == "greedy": - decoder = greedy_decoder - model.eval() test_loss = 0 test_cer, test_wer = [], [] @@ -141,12 +137,13 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths, tokenizer - ) + decoded_targets = tokenizer.decode_batch(labels) + decoded_preds = decoder(output.transpose(0, 1)) for j, _ in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + if j >= len(decoded_targets): + break + test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) @@ -187,6 +184,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) + decoder_config = config_dict.get("decoder", {}) if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) @@ -262,12 +260,19 @@ def main(config_path: str): if checkpoints_config["model_load_path"] is not None: checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + + model.load_state_dict(state_dict) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() + decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config) + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): train_args: TrainArgs = { "model": model, @@ -283,14 +288,13 @@ def main(config_path: str): train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = { "model": model, "device": device, "test_loader": valid_loader, "criterion": criterion, "tokenizer": tokenizer, - "decoder": "greedy", + "decoder": decoder, } if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 2b6d29b..6ffdef2 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,6 +1,6 @@ """Decoder for CTC-based ASR.""" "" -from dataclasses import dataclass import os +from dataclasses import dataclass import torch from torchaudio.datasets.utils import _extract_tar diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index b1de83e..4e3fddd 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -29,9 +29,17 @@ class CharTokenizer: """Use a character map and convert integer labels to an text sequence""" string = [] for i in labels: + i = int(i) string.append(self.index_map[i]) return "".join(string).replace("", " ") + def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for label in labels: + string.append(self.decode(label)) + return string + def get_vocab_size(self) -> int: """Get the number of unique characters in the dataset""" return len(self.char_map) -- cgit v1.2.3 From d44cf7b1cab683a8aa3876619c82226f4e6d6f3b Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 18:11:33 +0200 Subject: fix --- .gitignore | 4 + Dockerfile | 13 --- Makefile | 9 -- config.cluster.yaml | 34 ------ config.philipp.yaml | 2 +- data/own/Philipp_HerrK.flac | Bin 0 -> 595064 bytes lm_decoder_hparams.ipynb | 245 ++++++++++++++++++++++++++++++++++++++++++++ metrics.csv | 69 +++++++++++++ plots.ipynb | 131 +++++++++++++++++++++++ swr2_asr/train.py | 6 +- swr2_asr/utils/decoder.py | 3 +- 11 files changed, 456 insertions(+), 60 deletions(-) delete mode 100644 Dockerfile delete mode 100644 Makefile delete mode 100644 config.cluster.yaml create mode 100644 data/own/Philipp_HerrK.flac create mode 100644 lm_decoder_hparams.ipynb create mode 100644 metrics.csv create mode 100644 plots.ipynb (limited to 'swr2_asr/utils') diff --git a/.gitignore b/.gitignore index 485df5b..061bfca 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,11 @@ +# pictures +**/*.png + # Training files data/* !data/tokenizers !data/own +!data/metrics.csv # Mac **/.DS_Store diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index ca7463f..0000000 --- a/Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -FROM python:3.10 - -# install python poetry -RUN curl -sSL https://install.python-poetry.org | python3 - - -WORKDIR /app - -COPY readme.md mypy.ini poetry.lock pyproject.toml ./ -COPY swr2_asr ./swr2_asr -ENV POETRY_VIRTUALENVS_IN_PROJECT=true -RUN /root/.local/bin/poetry --no-interaction install --without dev - -ENTRYPOINT [ "/root/.local/bin/poetry", "run", "python", "-m", "swr2_asr" ] diff --git a/Makefile b/Makefile deleted file mode 100644 index a37644c..0000000 --- a/Makefile +++ /dev/null @@ -1,9 +0,0 @@ -format: - @poetry run black . - -format-check: - @poetry run black --check . - -lint: - @poetry run mypy --strict swr2_asr - @poetry run pylint swr2_asr \ No newline at end of file diff --git a/config.cluster.yaml b/config.cluster.yaml deleted file mode 100644 index 7af0aca..0000000 --- a/config.cluster.yaml +++ /dev/null @@ -1,34 +0,0 @@ -model: - n_cnn_layers: 3 - n_rnn_layers: 5 - rnn_dim: 512 - n_feats: 128 # number of mel features - stride: 2 - dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets - -training: - learning_rate: 0.0005 - batch_size: 400 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 150 - eval_every_n: 5 # evaluate every n epochs - num_workers: 12 # number of workers for dataloader - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically - -dataset: - download: False - dataset_root_path: "/mnt/lustre/mladm/mfa252/data" # files will be downloaded into this dir - language_name: "mls_german_opus" - limited_supervision: False # set to True if you want to use limited supervision - dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%) - shuffle: True - -tokenizer: - tokenizer_path: "data/tokenizers/char_tokenizer_german.json" - -checkpoints: - model_load_path: "data/runs/epoch50" # path to load model from - model_save_path: "data/runs/epoch" # path to save model to - -inference: - model_load_path: ~ # path to load model from - device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically diff --git a/config.philipp.yaml b/config.philipp.yaml index 608720f..7a93d05 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -18,7 +18,7 @@ tokenizer: tokenizer_path: "data/tokenizers/char_tokenizer_german.json" decoder: - type: "lm" # greedy, or lm (beam search) + type: "greedy" # greedy, or lm (beam search) lm: # config for lm decoder language_model_path: "data" # path where model and supplementary files are stored diff --git a/data/own/Philipp_HerrK.flac b/data/own/Philipp_HerrK.flac new file mode 100644 index 0000000..dec59e3 Binary files /dev/null and b/data/own/Philipp_HerrK.flac differ diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb new file mode 100644 index 0000000..5e56312 --- /dev/null +++ b/lm_decoder_hparams.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "lm_weights = [0, 1.0, 2.5,]\n", + "word_score = [-1.5, 0.0, 1.5]\n", + "beam_sizes = [50, 500]\n", + "beam_thresholds = [50]\n", + "beam_size_token = [10, 38]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n", + "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from tqdm.autonotebook import tqdm\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn.functional as F\n", + "\n", + "from swr2_asr.utils.decoder import decoder_factory\n", + "from swr2_asr.utils.tokenizer import CharTokenizer\n", + "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n", + "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n", + "from swr2_asr.utils.loss_scores import cer, wer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "34aafd9aca2541748dc41d8550334536", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00= len(decoded_targets):\n", + " break\n", + " pred = \" \".join(decoded_preds[j][0].words).strip()\n", + " target = decoded_targets[j]\n", + " \n", + " test_cer.append(cer(pred, target))\n", + " test_wer.append(wer(pred, target))\n", + "\n", + " avg_cer = sum(test_cer) / len(test_cer)\n", + " avg_wer = sum(test_wer) / len(test_wer)\n", + " \n", + " if avg_wer < best_wer:\n", + " best_wer = avg_wer\n", + " best_cer = avg_cer\n", + " best_config = config\n", + " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n", + " print(\"Config: \", best_config)\n", + " print(\"LM Weight: \", lm_weight, \n", + " \" Word Score: \", ws, \n", + " \" Beam Size: \", beam_size, \n", + " \" Beam Threshold: \", beam_threshold, \n", + " \" Beam Size Token: \", beam_size_t)\n", + " print(\"--------------------------------------------------------------\")\n", + " \n", + " pbar.update(1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/metrics.csv b/metrics.csv new file mode 100644 index 0000000..22b8cec --- /dev/null +++ b/metrics.csv @@ -0,0 +1,69 @@ +epoch,train_loss,test_loss,cer,wer +0.0,3.25246262550354,3.0130836963653564,1.0,0.9999533337969454 +1.0,2.791025161743164,0.0,0.0,0.0 +2.0,1.5954065322875977,0.0,0.0,0.0 +3.0,1.3106564283370972,0.0,0.0,0.0 +4.0,1.206541895866394,0.0,0.0,0.0 +5.0,1.1116338968276978,0.9584052684355759,0.26248163774768096,0.8057431713202183 +6.0,1.0295032262802124,0.0,0.0,0.0 +7.0,0.957234263420105,0.0,0.0,0.0 +8.0,0.8958202004432678,0.0,0.0,0.0 +9.0,0.8403098583221436,0.0,0.0,0.0 +10.0,0.7934719324111938,0.577774976386505,0.1647645650587519,0.5597785267513198 +11.0,0.7537956833839417,0.0,0.0,0.0 +12.0,0.7180628776550293,0.0,0.0,0.0 +13.0,0.6870554089546204,0.0,0.0,0.0 +14.0,0.6595032811164856,0.0,0.0,0.0 +15.0,0.6374552845954895,0.42232042328030084,0.12030436712014228,0.43601402176865556 +16.0,0.6134707927703857,0.0,0.0,0.0 +17.0,0.5946973562240601,0.0,0.0,0.0 +18.0,0.577201783657074,0.0,0.0,0.0 +19.0,0.5612062811851501,0.0,0.0,0.0 +20.0,0.5256602764129639,0.33855139215787244,0.09390776269838304,0.35605188295180307 +21.0,0.5190389752388,0.0,0.0,0.0 +22.0,0.5163558721542358,0.0,0.0,0.0 +23.0,0.5132778286933899,0.0,0.0,0.0 +24.0,0.5090991854667664,0.0,0.0,0.0 +25.0,0.5072354078292847,0.32589933276176464,0.08999255619329079,0.341225825396658 +26.0,0.5023046731948853,0.0,0.0,0.0 +27.0,0.4994561970233917,0.0,0.0,0.0 +28.0,0.4942632019519806,0.0,0.0,0.0 +29.0,0.4906529486179352,0.0,0.0,0.0 +30.0,0.4855062663555145,0.29864962175995297,0.08296308087950884,0.3177622785738594 +31.0,0.4822919964790344,0.0,0.0,0.0 +32.0,0.4456436336040497,0.0,0.0,0.0 +33.0,0.4389857053756714,0.0,0.0,0.0 +34.0,0.43762147426605225,0.0,0.0,0.0 +35.0,0.4351556599140167,0.5776603897412618,0.16294622142152407,0.5232870602289124 +36.0,0.43377435207366943,0.0,0.0,0.0 +37.0,0.4318349063396454,0.0,0.0,0.0 +38.0,0.43010208010673523,0.0,0.0,0.0 +39.0,0.4276123046875,0.0,0.0,0.0 +40.0,0.4253982901573181,0.5735072294871012,0.1586969400218906,0.5131595862326734 +41.0,0.4236880838871002,0.0,0.0,0.0 +42.0,0.42077934741973877,0.0,0.0,0.0 +43.0,0.4181424081325531,0.0,0.0,0.0 +44.0,0.4154696464538574,0.0,0.0,0.0 +45.0,0.419731080532074,0.5696070055166881,0.15437095897735878,0.5002024974353078 +46.0,0.4099026024341583,0.0,0.0,0.0 +47.0,0.4078012704849243,0.0,0.0,0.0 +48.0,0.40490180253982544,0.0,0.0,0.0 +49.0,0.4024839699268341,0.0,0.0,0.0 +50.0,0.3694721758365631,0.5247387786706288,0.1450933666590186,0.4700957797096995 +51.0,0.36624056100845337,0.0,0.0,0.0 +52.0,0.36418089270591736,0.0,0.0,0.0 +53.0,0.36366793513298035,0.0,0.0,0.0 +54.0,0.36317530274391174,0.0,0.0,0.0 +55.0,0.3624136447906494,0.510421613852183,0.14174752623520492,0.4632967062415951 +56.0,0.36174166202545166,0.0,0.0,0.0 +57.0,0.36113062500953674,0.0,0.0,0.0 +58.0,0.36098596453666687,0.0,0.0,0.0 +59.0,0.35909315943717957,0.0,0.0,0.0 +60.0,0.36021551489830017,0.5095615088939668,0.14084592211118552,0.45461000263956114 +61.0,0.35837724804878235,0.0,0.0,0.0 +62.0,0.3567410409450531,0.0,0.0,0.0 +63.0,0.3565385341644287,0.0,0.0,0.0 +64.0,0.35535314679145813,0.0,0.0,0.0 +65.0,0.35792484879493713,0.5086047914293077,0.13893481611889835,0.45137245514066726 +66.0,0.35215333104133606,0.0,0.0,0.0 +67.0,0.35401859879493713,0.0,0.0,0.0 \ No newline at end of file diff --git a/plots.ipynb b/plots.ipynb new file mode 100644 index 0000000..716834a --- /dev/null +++ b/plots.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n", + "# test_loss, cer, wer should not be plotted if they are 0.0\n", + "# plot train_loss and test_loss in one plot\n", + "# plot cer and wer in one plot\n", + " \n", + "# save plots as png\n", + "\n", + "csv_path = \"metrics.csv\"\n", + "\n", + "# load csv\n", + "df = pd.read_csv(csv_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot train_loss and test_loss\n", + "# do not use colors, distinguis by line style. use solid for train_loss and dashed for test_loss\n", + "plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n", + "\n", + "# create zip with epoch and test_loss for all epochs\n", + "# filter out all test_loss with value 0.0\n", + "# plot test_loss\n", + "epoch_loss = zip(df['epoch'], df['test_loss'])\n", + "epoch_loss = list(filter(lambda x: x[1] != 0.0, epoch_loss))\n", + "plt.plot([x[0] for x in epoch_loss], [x[1] for x in epoch_loss], label='test_loss', linestyle='dashed', color='black')\n", + "\n", + "# add markers for test_loss\n", + "for x, y in epoch_loss:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + "\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('loss')\n", + "plt.legend()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# set y limits to 0\n", + "plt.ylim(bottom=0)\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "# increase resolution\n", + "plt.savefig('train_test_loss.png', dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epoch_cer = zip(df['epoch'], df['cer'])\n", + "epoch_cer = list(filter(lambda x: x[1] != 0.0, epoch_cer))\n", + "plt.plot([x[0] for x in epoch_cer], [x[1] for x in epoch_cer], label='cer', linestyle='solid', color='black')\n", + "\n", + "# add markers for cer\n", + "for x, y in epoch_cer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "epoch_wer = zip(df['epoch'], df['wer'])\n", + "epoch_wer = list(filter(lambda x: x[1] != 0.0, epoch_wer))\n", + "plt.plot([x[0] for x in epoch_wer], [x[1] for x in epoch_wer], label='wer', linestyle='dashed', color='black')\n", + "\n", + "# add markers for wer\n", + "for x, y in epoch_wer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "# set y limits to 0 and 1\n", + "plt.ylim(bottom=0, top=1)\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('error rate')\n", + "plt.legend()\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# add ticks every 0.1 \n", + "plt.yticks([x/10 for x in range(0, 11, 1)])\n", + "\n", + "# increase resolution\n", + "plt.savefig('cer_wer.png', dpi=300)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 1e57ba0..5277c16 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -142,8 +142,10 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: for j, _ in enumerate(decoded_preds): if j >= len(decoded_targets): break - test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0])) + pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words + target = decoded_targets[j] + test_cer.append(cer(target, pred)) + test_wer.append(wer(target, pred)) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 6ffdef2..1fd002a 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -98,7 +98,8 @@ class GreedyDecoder: if greedy_type == "inference": res = self.inference(output) - res = [[DecoderOutput(words=res)]] + res = [x.split(" ") for x in res] + res = [[DecoderOutput(x)] for x in res] return res def train(self, output, labels, label_lengths): -- cgit v1.2.3