about summary refs log tree commit diff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--  swr2_asr/train.py  25
1 file changed, 12 insertions, 13 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 6f3bc6c..40626e7 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -15,8 +15,6 @@ from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
-# TODO: improve naming of functions
-
class HParams(TypedDict):
"""Type for the hyperparameters of the model."""
@@ -157,10 +155,10 @@ def run(
# load dataset
train_dataset = MLSDataset(
- dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
+ dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
)
valid_dataset = MLSDataset(
- dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+ dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
)
# load tokenizer (bpe by default):
@@ -171,7 +169,6 @@ def run(
dataset_path=dataset_path,
language=language,
split="all",
- download=False,
out_path="data/tokenizers/char_tokenizer_german.json",
)
@@ -211,7 +208,7 @@ def run(
# enable flag to find the most compatible algorithms in advance
if use_cuda:
- torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.benchmark = True # pylance: disable=no-member
model = SpeechRecognitionModel(
hparams["n_cnn_layers"],
@@ -253,7 +250,7 @@ def run(
iter_meter,
)
- test_loss,avg_cer,avg_wer = test(
+ test_loss, avg_cer, avg_wer = test(
model=model,
device=device,
test_loader=valid_loader,
@@ -262,12 +259,14 @@ def run(
)
print("saving epoch", str(epoch))
torch.save(
- {"epoch": epoch,
- "model_state_dict": model.state_dict(),
- "loss": loss,
- "test_loss": test_loss,
- "avg_cer": avg_cer,
- "avg_wer": avg_wer},
+ {
+ "epoch": epoch,
+ "model_state_dict": model.state_dict(),
+ "loss": loss,
+ "test_loss": test_loss,
+ "avg_cer": avg_cer,
+ "avg_wer": avg_wer,
+ },
path + str(epoch),
)