aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyproject.toml2
-rw-r--r--swr2_asr/inference_test.py (renamed from swr2_asr/train_3.py)0
-rw-r--r--swr2_asr/train.py (renamed from swr2_asr/train_2.py)17
3 files changed, 15 insertions, 4 deletions
diff --git a/pyproject.toml b/pyproject.toml
index 791a76e..8490aa5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ mypy = "^1.5.1"
pylint = "^2.17.5"
[tool.poetry.scripts]
-train = "swr2_asr.train_2:run"
+train = "swr2_asr.train:run_cli"
[build-system]
requires = ["poetry-core"]
diff --git a/swr2_asr/train_3.py b/swr2_asr/inference_test.py
index a6b0010..a6b0010 100644
--- a/swr2_asr/train_3.py
+++ b/swr2_asr/inference_test.py
diff --git a/swr2_asr/train_2.py b/swr2_asr/train.py
index bea5bf4..8ee96b9 100644
--- a/swr2_asr/train_2.py
+++ b/swr2_asr/train.py
@@ -1,5 +1,6 @@
"""Training script for the ASR model."""
from AudioLoader.speech.mls import MultilingualLibriSpeech
+import click
import torch
import torch.nn as nn
import torch.optim as optim
@@ -376,7 +377,8 @@ def test(model, device, test_loader, criterion):
)
-def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+ """Runs the training script."""
hparams = {
"n_cnn_layers": 3,
"n_rnn_layers": 5,
@@ -385,7 +387,7 @@ def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
"n_feats": 128,
"stride": 2,
"dropout": 0.1,
- "learning_rate": lr,
+ "learning_rate": learning_rate,
"batch_size": batch_size,
"epochs": epochs,
}
@@ -460,5 +462,14 @@ def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+@click.command()
+@click.option("--learning-rate", default=1e-3, help="Learning rate")
+@click.option("--batch_size", default=1, help="Batch size")
+@click.option("--epochs", default=1, help="Number of epochs")
+def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
+ """Runs the training script."""
+ run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
+
+
if __name__ == "__main__":
- run(lr=5e-4, batch_size=16, epochs=1)
+ run(learning_rate=5e-4, batch_size=16, epochs=1)