aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--swr2_asr/train.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 1ef42aa..81312d9 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -9,6 +9,8 @@ from torch.utils.data import DataLoader
import torchaudio
from .loss_scores import cer, wer
+MODEL_SAVE_PATH = "models/model.pt"
+LOSS
class TextTransform:
"""Maps characters to integers and vice versa"""
@@ -468,6 +470,11 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
epoch,
iter_meter,
)
+ if epoch%3 == 0 or epoch == epochs:
+ torch.save({
+ 'epoch': epoch,
+ 'model_state_dict': model.state_dict(),
+ },MODEL_SAVE_PATH)
test(model=model, device=device, test_loader=test_loader, criterion=criterion)