author | Marvin Borner | 2023-08-19 14:22:48 +0200
committer | Marvin Borner | 2023-08-19 14:22:48 +0200
commit | 897e74f695e291029a08b280c2cea40a2a9639cc (patch)
tree | 975b4568eea56d022c30738e8932c901a59b4a49 /swr2_asr
parent | 66c37e72ef2dc7c88e1814627f35e506c7c09648 (diff)
parent | aea161ee7f2c96aab529ca22675fb54cdcadbd12 (diff)
Merge remote-tracking branch 'origin/saving' into distributor
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/train.py | 31
1 file changed, 17 insertions, 14 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 29c2293..ba002e0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -12,8 +12,8 @@ import torch.distributed as dist
 import torchaudio
 from .loss_scores import cer, wer
 
-MODEL_SAVE_PATH = "models/model.pt"
-LOSS
+
+
 
 class TextTransform:
     """Maps characters to integers and vice versa"""
@@ -355,7 +355,7 @@ def train(
                     ({100.0 * batch_idx / len(train_loader)}%)]\t \
                     Loss: {loss.item()}"
             )
-
+    return loss.item()
 
 def test(model, device, test_loader, criterion):
     """Test"""
@@ -391,7 +391,7 @@ def test(model, device, test_loader, criterion):
     )
 
 
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1) -> None:
+def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1, load: bool=False, path: str="models/model.pt") -> None:
     """Runs the training script."""
     hparams = {
         "n_cnn_layers": 3,
@@ -474,10 +474,14 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
     print(
         "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
     )
-
     optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
     criterion = nn.CTCLoss(blank=28).to(device)
-
+    if load:
+        checkpoint = torch.load(path)
+        model.load_state_dict(checkpoint['model_state_dict'])
+        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        epoch = checkpoint['epoch']
+        loss = checkpoint['loss']
     scheduler = optim.lr_scheduler.OneCycleLR(
         optimizer,
         max_lr=hparams["learning_rate"],
@@ -488,7 +492,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
 
     iter_meter = IterMeter()
     for epoch in range(1, epochs + 1):
-        train(
+        loss = train(
            model,
            device,
            train_loader,
@@ -502,7 +506,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
         torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
-            },MODEL_SAVE_PATH)
+            'loss': loss},path)
 
     test(model=model, device=device, test_loader=test_loader, criterion=criterion)
 
@@ -511,10 +515,9 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
 @click.option("--batch_size", default=1, help="Batch size")
 @click.option("--epochs", default=1, help="Number of epochs")
 @click.option("--world_size", default=1, help="Number of nodes for distribution")
-def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int) -> None:
+@click.option("--load", default = False, help="Do you want to load a model?")
+@click.option("--path",default="models/model.pt",
+              help= "Path where the model will be saved to/loaded from" )
+def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int, load: bool, path: str) -> None:
     """Runs the training script."""
-    run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size)
-
-
-if __name__ == "__main__":
-    run(learning_rate=5e-4, batch_size=16, epochs=1)
+    run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size, load=load, path=path)
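The merged 'saving' branch replaces the hard-coded `MODEL_SAVE_PATH` with a `path` parameter and adds a `load` flag that restores a checkpoint dictionary before training. Below is a minimal, self-contained sketch of that save/resume pattern, not the repository's code: the `nn.Linear` model, the optimizer settings, and the placeholder loss value are invented for illustration, and, unlike the diff above, the sketch also stores `optimizer_state_dict` so that the load branch shown in the diff can actually restore it.

```python
# Minimal sketch of the checkpoint save/resume pattern added in this merge.
# Placeholder model/optimizer; only the dictionary keys mirror swr2_asr/train.py.
import os

import torch
from torch import nn, optim

path = "models/model.pt"              # same default as the new --path option
model = nn.Linear(10, 2)              # stand-in for the real ASR model
optimizer = optim.AdamW(model.parameters(), lr=5e-4)

# Save after an epoch: one dict holding the epoch counter, weights and last loss.
os.makedirs(os.path.dirname(path), exist_ok=True)
torch.save(
    {
        "epoch": 1,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),  # needed for a full resume
        "loss": 0.0,                                     # placeholder value
    },
    path,
)

# Resume (what the new load=True branch does): restore states and bookkeeping.
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
epoch = checkpoint["epoch"]
loss = checkpoint["loss"]
```

Keeping the epoch counter and last loss next to the weights is what allows `run` to pick up where an earlier run stopped; the keys `epoch`, `model_state_dict` and `loss` are the ones the diff writes with `torch.save`, while `optimizer_state_dict` is only read on load, so the sketch saves it as well to keep both sides consistent.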