diff options
author | Pherkel | 2023-09-04 22:55:27 +0200 |
---|---|---|
committer | GitHub | 2023-09-04 22:55:27 +0200 |
commit | 93e49e708fa59613406249069d31c2f6c8f2d2ab (patch) | |
tree | 9f413c226283db990116c9559ffffb9124b911d8 /swr2_asr/train.py | |
parent | 14ceeb5ad36beea2f05214aa26260cdd1d86590b (diff) | |
parent | 0d70a19e1fea6eda3f7b16ad0084591613f2de72 (diff) |
Merge pull request #27 from Algo-Boys/refactor_modularize
Refactor modularize
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 354 |
1 file changed, 103 insertions, 251 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6af1e80..63deb72 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,232 +1,57 @@ """Training script for the ASR model.""" import os +from typing import TypedDict + import click import torch import torch.nn.functional as F -import torchaudio -from AudioLoader.speech import MultilingualLibriSpeech from torch import nn, optim from torch.utils.data import DataLoader -from tokenizers import Tokenizer -from .tokenizer import CharTokenizer +from tqdm import tqdm + +from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer +from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer -train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) +# TODO: improve naming of functions -valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - -# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") -text_transform = CharTokenizer() -text_transform.from_file("data/tokenizers/char_tokenizer_german.json") - - -def data_processing(data, data_type="train"): - """Return the spectrograms, labels, and their lengths.""" - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - for sample in data: - if data_type == "train": - spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - elif data_type == "valid": - spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - else: - raise ValueError("data_type should be train or valid") - spectrograms.append(spec) - label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) - labels.append(label) - input_lengths.append(spec.shape[0] // 2) - label_lengths.append(len(label)) - - spectrograms = ( - 
nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) - return spectrograms, labels, input_lengths, label_lengths +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int -def greedy_decoder( - output, labels, label_lengths, blank_label=28, collapse_repeated=True -): - # TODO: adopt to support both tokenizers + +def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): """Greedily decode a sequence.""" + print("output shape", output.shape) arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + blank_label = tokenizer.encode(" ").ids[0] decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append( - text_transform.decode( - [int(x) for x in labels[i][: label_lengths[i]].tolist()] - ) - ) + targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: continue decode.append(index.item()) - decodes.append(text_transform.decode(decode)) + decodes.append(tokenizer.decode(decode)) return decodes, targets -# TODO: restructure into own file / class -class CNNLayerNorm(nn.Module): - """Layer normalization built for cnns input""" - - def __init__(self, n_feats: int): - super().__init__() - self.layer_norm = nn.LayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) - data = self.layer_norm(data) - return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) - - -class ResidualCNN(nn.Module): - """Residual CNN inspired by 
https://arxiv.org/pdf/1603.05027.pdf""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super().__init__() - - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.layer_norm1 = CNNLayerNorm(n_feats) - self.layer_norm2 = CNNLayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - residual = data # (batch, channel, feature, time) - data = self.layer_norm1(data) - data = F.gelu(data) - data = self.dropout1(data) - data = self.cnn1(data) - data = self.layer_norm2(data) - data = F.gelu(data) - data = self.dropout2(data) - data = self.cnn2(data) - data += residual - return data # (batch, channel, feature, time) - - -class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" - - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super().__init__() - - self.bi_gru = nn.GRU( - input_size=rnn_dim, - hidden_size=hidden_size, - num_layers=1, - batch_first=batch_first, - bidirectional=True, - ) - self.layer_norm = nn.LayerNorm(rnn_dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, data): - """data (batch, time, feature)""" - data = self.layer_norm(data) - data = F.gelu(data) - data = self.dropout(data) - data, _ = self.bi_gru(data) - return data - - -class SpeechRecognitionModel(nn.Module): - """Speech Recognition Model Inspired by DeepSpeech 2""" - - def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - stride: int = 2, - dropout: float = 0.1, - ): - super().__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) - # n residual cnn 
layers with filter size of 32 - self.rescnn_layers = nn.Sequential( - *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) - for _ in range(n_cnn_layers) - ] - ) - self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) - self.birnn_layers = nn.Sequential( - *[ - BidirectionalGRU( - rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, - hidden_size=rnn_dim, - dropout=dropout, - batch_first=i == 0, - ) - for i in range(n_rnn_layers) - ] - ) - self.classifier = nn.Sequential( - nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(rnn_dim, n_class), - ) - - def forward(self, data): - """data (batch, channel, feature, time)""" - data = self.cnn(data) - data = self.rescnn_layers(data) - sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) - data = data.transpose(1, 2) # (batch, time, feature) - data = self.fully_connected(data) - data = self.birnn_layers(data) - data = self.classifier(data) - return data - - class IterMeter: """keeps track of total iterations""" @@ -254,36 +79,30 @@ def train( ): """Train""" model.train() - data_len = len(train_loader.dataset) - for batch_idx, _data in enumerate(train_loader): - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) - + print(f"Epoch: {epoch}") + losses = [] + for _data in tqdm(train_loader, desc="batches"): + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) loss.backward() optimizer.step() scheduler.step() iter_meter.step() - if batch_idx % 100 == 0 or batch_idx == 
data_len: - print( - f"Train Epoch: \ - {epoch} \ - [{batch_idx * len(spectrograms)}/{data_len} \ - ({100.0 * batch_idx / len(train_loader)}%)]\t \ - Loss: {loss.item()}" - ) - return loss.item() + losses.append(loss.item()) + + print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") + return sum(losses) / len(losses) -# TODO: check how dataloader can be made more efficient -def test(model, device, test_loader, criterion): + +def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") model.eval() @@ -291,18 +110,20 @@ def test(model, device, test_loader, criterion): test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths + output=output.transpose(0, 1), + labels=labels, + label_lengths=_data["utterance_length"], + tokenizer=tokenizer, ) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) @@ -313,9 +134,11 @@ def test(model, device, test_loader, criterion): print( f"Test set: Average loss:\ - {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" + {test_loss}, Average CER: {None} Average WER: {None}\n" ) + return test_loss, avg_cer, avg_wer + def run( learning_rate: float, @@ -324,46 +147,66 @@ def run( load: bool, path: str, dataset_path: str, + language: str, ) -> None: """Runs the training script.""" - hparams = { 
- "n_cnn_layers": 3, - "n_rnn_layers": 5, - "rnn_dim": 512, - "n_class": 36, # TODO: dynamically determine this from vocab size - "n_feats": 128, - "stride": 2, - "dropout": 0.1, - "learning_rate": learning_rate, - "batch_size": batch_size, - "epochs": epochs, - } - use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") - download_dataset = not os.path.isdir(path) - train_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="dev", download=download_dataset + # load dataset + train_dataset = MLSDataset( + dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None ) - test_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="test", download=False + valid_dataset = MLSDataset( + dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None + ) + + # load tokenizer (bpe by default): + if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"): + print("There is no tokenizer available. 
Do you want to train it on the dataset?") + input("Press Enter to continue...") + train_char_tokenizer( + dataset_path=dataset_path, + language=language, + split="all", + download=False, + out_path="data/tokenizers/char_tokenizer_german.json", + ) + + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + + train_dataset.set_tokenizer(tokenizer) # type: ignore + valid_dataset.set_tokenizer(tokenizer) # type: ignore + + print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") + + hparams = HParams( + n_cnn_layers=3, + n_rnn_layers=5, + rnn_dim=512, + n_class=tokenizer.get_vocab_size(), + n_feats=128, + stride=2, + dropout=0.1, + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, ) train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), + collate_fn=lambda x: collate_fn(x), ) - test_loader = DataLoader( - test_dataset, + valid_loader = DataLoader( + valid_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), + collate_fn=lambda x: collate_fn(x), ) # enable flag to find the most compatible algorithms in advance @@ -379,12 +222,10 @@ def run( hparams["stride"], hparams["dropout"], ).to(device) - - print( - "Num Model Parameters", sum((param.nelement() for param in model.parameters())) - ) + print(tokenizer.encode(" ")) + print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(blank=28).to(device) + criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device) if load: checkpoint = torch.load(path) model.load_state_dict(checkpoint["model_state_dict"]) @@ -412,7 +253,13 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=test_loader, criterion=criterion) + test( + model=model, + device=device, + test_loader=valid_loader, + 
criterion=criterion, + tokenizer=tokenizer, + ) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -452,4 +299,9 @@ def run_cli( load=load, path=path, dataset_path=dataset_path, + language="mls_german_opus", ) + + +if __name__ == "__main__": + run_cli() # pylint: disable=no-value-for-parameter |