-rwxr-xr-x | distributed.sh    | 34
-rw-r--r-- | swr2_asr/train.py | 44
-rwxr-xr-x | train.sh          |  3
3 files changed, 67 insertions(+), 14 deletions(-)
diff --git a/distributed.sh b/distributed.sh
new file mode 100755
index 0000000..4949159
--- /dev/null
+++ b/distributed.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+#SBATCH --job-name=swr-teamprojekt
+#SBATCH --partition=a100
+#SBATCH --time=00:30:00
+
+### e.g. request 4 nodes with 1 gpu each, totally 4 gpus (WORLD_SIZE==4)
+### Note: --gres=gpu:x should equal to ntasks-per-node
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:a100:1
+#SBATCH --cpus-per-task=8
+#SBATCH --mem=64gb
+#SBATCH --chdir=/mnt/lustre/mladm/mfa252/SWR2-cool-projekt-main/
+#SBATCH --output=/mnt/lustre/mladm/mfa252/%x-%j.out
+
+### change 5-digit MASTER_PORT as you wish, slurm will raise Error if duplicated with others
+### change WORLD_SIZE as gpus/node * num_nodes
+export MASTER_PORT=18120
+export WORLD_SIZE=2
+
+### get the first node name as master address - customized for vgg slurm
+### e.g. master(gnodee[2-5],gnoded1) == gnodee2
+echo "NODELIST="${SLURM_NODELIST}
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=$master_addr
+echo "MASTER_ADDR="$MASTER_ADDR
+
+export NCCL_DEBUG="INFO"
+
+source venv/bin/activate
+
+### the command to run
+srun ./train.sh
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 56d10c0..643ea68 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -413,6 +413,7 @@ def run(
         "epochs": epochs,
         "world_size": world_size,
         "distributed": world_size > 1,
+        "rank": 0,
     }
 
     use_cuda = torch.cuda.is_available()
@@ -422,10 +423,10 @@ def run(
 
     download_dataset = not os.path.isdir(path)
     train_dataset = MultilingualLibriSpeech(
-        dataset_path, "mls_polish_opus", split="dev", download=download_dataset
+        dataset_path, "mls_german_opus", split="dev", download=download_dataset
     )
     test_dataset = MultilingualLibriSpeech(
-        dataset_path, "mls_polish_opus", split="test", download=False
+        dataset_path, "mls_german_opus", split="test", download=False
     )
 
     # initialize distributed training
@@ -434,6 +435,9 @@ def run(
         if "SLURM_PROCID" in os.environ:  # for slurm scheduler
             hparams["rank"] = int(os.environ["SLURM_PROCID"])
             hparams["gpu"] = hparams["rank"] % ngpus_per_node
+            print("slurm: ", str(hparams))
+        else:
+            print("No slurm process!")
         dist.init_process_group(
             backend="nccl",
             init_method="env://",
@@ -449,9 +453,9 @@ def run(
     train_loader = DataLoader(
         train_dataset,
         batch_size=hparams["batch_size"],
-        shuffle=True,
+        shuffle=(train_sampler is None),
         sampler=train_sampler,
-        num_workers=hparams["world_size"],  # TODO?
+        num_workers=0,
         pin_memory=True,
         collate_fn=lambda x: data_processing(x, "train"),
     )
@@ -461,11 +465,15 @@ def run(
         batch_size=hparams["batch_size"],
         shuffle=True,
         sampler=None,
-        num_workers=hparams["world_size"],  # TODO?
+        num_workers=0,
         pin_memory=True,
         collate_fn=lambda x: data_processing(x, "train"),
     )
 
+    # enable flag to find the most compatible algorithms in advance
+    if use_cuda:
+        torch.backends.cudnn.benchmark = True
+
     model = SpeechRecognitionModel(
         hparams["n_cnn_layers"],
         hparams["n_rnn_layers"],
@@ -483,11 +491,9 @@ def run(
             model = torch.nn.parallel.DistributedDataParallel(
                 model, device_ids=[hparams["gpu"]]
            )
-            model_without_ddp = model.module
         else:
             model.cuda()
             model = torch.nn.parallel.DistributedDataParallel(model)
-            model_without_ddp = model.module
 
     print(
         "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
@@ -510,6 +516,10 @@ def run(
 
     iter_meter = IterMeter()
     for epoch in range(1, epochs + 1):
+        torch.manual_seed(epoch)
+        if hparams["distributed"]:  # each gpu should get different part of dataset
+            train_loader.sampler.set_epoch(epoch)
+
         loss = train(
             model,
             device,
@@ -520,12 +530,14 @@ def run(
             epoch,
             iter_meter,
         )
-        if epoch % 3 == 0 or epoch == epochs:
-            torch.save(
-                {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
-                path,
-            )
-        test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+
+        if hparams["rank"] == 0:  # only validate on master node
+            test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+            if epoch % 3 == 0 or epoch == epochs:
+                torch.save(
+                    {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+                    path+str(epoch),
+                )
 
 
 @click.command()
@@ -536,7 +548,7 @@ def run(
 @click.option("--load", default=False, help="Do you want to load a model?")
 @click.option(
     "--path",
-    default="models/model.pt",
+    default="model.pt",
     help="Path where the model will be saved to/loaded from",
 )
 @click.option(
@@ -554,6 +566,7 @@ def run_cli(
     dataset_path: str,
 ) -> None:
     """Runs the training script."""
+
     run(
         learning_rate=learning_rate,
         batch_size=batch_size,
@@ -563,3 +576,6 @@ def run_cli(
         path=path,
         dataset_path=dataset_path,
     )
+
+if __name__ == '__main__':
+    run_cli()
diff --git a/train.sh b/train.sh
new file mode 100755
index 0000000..27acc16
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,3 @@
+#!/bin/sh
+
+yes no | python -m swr2_asr.train --batch_size=8 --world_size=2 --dataset_path=/mnt/lustre/mladm/mfa252/data
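
The commit follows the standard PyTorch pattern for SLURM-launched multi-node training: distributed.sh exports MASTER_ADDR, MASTER_PORT and WORLD_SIZE and runs one task per node via srun, while train.py reads SLURM_PROCID to derive its rank, joins the NCCL process group through init_method="env://", shards the data with a DistributedSampler, and lets only rank 0 validate and checkpoint. Below is a minimal, self-contained sketch of that pattern, not code from this commit; the helper names (init_distributed_from_slurm, run_epochs) and the toy dataset/model are illustrative assumptions.

# Hedged sketch of the SLURM + torch.distributed pattern used by this commit.
# Helper names and the toy data/model are illustrative, not repository code.
import os

import torch
import torch.distributed as dist
from torch import nn
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset


def init_distributed_from_slurm(world_size: int) -> int:
    """Join the NCCL process group using SLURM_PROCID as the global rank.

    MASTER_ADDR, MASTER_PORT and WORLD_SIZE come from the sbatch script;
    srun sets SLURM_PROCID per task.
    """
    rank = int(os.environ.get("SLURM_PROCID", "0"))
    torch.cuda.set_device(rank % torch.cuda.device_count())
    dist.init_process_group(
        backend="nccl",
        init_method="env://",  # reads MASTER_ADDR / MASTER_PORT
        world_size=world_size,
        rank=rank,
    )
    return rank


def run_epochs(world_size: int, epochs: int = 3, path: str = "model.pt") -> None:
    rank = init_distributed_from_slurm(world_size)

    # Toy data and model so the sketch runs under srun; swap in the real ones.
    dataset = TensorDataset(torch.randn(64, 10), torch.randn(64, 1))
    sampler = DistributedSampler(dataset)  # each rank sees a distinct shard
    loader = DataLoader(
        dataset,
        batch_size=8,
        shuffle=False,  # the sampler already shuffles
        sampler=sampler,
        pin_memory=True,
    )

    model = nn.Linear(10, 1).cuda()
    model = nn.parallel.DistributedDataParallel(model)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    for epoch in range(1, epochs + 1):
        sampler.set_epoch(epoch)  # reshuffle the shards differently every epoch
        for features, targets in loader:
            optimizer.zero_grad()
            loss = nn.functional.mse_loss(model(features.cuda()), targets.cuda())
            loss.backward()
            optimizer.step()

        if rank == 0:  # only the master rank validates and writes checkpoints
            torch.save(
                {"epoch": epoch, "model_state_dict": model.state_dict()},
                path + str(epoch),
            )


if __name__ == "__main__":
    run_epochs(world_size=int(os.environ.get("WORLD_SIZE", "1")))

Launched the same way as train.sh (one process per node under srun), each rank trains on its own shard and only rank 0 writes per-epoch checkpoint files, mirroring the path+str(epoch) convention introduced in the diff above.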