diff options
-rw-r--r-- | swr2_asr/train.py | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 1ef42aa..81312d9 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -9,6 +9,8 @@ from torch.utils.data import DataLoader import torchaudio from .loss_scores import cer, wer +MODEL_SAVE_PATH = "models/model.pt" +LOSS class TextTransform: """Maps characters to integers and vice versa""" @@ -468,6 +470,11 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No epoch, iter_meter, ) + if epoch%3 == 0 or epoch == epochs: + torch.save({ + 'epoch': epoch, + 'model_state_dict': model.state_dict(), + },MODEL_SAVE_PATH) test(model=model, device=device, test_loader=test_loader, criterion=criterion) |