diff options
-rw-r--r-- | .vscode/settings.json | 26 | ||||
-rw-r--r-- | pyproject.toml | 3 | ||||
-rw-r--r-- | swr2_asr/model_deep_speech.py | 25 | ||||
-rw-r--r-- | swr2_asr/train.py | 372 | ||||
-rw-r--r-- | swr2_asr/utils/data.py | 14 | ||||
-rw-r--r-- | swr2_asr/utils/loss_scores.py (renamed from swr2_asr/loss_scores.py) | 0 |
6 files changed, 232 insertions, 208 deletions
diff --git a/.vscode/settings.json b/.vscode/settings.json index 0054bca..1adbc18 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,14 +1,14 @@ { - "[python]": { - "editor.formatOnType": true, - "editor.defaultFormatter": "ms-python.black-formatter", - "editor.formatOnSave": true, - "editor.rulers": [88, 120], - }, - "black-formatter.importStrategy": "fromEnvironment", - "python.analysis.typeCheckingMode": "basic", - "ruff.organizeImports": true, - "ruff.importStrategy": "fromEnvironment", - "ruff.fixAll": true, - "ruff.run": "onType" -}
\ No newline at end of file + "[python]": { + "editor.formatOnType": true, + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true, + "editor.rulers": [88, 120] + }, + "black-formatter.importStrategy": "fromEnvironment", + "python.analysis.typeCheckingMode": "off", + "ruff.organizeImports": true, + "ruff.importStrategy": "fromEnvironment", + "ruff.fixAll": true, + "ruff.run": "onType" +} diff --git a/pyproject.toml b/pyproject.toml index 6f74b49..38cc51a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,12 +23,11 @@ black = "^23.7.0" mypy = "^1.5.1" pylint = "^2.17.5" ruff = "^0.0.285" -types-tqdm = "^4.66.0.1" [tool.ruff] select = ["E", "F", "B", "I"] fixable = ["ALL"] -line-length = 120 +line-length = 100 target-version = "py310" [tool.black] diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index dd07ff9..8ddbd99 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,8 +1,29 @@ -"""Main definition of model""" +"""Main definition of the Deep speech 2 model by Baidu Research. + +Following definition by Assembly AI +(https://www.assemblyai.com/blog/end-to-end-speech-recognition-pytorch/) +""" +from typing import TypedDict + import torch.nn.functional as F from torch import nn +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" + + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int + + class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" @@ -60,7 +81,7 @@ class ResidualCNN(nn.Module): class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" + """Bidirectional GRU with Layer Normalization and Dropout""" def __init__( self, diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 9f12bcb..ac7666b 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -7,50 +7,17 @@ import torch import torch.nn.functional as F from torch import nn, optim from torch.utils.data import DataLoader -from tqdm import tqdm - -from swr2_asr.model_deep_speech import SpeechRecognitionModel -from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer -from swr2_asr.utils import MLSDataset, Split, collate_fn,plot - -from .loss_scores import cer, wer - - -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): - """Greedily decode a sequence.""" - print("output shape", output.shape) - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes, targets - - -class IterMeter: +from tqdm.autonotebook import tqdm + +from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.utils.data import DataProcessing, MLSDataset, Split +from swr2_asr.utils.decoder import greedy_decoder +from swr2_asr.utils.tokenizer import CharTokenizer + +from .utils.loss_scores import cer, wer + + +class IterMeter(object): """keeps track of total iterations""" def __init__(self): @@ -61,123 +28,195 @@ class IterMeter: self.val += 1 def get(self): - """get""" + """get steps""" return self.val -def train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, -): - """Train""" +class TrainArgs(TypedDict): + """Type for the arguments of the training function.""" + + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + train_loader: DataLoader + criterion: nn.CTCLoss + optimizer: optim.AdamW + scheduler: optim.lr_scheduler.OneCycleLR + epoch: int + iter_meter: IterMeter + + +def train(train_args) -> float: + """Train + Args: + model: model + device: device type + train_loader: train dataloader + criterion: loss function + optimizer: optimizer + scheduler: learning rate scheduler + epoch: epoch number + iter_meter: iteration meter + + Returns: + avg_train_loss: avg_train_loss for the epoch + + Information: + spectrograms: (batch, time, feature) + labels: (batch, label_length) + + model output: (batch,time, n_class) + + """ + # get values from train_args: + ( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, + ) = train_args.values() + model.train() - print(f"Epoch: {epoch}") - losses = [] - for _data in tqdm(train_loader, desc="batches"): - spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) + print(f"training batch {epoch}") + train_losses = [] + for _data in tqdm(train_loader, desc="Training batches"): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) + train_losses.append(loss) loss.backward() optimizer.step() scheduler.step() iter_meter.step() + avg_train_loss = sum(train_losses) / len(train_losses) + print(f"Train set: Average loss: {avg_train_loss:.2f}") + return avg_train_loss + - losses.append(loss.item()) +class TestArgs(TypedDict): + """Type for the arguments of the test function.""" - print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") - return sum(losses) / len(losses) + model: SpeechRecognitionModel + device: torch.device # pylint: disable=no-member + test_loader: DataLoader + criterion: nn.CTCLoss + tokenizer: CharTokenizer + decoder: str -def test(model, device, test_loader, criterion, tokenizer): - """Test""" +def test(test_args: TestArgs) -> tuple[float, float, float]: print("\nevaluating...") + + # get values from test_args: + model, device, test_loader, criterion, tokenizer, decoder = test_args.values() + + if decoder == "greedy": + decoder = greedy_decoder + model.eval() test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for _data in test_loader: - spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) + for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) + loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output=output.transpose(0, 1), - labels=labels, - label_lengths=_data["utterance_length"], - tokenizer=tokenizer, + output.transpose(0, 1), labels, label_lengths, tokenizer ) - for j, pred in enumerate(decoded_preds): - test_cer.append(cer(decoded_targets[j], pred)) - test_wer.append(wer(decoded_targets[j], pred)) + if i == 1: + print(f"decoding first sample: {decoded_preds}") + for j, _ in enumerate(decoded_preds): + test_cer.append(cer(decoded_targets[j], decoded_preds[j])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j])) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) print( - f"Test set: Average loss:\ - {test_loss}, Average CER: {None} Average WER: {None}\n" + f"Test set: \ + Average loss: {test_loss:.4f}, \ + Average CER: {avg_cer:4f} \ + Average WER: {avg_wer:.4f}\n" ) return test_loss, avg_cer, avg_wer -def run( +def main( learning_rate: float, batch_size: int, epochs: int, - load: bool, - path: str, dataset_path: str, language: str, -) -> None: - """Runs the training script.""" + limited_supervision: bool, + model_load_path: str, + model_save_path: str, + dataset_percentage: float, + eval_every: int, + num_workers: int, +): + """Main function for training the model. + + Args: + learning_rate: learning rate for the optimizer + batch_size: batch size + epochs: number of epochs to train + dataset_path: path for the dataset + language: language of the dataset + limited_supervision: whether to use only limited supervision + model_load_path: path to load a model from + model_save_path: path to save the model to + dataset_percentage: percentage of the dataset to use + eval_every: evaluate every n epochs + num_workers: number of workers for the dataloader + """ use_cuda = torch.cuda.is_available() - torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member - # device = torch.device("mps") - # load dataset + torch.manual_seed(7) + + if not os.path.isdir(dataset_path): + os.makedirs(dataset_path) + train_dataset = MLSDataset( - dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True + dataset_path, + language, + Split.TEST, + download=True, + limited=limited_supervision, + size=dataset_percentage, ) valid_dataset = MLSDataset( - dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True + dataset_path, + language, + Split.TRAIN, + download=False, + limited=Falimited_supervisionlse, + size=dataset_percentage, ) - # load tokenizer (bpe by default): - if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"): - print("There is no tokenizer available. Do you want to train it on the dataset?") - input("Press Enter to continue...") - train_char_tokenizer( - dataset_path=dataset_path, - language=language, - split="all", - out_path="data/tokenizers/char_tokenizer_german.json", - ) - - tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - - train_dataset.set_tokenizer(tokenizer) # type: ignore - valid_dataset.set_tokenizer(tokenizer) # type: ignore + # TODO: initialize and possibly train tokenizer if none found - print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") + kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} hparams = HParams( n_cnn_layers=3, @@ -192,24 +231,24 @@ def run( epochs=epochs, ) + train_data_processing = DataProcessing("train", tokenizer) + valid_data_processing = DataProcessing("valid", tokenizer) + train_loader = DataLoader( - train_dataset, + dataset=train_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: collate_fn(x), + collate_fn=train_data_processing, + **kwargs, ) - valid_loader = DataLoader( - valid_dataset, + dataset=valid_dataset, batch_size=hparams["batch_size"], - shuffle=True, - collate_fn=lambda x: collate_fn(x), + shuffle=False, + collate_fn=valid_data_processing, + **kwargs, ) - # enable flag to find the most compatible algorithms in advance - if use_cuda: - torch.backends.cudnn.benchmark = True # pylance: disable=no-member - model = SpeechRecognitionModel( hparams["n_cnn_layers"], hparams["n_rnn_layers"], @@ -219,16 +258,9 @@ def run( hparams["stride"], hparams["dropout"], ).to(device) - print(tokenizer.encode(" ")) - print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) + optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device) - if load: - checkpoint = torch.load(path) - model.load_state_dict(checkpoint["model_state_dict"]) - optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) - epoch = checkpoint["epoch"] - loss = checkpoint["loss"] + criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hparams["learning_rate"], @@ -236,77 +268,63 @@ def run( epochs=hparams["epochs"], anneal_strategy="linear", ) + prev_epoch = 0 + + if model_load_path is not None: + checkpoint = torch.load(model_load_path) + model.load_state_dict(checkpoint["model_state_dict"]) + optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - for epoch in range(1, epochs + 1): - loss = train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, + if not os.path.isdir(os.path.dirname(model_save_path)): + os.makedirs(os.path.dirname(model_save_path)) + for epoch in range(prev_epoch + 1, epochs + 1): + train_args: TrainArgs = dict( + model=model, + device=device, + train_loader=train_loader, + criterion=criterion, + optimizer=optimizer, + scheduler=scheduler, + epoch=epoch, + iter_meter=iter_meter, ) - test_loss, avg_cer, avg_wer = test( + train_loss = train(train_args) + + test_loss, test_cer, test_wer = 0, 0, 0 + + test_args: TestArgs = dict( model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer=tokenizer, + decoder="greedy", ) - print("saving epoch", str(epoch)) + + if epoch % eval_every == 0: + test_loss, test_cer, test_wer = test(test_args) + + if model_save_path is None: + continue + + if not os.path.isdir(os.path.dirname(model_save_path)): + os.makedirs(os.path.dirname(model_save_path)) torch.save( { "epoch": epoch, "model_state_dict": model.state_dict(), - "loss": loss, + "optimizer_state_dict": optimizer.state_dict(), + "train_loss": train_loss, "test_loss": test_loss, - "avg_cer": avg_cer, - "avg_wer": avg_wer, + "avg_cer": test_cer, + "avg_wer": test_wer, }, - path + str(epoch), - plot(epochs,path) + model_save_path + str(epoch), ) -@click.command() -@click.option("--learning_rate", default=1e-3, help="Learning rate") -@click.option("--batch_size", default=10, help="Batch size") -@click.option("--epochs", default=1, help="Number of epochs") -@click.option("--load", default=False, help="Do you want to load a model?") -@click.option( - "--path", - default="model", - help="Path where the model will be saved to/loaded from", -) -@click.option( - "--dataset_path", - default="data/", - help="Path for the dataset directory", -) -def run_cli( - learning_rate: float, - batch_size: int, - epochs: int, - load: bool, - path: str, - dataset_path: str, -) -> None: - """Runs the training script.""" - - run( - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - load=load, - path=path, - dataset_path=dataset_path, - language="mls_german_opus", - ) - - if __name__ == "__main__": - run_cli() # pylint: disable=no-value-for-parameter + main() # pylint: disable=no-value-for-parameter diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 74d10c9..e939e1d 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -76,20 +76,6 @@ def split_to_mls_split(split_name: Split) -> MLSSplit: return split_name # type: ignore -class Sample(TypedDict): - """Type for a sample in the dataset""" - - waveform: torch.Tensor - spectrogram: torch.Tensor - input_length: int - utterance: torch.Tensor - utterance_length: int - sample_rate: int - speaker_id: str - book_id: str - chapter_id: str - - class MLSDataset(Dataset): """Custom Dataset for reading Multilingual LibriSpeech diff --git a/swr2_asr/loss_scores.py b/swr2_asr/utils/loss_scores.py index 80285f6..80285f6 100644 --- a/swr2_asr/loss_scores.py +++ b/swr2_asr/utils/loss_scores.py |