diff options
author | JoJoBarthold2 | 2023-08-19 13:29:20 +0200 |
---|---|---|
committer | JoJoBarthold2 | 2023-08-19 13:29:20 +0200 |
commit | d5568bb9f51c4b586c7bd8537140cb1e201f5840 (patch) | |
tree | 4786e239a299fd7db231db799ccbcb93a8a830fa /swr2_asr | |
parent | fd3106c2cce565d378def73b0d77b0123f68523b (diff) |
also now saves loss ( hahah funny meme)
| ||
|| |_
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/train.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6eaf4c1..9a8620f 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -351,10 +351,8 @@ def train( [{batch_idx * len(spectrograms)}/{data_len} \ ({100.0 * batch_idx / len(train_loader)}%)]\t \ Loss: {loss.item()}" - ) - return loss - + return loss.item() def test(model, device, test_loader, criterion): """Test""" @@ -476,7 +474,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), - },MODEL_SAVE_PATH) + 'loss': loss},MODEL_SAVE_PATH) test(model=model, device=device, test_loader=test_loader, criterion=criterion) |