"""Training script for the ASR model.""" import os import click import torch import torch.nn.functional as F import torchaudio from AudioLoader.speech import MultilingualLibriSpeech from torch import nn, optim from torch.utils.data import DataLoader from tokenizers import Tokenizer from .tokenizer import CharTokenizer from .loss_scores import cer, wer train_audio_transforms = nn.Sequential( torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), torchaudio.transforms.FrequencyMasking(freq_mask_param=30), torchaudio.transforms.TimeMasking(time_mask_param=100), ) valid_audio_transforms = torchaudio.transforms.MelSpectrogram() # text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") text_transform = CharTokenizer() text_transform.from_file("data/tokenizers/char_tokenizer_german.json") def data_processing(data, data_type="train"): """Return the spectrograms, labels, and their lengths.""" spectrograms = [] labels = [] input_lengths = [] label_lengths = [] for sample in data: if data_type == "train": spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) elif data_type == "valid": spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) else: raise ValueError("data_type should be train or valid") spectrograms.append(spec) label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) labels.append(label) input_lengths.append(spec.shape[0] // 2) label_lengths.append(len(label)) spectrograms = ( nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) .unsqueeze(1) .transpose(2, 3) ) labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) return spectrograms, labels, input_lengths, label_lengths def greedy_decoder( output, labels, label_lengths, blank_label=28, collapse_repeated=True ): # TODO: adopt to support both tokenizers """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] targets.append( text_transform.decode( [int(x) for x in labels[i][: label_lengths[i]].tolist()] ) ) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: continue decode.append(index.item()) decodes.append(text_transform.decode(decode)) return decodes, targets # TODO: restructure into own file / class class CNNLayerNorm(nn.Module): """Layer normalization built for cnns input""" def __init__(self, n_feats: int): super(CNNLayerNorm, self).__init__() self.layer_norm = nn.LayerNorm(n_feats) def forward(self, data): """x (batch, channel, feature, time)""" data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) data = self.layer_norm(data) return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) class ResidualCNN(nn.Module): """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" def __init__( self, in_channels: int, out_channels: int, kernel: int, stride: int, dropout: float, n_feats: int, ): super(ResidualCNN, self).__init__() self.cnn1 = nn.Conv2d( in_channels, out_channels, kernel, stride, padding=kernel // 2 ) self.cnn2 = nn.Conv2d( out_channels, out_channels, kernel, stride, padding=kernel // 2, ) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.layer_norm1 = CNNLayerNorm(n_feats) self.layer_norm2 = CNNLayerNorm(n_feats) def forward(self, data): """x (batch, channel, feature, time)""" residual = data # (batch, channel, feature, time) data = self.layer_norm1(data) data = F.gelu(data) data = self.dropout1(data) data = self.cnn1(data) data = self.layer_norm2(data) data = F.gelu(data) data = self.dropout2(data) data = self.cnn2(data) data += residual return data # (batch, channel, feature, time) class BidirectionalGRU(nn.Module): """BIdirectional GRU with Layer Normalization and Dropout""" def __init__( self, rnn_dim: int, hidden_size: int, dropout: float, batch_first: bool, ): super(BidirectionalGRU, self).__init__() self.bi_gru = nn.GRU( input_size=rnn_dim, hidden_size=hidden_size, num_layers=1, batch_first=batch_first, bidirectional=True, ) self.layer_norm = nn.LayerNorm(rnn_dim) self.dropout = nn.Dropout(dropout) def forward(self, data): """data (batch, time, feature)""" data = self.layer_norm(data) data = F.gelu(data) data = self.dropout(data) data, _ = self.bi_gru(data) return data class SpeechRecognitionModel(nn.Module): """Speech Recognition Model Inspired by DeepSpeech 2""" def __init__( self, n_cnn_layers: int, n_rnn_layers: int, rnn_dim: int, n_class: int, n_feats: int, stride: int = 2, dropout: float = 0.1, ): super(SpeechRecognitionModel, self).__init__() n_feats //= 2 self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) # n residual cnn layers with filter size of 32 self.rescnn_layers = nn.Sequential( *[ ResidualCNN( 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats ) for _ in range(n_cnn_layers) ] ) self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) self.birnn_layers = nn.Sequential( *[ BidirectionalGRU( rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, hidden_size=rnn_dim, dropout=dropout, batch_first=i == 0, ) for i in range(n_rnn_layers) ] ) self.classifier = nn.Sequential( nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 nn.GELU(), nn.Dropout(dropout), nn.Linear(rnn_dim, n_class), ) def forward(self, data): """data (batch, channel, feature, time)""" data = self.cnn(data) data = self.rescnn_layers(data) sizes = data.size() data = data.view( sizes[0], sizes[1] * sizes[2], sizes[3] ) # (batch, feature, time) data = data.transpose(1, 2) # (batch, time, feature) data = self.fully_connected(data) data = self.birnn_layers(data) data = self.classifier(data) return data class IterMeter(object): """keeps track of total iterations""" def __init__(self): self.val = 0 def step(self): """step""" self.val += 1 def get(self): """get""" return self.val def train( model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, ): """Train""" model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): spectrograms, labels, input_lengths, label_lengths = _data spectrograms, labels = spectrograms.to(device), labels.to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) loss = criterion(output, labels, input_lengths, label_lengths) loss.backward() optimizer.step() scheduler.step() iter_meter.step() if batch_idx % 100 == 0 or batch_idx == data_len: print( f"Train Epoch: \ {epoch} \ [{batch_idx * len(spectrograms)}/{data_len} \ ({100.0 * batch_idx / len(train_loader)}%)]\t \ Loss: {loss.item()}" ) return loss.item() # TODO: check how dataloader can be made more efficient def test(model, device, test_loader, criterion): """Test""" print("\nevaluating...") model.eval() test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: spectrograms, labels, input_lengths, label_lengths = _data spectrograms, labels = spectrograms.to(device), labels.to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( output.transpose(0, 1), labels, label_lengths ) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) avg_cer = sum(test_cer) / len(test_cer) avg_wer = sum(test_wer) / len(test_wer) print( f"Test set: Average loss:\ {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" ) def run( learning_rate: float, batch_size: int, epochs: int, load: bool, path: str, dataset_path: str, ) -> None: """Runs the training script.""" hparams = { "n_cnn_layers": 3, "n_rnn_layers": 5, "rnn_dim": 512, "n_class": 36, # TODO: dynamically determine this from vocab size "n_feats": 128, "stride": 2, "dropout": 0.1, "learning_rate": learning_rate, "batch_size": batch_size, "epochs": epochs, } use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") download_dataset = not os.path.isdir(path) train_dataset = MultilingualLibriSpeech( dataset_path, "mls_german_opus", split="dev", download=download_dataset ) test_dataset = MultilingualLibriSpeech( dataset_path, "mls_german_opus", split="test", download=False ) train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, collate_fn=lambda x: data_processing(x, "train"), ) test_loader = DataLoader( test_dataset, batch_size=hparams["batch_size"], shuffle=True, collate_fn=lambda x: data_processing(x, "train"), ) # enable flag to find the most compatible algorithms in advance if use_cuda: torch.backends.cudnn.benchmark = True model = SpeechRecognitionModel( hparams["n_cnn_layers"], hparams["n_rnn_layers"], hparams["rnn_dim"], hparams["n_class"], hparams["n_feats"], hparams["stride"], hparams["dropout"], ).to(device) print( "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) ) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) criterion = nn.CTCLoss(blank=28).to(device) if load: checkpoint = torch.load(path) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) epoch = checkpoint["epoch"] loss = checkpoint["loss"] scheduler = optim.lr_scheduler.OneCycleLR( optimizer, max_lr=hparams["learning_rate"], steps_per_epoch=int(len(train_loader)), epochs=hparams["epochs"], anneal_strategy="linear", ) iter_meter = IterMeter() for epoch in range(1, epochs + 1): loss = train( model, device, train_loader, criterion, optimizer, scheduler, epoch, iter_meter, ) test(model=model, device=device, test_loader=test_loader, criterion=criterion) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, path + str(epoch), ) @click.command() @click.option("--learning_rate", default=1e-3, help="Learning rate") @click.option("--batch_size", default=10, help="Batch size") @click.option("--epochs", default=1, help="Number of epochs") @click.option("--load", default=False, help="Do you want to load a model?") @click.option( "--path", default="model", help="Path where the model will be saved to/loaded from", ) @click.option( "--dataset_path", default="data/", help="Path for the dataset directory", ) def run_cli( learning_rate: float, batch_size: int, epochs: int, load: bool, path: str, dataset_path: str, ) -> None: """Runs the training script.""" run( learning_rate=learning_rate, batch_size=batch_size, epochs=epochs, load=load, path=path, dataset_path=dataset_path, )