diff options
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/train.py | 45 |
1 file changed, 35 insertions, 10 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 346be0b..ba002e0 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,11 +1,14 @@ """Training script for the ASR model.""" from AudioLoader.speech import MultilingualLibriSpeech +import os import click import torch import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler +import torch.distributed as dist import torchaudio from .loss_scores import cer, wer @@ -388,7 +391,7 @@ def test(model, device, test_loader, criterion): ) -def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load: bool=False, path: str="models/model.pt") -> None: +def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1, load: bool=False, path: str="models/model.pt") -> None: """Runs the training script.""" hparams = { "n_cnn_layers": 3, @@ -401,6 +404,8 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load: "learning_rate": learning_rate, "batch_size": batch_size, "epochs": epochs, + "world_size": world_size, + "distributed": world_size > 1, } use_cuda = torch.cuda.is_available() @@ -415,22 +420,34 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load: "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False ) - kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} + # initialize distributed training + ngpus_per_node = torch.cuda.device_count() + if hparams["distributed"]: + if 'SLURM_PROCID' in os.environ: # for slurm scheduler + hparams["rank"] = int(os.environ['SLURM_PROCID']) + hparams["gpu"] = hparams["rank"] % ngpus_per_node + dist.init_process_group(backend="nccl", init_method="env://", + world_size=hparams["world_size"], rank=hparams["rank"]) + train_sampler = DistributedSampler(train_dataset, shuffle=True) train_loader = DataLoader( train_dataset, 
batch_size=hparams["batch_size"], shuffle=True, + sampler=train_sampler, + num_workers=hparams["world_size"], # TODO? + pin_memory=True, collate_fn=lambda x: data_processing(x, "train"), - **kwargs, ) test_loader = DataLoader( test_dataset, batch_size=hparams["batch_size"], shuffle=True, + sampler=None, + num_workers=hparams["world_size"], # TODO? + pin_memory=True, collate_fn=lambda x: data_processing(x, "train"), - **kwargs, ) model = SpeechRecognitionModel( @@ -443,6 +460,17 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load: hparams["dropout"], ).to(device) + if hparams["distributed"]: + if "gpu" in hparams: + torch.cuda.set_device(hparams["gpu"]) + model.cuda(hparams["gpu"]) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[hparams["gpu"]]) + model_without_ddp = model.module + else: + model.cuda() + model = torch.nn.parallel.DistributedDataParallel(model) + model_without_ddp = model.module + print( "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) ) @@ -486,13 +514,10 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load: @click.option("--learning-rate", default=1e-3, help="Learning rate") @click.option("--batch_size", default=1, help="Batch size") @click.option("--epochs", default=1, help="Number of epochs") +@click.option("--world_size", default=1, help="Number of nodes for distribution") @click.option("--load", default = False, help="Do you want to load a model?") @click.option("--path",default="models/model.pt", help= "Path where the model will be saved to/loaded from" ) -def run_cli(learning_rate: float, batch_size: int, epochs: int, load:bool,path:str) -> None: +def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int, load: bool, path: str) -> None: """Runs the training script.""" - run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs,load= load, path = path) - - -if __name__ == "__main__": - 
run(learning_rate=5e-4, batch_size=16, epochs=1,load=False, path= "models/model.pt") + run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size, load=load, path=path) |