Diffstat (limited to 'swr2_asr')
 swr2_asr/__main__.py            |  12
 swr2_asr/inference.py           |  16
 swr2_asr/model_deep_speech.py   |  17
 swr2_asr/train.py               | 192
 swr2_asr/utils/data.py          |   7
 swr2_asr/utils/tokenizer.py     |   8
 swr2_asr/utils/visualization.py |   8
 7 files changed, 109 insertions(+), 151 deletions(-)
diff --git a/swr2_asr/__main__.py b/swr2_asr/__main__.py
deleted file mode 100644
index be294fb..0000000
--- a/swr2_asr/__main__.py
+++ /dev/null
@@ -1,12 +0,0 @@
-"""Main entrypoint for swr2-asr."""
-import torch
-import torchaudio
-
-if __name__ == "__main__":
-    # test if GPU is available
-    print("GPU available: ", torch.cuda.is_available())
-
-    # test if torchaudio is installed correctly
-    print("torchaudio version: ", torchaudio.__version__)
-    print("torchaudio backend: ", torchaudio.get_audio_backend())
-    print("torchaudio info: ", torchaudio.get_audio_backend())
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index c3eec42..f8342f7 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -1,11 +1,12 @@
 """Training script for the ASR model."""
+from typing import TypedDict
+
 import torch
-import torchaudio
 import torch.nn.functional as F
-from typing import TypedDict
+import torchaudio
 
-from swr2_asr.tokenizer import CharTokenizer
 from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.tokenizer import CharTokenizer
 
 
 class HParams(TypedDict):
@@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True):
     arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
     blank_label = tokenizer.encode(" ").ids[0]
     decodes = []
-    targets = []
-    for i, args in enumerate(arg_maxes):
+    for _i, args in enumerate(arg_maxes):
         decode = []
         for j, index in enumerate(args):
             if index != blank_label:
@@ -44,7 +44,7 @@ def main() -> None:
     """inference function."""
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    device = torch.device(device)
+    device = torch.device(device)  # pylint: disable=no-member
 
     tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
 
@@ -90,7 +90,7 @@ def main() -> None:
     model.load_state_dict(state_dict)
 
     # waveform, sample_rate = torchaudio.load("test.opus")
-    waveform, sample_rate = torchaudio.load("marvin_rede.flac")
+    waveform, sample_rate = torchaudio.load("marvin_rede.flac")  # pylint: disable=no-member
     if sample_rate != spectrogram_hparams["sample_rate"]:
         resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
         waveform = resampler(waveform)
@@ -103,7 +103,7 @@ def main() -> None:
     specs = [spec]
     specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
 
-    output = model(specs)
+    output = model(specs)  # pylint: disable=not-callable
     output = F.log_softmax(output, dim=2)
     output = output.transpose(0, 1)  # (time, batch, n_class)
     decodes = greedy_decoder(output, tokenizer)
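For orientation, the greedy decoding that inference.py applies above boils down to: take the argmax over classes at each time step, skip blanks, and collapse repeated labels. A minimal self-contained sketch of that rule (the blank index 0 is a placeholder; the real code derives the blank label from the tokenizer):

    import torch

    BLANK = 0  # placeholder blank index; inference.py derives this from the tokenizer

    def greedy_ctc(output: torch.Tensor) -> list[list[int]]:
        """Decode (batch, time, n_class) log-probabilities into label-id lists."""
        arg_maxes = torch.argmax(output, dim=2)
        decodes = []
        for args in arg_maxes:
            decode = []
            for j, index in enumerate(args):
                if index == BLANK:
                    continue  # skip CTC blanks
                if j != 0 and index == args[j - 1]:
                    continue  # collapse repeated labels
                decode.append(int(index))
            decodes.append(decode)
        return decodes

    logits = torch.randn(1, 50, 29).log_softmax(dim=2)  # dummy (batch, time, n_class) output
    print(greedy_ctc(logits))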
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
index 8ddbd99..77f4c8a 100644
--- a/swr2_asr/model_deep_speech.py
+++ b/swr2_asr/model_deep_speech.py
@@ -3,27 +3,10 @@
 Following definition by Assembly AI
 (https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/)
 """
-from typing import TypedDict
-
 import torch.nn.functional as F
 from torch import nn
 
 
-class HParams(TypedDict):
-    """Type for the hyperparameters of the model."""
-
-    n_cnn_layers: int
-    n_rnn_layers: int
-    rnn_dim: int
-    n_class: int
-    n_feats: int
-    stride: int
-    dropout: float
-    learning_rate: float
-    batch_size: int
-    epochs: int
-
-
 class CNNLayerNorm(nn.Module):
     """Layer normalization built for cnns input"""
 
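The hunk above only shows the CNNLayerNorm header. In the AssemblyAI tutorial this module cites, such a layer transposes the feature axis to the end, applies nn.LayerNorm over it, and transposes back; a sketch under that assumption (not code from this repository):

    from torch import nn

    class CNNLayerNormSketch(nn.Module):
        """Layer norm over the feature axis of (batch, channel, feature, time) input."""

        def __init__(self, n_feats: int):
            super().__init__()
            self.layer_norm = nn.LayerNorm(n_feats)

        def forward(self, data):
            # (batch, channel, feature, time) -> (batch, channel, time, feature)
            data = data.transpose(2, 3).contiguous()
            data = self.layer_norm(data)
            # back to (batch, channel, feature, time)
            return data.transpose(2, 3).contiguous()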
""" use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member torch.manual_seed(7) - if not os.path.isdir(dataset_path): - os.makedirs(dataset_path) + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + training_config = config_dict.get("training", {}) + dataset_config = config_dict.get("dataset", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + checkpoints_config = config_dict.get("checkpoints", {}) + + print(training_config["learning_rate"]) + + if not os.path.isdir(dataset_config["dataset_root_path"]): + os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TEST, - download=True, - limited=limited_supervision, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) valid_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TRAIN, - download=False, - limited=Falimited_supervisionlse, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) - # TODO: initialize and possibly train tokenizer if none found - - kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - ) + kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {} + + if tokenizer_config["tokenizer_path"] is None: + print("Tokenizer not found!") + if click.confirm("Do you want to train a new tokenizer?", default=True): + pass + else: + return + tokenizer = CharTokenizer.train( + dataset_config["dataset_root_path"], dataset_config["language_name"] + ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) train_data_processing = DataProcessing("train", tokenizer) valid_data_processing = DataProcessing("valid", tokenizer) train_loader = DataLoader( dataset=train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=train_data_processing, **kwargs, ) valid_loader = DataLoader( dataset=valid_dataset, - batch_size=hparams["batch_size"], - shuffle=False, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=valid_data_processing, **kwargs, ) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"]) criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = 
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index e939e1d..0e06eec 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -1,13 +1,12 @@
 """Class containing utils for the ASR system."""
 import os
 from enum import Enum
-from typing import TypedDict
 
 import numpy as np
 import torch
 import torchaudio
 from torch import Tensor, nn
-from torch.utils.data import Dataset
+from torch.utils.data import DataLoader, Dataset
 from torchaudio.datasets.utils import _extract_tar
 
 from swr2_asr.utils.tokenizer import CharTokenizer
@@ -125,7 +124,7 @@ class MLSDataset(Dataset):
         self._handle_download_dataset(download)
         self._validate_local_directory()
 
-        if limited and (split == Split.TRAIN or split == Split.VALID):
+        if limited and split in (Split.TRAIN, Split.VALID):
            self.initialize_limited()
         else:
             self.initialize()
@@ -351,8 +350,6 @@ class MLSDataset(Dataset):
 
 
 if __name__ == "__main__":
-    from torch.utils.data import DataLoader
-
     DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
     LANGUAGE = "mls_german_opus"
     split = Split.DEV
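As a hypothetical smoke test of the data plumbing touched above, the pieces compose as follows (the dataset root is a placeholder; the tokenizer file name follows inference.py):

    from torch.utils.data import DataLoader

    from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
    dataset = MLSDataset(
        "/path/to/mls",       # placeholder dataset root
        "mls_german_opus",
        Split.DEV,
        download=False,
        limited=False,
        size=1.0,
    )
    loader = DataLoader(dataset, batch_size=2, collate_fn=DataProcessing("valid", tokenizer))
    # batches carry the same 4-tuple that train.py's test() unpacks
    spectrograms, labels, input_lengths, label_lengths = next(iter(loader))
    print(spectrograms.shape, labels.shape)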
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 5482bbe..22569eb 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -1,8 +1,6 @@
 """Tokenizer for Multilingual Librispeech datasets"""
-
-
-from datetime import datetime
 import os
+from datetime import datetime
 
 from tqdm.autonotebook import tqdm
 
@@ -119,8 +117,8 @@ class CharTokenizer:
             line = line.strip()
             if line:
                 char, index = line.split()
-                tokenizer.char_map[char] = int(index)
-                tokenizer.index_map[int(index)] = char
+                load_tokenizer.char_map[char] = int(index)
+                load_tokenizer.index_map[int(index)] = char
         return load_tokenizer
 
 
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index 80f942a..a55d0d5 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -6,10 +6,10 @@ import torch
 
 def plot(epochs, path):
     """Plots the losses over the epochs"""
-    losses = list()
-    test_losses = list()
-    cers = list()
-    wers = list()
+    losses = []
+    test_losses = []
+    cers = []
+    wers = []
     for epoch in range(1, epochs + 1):
         current_state = torch.load(path + str(epoch))
         losses.append(current_state["loss"])
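plot() above only collects the four series from the saved checkpoints; the actual rendering is outside this hunk. One way the values could be drawn with matplotlib, as a sketch:

    import matplotlib.pyplot as plt

    def plot_curves(epochs, losses, test_losses, cers, wers):
        """Render the series collected by plot() as two side-by-side panels."""
        xs = range(1, epochs + 1)
        fig, (ax_loss, ax_err) = plt.subplots(1, 2, figsize=(10, 4))
        ax_loss.plot(xs, losses, label="train loss")
        ax_loss.plot(xs, test_losses, label="test loss")
        ax_loss.set_xlabel("epoch")
        ax_loss.legend()
        ax_err.plot(xs, cers, label="avg CER")
        ax_err.plot(xs, wers, label="avg WER")
        ax_err.set_xlabel("epoch")
        ax_err.legend()
        fig.savefig("training_curves.png")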