diff options
author | JoJoBarthold2 | 2023-08-19 13:26:01 +0200 |
---|---|---|
committer | JoJoBarthold2 | 2023-08-19 13:26:01 +0200 |
commit | fd3106c2cce565d378def73b0d77b0123f68523b (patch) | |
tree | f0e923f0875fc1bcec890d622bf45ae86cc370fc | |
parent | 33a52ef9f681b8665b5b6000e74de447e0983c78 (diff) |
train now returns loss so it can be saved ( amen )
-rw-r--r-- | swr2_asr/train.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 81312d9..6eaf4c1 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -351,7 +351,9 @@ def train( [{batch_idx * len(spectrograms)}/{data_len} \ ({100.0 * batch_idx / len(train_loader)}%)]\t \ Loss: {loss.item()}" + ) + return loss def test(model, device, test_loader, criterion): @@ -460,7 +462,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No iter_meter = IterMeter() for epoch in range(1, epochs + 1): - train( + loss = train( model, device, train_loader, |