author    Marvin Borner  2023-08-19 14:40:47 +0200
committer Marvin Borner  2023-08-19 14:40:47 +0200
commit    ec177107cb3a1a31d2fc49cc4990413af287305e (patch)
tree      79d45badaf5f7ee5bf049db214b8b598c335a617 /swr2_asr/train.py
parent    897e74f695e291029a08b280c2cea40a2a9639cc (diff)
Fixed some distribution thingies
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--  swr2_asr/train.py  98
1 file changed, 70 insertions(+), 28 deletions(-)
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ba002e0..56d10c0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -13,8 +13,6 @@ import torchaudio
from .loss_scores import cer, wer
-
-
class TextTransform:
"""Maps characters to integers and vice versa"""
@@ -357,6 +355,7 @@ def train(
)
return loss.item()
+
def test(model, device, test_loader, criterion):
"""Test"""
print("\nevaluating...")
@@ -391,7 +390,15 @@ def test(model, device, test_loader, criterion):
)
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1, load: bool=False, path: str="models/model.pt") -> None:
+def run(
+ learning_rate: float,
+ batch_size: int,
+ epochs: int,
+ world_size: int,
+ load: bool,
+ path: str,
+ dataset_path: str,
+) -> None:
"""Runs the training script."""
hparams = {
"n_cnn_layers": 3,
@@ -413,29 +420,38 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
+ download_dataset = not os.path.isdir(dataset_path)
train_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+ dataset_path, "mls_polish_opus", split="dev", download=download_dataset
)
test_dataset = MultilingualLibriSpeech(
- "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+ dataset_path, "mls_polish_opus", split="test", download=False
)
# initialize distributed training
ngpus_per_node = torch.cuda.device_count()
if hparams["distributed"]:
- if 'SLURM_PROCID' in os.environ: # for slurm scheduler
- hparams["rank"] = int(os.environ['SLURM_PROCID'])
+ if "SLURM_PROCID" in os.environ: # for slurm scheduler
+ hparams["rank"] = int(os.environ["SLURM_PROCID"])
hparams["gpu"] = hparams["rank"] % ngpus_per_node
- dist.init_process_group(backend="nccl", init_method="env://",
- world_size=hparams["world_size"], rank=hparams["rank"])
+ dist.init_process_group(
+ backend="nccl",
+ init_method="env://",
+ world_size=hparams["world_size"],
+ rank=hparams["rank"],
+ )
- train_sampler = DistributedSampler(train_dataset, shuffle=True)
+ train_sampler = (
+ DistributedSampler(train_dataset, shuffle=True)
+ if hparams["distributed"]
+ else None
+ )
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
sampler=train_sampler,
- num_workers=hparams["world_size"], # TODO?
+ num_workers=hparams["world_size"], # TODO?
pin_memory=True,
collate_fn=lambda x: data_processing(x, "train"),
)
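(The hunk above derives the process rank from SLURM_PROCID, joins an NCCL process group, and only builds a DistributedSampler in the distributed case. One caveat: PyTorch's DataLoader rejects shuffle=True combined with a non-None sampler, so the loader as written would raise a ValueError once train_sampler is set. A self-contained sketch of the same setup follows; the fallback rank of 0 and the shuffle guard are assumptions.)

```python
# Sketch of SLURM-based distributed setup, assuming one process per GPU and
# MASTER_ADDR/MASTER_PORT already exported (required by init_method="env://").
import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler


def make_train_loader(dataset, batch_size: int, world_size: int, distributed: bool):
    sampler = None
    if distributed:
        rank = int(os.environ.get("SLURM_PROCID", 0))  # 0 is a local-run fallback
        dist.init_process_group(
            backend="nccl",  # requires CUDA; "gloo" works on CPU-only machines
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        sampler = DistributedSampler(dataset, shuffle=True)  # shards data per rank
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=sampler is None,  # DataLoader forbids shuffle=True with a sampler
        sampler=sampler,
        pin_memory=True,
    )
```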
@@ -445,7 +461,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
batch_size=hparams["batch_size"],
shuffle=True,
sampler=None,
- num_workers=hparams["world_size"], # TODO?
+ num_workers=hparams["world_size"], # TODO?
pin_memory=True,
collate_fn=lambda x: data_processing(x, "train"),
)
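(For comparison: the context lines above keep shuffle=True and the "train" collate mode for the test split, which makes metrics order-dependent and applies training-time processing during evaluation. Evaluation loaders are usually deterministic; a standalone sketch of that more common pattern follows, built on dummy data rather than the repository's dataset.)

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

eval_dataset = TensorDataset(torch.randn(32, 10))
eval_loader = DataLoader(
    eval_dataset,
    batch_size=8,
    shuffle=False,    # deterministic order so metrics are comparable across runs
    num_workers=0,    # worker count is a per-machine tuning knob, not world_size
    pin_memory=True,  # faster host-to-GPU copies when CUDA is in play
)
```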
@@ -464,7 +480,9 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
if "gpu" in hparams:
torch.cuda.set_device(hparams["gpu"])
model.cuda(hparams["gpu"])
- model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[hparams["gpu"]])
+ model = torch.nn.parallel.DistributedDataParallel(
+ model, device_ids=[hparams["gpu"]]
+ )
model_without_ddp = model.module
else:
model.cuda()
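(The DDP branch above pins each process to a single GPU before wrapping the model. A compact sketch of that pattern, assuming the process group is already initialized as in the earlier sketch; `.module` hands back the unwrapped model for checkpointing.)

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def wrap_for_ddp(model: nn.Module, gpu: int | None):
    """Return (ddp_model, raw_model); assumes init_process_group already ran."""
    if gpu is not None:
        torch.cuda.set_device(gpu)          # pin this process to a single device
        model = model.cuda(gpu)
        ddp = DDP(model, device_ids=[gpu])  # gradients all-reduced across ranks
    else:
        model = model.cuda()
        ddp = DDP(model)                    # single-device fallback
    return ddp, ddp.module                  # .module is the unwrapped model
```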
@@ -478,10 +496,10 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
criterion = nn.CTCLoss(blank=28).to(device)
if load:
checkpoint = torch.load(path)
- model.load_state_dict(checkpoint['model_state_dict'])
- optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
- epoch = checkpoint['epoch']
- loss = checkpoint['loss']
+ model.load_state_dict(checkpoint["model_state_dict"])
+ optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+ epoch = checkpoint["epoch"]
+ loss = checkpoint["loss"]
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
max_lr=hparams["learning_rate"],
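(One asymmetry worth flagging: the load path above restores "optimizer_state_dict", but the save in the next hunk writes only the epoch, model state, and loss, so a resumed run would hit a KeyError. A small round-trip sketch that keeps save and load symmetric; the file name and dummy model are illustrative.)

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.AdamW(model.parameters(), lr=5e-4)

torch.save(
    {
        "epoch": 3,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),  # required to resume
        "loss": 0.0,
    },
    "model.pt",
)

checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1  # continue after the saved epoch
```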
@@ -502,22 +520,46 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
epoch,
iter_meter,
)
- if epoch%3 == 0 or epoch == epochs:
- torch.save({
- 'epoch': epoch,
- 'model_state_dict': model.state_dict(),
- 'loss': loss},path)
+ if epoch % 3 == 0 or epoch == epochs:
+ torch.save(
+ {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+ path,
+ )
test(model=model, device=device, test_loader=test_loader, criterion=criterion)
@click.command()
-@click.option("--learning-rate", default=1e-3, help="Learning rate")
+@click.option("--learning_rate", default=1e-3, help="Learning rate")
@click.option("--batch_size", default=1, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
@click.option("--world_size", default=1, help="Number of nodes for distribution")
-@click.option("--load", default = False, help="Do you want to load a model?")
-@click.option("--path",default="models/model.pt",
- help= "Path where the model will be saved to/loaded from" )
-def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int, load: bool, path: str) -> None:
+@click.option("--load", default=False, help="Do you want to load a model?")
+@click.option(
+ "--path",
+ default="models/model.pt",
+ help="Path where the model will be saved to/loaded from",
+)
+@click.option(
+ "--dataset_path",
+ default="data/",
+ help="Path for the dataset directory",
+)
+def run_cli(
+ learning_rate: float,
+ batch_size: int,
+ epochs: int,
+ world_size: int,
+ load: bool,
+ path: str,
+ dataset_path: str,
+) -> None:
"""Runs the training script."""
- run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size, load=load, path=path)
+ run(
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
+ world_size=world_size,
+ load=load,
+ path=path,
+ dataset_path=dataset_path,
+ )
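(Usage note: with the options declared above, the command can be exercised in-process via click's test runner. This is a hedged illustration; the flag names match the decorators in this diff, and actually invoking it would start a real training run.)

```python
from click.testing import CliRunner

from swr2_asr.train import run_cli

runner = CliRunner()
result = runner.invoke(
    run_cli,
    [
        "--learning_rate", "5e-4",
        "--batch_size", "8",
        "--epochs", "3",
        "--world_size", "1",
        "--load", "False",
        "--path", "models/model.pt",
        "--dataset_path", "data/",
    ],
)
print(result.exit_code)
```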