diff options
-rw-r--r-- | pyproject.toml | 2 | ||||
-rw-r--r-- | swr2_asr/inference_test.py (renamed from swr2_asr/train_3.py) | 0 | ||||
-rw-r--r-- | swr2_asr/train.py (renamed from swr2_asr/train_2.py) | 17 |
3 files changed, 15 insertions, 4 deletions
diff --git a/pyproject.toml b/pyproject.toml index 791a76e..8490aa5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ mypy = "^1.5.1" pylint = "^2.17.5" [tool.poetry.scripts] -train = "swr2_asr.train_2:run" +train = "swr2_asr.train:run_cli" [build-system] requires = ["poetry-core"] diff --git a/swr2_asr/train_3.py b/swr2_asr/inference_test.py index a6b0010..a6b0010 100644 --- a/swr2_asr/train_3.py +++ b/swr2_asr/inference_test.py diff --git a/swr2_asr/train_2.py b/swr2_asr/train.py index bea5bf4..8ee96b9 100644 --- a/swr2_asr/train_2.py +++ b/swr2_asr/train.py @@ -1,5 +1,6 @@ """Training script for the ASR model.""" from AudioLoader.speech.mls import MultilingualLibriSpeech +import click import torch import torch.nn as nn import torch.optim as optim @@ -376,7 +377,8 @@ def test(model, device, test_loader, criterion): ) -def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: +def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: + """Runs the training script.""" hparams = { "n_cnn_layers": 3, "n_rnn_layers": 5, @@ -385,7 +387,7 @@ def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: "n_feats": 128, "stride": 2, "dropout": 0.1, - "learning_rate": lr, + "learning_rate": learning_rate, "batch_size": batch_size, "epochs": epochs, } @@ -460,5 +462,14 @@ def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: test(model=model, device=device, test_loader=test_loader, criterion=criterion) +@click.command() +@click.option("--learning-rate", default=1e-3, help="Learning rate") +@click.option("--batch_size", default=1, help="Batch size") +@click.option("--epochs", default=1, help="Number of epochs") +def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None: + """Runs the training script.""" + run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs) + + if __name__ == "__main__": - run(lr=5e-4, batch_size=16, epochs=1) + run(learning_rate=5e-4, batch_size=16, epochs=1) |