diff options
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 34 |
1 files changed, 20 insertions, 14 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9c7ede9..5277c16 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split -from swr2_asr.utils.decoder import greedy_decoder -from swr2_asr.utils.tokenizer import CharTokenizer - +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.loss_scores import cer, wer +from swr2_asr.utils.tokenizer import CharTokenizer class IterMeter: @@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: # get values from test_args: model, device, test_loader, criterion, tokenizer, decoder = test_args.values() - if decoder == "greedy": - decoder = greedy_decoder - model.eval() test_loss = 0 test_cer, test_wer = [], [] @@ -141,12 +137,15 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths, tokenizer - ) + decoded_targets = tokenizer.decode_batch(labels) + decoded_preds = decoder(output.transpose(0, 1)) for j, _ in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + if j >= len(decoded_targets): + break + pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words + target = decoded_targets[j] + test_cer.append(cer(target, pred)) + test_wer.append(wer(target, pred)) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) @@ -187,6 +186,7 @@ def main(config_path: str): dataset_config = config_dict.get("dataset", {}) tokenizer_config = config_dict.get("tokenizer", {}) checkpoints_config = config_dict.get("checkpoints", {}) + decoder_config = config_dict.get("decoder", {}) if not os.path.isdir(dataset_config["dataset_root_path"]): os.makedirs(dataset_config["dataset_root_path"]) @@ -262,12 +262,19 @@ def main(config_path: str): if checkpoints_config["model_load_path"] is not None: checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"]) + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + + model.load_state_dict(state_dict) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() + decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config) + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): train_args: TrainArgs = { "model": model, @@ -283,14 +290,13 @@ def main(config_path: str): train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = { "model": model, "device": device, "test_loader": valid_loader, "criterion": criterion, "tokenizer": tokenizer, - "decoder": "greedy", + "decoder": decoder, } if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: |