diff options
author | Pherkel | 2023-09-18 18:13:46 +0200 |
---|---|---|
committer | GitHub | 2023-09-18 18:13:46 +0200 |
commit | f94506764bde3e4d41dc593e9d11aa7330c00e30 (patch) | |
tree | 6fc438536a72e195805c1aea97926f4c9bbd4f85 /swr2_asr | |
parent | 8b3a0b47813733ef67befa6959a4d24f8518b5b7 (diff) | |
parent | 21a3b1d7cc8544fa0031b8934283382bdfd1d8f1 (diff) |
Merge pull request #38 from Algo-Boys/decoder
Decoder
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/inference.py | 35 | ||||
-rw-r--r-- | swr2_asr/train.py | 34 | ||||
-rw-r--r-- | swr2_asr/utils/data.py | 20 | ||||
-rw-r--r-- | swr2_asr/utils/decoder.py | 169 | ||||
-rw-r--r-- | swr2_asr/utils/tokenizer.py | 14 |
5 files changed, 217 insertions, 55 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index 3c58af0..64a6eeb 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -6,25 +6,10 @@ import torchaudio import yaml from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.tokenizer import CharTokenizer -def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True): - """Greedily decode a sequence.""" - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.get_blank_token() - decodes = [] - for args in arg_maxes: - decode = [] - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes - - @click.command() @click.option( "--config_path", @@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None: model_config = config_dict.get("model", {}) tokenizer_config = config_dict.get("tokenizer", {}) inference_config = config_dict.get("inference", {}) + decoder_config = config_dict.get("decoder", {}) - if inference_config["device"] == "cpu": + if inference_config.get("device", "") == "cpu": device = "cpu" - elif inference_config["device"] == "cuda": + elif inference_config.get("device", "") == "cuda": device = "cuda" if torch.cuda.is_available() else "cpu" + elif inference_config.get("device", "") == "mps": + device = "mps" + else: + device = "cpu" device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) @@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None: spec = spec.unsqueeze(0) spec = spec.transpose(1, 2) spec = spec.unsqueeze(0) + spec = spec.to(device) output = model(spec) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) # (batch, time, n_class) - decoded_preds = greedy_decoder(output, tokenizer) - print(decoded_preds) + decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config) + + preds = decoder(output) + preds = " ".join(preds[0][0].words).strip() + + print(preds) if __name__ == "__main__": diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9c7ede9..5277c16 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split -from swr2_asr.utils.decoder import greedy_decoder -from swr2_asr.utils.tokenizer import CharTokenizer - +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.loss_scores import cer, wer +from swr2_asr.utils.tokenizer import CharTokenizer class IterMeter: @@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: # get values from test_args: model, device, test_loader, criterion, tokenizer, decoder = test_args.values() - if decoder == "greedy": - decoder = greedy_decoder - model.eval() test_loss = 0 test_cer, test_wer = [], [] @@ -141,12 +137,15 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths, tokenizer - ) + decoded_targets = tokenizer.decode_batch(labels) + decoded_preds = decoder(output.transpose(0, 1)) for j, _ in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + if j >= len(decoded_targets): + break + pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words + target = decoded_targets[j] + test_cer.append(cer(target, pred)) + test_wer.append(wer(target, pred)) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) @@ -187,6 +186,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) + decoder_config = config_dict.get("decoder", {}) if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) @@ -262,12 +262,19 @@ def main(config_path: str): if checkpoints_config["model_load_path"] is not None: checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + + model.load_state_dict(state_dict) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() + decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config) + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): train_args: TrainArgs = { "model": model, @@ -283,14 +290,13 @@ def main(config_path: str): train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = { "model": model, "device": device, "test_loader": valid_loader, "criterion": criterion, "tokenizer": tokenizer, - "decoder": "greedy", + "decoder": decoder, } if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index d551c98..74cd572 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -6,7 +6,7 @@ import numpy as np import torch import torchaudio from torch import Tensor, nn -from torch.utils.data import DataLoader, Dataset +from torch.utils.data import Dataset from torchaudio.datasets.utils import _extract_tar from swr2_asr.utils.tokenizer import CharTokenizer @@ -343,3 +343,21 @@ class MLSDataset(Dataset): dataset_lookup_entry["chapterid"], idx, ) # type: ignore + + +def create_lexicon(vocab_counts_path, lexicon_path): + """Create a lexicon from the vocab_counts.txt file""" + words_list = [] + with open(vocab_counts_path, "r", encoding="utf-8") as file: + for line in file: + words = line.split() + if len(words) >= 1: + word = words[0] + words_list.append(word) + + with open(lexicon_path, "w", encoding="utf-8") as file: + for word in words_list: + file.write(f"{word} ") + for char in word: + file.write(char + " ") + file.write("<SPACE>") diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index fcddb79..1fd002a 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,26 +1,155 @@ """Decoder for CTC-based ASR.""" "" +import os +from dataclasses import dataclass + import torch +from torchaudio.datasets.utils import _extract_tar +from torchaudio.models.decoder import ctc_decoder +from swr2_asr.utils.data import create_lexicon from swr2_asr.utils.tokenizer import CharTokenizer -# TODO: refactor to use torch CTC decoder class -def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): - """Greedily decode a sequence.""" - blank_label = tokenizer.get_blank_token() - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist())) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes, targets - - -# TODO: add beam search decoder +@dataclass +class DecoderOutput: + """Decoder output.""" + + words: list[str] + + +def decoder_factory(decoder_type: str = "greedy") -> callable: + """Decoder factory.""" + if decoder_type == "greedy": + return get_greedy_decoder + if decoder_type == "lm": + return get_beam_search_decoder + raise NotImplementedError + + +def get_greedy_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + *_, +): + """Greedy decoder.""" + return GreedyDecoder(tokenizer) + + +def get_beam_search_decoder( + tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name + hparams: dict, # pylint: disable=redefined-outer-name +): + """Beam search decoder.""" + hparams = hparams.get("lm", {}) + language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = ( + hparams["language"], + hparams["language_model_path"], + hparams["n_gram"], + hparams["beam_size"], + hparams["beam_threshold"], + hparams["n_best"], + hparams["lm_weight"], + hparams["word_score"], + ) + + if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")): + url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz" + torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz") + _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True) + + tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt") + if not os.path.isfile(tokens_path): + tokenizer.create_tokens_txt(tokens_path) + + lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt") + if not os.path.isfile(lexicon_path): + occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt") + create_lexicon(occurences_path, lexicon_path) + + lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa") + + decoder = ctc_decoder( + lexicon=lexicon_path, + tokens=tokens_path, + lm=lm_path, + blank_token="_", + sil_token="<SPACE>", + unk_word="<UNK>", + nbest=n_best, + beam_size=beam_size, + beam_threshold=beam_threshold, + lm_weight=lm_weight, + word_score=word_score, + ) + return decoder + + +class GreedyDecoder: + """Greedy decoder.""" + + def __init__(self, tokenizer: CharTokenizer): # pylint: disable=redefined-outer-name + self.tokenizer = tokenizer + + def __call__( + self, output, greedy_type: str = "inference", labels=None, label_lengths=None + ): # pylint: disable=redefined-outer-name + """Greedily decode a sequence.""" + if greedy_type == "train": + res = self.train(output, labels, label_lengths) + if greedy_type == "inference": + res = self.inference(output) + + res = [x.split(" ") for x in res] + res = [[DecoderOutput(x)] for x in res] + return res + + def train(self, output, labels, label_lengths): + """Greedily decode a sequence with known labels.""" + blank_label = tokenizer.get_blank_token() + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append(self.tokenizer.decode(labels[i][: label_lengths[i]].tolist())) + for j, index in enumerate(args): + if index != blank_label: + if j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(self.tokenizer.decode(decode)) + return decodes, targets + + def inference(self, output): + """Greedily decode a sequence.""" + collapse_repeated = True + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + blank_label = self.tokenizer.get_blank_token() + decodes = [] + for args in arg_maxes: + decode = [] + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(self.tokenizer.decode(decode)) + + return decodes + + +if __name__ == "__main__": + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt") + + hparams = { + "language": "german", + "lang_model_path": "data", + "n_gram": 3, + "beam_size": 100, + "beam_threshold": 100, + "n_best": 1, + "lm_weight": 0.5, + "word_score": 1.0, + } + + get_beam_search_decoder(tokenizer, hparams) diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index 1cc7b84..4e3fddd 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -29,9 +29,17 @@ class CharTokenizer: """Use a character map and convert integer labels to an text sequence""" string = [] for i in labels: + i = int(i) string.append(self.index_map[i]) return "".join(string).replace("<SPACE>", " ") + def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for label in labels: + string.append(self.decode(label)) + return string + def get_vocab_size(self) -> int: """Get the number of unique characters in the dataset""" return len(self.char_map) @@ -120,3 +128,9 @@ class CharTokenizer: load_tokenizer.char_map[char] = int(index) load_tokenizer.index_map[int(index)] = char return load_tokenizer + + def create_tokens_txt(self, path: str): + """Create a txt file with all the characters""" + with open(path, "w", encoding="utf-8") as file: + for char, _ in self.char_map.items(): + file.write(f"{char}\n") |