diff options
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 25 |
1 files changed, 12 insertions, 13 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6f3bc6c..40626e7 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -15,8 +15,6 @@ from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer -# TODO: improve naming of functions - class HParams(TypedDict): """Type for the hyperparameters of the model.""" @@ -157,10 +155,10 @@ def run( # load dataset train_dataset = MLSDataset( - dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None + dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True ) valid_dataset = MLSDataset( - dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None + dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True ) # load tokenizer (bpe by default): @@ -171,7 +169,6 @@ def run( dataset_path=dataset_path, language=language, split="all", - download=False, out_path="data/tokenizers/char_tokenizer_german.json", ) @@ -211,7 +208,7 @@ def run( # enable flag to find the most compatible algorithms in advance if use_cuda: - torch.backends.cudnn.benchmark = True + torch.backends.cudnn.benchmark = True # pylance: disable=no-member model = SpeechRecognitionModel( hparams["n_cnn_layers"], @@ -253,7 +250,7 @@ def run( iter_meter, ) - test_loss,avg_cer,avg_wer = test( + test_loss, avg_cer, avg_wer = test( model=model, device=device, test_loader=valid_loader, @@ -262,12 +259,14 @@ def run( ) print("saving epoch", str(epoch)) torch.save( - {"epoch": epoch, - "model_state_dict": model.state_dict(), - "loss": loss, - "test_loss": test_loss, - "avg_cer": avg_cer, - "avg_wer": avg_wer}, + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "loss": loss, + "test_loss": test_loss, + "avg_cer": avg_cer, + "avg_wer": avg_wer, + }, path + str(epoch), ) |