-rwxr-xr-x  distributed.sh     34
-rwxr-xr-x  hpc.sh             19
-rwxr-xr-x  hpc_train.sh        3
-rw-r--r--  swr2_asr/train.py  73
-rwxr-xr-x  train.sh            3
5 files changed, 33 insertions, 99 deletions
diff --git a/distributed.sh b/distributed.sh
deleted file mode 100755
index 4949159..0000000
--- a/distributed.sh
+++ /dev/null
@@ -1,34 +0,0 @@
-#!/bin/bash
-
-#SBATCH --job-name=swr-teamprojekt
-#SBATCH --partition=a100
-#SBATCH --time=00:30:00
-
-### e.g. request 4 nodes with 1 gpu each, totally 4 gpus (WORLD_SIZE==4)
-### Note: --gres=gpu:x should equal to ntasks-per-node
-#SBATCH --nodes=2
-#SBATCH --ntasks-per-node=1
-#SBATCH --gres=gpu:a100:1
-#SBATCH --cpus-per-task=8
-#SBATCH --mem=64gb
-#SBATCH --chdir=/mnt/lustre/mladm/mfa252/SWR2-cool-projekt-main/
-#SBATCH --output=/mnt/lustre/mladm/mfa252/%x-%j.out
-
-### change 5-digit MASTER_PORT as you wish, slurm will raise Error if duplicated with others
-### change WORLD_SIZE as gpus/node * num_nodes
-export MASTER_PORT=18120
-export WORLD_SIZE=2
-
-### get the first node name as master address - customized for vgg slurm
-### e.g. master(gnodee[2-5],gnoded1) == gnodee2
-echo "NODELIST="${SLURM_NODELIST}
-master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
-export MASTER_ADDR=$master_addr
-echo "MASTER_ADDR="$MASTER_ADDR
-
-export NCCL_DEBUG="INFO"
-
-source venv/bin/activate
-
-### the command to run
-srun ./train.sh
diff --git a/hpc.sh b/hpc.sh
new file mode 100755
index 0000000..ba0c5eb
--- /dev/null
+++ b/hpc.sh
@@ -0,0 +1,19 @@
+#!/bin/bash
+
+#SBATCH --job-name=swr-teamprojekt
+#SBATCH --partition=a100
+#SBATCH --time=00:30:00
+
+### Note: --gres=gpu:x should equal ntasks-per-node
+#SBATCH --nodes=1
+#SBATCH --ntasks-per-node=1
+#SBATCH --gres=gpu:a100:1
+#SBATCH --cpus-per-task=8
+#SBATCH --mem=64gb
+#SBATCH --chdir=/mnt/lustre/mladm/mfa252/SWR2-cool-projekt-main/
+#SBATCH --output=/mnt/lustre/mladm/mfa252/%x-%j.out
+
+source venv/bin/activate
+
+### the command to run
+srun ./hpc_train.sh
diff --git a/hpc_train.sh b/hpc_train.sh
new file mode 100755
index 0000000..c7d1636
--- /dev/null
+++ b/hpc_train.sh
@@ -0,0 +1,3 @@
+#!/bin/sh
+
+yes no | python -m swr2_asr.train --epochs=100 --batch_size=30 --dataset_path=/mnt/lustre/mladm/mfa252/data
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 643ea68..4c97482 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -7,8 +7,6 @@ import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
-from torch.utils.data.distributed import DistributedSampler
-import torch.distributed as dist
import torchaudio
from .loss_scores import cer, wer
@@ -394,7 +392,6 @@ def run(
learning_rate: float,
batch_size: int,
epochs: int,
- world_size: int,
load: bool,
path: str,
dataset_path: str,
@@ -411,9 +408,6 @@ def run(
"learning_rate": learning_rate,
"batch_size": batch_size,
"epochs": epochs,
- "world_size": world_size,
- "distributed": world_size > 1,
- "rank": 0,
}
use_cuda = torch.cuda.is_available()
@@ -429,34 +423,10 @@ def run(
dataset_path, "mls_german_opus", split="test", download=False
)
- # initialize distributed training
- ngpus_per_node = torch.cuda.device_count()
- if hparams["distributed"]:
- if "SLURM_PROCID" in os.environ: # for slurm scheduler
- hparams["rank"] = int(os.environ["SLURM_PROCID"])
- hparams["gpu"] = hparams["rank"] % ngpus_per_node
- print("slurm: ", str(hparams))
- else:
- print("No slurm process!")
- dist.init_process_group(
- backend="nccl",
- init_method="env://",
- world_size=hparams["world_size"],
- rank=hparams["rank"],
- )
-
- train_sampler = (
- DistributedSampler(train_dataset, shuffle=True)
- if hparams["distributed"]
- else None
- )
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
- shuffle=(train_sampler is None),
- sampler=train_sampler,
- num_workers=0,
- pin_memory=True,
+ shuffle=True,
collate_fn=lambda x: data_processing(x, "train"),
)
@@ -464,9 +434,6 @@ def run(
test_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- sampler=None,
- num_workers=0,
- pin_memory=True,
collate_fn=lambda x: data_processing(x, "train"),
)
@@ -484,17 +451,6 @@ def run(
hparams["dropout"],
).to(device)
- if hparams["distributed"]:
- if "gpu" in hparams:
- torch.cuda.set_device(hparams["gpu"])
- model.cuda(hparams["gpu"])
- model = torch.nn.parallel.DistributedDataParallel(
- model, device_ids=[hparams["gpu"]]
- )
- else:
- model.cuda()
- model = torch.nn.parallel.DistributedDataParallel(model)
-
print(
"Num Model Parameters", sum([param.nelement() for param in model.parameters()])
)
@@ -516,10 +472,6 @@ def run(
iter_meter = IterMeter()
for epoch in range(1, epochs + 1):
- torch.manual_seed(epoch)
- if hparams["distributed"]: # each gpu should get different part of dataset
- train_loader.sampler.set_epoch(epoch)
-
loss = train(
model,
device,
@@ -531,24 +483,22 @@ def run(
iter_meter,
)
- if hparams["rank"] == 0: # only validate on master node
- test(model=model, device=device, test_loader=test_loader, criterion=criterion)
- if epoch % 3 == 0 or epoch == epochs:
- torch.save(
- {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
- path+str(epoch),
- )
+ test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ print("saving epoch", str(epoch))
+ torch.save(
+ {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+ path + str(epoch),
+ )
@click.command()
@click.option("--learning_rate", default=1e-3, help="Learning rate")
-@click.option("--batch_size", default=1, help="Batch size")
+@click.option("--batch_size", default=10, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
-@click.option("--world_size", default=1, help="Number of nodes for distribution")
@click.option("--load", default=False, help="Do you want to load a model?")
@click.option(
"--path",
- default="model.pt",
+ default="model",
help="Path where the model will be saved to/loaded from",
)
@click.option(
@@ -560,7 +510,6 @@ def run_cli(
learning_rate: float,
batch_size: int,
epochs: int,
- world_size: int,
load: bool,
path: str,
dataset_path: str,
@@ -571,11 +520,11 @@ def run_cli(
learning_rate=learning_rate,
batch_size=batch_size,
epochs=epochs,
- world_size=world_size,
load=load,
path=path,
dataset_path=dataset_path,
)
-if __name__ == '__main__':
+
+if __name__ == "__main__":
run_cli()
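
Note: with the DistributedDataParallel wrapper removed, the checkpoints written by the torch.save call above should contain the model's own parameter names rather than DDP's "module."-prefixed keys, so they can be loaded straight into an unwrapped model. A minimal sketch of inspecting such a checkpoint, assuming a hypothetical file "model100" produced with the new default path "model" after epoch 100:

import torch

# Load the checkpoint dict written as {"epoch": ..., "model_state_dict": ..., "loss": ...}.
checkpoint = torch.load("model100", map_location="cpu")
print("epoch:", checkpoint["epoch"])
print("loss:", checkpoint["loss"])

# The weights can be restored via model.load_state_dict(checkpoint["model_state_dict"])
# once a model with the same architecture/hyperparameters has been constructed.
state_dict = checkpoint["model_state_dict"]
print("parameter tensors:", len(state_dict))
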
diff --git a/train.sh b/train.sh
deleted file mode 100755
index 27acc16..0000000
--- a/train.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-#!/bin/sh
-
-yes no | python -m swr2_asr.train --batch_size=8 --world_size=2 --dataset_path=/mnt/lustre/mladm/mfa252/data
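
With world_size gone from both run() and the click options, the training entry point can also be invoked directly from Python using only the remaining keyword arguments. A minimal sketch, assuming the package is importable and reusing the cluster dataset path from hpc_train.sh (substitute your own path elsewhere):

from swr2_asr.train import run

# Mirrors the defaults of the updated click options.
run(
    learning_rate=1e-3,
    batch_size=10,
    epochs=1,
    load=False,
    path="model",  # checkpoints are written as path + str(epoch)
    dataset_path="/mnt/lustre/mladm/mfa252/data",
)
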