From fd3106c2cce565d378def73b0d77b0123f68523b Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Sat, 19 Aug 2023 13:26:01 +0200 Subject: train now returns loss so it can be saved ( amen ) --- swr2_asr/train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 81312d9..6eaf4c1 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -351,7 +351,9 @@ def train( [{batch_idx * len(spectrograms)}/{data_len} \ ({100.0 * batch_idx / len(train_loader)}%)]\t \ Loss: {loss.item()}" + ) + return loss def test(model, device, test_loader, criterion): @@ -460,7 +462,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No iter_meter = IterMeter() for epoch in range(1, epochs + 1): - train( + loss = train( model, device, train_loader, -- cgit v1.2.3