aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr')
-rw-r--r--swr2_asr/train.py45
1 file changed, 35 insertions, 10 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 346be0b..ba002e0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,11 +1,14 @@
"""Training script for the ASR model."""
from AudioLoader.speech import MultilingualLibriSpeech
+import os
import click
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+import torch.distributed as dist
import torchaudio
from .loss_scores import cer, wer
@@ -388,7 +391,7 @@ def test(model, device, test_loader, criterion):
)
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load: bool=False, path: str="models/model.pt") -> None:
+def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1, load: bool=False, path: str="models/model.pt") -> None:
"""Runs the training script."""
hparams = {
"n_cnn_layers": 3,
@@ -401,6 +404,8 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load:
"learning_rate": learning_rate,
"batch_size": batch_size,
"epochs": epochs,
+ "world_size": world_size,
+ "distributed": world_size > 1,
}
use_cuda = torch.cuda.is_available()
@@ -415,22 +420,34 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load:
"/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
)
- kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+ # initialize distributed training
+ ngpus_per_node = torch.cuda.device_count()
+ if hparams["distributed"]:
+ if 'SLURM_PROCID' in os.environ: # for slurm scheduler
+ hparams["rank"] = int(os.environ['SLURM_PROCID'])
+ hparams["gpu"] = hparams["rank"] % ngpus_per_node
+ dist.init_process_group(backend="nccl", init_method="env://",
+ world_size=hparams["world_size"], rank=hparams["rank"])
+ train_sampler = DistributedSampler(train_dataset, shuffle=True)
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
+ sampler=train_sampler,
+ num_workers=hparams["world_size"], # TODO?
+ pin_memory=True,
collate_fn=lambda x: data_processing(x, "train"),
- **kwargs,
)
test_loader = DataLoader(
test_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
+ sampler=None,
+ num_workers=hparams["world_size"], # TODO?
+ pin_memory=True,
collate_fn=lambda x: data_processing(x, "train"),
- **kwargs,
)
model = SpeechRecognitionModel(
@@ -443,6 +460,17 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load:
hparams["dropout"],
).to(device)
+ if hparams["distributed"]:
+ if "gpu" in hparams:
+ torch.cuda.set_device(hparams["gpu"])
+ model.cuda(hparams["gpu"])
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[hparams["gpu"]])
+ model_without_ddp = model.module
+ else:
+ model.cuda()
+ model = torch.nn.parallel.DistributedDataParallel(model)
+ model_without_ddp = model.module
+
print(
"Num Model Parameters", sum([param.nelement() for param in model.parameters()])
)
@@ -486,13 +514,10 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3,load:
@click.option("--learning-rate", default=1e-3, help="Learning rate")
@click.option("--batch_size", default=1, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
+@click.option("--world_size", default=1, help="Number of nodes for distribution")
@click.option("--load", default = False, help="Do you want to load a model?")
@click.option("--path",default="models/model.pt",
help= "Path where the model will be saved to/loaded from" )
-def run_cli(learning_rate: float, batch_size: int, epochs: int, load:bool,path:str) -> None:
+def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int, load: bool, path: str) -> None:
"""Runs the training script."""
- run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs,load= load, path = path)
-
-
-if __name__ == "__main__":
- run(learning_rate=5e-4, batch_size=16, epochs=1,load=False, path= "models/model.pt")
+ run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size, load=load, path=path)