aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJoJoBarthold22023-08-19 13:26:01 +0200
committerJoJoBarthold22023-08-19 13:26:01 +0200
commitfd3106c2cce565d378def73b0d77b0123f68523b (patch)
treef0e923f0875fc1bcec890d622bf45ae86cc370fc
parent33a52ef9f681b8665b5b6000e74de447e0983c78 (diff)
train now returns loss so it can be saved ( amen )
-rw-r--r--swr2_asr/train.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 81312d9..6eaf4c1 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -351,7 +351,9 @@ def train(
[{batch_idx * len(spectrograms)}/{data_len} \
({100.0 * batch_idx / len(train_loader)}%)]\t \
Loss: {loss.item()}"
+
)
+ return loss
def test(model, device, test_loader, criterion):
@@ -460,7 +462,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
iter_meter = IterMeter()
for epoch in range(1, epochs + 1):
- train(
+ loss = train(
model,
device,
train_loader,