about summary refs log tree commit diff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
author: Pherkel <2023-09-18 18:13:46 +0200>
committer: GitHub <2023-09-18 18:13:46 +0200>
commitf94506764bde3e4d41dc593e9d11aa7330c00e30 (patch)
tree6fc438536a72e195805c1aea97926f4c9bbd4f85 /swr2_asr/train.py
parent8b3a0b47813733ef67befa6959a4d24f8518b5b7 (diff)
parent21a3b1d7cc8544fa0031b8934283382bdfd1d8f1 (diff)
Merge pull request #38 from Algo-Boys/decoder
Decoder
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--  swr2_asr/train.py  34
1 files changed, 20 insertions, 14 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 9c7ede9..5277c16 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -12,10 +12,9 @@ from tqdm.autonotebook import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
-from swr2_asr.utils.decoder import greedy_decoder
-from swr2_asr.utils.tokenizer import CharTokenizer
-
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.loss_scores import cer, wer
+from swr2_asr.utils.tokenizer import CharTokenizer
class IterMeter:
@@ -123,9 +122,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
# get values from test_args:
model, device, test_loader, criterion, tokenizer, decoder = test_args.values()
- if decoder == "greedy":
- decoder = greedy_decoder
-
model.eval()
test_loss = 0
test_cer, test_wer = [], []
@@ -141,12 +137,15 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths, tokenizer
- )
+ decoded_targets = tokenizer.decode_batch(labels)
+ decoded_preds = decoder(output.transpose(0, 1))
for j, _ in enumerate(decoded_preds):
- test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
- test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+ if j >= len(decoded_targets):
+ break
+ pred = " ".join(decoded_preds[j][0].words).strip() # batch, top, words
+ target = decoded_targets[j]
+ test_cer.append(cer(target, pred))
+ test_wer.append(wer(target, pred))
avg_cer = sum(test_cer) / len(test_cer)
avg_wer = sum(test_wer) / len(test_wer)
@@ -187,6 +186,7 @@ def main(config_path: str):
dataset_config = config_dict.get("dataset", {})
tokenizer_config = config_dict.get("tokenizer", {})
checkpoints_config = config_dict.get("checkpoints", {})
+ decoder_config = config_dict.get("decoder", {})
if not os.path.isdir(dataset_config["dataset_root_path"]):
os.makedirs(dataset_config["dataset_root_path"])
@@ -262,12 +262,19 @@ def main(config_path: str):
if checkpoints_config["model_load_path"] is not None:
checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
- model.load_state_dict(checkpoint["model_state_dict"])
+ state_dict = {
+ k[len("module.") :] if k.startswith("module.") else k: v
+ for k, v in checkpoint["model_state_dict"].items()
+ }
+
+ model.load_state_dict(state_dict)
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
+ decoder = decoder_factory(decoder_config.get("type", "greedy"))(tokenizer, decoder_config)
+
for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
train_args: TrainArgs = {
"model": model,
@@ -283,14 +290,13 @@ def main(config_path: str):
train_loss = train(train_args)
test_loss, test_cer, test_wer = 0, 0, 0
-
test_args: TestArgs = {
"model": model,
"device": device,
"test_loader": valid_loader,
"criterion": criterion,
"tokenizer": tokenizer,
- "decoder": "greedy",
+ "decoder": decoder,
}
if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: