author     Marvin Borner  2023-08-19 14:40:47 +0200
committer  Marvin Borner  2023-08-19 14:40:47 +0200
commit     ec177107cb3a1a31d2fc49cc4990413af287305e (patch)
tree       79d45badaf5f7ee5bf049db214b8b598c335a617 /swr2_asr/train.py
parent     897e74f695e291029a08b280c2cea40a2a9639cc (diff)
Fixed some distribution thingies
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--  swr2_asr/train.py  98
1 file changed, 70 insertions(+), 28 deletions(-)
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ba002e0..56d10c0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -13,8 +13,6 @@
 import torchaudio
 
 from .loss_scores import cer, wer
 
-
-
 class TextTransform:
     """Maps characters to integers and vice versa"""
@@ -357,6 +355,7 @@ def train(
     )
     return loss.item()
 
+
 def test(model, device, test_loader, criterion):
     """Test"""
     print("\nevaluating...")
@@ -391,7 +390,15 @@ def test(model, device, test_loader, criterion):
     )
 
 
-def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world_size: int = 1, load: bool=False, path: str="models/model.pt") -> None:
+def run(
+    learning_rate: float,
+    batch_size: int,
+    epochs: int,
+    world_size: int,
+    load: bool,
+    path: str,
+    dataset_path: str,
+) -> None:
     """Runs the training script."""
     hparams = {
         "n_cnn_layers": 3,
@@ -413,29 +420,38 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
     device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member
     # device = torch.device("mps")
 
+    download_dataset = not os.path.isdir(path)
     train_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+        dataset_path, "mls_polish_opus", split="dev", download=download_dataset
     )
 
     test_dataset = MultilingualLibriSpeech(
-        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+        dataset_path, "mls_polish_opus", split="test", download=False
     )
 
     # initialize distributed training
     ngpus_per_node = torch.cuda.device_count()
     if hparams["distributed"]:
-        if 'SLURM_PROCID' in os.environ:  # for slurm scheduler
-            hparams["rank"] = int(os.environ['SLURM_PROCID'])
+        if "SLURM_PROCID" in os.environ:  # for slurm scheduler
+            hparams["rank"] = int(os.environ["SLURM_PROCID"])
         hparams["gpu"] = hparams["rank"] % ngpus_per_node
-        dist.init_process_group(backend="nccl", init_method="env://",
-                                world_size=hparams["world_size"], rank=hparams["rank"])
+        dist.init_process_group(
+            backend="nccl",
+            init_method="env://",
+            world_size=hparams["world_size"],
+            rank=hparams["rank"],
+        )
 
-    train_sampler = DistributedSampler(train_dataset, shuffle=True)
+    train_sampler = (
+        DistributedSampler(train_dataset, shuffle=True)
+        if hparams["distributed"]
+        else None
+    )
 
     train_loader = DataLoader(
         train_dataset,
         batch_size=hparams["batch_size"],
         shuffle=True,
         sampler=train_sampler,
-        num_workers=hparams["world_size"], # TODO?
+        num_workers=hparams["world_size"],  # TODO?
         pin_memory=True,
         collate_fn=lambda x: data_processing(x, "train"),
     )
@@ -445,7 +461,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
         batch_size=hparams["batch_size"],
         shuffle=True,
         sampler=None,
-        num_workers=hparams["world_size"], # TODO?
+        num_workers=hparams["world_size"],  # TODO?
         pin_memory=True,
         collate_fn=lambda x: data_processing(x, "train"),
     )
@@ -464,7 +480,9 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
         if "gpu" in hparams:
             torch.cuda.set_device(hparams["gpu"])
             model.cuda(hparams["gpu"])
-            model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[hparams["gpu"]])
+            model = torch.nn.parallel.DistributedDataParallel(
+                model, device_ids=[hparams["gpu"]]
+            )
             model_without_ddp = model.module
         else:
             model.cuda()
@@ -478,10 +496,10 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
     criterion = nn.CTCLoss(blank=28).to(device)
     if load:
         checkpoint = torch.load(path)
-        model.load_state_dict(checkpoint['model_state_dict'])
-        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
-        epoch = checkpoint['epoch']
-        loss = checkpoint['loss']
+        model.load_state_dict(checkpoint["model_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        epoch = checkpoint["epoch"]
+        loss = checkpoint["loss"]
     scheduler = optim.lr_scheduler.OneCycleLR(
         optimizer,
         max_lr=hparams["learning_rate"],
@@ -502,22 +520,46 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3, world
                 epoch,
                 iter_meter,
             )
-        if epoch%3 == 0 or epoch == epochs:
-            torch.save({
-                'epoch': epoch,
-                'model_state_dict': model.state_dict(),
-                'loss': loss},path)
+        if epoch % 3 == 0 or epoch == epochs:
+            torch.save(
+                {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+                path,
+            )
         test(model=model, device=device, test_loader=test_loader, criterion=criterion)
 
 
 @click.command()
-@click.option("--learning-rate", default=1e-3, help="Learning rate")
+@click.option("--learning_rate", default=1e-3, help="Learning rate")
 @click.option("--batch_size", default=1, help="Batch size")
 @click.option("--epochs", default=1, help="Number of epochs")
 @click.option("--world_size", default=1, help="Number of nodes for distribution")
-@click.option("--load", default = False, help="Do you want to load a model?")
-@click.option("--path",default="models/model.pt",
-    help= "Path where the model will be saved to/loaded from" )
-def run_cli(learning_rate: float, batch_size: int, epochs: int, world_size: int, load: bool, path: str) -> None:
+@click.option("--load", default=False, help="Do you want to load a model?")
+@click.option(
+    "--path",
+    default="models/model.pt",
+    help="Path where the model will be saved to/loaded from",
+)
+@click.option(
+    "--dataset_path",
+    default="data/",
+    help="Path for the dataset directory",
+)
+def run_cli(
+    learning_rate: float,
+    batch_size: int,
+    epochs: int,
+    world_size: int,
+    load: bool,
+    path: str,
+    dataset_path: str,
+) -> None:
     """Runs the training script."""
-    run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, world_size=world_size, load=load, path=path)
+    run(
+        learning_rate=learning_rate,
+        batch_size=batch_size,
+        epochs=epochs,
+        world_size=world_size,
+        load=load,
+        path=path,
+        dataset_path=dataset_path,
+    )
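
Note on the distributed wiring in this commit: it derives the process rank from SLURM_PROCID, maps it to a local GPU, joins an NCCL process group via the env:// rendezvous, and builds a DistributedSampler only when running distributed. Below is a minimal, self-contained sketch of that pattern; the helper names init_distributed and make_loader are hypothetical, not code from this repository, and the dataset argument stands in for MultilingualLibriSpeech. One caveat the sketch encodes: DataLoader raises an error when shuffle=True is combined with an explicit sampler, so shuffling is delegated to the sampler in the distributed case.

# Hedged sketch of the SLURM + NCCL setup used above; helper names are
# hypothetical and the dataset argument stands in for MultilingualLibriSpeech.
import os

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler


def init_distributed(world_size: int) -> int:
    """Derive rank/GPU from the SLURM environment and join the process group."""
    rank = int(os.environ["SLURM_PROCID"])  # global rank assigned by SLURM
    gpu = rank % torch.cuda.device_count()  # local GPU for this process
    torch.cuda.set_device(gpu)
    dist.init_process_group(
        backend="nccl", init_method="env://", world_size=world_size, rank=rank
    )
    return gpu


def make_loader(dataset, batch_size: int, distributed: bool) -> DataLoader:
    """Build a DataLoader whose sharding depends on the run mode."""
    # DistributedSampler gives each rank a disjoint shard and handles shuffling.
    sampler = DistributedSampler(dataset, shuffle=True) if distributed else None
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(sampler is None),  # DataLoader forbids shuffle=True with a sampler
        sampler=sampler,
        pin_memory=True,
    )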
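
The checkpoint handling touched above follows the standard torch.save/torch.load state-dict pattern. A small sketch of the round trip follows; nn.Linear and "checkpoint.pt" are placeholders, not this repository's model or path. The load path in the commit reads "optimizer_state_dict" while its save call stores only epoch, model state, and loss, so the sketch saves the optimizer state as well to keep the round trip consistent.

# Hedged sketch of the checkpoint-dict pattern; nn.Linear and "checkpoint.pt"
# are placeholders, not this repository's model or path.
import torch
from torch import nn, optim

model = nn.Linear(4, 2)
optimizer = optim.AdamW(model.parameters(), lr=1e-3)

# Save everything needed to resume training in a single dict.
torch.save(
    {
        "epoch": 3,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": 0.42,
    },
    "checkpoint.pt",
)

# Load and restore; training can then resume from the saved epoch.
checkpoint = torch.load("checkpoint.pt")
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"]
last_loss = checkpoint["loss"]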