-rw-r--r--  pyproject.toml       2
-rw-r--r--  swr2_asr/train.py  120
2 files changed, 98 insertions, 24 deletions
diff --git a/pyproject.toml b/pyproject.toml
index fdd89a5..1c29b7c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,7 @@ readme = "readme.md"
packages = [{include = "swr2_asr"}]
[tool.poetry.dependencies]
-python = "~3.10"
+python = "^3.10"
torch = "2.0.0"
torchaudio = "2.0.1"
audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"}
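
A note on the loosened Python constraint: in Poetry's versioning scheme `~3.10` means `>=3.10,<3.11`, whereas `^3.10` means `>=3.10,<4.0`, so the project can now be installed under Python 3.11 and later 3.x releases as well.
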
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 81312d9..56d10c0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,16 +1,17 @@
"""Training script for the ASR model."""
from AudioLoader.speech import MultilingualLibriSpeech
+import os
import click
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+import torch.distributed as dist
import torchaudio
from .loss_scores import cer, wer
-MODEL_SAVE_PATH = "models/model.pt"
-LOSS
class TextTransform:
"""Maps characters to integers and vice versa"""
@@ -352,6 +353,7 @@ def train(
({100.0 * batch_idx / len(train_loader)}%)]\t \
Loss: {loss.item()}"
)
+    return loss.item()
def test(model, device, test_loader, criterion):
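
Having `train` return the loss of its last batch lets the caller store that value in the checkpoint dictionary written at the end of each saving epoch.
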
@@ -388,7 +390,15 @@ def test(model, device, test_loader, criterion):
)
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+def run(
+    learning_rate: float,
+    batch_size: int,
+    epochs: int,
+    world_size: int,
+    load: bool,
+    path: str,
+    dataset_path: str,
+) -> None:
    """Runs the training script."""
    hparams = {
        "n_cnn_layers": 3,
@@ -401,6 +411,8 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
"learning_rate": learning_rate,
"batch_size": batch_size,
"epochs": epochs,
+ "world_size": world_size,
+ "distributed": world_size > 1,
}
use_cuda = torch.cuda.is_available()
@@ -408,29 +420,50 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
    device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
    # device = torch.device("mps")
+    download_dataset = not os.path.isdir(path)
    train_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+        dataset_path, "mls_polish_opus", split="dev", download=download_dataset
    )
    test_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+        dataset_path, "mls_polish_opus", split="test", download=False
    )
-    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+    # initialize distributed training
+    ngpus_per_node = torch.cuda.device_count()
+    if hparams["distributed"]:
+        if "SLURM_PROCID" in os.environ: # for slurm scheduler
+            hparams["rank"] = int(os.environ["SLURM_PROCID"])
+            hparams["gpu"] = hparams["rank"] % ngpus_per_node
+        dist.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=hparams["world_size"],
+            rank=hparams["rank"],
+        )
+    train_sampler = (
+        DistributedSampler(train_dataset, shuffle=True)
+        if hparams["distributed"]
+        else None
+    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
+        sampler=train_sampler,
+        num_workers=hparams["world_size"], # TODO?
+        pin_memory=True,
        collate_fn=lambda x: data_processing(x, "train"),
-        **kwargs,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
+        sampler=None,
+        num_workers=hparams["world_size"], # TODO?
+        pin_memory=True,
        collate_fn=lambda x: data_processing(x, "train"),
-        **kwargs,
    )
    model = SpeechRecognitionModel(
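
For orientation, the usual way `init_process_group`, `DistributedSampler`, and `DataLoader` fit together is sketched below. Two things to watch: PyTorch's `DataLoader` refuses `shuffle=True` together with an explicit `sampler` (it raises a `ValueError`), so shuffling is normally requested only when no sampler shards the data; and with `init_method="env://"` the rank and world size must also be resolvable outside SLURM, e.g. from the `RANK` and `WORLD_SIZE` environment variables. The dataset below is a stand-in, not the project's `MultilingualLibriSpeech`, and the helper name is illustrative.

# Minimal sketch of distributed data loading; TensorDataset stands in for the real dataset.
import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def build_train_loader(batch_size: int = 8, distributed: bool = False) -> DataLoader:
    dataset = TensorDataset(torch.randn(64, 10), torch.randint(0, 2, (64,)))

    sampler = None
    if distributed:
        # env:// initialization reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE
        # from the environment; under SLURM, SLURM_PROCID is typically mapped to RANK first.
        dist.init_process_group(backend="nccl", init_method="env://")
        sampler = DistributedSampler(dataset, shuffle=True)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),  # shuffle and sampler are mutually exclusive
        sampler=sampler,
        num_workers=2,
        pin_memory=torch.cuda.is_available(),
    )
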
@@ -443,13 +476,30 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
hparams["dropout"],
).to(device)
+ if hparams["distributed"]:
+ if "gpu" in hparams:
+ torch.cuda.set_device(hparams["gpu"])
+ model.cuda(hparams["gpu"])
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[hparams["gpu"]]
+ )
+ model_without_ddp = model.module
+ else:
+ model.cuda()
+ model = torch.nn.parallel.DistributedDataParallel(model)
+ model_without_ddp = model.module
+
print(
"Num Model Parameters", sum([param.nelement() for param in model.parameters()])
)
-
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
criterion = nn.CTCLoss(blank=28).to(device)
-
+ if load:
+ checkpoint = torch.load(path)
+ model.load_state_dict(checkpoint["model_state_dict"])
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ epoch = checkpoint["epoch"]
+ loss = checkpoint["loss"]
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=hparams["learning_rate"],
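
The resume path above restores optimizer state alongside the weights. Two hedged notes: a `DistributedDataParallel` wrapper prefixes every `state_dict()` key with `module.`, which is why the unwrapped handle `model_without_ddp` is kept around, and the save call in the epoch loop below writes only `epoch`, `model_state_dict`, and `loss`, so the `optimizer_state_dict` lookup here only succeeds for checkpoints that also include it. A sketch of a symmetric save/load pair (names are illustrative):

# Sketch of a resumable checkpoint; model_without_ddp is the unwrapped module
# (model.module under DDP, otherwise the model itself).
import torch


def save_checkpoint(path, epoch, model_without_ddp, optimizer, loss):
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": model_without_ddp.state_dict(),  # keys without "module." prefix
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        path,
    )


def load_checkpoint(path, model_without_ddp, optimizer, device):
    # map_location keeps a GPU-written checkpoint loadable on CPU-only machines
    checkpoint = torch.load(path, map_location=device)
    model_without_ddp.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    return checkpoint["epoch"], checkpoint["loss"]
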
@@ -460,7 +510,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
    iter_meter = IterMeter()
    for epoch in range(1, epochs + 1):
-        train(
+        loss = train(
            model,
            device,
            train_loader,
@@ -470,22 +520,46 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No
            epoch,
            iter_meter,
        )
-        if epoch%3 == 0 or epoch == epochs:
-            torch.save({
-                'epoch': epoch,
-                'model_state_dict': model.state_dict(),
-                },MODEL_SAVE_PATH)
+        if epoch % 3 == 0 or epoch == epochs:
+            torch.save(
+                {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+                path,
+            )
    test(model=model, device=device, test_loader=test_loader, criterion=criterion)
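
One step that normally accompanies a `DistributedSampler` but does not appear in this loop is `train_sampler.set_epoch(epoch)` at the start of each epoch; without it, every epoch reuses the same shuffled order on every rank. A small self-contained illustration (the two-replica sampler is purely for demonstration):

# Demonstrates why set_epoch matters: it reseeds the per-epoch shuffle.
import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(8))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)

for epoch in range(3):
    sampler.set_epoch(epoch)  # omit this and every epoch prints the same indices
    print(epoch, list(iter(sampler)))  # indices this rank would draw in this epoch
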
@click.command()
-@click.option("--learning-rate", default=1e-3, help="Learning rate")
+@click.option("--learning_rate", default=1e-3, help="Learning rate")
@click.option("--batch_size", default=1, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
-def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
+@click.option("--world_size", default=1, help="Number of nodes for distribution")
+@click.option("--load", default=False, help="Do you want to load a model?")
+@click.option(
+ "--path",
+ default="models/model.pt",
+ help="Path where the model will be saved to/loaded from",
+)
+@click.option(
+ "--dataset_path",
+ default="data/",
+ help="Path for the dataset directory",
+)
+def run_cli(
+ learning_rate: float,
+ batch_size: int,
+ epochs: int,
+ world_size: int,
+ load: bool,
+ path: str,
+ dataset_path: str,
+) -> None:
"""Runs the training script."""
- run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
-
-
-if __name__ == "__main__":
- run(learning_rate=5e-4, batch_size=16, epochs=1)
+ run(
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
+ world_size=world_size,
+ load=load,
+ path=path,
+ dataset_path=dataset_path,
+ )
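
With the `if __name__ == "__main__"` block removed, nothing in this file invokes `run_cli` anymore, so it is presumably exposed as a console-script entry point or called from a small guard such as the hedged sketch below; the module path and flag values are illustrative and simply mirror the options defined above.

# Hypothetical entry point for the click command defined above.
# Rough shell equivalent, assuming the package layout from the diff header:
#   python -m swr2_asr.train --learning_rate 5e-4 --batch_size 16 --epochs 3 \
#       --world_size 1 --path models/model.pt --dataset_path data/
if __name__ == "__main__":
    run_cli()  # click parses the command-line options itself
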