aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr')
-rw-r--r--swr2_asr/inference.py35
-rw-r--r--swr2_asr/train.py34
-rw-r--r--swr2_asr/utils/data.py20
-rw-r--r--swr2_asr/utils/decoder.py169
-rw-r--r--swr2_asr/utils/tokenizer.py14
5 files changed, 217 insertions, 55 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 3c58af0..64a6eeb 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -6,25 +6,10 @@ import torchaudio
import yaml
from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.tokenizer import CharTokenizer
-def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
- """Greedily decode a sequence."""
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- blank_label = tokenizer.get_blank_token()
- decodes = []
- for args in arg_maxes:
- decode = []
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes
-
-
@click.command()
@click.option(
"--config_path",
@@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None:
model_config = config_dict.get("model", {})
tokenizer_config = config_dict.get("tokenizer", {})
inference_config = config_dict.get("inference", {})
+ decoder_config = config_dict.get("decoder", {})
- if inference_config["device"] == "cpu":
+ if inference_config.get("device", "") == "cpu":
device = "cpu"
- elif inference_config["device"] == "cuda":
+ elif inference_config.get("device", "") == "cuda":
device = "cuda" if torch.cuda.is_available() else "cpu"
+ elif inference_config.get("device", "") == "mps":
+ device = "mps"
+ else:
+ device = "cpu"
device = torch.device(device) # pylint: disable=no-member
tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
@@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None:
spec = spec.unsqueeze(0)
spec = spec.transpose(1, 2)
spec = spec.unsqueeze(0)
+ spec = spec.to(device)
output = model(spec) # pylint: disable=not-callable
output = F.log_softmax(output, dim=2) # (batch, time, n_class)
- decoded_preds = greedy_decoder(output, tokenizer)
- print(decoded_preds)
+ decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)
+
+ preds = decoder(output)
+ preds = " ".join(preds[0][0].words).strip()
+
+ print(preds)
if __name__ == "__main__":
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 9c7ede9..5277c16 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
-from swr2_asr.utils.decoder import greedy_decoder
-from swr2_asr.utils.tokenizer import CharTokenizer
-
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.loss_scores import cer, wer
+from swr2_asr.utils.tokenizer import CharTokenizer
class IterMeter:
@@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
# get values from test_args:
model, device, test_loader, criterion, tokenizer, decoder = test_args.values()
- if decoder == "greedy":
- decoder = greedy_decoder
-
model.eval()
test_loss = 0
test_cer, test_wer = [], []
@@ -141,12 +137,15 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths, tokenizer
- )
+ decoded_targets = tokenizer.decode_batch(labels)
+ decoded_preds = decoder(output.transpose(0, 1))
for j, _ in enumerate(decoded_preds):
- test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
- test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+ if j >= len(decoded_targets):
+ break
+ pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words
+ target = decoded_targets[j]
+ test_cer.append(cer(target, pred))
+ test_wer.append(wer(target, pred))
avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
@@ -187,6 +186,7 @@ def main(config_path: str):
dataset_config = config_dict.get("dataset", {})
tokenizer_config = config_dict.get("tokenizer", {})
checkpoints_config = config_dict.get("checkpoints", {})
+ decoder_config = config_dict.get("decoder", {})
if not os.path.isdir(dataset_config["dataset_root_path"]):
os.makedirs(dataset_config["dataset_root_path"])
@@ -262,12 +262,19 @@ def main(config_path: str):
if checkpoints_config["model_load_path"] is not None:
checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
- model.load_state_dict(checkpoint["model_state_dict"])
+ state_dict = {
+ k[len("module.") :] if k.startswith("module.") else k: v
+ for k, v in checkpoint["model_state_dict"].items()
+ }
+
+ model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
+ decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config)
+
for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
train_args: TrainArgs = {
"model": model,
@@ -283,14 +290,13 @@ def main(config_path: str):
train_loss = train(train_args)
test_loss, test_cer, test_wer = 0, 0, 0
-
test_args: TestArgs = {
"model": model,
"device": device,
"test_loader": valid_loader,
"criterion": criterion,
"tokenizer": tokenizer,
- "decoder": "greedy",
+ "decoder": decoder,
}
if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index d551c98..74cd572 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -6,7 +6,7 @@ import numpy as np
import torch
import torchaudio
from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
from torchaudio.datasets.utils import _extract_tar
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -343,3 +343,21 @@ class MLSDataset(Dataset):
dataset_lookup_entry["chapterid"],
idx,
) # type: ignore
+
+
+def create_lexicon(vocab_counts_path, lexicon_path):
+ """Create a lexicon from the vocab_counts.txt file"""
+ words_list = []
+ with open(vocab_counts_path, "r", encoding="utf-8") as file:
+ for line in file:
+ words = line.split()
+ if len(words) >= 1:
+ word = words[0]
+ words_list.append(word)
+
+ with open(lexicon_path, "w", encoding="utf-8") as file:
+ for word in words_list:
+ file.write(f"{word} ")
+ for char in word:
+ file.write(char + " ")
+            file.write("<SPACE>\n")
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index fcddb79..1fd002a 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,26 +1,155 @@
"""Decoder for CTC-based ASR.""" ""
+import os
+from dataclasses import dataclass
+
import torch
+from torchaudio.datasets.utils import _extract_tar
+from torchaudio.models.decoder import ctc_decoder
+from swr2_asr.utils.data import create_lexicon
from swr2_asr.utils.tokenizer import CharTokenizer
-# TODO: refactor to use torch CTC decoder class
-def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
- """Greedily decode a sequence."""
- blank_label = tokenizer.get_blank_token()
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
- decode = []
- targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes, targets
-
-
-# TODO: add beam search decoder
+@dataclass
+class DecoderOutput:
+ """Decoder output."""
+
+ words: list[str]
+
+
+def decoder_factory(decoder_type: str = "greedy") -> callable:
+ """Decoder factory."""
+ if decoder_type == "greedy":
+ return get_greedy_decoder
+ if decoder_type == "lm":
+ return get_beam_search_decoder
+ raise NotImplementedError
+
+
+def get_greedy_decoder(
+ tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
+ *_,
+):
+ """Greedy decoder."""
+ return GreedyDecoder(tokenizer)
+
+
+def get_beam_search_decoder(
+ tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
+ hparams: dict, # pylint: disable=redefined-outer-name
+):
+ """Beam search decoder."""
+ hparams = hparams.get("lm", {})
+ language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+ hparams["language"],
+ hparams["language_model_path"],
+ hparams["n_gram"],
+ hparams["beam_size"],
+ hparams["beam_threshold"],
+ hparams["n_best"],
+ hparams["lm_weight"],
+ hparams["word_score"],
+ )
+
+ if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")):
+ url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz"
+ torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz")
+        _extract_tar(f"data/mls_lm_{language}.tar.gz", overwrite=True)
+
+ tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt")
+ if not os.path.isfile(tokens_path):
+ tokenizer.create_tokens_txt(tokens_path)
+
+ lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt")
+ if not os.path.isfile(lexicon_path):
+ occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt")
+ create_lexicon(occurences_path, lexicon_path)
+
+ lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa")
+
+ decoder = ctc_decoder(
+ lexicon=lexicon_path,
+ tokens=tokens_path,
+ lm=lm_path,
+ blank_token="_",
+ sil_token="<SPACE>",
+ unk_word="<UNK>",
+ nbest=n_best,
+ beam_size=beam_size,
+ beam_threshold=beam_threshold,
+ lm_weight=lm_weight,
+ word_score=word_score,
+ )
+ return decoder
+
+
+class GreedyDecoder:
+ """Greedy decoder."""
+
+ def __init__(self, tokenizer: CharTokenizer): # pylint: disable=redefined-outer-name
+ self.tokenizer = tokenizer
+
+ def __call__(
+ self, output, greedy_type: str = "inference", labels=None, label_lengths=None
+ ): # pylint: disable=redefined-outer-name
+ """Greedily decode a sequence."""
+ if greedy_type == "train":
+            return self.train(output, labels, label_lengths)
+ if greedy_type == "inference":
+ res = self.inference(output)
+
+ res = [x.split(" ") for x in res]
+ res = [[DecoderOutput(x)] for x in res]
+ return res
+
+ def train(self, output, labels, label_lengths):
+ """Greedily decode a sequence with known labels."""
+        blank_label = self.tokenizer.get_blank_token()
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(self.tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(self.tokenizer.decode(decode))
+ return decodes, targets
+
+ def inference(self, output):
+ """Greedily decode a sequence."""
+ collapse_repeated = True
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = self.tokenizer.get_blank_token()
+ decodes = []
+ for args in arg_maxes:
+ decode = []
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(self.tokenizer.decode(decode))
+
+ return decodes
+
+
+if __name__ == "__main__":
+ tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+ tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")
+
+ hparams = {
+ "language": "german",
+ "lang_model_path": "data",
+ "n_gram": 3,
+ "beam_size": 100,
+ "beam_threshold": 100,
+ "n_best": 1,
+ "lm_weight": 0.5,
+ "word_score": 1.0,
+ }
+
+ get_beam_search_decoder(tokenizer, hparams)
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 1cc7b84..4e3fddd 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -29,9 +29,17 @@ class CharTokenizer:
"""Use a character map and convert integer labels to an text sequence"""
string = []
for i in labels:
+ i = int(i)
string.append(self.index_map[i])
return "".join(string).replace("<SPACE>", " ")
+ def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ """Use a character map and convert integer labels to an text sequence"""
+ string = []
+ for label in labels:
+ string.append(self.decode(label))
+ return string
+
def get_vocab_size(self) -> int:
"""Get the number of unique characters in the dataset"""
return len(self.char_map)
@@ -120,3 +128,9 @@ class CharTokenizer:
load_tokenizer.char_map[char] = int(index)
load_tokenizer.index_map[int(index)] = char
return load_tokenizer
+
+ def create_tokens_txt(self, path: str):
+ """Create a txt file with all the characters"""
+ with open(path, "w", encoding="utf-8") as file:
+ for char, _ in self.char_map.items():
+ file.write(f"{char}\n")