Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--  swr2_asr/train.py  44
1 file changed, 30 insertions(+), 14 deletions(-)
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 56d10c0..643ea68 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -413,6 +413,7 @@ def run(
        "epochs": epochs,
        "world_size": world_size,
        "distributed": world_size > 1,
+        "rank": 0,
    }

    use_cuda = torch.cuda.is_available()
@@ -422,10 +423,10 @@
    download_dataset = not os.path.isdir(path)
    train_dataset = MultilingualLibriSpeech(
-        dataset_path, "mls_polish_opus", split="dev", download=download_dataset
+        dataset_path, "mls_german_opus", split="dev", download=download_dataset
    )
    test_dataset = MultilingualLibriSpeech(
-        dataset_path, "mls_polish_opus", split="test", download=False
+        dataset_path, "mls_german_opus", split="test", download=False
    )

    # initialize distributed training
@@ -434,6 +435,9 @@
    if "SLURM_PROCID" in os.environ: # for slurm scheduler
        hparams["rank"] = int(os.environ["SLURM_PROCID"])
        hparams["gpu"] = hparams["rank"] % ngpus_per_node
+        print("slurm: ", str(hparams))
+    else:
+        print("No slurm process!")
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
@@ -449,9 +453,9 @@
    train_loader = DataLoader(
        train_dataset,
        batch_size=hparams["batch_size"],
-        shuffle=True,
+        shuffle=(train_sampler is None),
        sampler=train_sampler,
-        num_workers=hparams["world_size"], # TODO?
+        num_workers=0,
        pin_memory=True,
        collate_fn=lambda x: data_processing(x, "train"),
    )
@@ -461,11 +465,15 @@
        batch_size=hparams["batch_size"],
        shuffle=True,
        sampler=None,
-        num_workers=hparams["world_size"], # TODO?
+        num_workers=0,
        pin_memory=True,
        collate_fn=lambda x: data_processing(x, "train"),
    )

+    # enable flag to find the most compatible algorithms in advance
+    if use_cuda:
+        torch.backends.cudnn.benchmark = True
+
    model = SpeechRecognitionModel(
        hparams["n_cnn_layers"],
        hparams["n_rnn_layers"],
@@ -483,11 +491,9 @@
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[hparams["gpu"]]
        )
-        model_without_ddp = model.module
    else:
        model.cuda()
        model = torch.nn.parallel.DistributedDataParallel(model)
-        model_without_ddp = model.module

    print(
        "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
@@ -510,6 +516,10 @@
    iter_meter = IterMeter()
    for epoch in range(1, epochs + 1):
+        torch.manual_seed(epoch)
+        if hparams["distributed"]: # each gpu should get different part of dataset
+            train_loader.sampler.set_epoch(epoch)
+
        loss = train(
            model,
            device,
@@ -520,12 +530,14 @@
            epoch,
            iter_meter,
        )
-        if epoch % 3 == 0 or epoch == epochs:
-            torch.save(
-                {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
-                path,
-            )
-        test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+
+        if hparams["rank"] == 0: # only validate on master node
+            test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+            if epoch % 3 == 0 or epoch == epochs:
+                torch.save(
+                    {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+                    path+str(epoch),
+                )


@click.command()
@@ -536,7 +548,7 @@ def run(
@click.option("--load", default=False, help="Do you want to load a model?")
@click.option(
    "--path",
-    default="models/model.pt",
+    default="model.pt",
    help="Path where the model will be saved to/loaded from",
)
@click.option(
@@ -554,6 +566,7 @@ def run_cli(
    dataset_path: str,
) -> None:
    """Runs the training script."""
+
    run(
        learning_rate=learning_rate,
        batch_size=batch_size,
@@ -563,3 +576,6 @@ def run_cli(
        path=path,
        dataset_path=dataset_path,
    )
+
+if __name__ == '__main__':
+    run_cli()
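
The changes above follow the usual PyTorch DDP recipe: let a DistributedSampler do the shuffling instead of the DataLoader, call set_epoch() so each rank draws a different shard every epoch, and restrict validation and checkpointing to rank 0. The following is a minimal, self-contained sketch of that pattern, not code from swr2_asr: the toy dataset and the Linear model are stand-ins, and rank/world size are read from the SLURM environment as in the diff.

import os

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Rank and world size from the SLURM environment, mirroring the diff; both
# fall back to single-process values so the sketch runs without a scheduler.
rank = int(os.environ.get("SLURM_PROCID", 0))
world_size = int(os.environ.get("SLURM_NTASKS", 1))
distributed = world_size > 1

# Toy stand-ins for MultilingualLibriSpeech and SpeechRecognitionModel.
dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))
model = torch.nn.Linear(10, 2)

# Passing num_replicas/rank explicitly keeps the sampler usable without an
# initialized process group, which is enough for illustration.
sampler = (
    DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    if distributed
    else None
)
loader = DataLoader(
    dataset,
    batch_size=8,
    shuffle=(sampler is None),  # the sampler already shuffles per rank
    sampler=sampler,
)

for epoch in range(1, 4):
    if distributed:
        loader.sampler.set_epoch(epoch)  # different shard/shuffle per epoch
    for features, labels in loader:
        pass  # stand-in for the real train() step
    if rank == 0:  # only the master process validates and checkpoints
        torch.save(
            {"epoch": epoch, "model_state_dict": model.state_dict()},
            "model.pt" + str(epoch),
        )

Run it single-process with plain python, or under SLURM with one task per GPU; only rank 0 writes the per-epoch model.pt<epoch> files, matching the path+str(epoch) naming introduced by this commit.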