author | Marvin | 2023-08-19 14:43:52 +0200 |
---|---|---|
committer | GitHub | 2023-08-19 14:43:52 +0200 |
commit | b5aee436d95c6eb54adb7dc3f405249520ff7e9b (patch) | |
tree | 79d45badaf5f7ee5bf049db214b8b598c335a617 /swr2_asr | |
parent | 33a52ef9f681b8665b5b6000e74de447e0983c78 (diff) | |
parent | ec177107cb3a1a31d2fc49cc4990413af287305e (diff) | |
Merge pull request #16 from Algo-Boys/distributor
Slurm distributor attempt one
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/train.py | 120 |
1 file changed, 97 insertions, 23 deletions
```diff
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 81312d9..56d10c0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,16 +1,17 @@
 """Training script for the ASR model."""
 from AudioLoader.speech import MultilingualLibriSpeech
+import os
 import click
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+import torch.distributed as dist
 import torchaudio
 
 from .loss_scores import cer, wer
 
-MODEL_SAVE_PATH = "models/model.pt"
-LOSS
 class TextTransform:
     """Maps characters to integers and vice versa"""
@@ -352,6 +353,7 @@ def train(
                 ({100.0 * batch_idx / len(train_loader)}%)]\t \
                 Loss: {loss.item()}"
             )
+    return loss.item()
 
 
 def test(model, device, test_loader, criterion):
@@ -388,7 +390,15 @@ def test(model, device, test_loader, criterion):
     )
 
 
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+def run(
+    learning_rate: float,
+    batch_size: int,
+    epochs: int,
+    world_size: int,
+    load: bool,
+    path: str,
+    dataset_path: str,
+) -> None:
     """Runs the training script."""
     hparams = {
         "n_cnn_layers": 3,
@@ -401,6 +411,8 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
         "learning_rate": learning_rate,
         "batch_size": batch_size,
         "epochs": epochs,
+        "world_size": world_size,
+        "distributed": world_size > 1,
     }
 
     use_cuda = torch.cuda.is_available()
@@ -408,29 +420,50 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
     device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member
     # device = torch.device("mps")
 
+    download_dataset = not os.path.isdir(path)
     train_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+        dataset_path, "mls_polish_opus", split="dev", download=download_dataset
    )
     test_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+        dataset_path, "mls_polish_opus", split="test", download=False
     )
 
-    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+    # initialize distributed training
+    ngpus_per_node = torch.cuda.device_count()
+    if hparams["distributed"]:
+        if "SLURM_PROCID" in os.environ:  # for slurm scheduler
+            hparams["rank"] = int(os.environ["SLURM_PROCID"])
+            hparams["gpu"] = hparams["rank"] % ngpus_per_node
+        dist.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=hparams["world_size"],
+            rank=hparams["rank"],
+        )
+    train_sampler = (
+        DistributedSampler(train_dataset, shuffle=True)
+        if hparams["distributed"]
+        else None
+    )
     train_loader = DataLoader(
         train_dataset,
         batch_size=hparams["batch_size"],
         shuffle=True,
+        sampler=train_sampler,
+        num_workers=hparams["world_size"],  # TODO?
+        pin_memory=True,
         collate_fn=lambda x: data_processing(x, "train"),
-        **kwargs,
     )
 
     test_loader = DataLoader(
         test_dataset,
         batch_size=hparams["batch_size"],
         shuffle=True,
+        sampler=None,
+        num_workers=hparams["world_size"],  # TODO?
+        pin_memory=True,
         collate_fn=lambda x: data_processing(x, "train"),
-        **kwargs,
     )
 
     model = SpeechRecognitionModel(
@@ -443,13 +476,30 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
         hparams["dropout"],
     ).to(device)
 
+    if hparams["distributed"]:
+        if "gpu" in hparams:
+            torch.cuda.set_device(hparams["gpu"])
+            model.cuda(hparams["gpu"])
+            model = torch.nn.parallel.DistributedDataParallel(
+                model, device_ids=[hparams["gpu"]]
+            )
+            model_without_ddp = model.module
+        else:
+            model.cuda()
+            model = torch.nn.parallel.DistributedDataParallel(model)
+            model_without_ddp = model.module
+
     print(
         "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
     )
-
     optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
     criterion = nn.CTCLoss(blank=28).to(device)
-
+    if load:
+        checkpoint = torch.load(path)
+        model.load_state_dict(checkpoint["model_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        epoch = checkpoint["epoch"]
+        loss = checkpoint["loss"]
     scheduler = optim.lr_scheduler.OneCycleLR(
         optimizer,
         max_lr=hparams["learning_rate"],
@@ -460,7 +510,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
 
     iter_meter = IterMeter()
     for epoch in range(1, epochs + 1):
-        train(
+        loss = train(
             model,
             device,
             train_loader,
@@ -470,22 +520,46 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
             epoch,
             iter_meter,
         )
-        if epoch%3 == 0 or epoch == epochs:
-            torch.save({
-                'epoch': epoch,
-                'model_state_dict': model.state_dict(),
-                },MODEL_SAVE_PATH)
+        if epoch % 3 == 0 or epoch == epochs:
+            torch.save(
+                {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+                path,
+            )
         test(model=model, device=device, test_loader=test_loader, criterion=criterion)
 
 
 @click.command()
-@click.option("--learning-rate", default=1e-3, help="Learning rate")
+@click.option("--learning_rate", default=1e-3, help="Learning rate")
 @click.option("--batch_size", default=1, help="Batch size")
 @click.option("--epochs", default=1, help="Number of epochs")
-def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
+@click.option("--world_size", default=1, help="Number of nodes for distribution")
+@click.option("--load", default=False, help="Do you want to load a model?")
+@click.option(
+    "--path",
+    default="models/model.pt",
+    help="Path where the model will be saved to/loaded from",
+)
+@click.option(
+    "--dataset_path",
+    default="data/",
+    help="Path for the dataset directory",
+)
+def run_cli(
+    learning_rate: float,
+    batch_size: int,
+    epochs: int,
+    world_size: int,
+    load: bool,
+    path: str,
+    dataset_path: str,
+) -> None:
     """Runs the training script."""
-    run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
-
-
-if __name__ == "__main__":
-    run(learning_rate=5e-4, batch_size=16, epochs=1)
+    run(
+        learning_rate=learning_rate,
+        batch_size=batch_size,
+        epochs=epochs,
+        world_size=world_size,
+        load=load,
+        path=path,
+        dataset_path=dataset_path,
+    )
```
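A hedged note on the new data-loading path: PyTorch's `DataLoader` treats an explicit `sampler` as mutually exclusive with `shuffle=True`, so in the distributed case the train loader as committed would raise a `ValueError`, and a `DistributedSampler` also needs `set_epoch` once per epoch to reshuffle across ranks. A sketch of the usual pattern, reusing the names from `run()`:

```python
train_loader = DataLoader(
    train_dataset,
    batch_size=hparams["batch_size"],
    shuffle=train_sampler is None,   # shuffle only when no sampler is given
    sampler=train_sampler,
    num_workers=hparams["world_size"],
    pin_memory=True,
    collate_fn=lambda x: data_processing(x, "train"),
)

for epoch in range(1, epochs + 1):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # new shuffling order on every rank each epoch
    ...
```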
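The merge also drops the previous `if __name__ == "__main__":` block without adding a replacement, so `run_cli` is defined as a click command but never invoked when the file is executed directly. A minimal sketch of a re-added entry point, assuming the click command is meant to remain the CLI:

```python
if __name__ == "__main__":
    # click parses the command-line options and supplies run_cli's parameters
    run_cli()  # pylint: disable=no-value-for-parameter
```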
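Because `dist.init_process_group` is called with `init_method="env://"` but an explicit `world_size` and `rank`, each process still has to find `MASTER_ADDR` and `MASTER_PORT` in its environment, and the rank is only derived when Slurm sets `SLURM_PROCID` (one task per process under `srun`). A rough sketch of the per-task environment the merged code appears to assume; the hostname and port below are placeholders, not values from the repository:

```python
import os

# Rendezvous address for init_method="env://" (same values on every rank).
os.environ["MASTER_ADDR"] = "node001"  # placeholder: hostname of the rank-0 node
os.environ["MASTER_PORT"] = "29500"    # placeholder: any free TCP port

# Set automatically by `srun`, one distinct value per launched task;
# train.run() turns it into the process rank and the local GPU index:
#   rank = int(os.environ["SLURM_PROCID"])
#   gpu  = rank % torch.cuda.device_count()
os.environ.setdefault("SLURM_PROCID", "0")
```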