diff options
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/train.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6eaf4c1..9a8620f 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -351,10 +351,8 @@ def train( [{batch_idx * len(spectrograms)}/{data_len} \ ({100.0 * batch_idx / len(train_loader)}%)]\t \ Loss: {loss.item()}" - ) - return loss - + return loss.item() def test(model, device, test_loader, criterion): """Test""" @@ -476,7 +474,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), - },MODEL_SAVE_PATH) + 'loss': loss},MODEL_SAVE_PATH) test(model=model, device=device, test_loader=test_loader, criterion=criterion) |