aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
authorJoJoBarthold22023-09-05 10:25:28 +0200
committerJoJoBarthold22023-09-05 10:25:28 +0200
commite9721154ed94dbedb3443bac83a8e1e3f9d4b734 (patch)
tree72f76724bc4bedc5bae6d5cdda11403309f14a1a /swr2_asr/train.py
parent93e49e708fa59613406249069d31c2f6c8f2d2ab (diff)
model now saves validation_loss and cer and wer
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py9
1 files changed, 7 insertions, 2 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 63deb72..6f3bc6c 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -253,7 +253,7 @@ def run(
iter_meter,
)
- test(
+ test_loss,avg_cer,avg_wer = test(
model=model,
device=device,
test_loader=valid_loader,
@@ -262,7 +262,12 @@ def run(
)
print("saving epoch", str(epoch))
torch.save(
- {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+ {"epoch": epoch,
+ "model_state_dict": model.state_dict(),
+ "loss": loss,
+ "test_loss": test_loss,
+ "avg_cer": avg_cer,
+ "avg_wer": avg_wer},
path + str(epoch),
)