author     Marvin Borner    2023-08-19 14:22:48 +0200
committer  Marvin Borner    2023-08-19 14:22:48 +0200
commit     897e74f695e291029a08b280c2cea40a2a9639cc (patch)
tree       975b4568eea56d022c30738e8932c901a59b4a49 /swr2_asr
parent     66c37e72ef2dc7c88e1814627f35e506c7c09648 (diff)
parent     aea161ee7f2c96aab529ca22675fb54cdcadbd12 (diff)
Merge remote-tracking branch 'origin/saving' into distributor
Diffstat (limited to 'swr2_asr')
-rw-r--r--    swr2_asr/train.py    31
1 file changed, 17 insertions(+), 14 deletions(-)
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 29c2293..ba002e0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -12,8 +12,8 @@ import torch.distributed as dist
import torchaudio
from .loss_scores import cer, wer
-MODEL_SAVE_PATH = "models/model.pt"
-LOSS
+
+
class TextTransform:
"""Maps characters to integers and vice versa"""
@@ -355,7 +355,7 @@ def train(
({100.0 * batch_idx / len(train_loader)}%)]\t \
Loss: {loss.item()}"
)
-
+ return loss.item()
def test(model, device, test_loader, criterion):
"""Test"""
@@ -391,7 +391,7 @@ def test(model, device, test_loader, criterion):
)
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1) -> None:
+def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1, load: bool=False, path: str="models/model.pt") -> None:
"""Runs the training script."""
hparams = {
"n_cnn_layers": 3,
@@ -474,10 +474,14 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
print(
"Num Model Parameters", sum([param.nelement() for param in model.parameters()])
)
-
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
criterion = nn.CTCLoss(blank=28).to(device)
-
+ if load:
+ checkpoint = torch.load(path)
+ model.load_state_dict(checkpoint['model_state_dict'])
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+ epoch = checkpoint['epoch']
+ loss = checkpoint['loss']
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=hparams["learning_rate"],
@@ -488,7 +492,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
iter_meter = IterMeter()
for epoch in range(1, epochs + 1):
- train(
+ loss = train(
model,
device,
train_loader,
@@ -502,7 +506,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
- },MODEL_SAVE_PATH)
+ 'loss': loss},path)
test(model=model, device=device, test_loader=test_loader, criterion=criterion)
@@ -511,10 +515,9 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
@click.option("--batch_size", default=1, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
@click.option("--world_size", default=1, help="Number of nodes for distribution")
-def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int) -> None:
+@click.option("--load", default = False, help="Do you want to load a model?")
+@click.option("--path",default="models/model.pt",
+ help= "Path where the model will be saved to/loaded from" )
+def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int, load: bool, path: str) -> None:
"""Runs the training script."""
- run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size)
-
-
-if __name__ == "__main__":
- run(learning_rate=5e-4, batch_size=16, epochs=1)
+ run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size, load=load, path=path)
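
Below is a minimal sketch of the checkpoint round-trip this merge introduces. The model, optimizer, and hparams are stand-ins (nn.Linear, AdamW with lr=5e-4, CKPT_PATH) and are not taken from the repo; only the dictionary keys and the torch.save/torch.load pattern mirror the diff above.

import torch
import torch.nn as nn
import torch.optim as optim

CKPT_PATH = "models/model.pt"  # same default as the new --path option

model = nn.Linear(10, 2)  # stand-in for the ASR model built in run()
optimizer = optim.AdamW(model.parameters(), lr=5e-4)

# Save after an epoch, using the keys written in run():
torch.save(
    {
        "epoch": 1,
        "model_state_dict": model.state_dict(),
        # Not in the visible hunk; added here so the resume below works.
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": 0.0,  # the value returned by train()
    },
    CKPT_PATH,
)

# Resume, using the keys read when load=True:
checkpoint = torch.load(CKPT_PATH)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]

Note that, as far as the hunks shown go, run() saves only 'epoch', 'model_state_dict', and 'loss', while the load branch also reads 'optimizer_state_dict'; unless the optimizer state is written elsewhere in the file, resuming with --load would raise a KeyError, which is why the sketch saves it as well.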