diff options
author | Pherkel | 2023-08-24 00:03:56 +0200 |
---|---|---|
committer | Pherkel | 2023-08-24 00:03:56 +0200 |
commit | 403472ca4e65e8ed404e8a73fb9b3fbafe3f2a53 (patch) | |
tree | e5bfaca7d1982f1fadb1abe1da023d2020151363 /swr2_asr/train.py | |
parent | d65e728575e07a54cec52ccb57af3cafedaac1a2 (diff) |
wip: commit before going on vacation :)
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 298 |
1 files changed, 65 insertions, 233 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6af1e80..53cdac1 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,74 +1,44 @@ """Training script for the ASR model.""" import os +from typing import TypedDict + import click import torch import torch.nn.functional as F -import torchaudio -from AudioLoader.speech import MultilingualLibriSpeech +from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader -from tokenizers import Tokenizer -from .tokenizer import CharTokenizer + +from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.tokenizer import train_bpe_tokenizer +from swr2_asr.utils import MLSDataset, Split from .loss_scores import cer, wer -train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) -valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - -# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") -text_transform = CharTokenizer() -text_transform.from_file("data/tokenizers/char_tokenizer_german.json") - - -def data_processing(data, data_type="train"): - """Return the spectrograms, labels, and their lengths.""" - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - for sample in data: - if data_type == "train": - spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - elif data_type == "valid": - spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - else: - raise ValueError("data_type should be train or valid") - spectrograms.append(spec) - label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) - labels.append(label) - input_lengths.append(spec.shape[0] // 2) - label_lengths.append(len(label)) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" - return spectrograms, labels, input_lengths, label_lengths + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int -def greedy_decoder( - output, labels, label_lengths, blank_label=28, collapse_repeated=True -): - # TODO: adopt to support both tokenizers +def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append( - text_transform.decode( - [int(x) for x in labels[i][: label_lengths[i]].tolist()] - ) - ) + targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: @@ -78,155 +48,6 @@ def greedy_decoder( return decodes, targets -# TODO: restructure into own file / class -class CNNLayerNorm(nn.Module): - """Layer normalization built for cnns input""" - - def __init__(self, n_feats: int): - super().__init__() - self.layer_norm = nn.LayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) - data = self.layer_norm(data) - return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) - - -class ResidualCNN(nn.Module): - """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super().__init__() - - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.layer_norm1 = CNNLayerNorm(n_feats) - self.layer_norm2 = CNNLayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - residual = data # (batch, channel, feature, time) - data = self.layer_norm1(data) - data = F.gelu(data) - data = self.dropout1(data) - data = self.cnn1(data) - data = self.layer_norm2(data) - data = F.gelu(data) - data = self.dropout2(data) - data = self.cnn2(data) - data += residual - return data # (batch, channel, feature, time) - - -class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" - - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super().__init__() - - self.bi_gru = nn.GRU( - input_size=rnn_dim, - hidden_size=hidden_size, - num_layers=1, - batch_first=batch_first, - bidirectional=True, - ) - self.layer_norm = nn.LayerNorm(rnn_dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, data): - """data (batch, time, feature)""" - data = self.layer_norm(data) - data = F.gelu(data) - data = self.dropout(data) - data, _ = self.bi_gru(data) - return data - - -class SpeechRecognitionModel(nn.Module): - """Speech Recognition Model Inspired by DeepSpeech 2""" - - def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - stride: int = 2, - dropout: float = 0.1, - ): - super().__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) - # n residual cnn layers with filter size of 32 - self.rescnn_layers = nn.Sequential( - *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) - for _ in range(n_cnn_layers) - ] - ) - self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) - self.birnn_layers = nn.Sequential( - *[ - BidirectionalGRU( - rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, - hidden_size=rnn_dim, - dropout=dropout, - batch_first=i == 0, - ) - for i in range(n_rnn_layers) - ] - ) - self.classifier = nn.Sequential( - nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(rnn_dim, n_class), - ) - - def forward(self, data): - """data (batch, channel, feature, time)""" - data = self.cnn(data) - data = self.rescnn_layers(data) - sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) - data = data.transpose(1, 2) # (batch, time, feature) - data = self.fully_connected(data) - data = self.birnn_layers(data) - data = self.classifier(data) - return data - - class IterMeter: """keeps track of total iterations""" @@ -256,9 +77,8 @@ def train( model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): - spectrograms, labels, input_lengths, label_lengths = _data + _, spectrograms, input_lengths, labels, label_lengths, *_ = _data spectrograms, labels = spectrograms.to(device), labels.to(device) - optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) @@ -282,7 +102,6 @@ def train( return loss.item() -# TODO: check how dataloader can be made more efficient def test(model, device, test_loader, criterion): """Test""" print("\nevaluating...") @@ -301,9 +120,7 @@ def test(model, device, test_loader, criterion): loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths - ) + decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -324,46 +141,62 @@ def run( load: bool, path: str, dataset_path: str, + language: str, ) -> None: """Runs the training script.""" - hparams = { - "n_cnn_layers": 3, - "n_rnn_layers": 5, - "rnn_dim": 512, - "n_class": 36, # TODO: dynamically determine this from vocab size - "n_feats": 128, - "stride": 2, - "dropout": 0.1, - "learning_rate": learning_rate, - "batch_size": batch_size, - "epochs": epochs, - } - use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") - download_dataset = not os.path.isdir(path) - train_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="dev", download=download_dataset - ) - test_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="test", download=False + # load dataset + train_dataset = MLSDataset(dataset_path, language, Split.train, download=True) + valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True) + test_dataset = MLSDataset(dataset_path, language, Split.test, download=True) + + # TODO: add flag to choose tokenizer + # load tokenizer (bpe by default): + if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): + print("There is no tokenizer available. Do you want to train it on the dataset?") + input("Press Enter to continue...") + train_bpe_tokenizer( + dataset_path=dataset_path, + language=language, + split="all", + download=False, + out_path="data/tokenizers/bpe_tokenizer_german_3000.json", + vocab_size=3000, + ) + + tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") + + train_dataset.set_tokenizer(tokenizer) + valid_dataset.set_tokenizer(tokenizer) + test_dataset.set_tokenizer(tokenizer) + + hparams = HParams( + n_cnn_layers=3, + n_rnn_layers=5, + rnn_dim=512, + n_class=tokenizer.get_vocab_size(), + n_feats=128, + stride=2, + dropout=0.1, + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, ) train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), ) - test_loader = DataLoader( - test_dataset, + valid_loader = DataLoader( + valid_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), ) # enable flag to find the most compatible algorithms in advance @@ -380,9 +213,7 @@ def run( hparams["dropout"], ).to(device) - print( - "Num Model Parameters", sum((param.nelement() for param in model.parameters())) - ) + print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) criterion = nn.CTCLoss(blank=28).to(device) if load: @@ -412,7 +243,7 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=test_loader, criterion=criterion) + test(model=model, device=device, test_loader=valid_loader, criterion=criterion) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -452,4 +283,5 @@ def run_cli( load=load, path=path, dataset_path=dataset_path, + language="mls_german_opus", ) |