diff options
-rw-r--r-- | config.philipp.yaml | 10
-rw-r--r-- | pyproject.toml | 2
-rw-r--r-- | swr2_asr/train.py | 32
-rw-r--r-- | swr2_asr/utils/decoder.py | 2
-rw-r--r-- | swr2_asr/utils/tokenizer.py | 8
5 files changed, 33 insertions, 21 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml index 38a68f8..608720f 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -3,7 +3,7 @@ dataset: dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir language_name: "mls_german_opus" limited_supervision: True # set to True if you want to use limited supervision - dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%) + dataset_percentage: 0.01 # percentage of dataset to use (1.0 = 100%) shuffle: True model: @@ -33,13 +33,13 @@ decoder: training: learning_rate: 0.0005 batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU) - epochs: 3 - eval_every_n: 3 # evaluate every n epochs + epochs: 100 + eval_every_n: 1 # evaluate every n epochs num_workers: 8 # number of workers for dataloader checkpoints: # use "~" to disable saving/loading - model_load_path: "YOUR/PATH" # path to load model from - model_save_path: "YOUR/PATH" # path to save model to + model_load_path: "data/epoch67" # path to load model from + model_save_path: ~ # path to save model to inference: model_load_path: "data/epoch67" # path to load model from
\ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index f6d19dd..2f26e5e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,7 +35,7 @@ target-version = "py310" line-length = 100 [tool.poetry.scripts] -train = "swr2_asr.train:run_cli" +train = "swr2_asr.train:main" train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer" train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer" diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9c7ede9..1e57ba0 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split -from swr2_asr.utils.decoder import greedy_decoder -from swr2_asr.utils.tokenizer import CharTokenizer - +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.loss_scores import cer, wer +from swr2_asr.utils.tokenizer import CharTokenizer class IterMeter: @@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: # get values from test_args: model, device, test_loader, criterion, tokenizer, decoder = test_args.values() - if decoder == "greedy": - decoder = greedy_decoder - model.eval() test_loss = 0 test_cer, test_wer = [], [] @@ -141,12 +137,13 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths, tokenizer - ) + decoded_targets = tokenizer.decode_batch(labels) + decoded_preds = decoder(output.transpose(0, 1)) for j, _ in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + if j >= len(decoded_targets): + break + test_cer.append(cer(decoded_targets[j], decoded_preds[j][0].words[0])) + 
test_wer.append(wer(decoded_targets[j], decoded_preds[j][0].words[0])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) @@ -187,6 +184,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) + decoder_config = config_dict.get("decoder", {}) if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) @@ -262,12 +260,19 @@ def main(config_path: str): if checkpoints_config["model_load_path"] is not None: checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + + model.load_state_dict(state_dict) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() + decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config) + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): train_args: TrainArgs = { "model": model, @@ -283,14 +288,13 @@ def main(config_path: str): train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = { "model": model, "device": device, "test_loader": valid_loader, "criterion": criterion, "tokenizer": tokenizer, - "decoder": "greedy", + "decoder": decoder, } if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 2b6d29b..6ffdef2 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,6 +1,6 @@ """Decoder for CTC-based ASR.""" "" -from dataclasses import dataclass import os +from dataclasses import dataclass import torch from torchaudio.datasets.utils import 
_extract_tar diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index b1de83e..4e3fddd 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -29,9 +29,17 @@ class CharTokenizer: """Use a character map and convert integer labels to an text sequence""" string = [] for i in labels: + i = int(i) string.append(self.index_map[i]) return "".join(string).replace("<SPACE>", " ") + def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for label in labels: + string.append(self.decode(label)) + return string + def get_vocab_size(self) -> int: """Get the number of unique characters in the dataset""" return len(self.char_map) |