diff options
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 73 |
1 files changed, 11 insertions, 62 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 643ea68..4c97482 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -7,8 +7,6 @@ import torch.nn as nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -import torch.distributed as dist import torchaudio from .loss_scores import cer, wer @@ -394,7 +392,6 @@ def run( learning_rate: float, batch_size: int, epochs: int, - world_size: int, load: bool, path: str, dataset_path: str, @@ -411,9 +408,6 @@ def run( "learning_rate": learning_rate, "batch_size": batch_size, "epochs": epochs, - "world_size": world_size, - "distributed": world_size > 1, - "rank": 0, } use_cuda = torch.cuda.is_available() @@ -429,34 +423,10 @@ def run( dataset_path, "mls_german_opus", split="test", download=False ) - # initialize distributed training - ngpus_per_node = torch.cuda.device_count() - if hparams["distributed"]: - if "SLURM_PROCID" in os.environ: # for slurm scheduler - hparams["rank"] = int(os.environ["SLURM_PROCID"]) - hparams["gpu"] = hparams["rank"] % ngpus_per_node - print("slurm: ", str(hparams)) - else: - print("No slurm process!") - dist.init_process_group( - backend="nccl", - init_method="env://", - world_size=hparams["world_size"], - rank=hparams["rank"], - ) - - train_sampler = ( - DistributedSampler(train_dataset, shuffle=True) - if hparams["distributed"] - else None - ) train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], - shuffle=(train_sampler is None), - sampler=train_sampler, - num_workers=0, - pin_memory=True, + shuffle=True, collate_fn=lambda x: data_processing(x, "train"), ) @@ -464,9 +434,6 @@ def run( test_dataset, batch_size=hparams["batch_size"], shuffle=True, - sampler=None, - num_workers=0, - pin_memory=True, collate_fn=lambda x: data_processing(x, "train"), ) @@ -484,17 +451,6 @@ def run( hparams["dropout"], ).to(device) - if hparams["distributed"]: - if "gpu" in hparams: - torch.cuda.set_device(hparams["gpu"]) - model.cuda(hparams["gpu"]) - model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[hparams["gpu"]] - ) - else: - model.cuda() - model = torch.nn.parallel.DistributedDataParallel(model) - print( "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) ) @@ -516,10 +472,6 @@ def run( iter_meter = IterMeter() for epoch in range(1, epochs + 1): - torch.manual_seed(epoch) - if hparams["distributed"]: # each gpu should get different part of dataset - train_loader.sampler.set_epoch(epoch) - loss = train( model, device, @@ -531,24 +483,22 @@ def run( iter_meter, ) - if hparams["rank"] == 0: # only validate on master node - test(model=model, device=device, test_loader=test_loader, criterion=criterion) - if epoch % 3 == 0 or epoch == epochs: - torch.save( - {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, - path+str(epoch), - ) + test(model=model, device=device, test_loader=test_loader, criterion=criterion) + print("saving epoch", str(epoch)) + torch.save( + {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, + path + str(epoch), + ) @click.command() @click.option("--learning_rate", default=1e-3, help="Learning rate") -@click.option("--batch_size", default=1, help="Batch size") +@click.option("--batch_size", default=10, help="Batch size") @click.option("--epochs", default=1, help="Number of epochs") -@click.option("--world_size", default=1, help="Number of nodes for distribution") @click.option("--load", default=False, help="Do you want to load a model?") @click.option( "--path", - default="model.pt", + default="model", help="Path where the model will be saved to/loaded from", ) @click.option( @@ -560,7 +510,6 @@ def run_cli( learning_rate: float, batch_size: int, epochs: int, - world_size: int, load: bool, path: str, dataset_path: str, @@ -571,11 +520,11 @@ def run_cli( learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, - world_size=world_size, load=load, path=path, dataset_path=dataset_path, ) -if __name__ == '__main__': + +if __name__ == "__main__": run_cli() |