diff options
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 89 |
1 files changed, 69 insertions, 20 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 2628028..d13683f 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,4 +1,5 @@ """Training script for the ASR model.""" +import os import click import torch import torch.nn.functional as F @@ -7,6 +8,7 @@ from AudioLoader.speech import MultilingualLibriSpeech from torch import nn, optim from torch.utils.data import DataLoader from tokenizers import Tokenizer +from .tokenizer import CharTokenizer from .loss_scores import cer, wer @@ -18,7 +20,9 @@ train_audio_transforms = nn.Sequential( valid_audio_transforms = torchaudio.transforms.MelSpectrogram() -text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") +# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") +text_transform = CharTokenizer() +text_transform.from_file("data/tokenizers/char_tokenizer_german.json") def data_processing(data, data_type="train"): @@ -59,7 +63,11 @@ def greedy_decoder( targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append(text_transform.decode(labels[i][: label_lengths[i]].tolist())) + targets.append( + text_transform.decode( + [int(x) for x in labels[i][: label_lengths[i]].tolist()] + ) + ) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: @@ -269,6 +277,7 @@ def train( ({100.0 * batch_idx / len(train_loader)}%)]\t \ Loss: {loss.item()}" ) + return loss.item() def test(model, device, test_loader, criterion): @@ -305,13 +314,20 @@ def test(model, device, test_loader, criterion): ) -def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: +def run( + learning_rate: float, + batch_size: int, + epochs: int, + load: bool, + path: str, + dataset_path: str, +) -> None: """Runs the training script.""" hparams = { "n_cnn_layers": 3, "n_rnn_layers": 5, "rnn_dim": 512, - "n_class": 46, + "n_class": 36, "n_feats": 128, "stride": 2, "dropout": 0.1, @@ -325,21 +341,19 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") + download_dataset = not os.path.isdir(path) train_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="dev", download=False + dataset_path, "mls_german_opus", split="dev", download=download_dataset ) test_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel 2/SWR2-ASR/", "mls_german_opus", split="test", download=False + dataset_path, "mls_german_opus", split="test", download=False ) - kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} - train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, collate_fn=lambda x: data_processing(x, "train"), - **kwargs, ) test_loader = DataLoader( @@ -347,9 +361,12 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No batch_size=hparams["batch_size"], shuffle=True, collate_fn=lambda x: data_processing(x, "train"), - **kwargs, ) + # enable flag to find the most compatible algorithms in advance + if use_cuda: + torch.backends.cudnn.benchmark = True + model = SpeechRecognitionModel( hparams["n_cnn_layers"], hparams["n_rnn_layers"], @@ -363,10 +380,14 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No print( "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) ) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) criterion = nn.CTCLoss(blank=28).to(device) - + if load: + checkpoint = torch.load(path) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + epoch = checkpoint["epoch"] + loss = checkpoint["loss"] scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hparams["learning_rate"], @@ -377,7 +398,7 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No iter_meter = IterMeter() for epoch in range(1, epochs + 1): - train( + loss = train( model, device, train_loader, @@ -387,17 +408,45 @@ def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> No epoch, iter_meter, ) + test(model=model, device=device, test_loader=test_loader, criterion=criterion) + print("saving epoch", str(epoch)) + torch.save( + {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, + path + str(epoch), + ) @click.command() -@click.option("--learning-rate", default=1e-3, help="Learning rate") -@click.option("--batch_size", default=1, help="Batch size") +@click.option("--learning_rate", default=1e-3, help="Learning rate") +@click.option("--batch_size", default=10, help="Batch size") @click.option("--epochs", default=1, help="Number of epochs") -def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None: +@click.option("--load", default=False, help="Do you want to load a model?") +@click.option( + "--path", + default="model", + help="Path where the model will be saved to/loaded from", +) +@click.option( + "--dataset_path", + default="data/", + help="Path for the dataset directory", +) +def run_cli( + learning_rate: float, + batch_size: int, + epochs: int, + load: bool, + path: str, + dataset_path: str, +) -> None: """Runs the training script.""" - run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs) - -if __name__ == "__main__": - run(learning_rate=5e-4, batch_size=16, epochs=1) + run( + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, + load=load, + path=path, + dataset_path=dataset_path, + ) |