diff options
author | Marvin | 2023-09-05 13:33:13 +0200 |
---|---|---|
committer | GitHub | 2023-09-05 13:33:13 +0200 |
commit | 46b23fd90f5ef9c3126ee66473b012fa715da008 (patch) | |
tree | 158fbd0db3541d9c0c9497fe0b1bf6f7221cc7ff /swr2_asr/train.py | |
parent | 93e49e708fa59613406249069d31c2f6c8f2d2ab (diff) | |
parent | 3dd8dfcf3dd8d1e127e7f8fd68224c164d9f6fb1 (diff) |
Merge pull request #29 from Algo-Boys/plot
Plot
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 9 |
1 files changed, 7 insertions, 2 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 63deb72..6f3bc6c 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -253,7 +253,7 @@ def run( iter_meter, ) - test( + test_loss,avg_cer,avg_wer = test( model=model, device=device, test_loader=valid_loader, @@ -262,7 +262,12 @@ def run( ) print("saving epoch", str(epoch)) torch.save( - {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, + {"epoch": epoch, + "model_state_dict": model.state_dict(), + "loss": loss, + "test_loss": test_loss, + "avg_cer": avg_cer, + "avg_wer": avg_wer}, path + str(epoch), ) |