 distributed.sh    | 34 ++++++++++++++++++++++++++++++++++
 swr2_asr/train.py | 44 ++++++++++++++++++++++++++++++--------------
 train.sh          |  3 +++
 3 files changed, 67 insertions(+), 14 deletions(-)
diff --git a/distributed.sh b/distributed.sh
new file mode 100755
index 0000000..4949159
--- /dev/null
+++ b/distributed.sh
@@ -0,0 +1,34 @@
+#!/bin/bash
+
+#SBATCH --job-name=swr-teamprojekt
+#SBATCH --partition=a100
+#SBATCH --time=00:30:00
+
+### e.g. request 4 nodes with 1 GPU each, 4 GPUs in total (WORLD_SIZE==4)
+### Note: --gres=gpu:x should equal --ntasks-per-node
+#SBATCH --nodes=2
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:a100:1
+#SBATCH --cpus-per-task=8
+#SBATCH --mem=64gb
+#SBATCH --chdir=/mnt/lustre/mladm/mfa252/SWR2-cool-projekt-main/
+#SBATCH --output=/mnt/lustre/mladm/mfa252/%x-%j.out
+
+### choose any unused 5-digit MASTER_PORT; an error is raised if it collides with another job's port
+### set WORLD_SIZE to gpus-per-node * num-nodes
+export MASTER_PORT=18120
+export WORLD_SIZE=2
+
+### use the hostname of the first node as the master address - customized for vgg slurm
+### e.g. master(gnodee[2-5],gnoded1) == gnodee2
+echo "NODELIST="${SLURM_NODELIST}
+master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
+export MASTER_ADDR=$master_addr
+echo "MASTER_ADDR="$MASTER_ADDR
+
+export NCCL_DEBUG="INFO"
+
+source venv/bin/activate
+
+### the command to run
+srun ./train.sh
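
The exported MASTER_ADDR, MASTER_PORT and WORLD_SIZE, together with SLURM_PROCID set by srun for each task, are what the `env://` init method in train.py consumes. A minimal sketch of that plumbing, mirroring the hunks below (not part of this commit):

```python
# Sketch: how the variables exported above reach torch.distributed when
# init_method="env://" is used, as in the train.py changes below.
import os
import torch.distributed as dist

rank = int(os.environ["SLURM_PROCID"])      # one value per srun task
world_size = int(os.environ["WORLD_SIZE"])  # exported in distributed.sh

# MASTER_ADDR and MASTER_PORT are read implicitly from the environment.
dist.init_process_group(
    backend="nccl", init_method="env://", world_size=world_size, rank=rank
)
```
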
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 56d10c0..643ea68 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -413,6 +413,7 @@ def run(
"epochs": epochs,
"world_size": world_size,
"distributed": world_size > 1,
+ "rank": 0,
}
use_cuda = torch.cuda.is_available()
@@ -422,10 +423,10 @@ def run(
download_dataset = not os.path.isdir(path)
train_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_polish_opus", split="dev", download=download_dataset
+ dataset_path, "mls_german_opus", split="dev", download=download_dataset
)
test_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_polish_opus", split="test", download=False
+ dataset_path, "mls_german_opus", split="test", download=False
)
# initialize distributed training
@@ -434,6 +435,9 @@ def run(
if "SLURM_PROCID" in os.environ: # for slurm scheduler
hparams["rank"] = int(os.environ["SLURM_PROCID"])
hparams["gpu"] = hparams["rank"] % ngpus_per_node
+ print("slurm: ", str(hparams))
+ else:
+ print("No slurm process!")
dist.init_process_group(
backend="nccl",
init_method="env://",
@@ -449,9 +453,9 @@ def run(
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
- shuffle=True,
+ shuffle=(train_sampler is None),
sampler=train_sampler,
- num_workers=hparams["world_size"], # TODO?
+ num_workers=0,
pin_memory=True,
collate_fn=lambda x: data_processing(x, "train"),
)
@@ -461,11 +465,15 @@ def run(
batch_size=hparams["batch_size"],
shuffle=True,
sampler=None,
- num_workers=hparams["world_size"], # TODO?
+ num_workers=0,
pin_memory=True,
collate_fn=lambda x: data_processing(x, "train"),
)
+ # let cuDNN benchmark convolution algorithms and cache the fastest ones for the observed input shapes
+ if use_cuda:
+ torch.backends.cudnn.benchmark = True
+
model = SpeechRecognitionModel(
hparams["n_cnn_layers"],
hparams["n_rnn_layers"],
@@ -483,11 +491,9 @@ def run(
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[hparams["gpu"]]
)
- model_without_ddp = model.module
else:
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(model)
- model_without_ddp = model.module
print(
"Num Model Parameters", sum([param.nelement() for param in model.parameters()])
@@ -510,6 +516,10 @@ def run(
iter_meter = IterMeter()
for epoch in range(1, epochs + 1):
+ torch.manual_seed(epoch)
+ if hparams["distributed"]: # each GPU should get a different part of the dataset
+ train_loader.sampler.set_epoch(epoch)
+
loss = train(
model,
device,
@@ -520,12 +530,14 @@ def run(
epoch,
iter_meter,
)
- if epoch % 3 == 0 or epoch == epochs:
- torch.save(
- {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
- path,
- )
- test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+
+ if hparams["rank"] == 0: # validate and checkpoint only on rank 0
+ test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ if epoch % 3 == 0 or epoch == epochs:
+ torch.save(
+ {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
path + str(epoch),
+ )
@click.command()
@@ -536,7 +548,7 @@ def run(
@click.option("--load", default=False, help="Do you want to load a model?")
@click.option(
"--path",
- default="models/model.pt",
+ default="model.pt",
help="Path where the model will be saved to/loaded from",
)
@click.option(
@@ -554,6 +566,7 @@ def run_cli(
dataset_path: str,
) -> None:
"""Runs the training script."""
+
run(
learning_rate=learning_rate,
batch_size=batch_size,
@@ -563,3 +576,6 @@ def run_cli(
path=path,
dataset_path=dataset_path,
)
+
+if __name__ == '__main__':
+ run_cli()
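
The hunks above call `train_loader.sampler.set_epoch(epoch)` and pass a `train_sampler` that is constructed outside the shown context. A minimal sketch of what that sampler presumably looks like, assuming torch's DistributedSampler is used:

```python
# Sketch only: construction of the `train_sampler` referenced above, assuming
# torch.utils.data.distributed.DistributedSampler.
from torch.utils.data.distributed import DistributedSampler

train_sampler = (
    DistributedSampler(
        train_dataset,
        num_replicas=hparams["world_size"],
        rank=hparams["rank"],
        shuffle=True,
    )
    if hparams["distributed"]
    else None
)
# set_epoch(epoch), called once per epoch above, reseeds the shuffle so every
# rank draws a different, epoch-dependent shard of the dataset.
```
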
diff --git a/train.sh b/train.sh
new file mode 100755
index 0000000..27acc16
--- /dev/null
+++ b/train.sh
@@ -0,0 +1,3 @@
+#!/bin/sh
+
+yes no | python -m swr2_asr.train --batch_size=8 --world_size=2 --dataset_path=/mnt/lustre/mladm/mfa252/data
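
The job is submitted with `sbatch distributed.sh`; its final `srun ./train.sh` line then starts this script once per task, and the `yes no |` pipe presumably answers any interactive prompt (e.g. a dataset-download confirmation) with "no". Because the checkpoints above are saved from the DDP-wrapped model, every parameter name carries a `module.` prefix. A sketch of loading such a file back into an unwrapped model (the file name `model.pt3` is an assumed example for epoch 3, given the default path "model.pt" and the `path + str(epoch)` naming):

```python
# Sketch only: loading a checkpoint written by the rank-0 save in train.py.
# The DDP wrapper prefixes parameter names with "module.", so strip it before
# loading into a plain SpeechRecognitionModel instance.
import torch

checkpoint = torch.load("model.pt3", map_location="cpu")  # path + str(epoch)
state_dict = {
    key.removeprefix("module."): value  # requires Python 3.9+
    for key, value in checkpoint["model_state_dict"].items()
}
# model = SpeechRecognitionModel(...)  # built with the same hparams as in training
# model.load_state_dict(state_dict)
```
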