From 18abef41ed2f8060985f97248300e5ec8b6cb327 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Fri, 18 Aug 2023 16:37:07 +0200 Subject: added simple wav2vec2 inference --- swr2_asr/train.py | 15 --------------- 1 file changed, 15 deletions(-) delete mode 100644 swr2_asr/train.py (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py deleted file mode 100644 index f61776a..0000000 --- a/swr2_asr/train.py +++ /dev/null @@ -1,15 +0,0 @@ -"""Training script for the ASR model.""" -from AudioLoader.speech.mls import MultilingualLibriSpeech - - -def main() -> None: - """Main function.""" - dataset = MultilingualLibriSpeech( - "data", "mls_polish_opus", split="train", download=True - ) - - print(dataset[1]) - - -if __name__ == "__main__": - main() -- cgit v1.2.3 From 9b4592caac90a41eb6ce18558588ef504c49f58e Mon Sep 17 00:00:00 2001 From: Pherkel Date: Sat, 19 Aug 2023 10:55:02 +0200 Subject: renamed files --- pyproject.toml | 2 +- swr2_asr/inference_test.py | 74 +++++++ swr2_asr/train.py | 475 +++++++++++++++++++++++++++++++++++++++++++++ swr2_asr/train_2.py | 464 ------------------------------------------- swr2_asr/train_3.py | 74 ------- 5 files changed, 550 insertions(+), 539 deletions(-) create mode 100644 swr2_asr/inference_test.py create mode 100644 swr2_asr/train.py delete mode 100644 swr2_asr/train_2.py delete mode 100644 swr2_asr/train_3.py (limited to 'swr2_asr/train.py') diff --git a/pyproject.toml b/pyproject.toml index 791a76e..8490aa5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,7 @@ mypy = "^1.5.1" pylint = "^2.17.5" [tool.poetry.scripts] -train = "swr2_asr.train_2:run" +train = "swr2_asr.train:run_cli" [build-system] requires = ["poetry-core"] diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py new file mode 100644 index 0000000..a6b0010 --- /dev/null +++ b/swr2_asr/inference_test.py @@ -0,0 +1,74 @@ +"""Training script for the ASR model.""" +from AudioLoader.speech.mls import MultilingualLibriSpeech +import torch +import torchaudio +import torchaudio.functional as F + + +class GreedyCTCDecoder(torch.nn.Module): + def __init__(self, labels, blank=0) -> None: + super().__init__() + self.labels = labels + self.blank = blank + + def forward(self, emission: torch.Tensor) -> str: + """Given a sequence emission over labels, get the best path string + Args: + emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`. + + Returns: + str: The resulting transcript + """ + indices = torch.argmax(emission, dim=-1) # [num_seq,] + indices = torch.unique_consecutive(indices, dim=-1) + indices = [i for i in indices if i != self.blank] + return "".join([self.labels[i] for i in indices]) + + +def main() -> None: + """Main function.""" + # choose between cuda, cpu and mps devices + device = "cuda" if torch.cuda.is_available() else "cpu" + # device = "mps" + device = torch.device(device) + + torch.random.manual_seed(42) + + bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H + + print(f"Sample rate (model): {bundle.sample_rate}") + print(f"Labels (model): {bundle.get_labels()}") + + model = bundle.get_model().to(device) + + print(model.__class__) + + # only do all things for one single sample + dataset = MultilingualLibriSpeech( + "data", "mls_german_opus", split="train", download=True + ) + + print(dataset[0]) + + # load waveforms and sample rate from dataset + waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"] + + if sample_rate != bundle.sample_rate: + waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate)) + + waveform.to(device) + + with torch.inference_mode(): + features, _ = model.extract_features(waveform) + + with torch.inference_mode(): + emission, _ = model(waveform) + + decoder = GreedyCTCDecoder(labels=bundle.get_labels()) + transcript = decoder(emission[0]) + + print(transcript) + + +if __name__ == "__main__": + main() diff --git a/swr2_asr/train.py b/swr2_asr/train.py new file mode 100644 index 0000000..8ee96b9 --- /dev/null +++ b/swr2_asr/train.py @@ -0,0 +1,475 @@ +"""Training script for the ASR model.""" +from AudioLoader.speech.mls import MultilingualLibriSpeech +import click +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader +import torchaudio +from .loss_scores import cer, wer + + +class TextTransform: + """Maps characters to integers and vice versa""" + + def __init__(self): + char_map_str = """ + ' 0 + 1 + a 2 + b 3 + c 4 + d 5 + e 6 + f 7 + g 8 + h 9 + i 10 + j 11 + k 12 + l 13 + m 14 + n 15 + o 16 + p 17 + q 18 + r 19 + s 20 + t 21 + u 22 + v 23 + w 24 + x 25 + y 26 + z 27 + ä 28 + ö 29 + ü 30 + ß 31 + - 32 + é 33 + è 34 + à 35 + ù 36 + ç 37 + â 38 + ê 39 + î 40 + ô 41 + û 42 + ë 43 + ï 44 + ü 45 + """ + self.char_map = {} + self.index_map = {} + for line in char_map_str.strip().split("\n"): + ch, index = line.split() + self.char_map[ch] = int(index) + self.index_map[int(index)] = ch + self.index_map[1] = " " + + def text_to_int(self, text): + """Use a character map and convert text to an integer sequence""" + int_sequence = [] + for c in text: + if c == " ": + ch = self.char_map[""] + else: + ch = self.char_map[c] + int_sequence.append(ch) + return int_sequence + + def int_to_text(self, labels): + """Use a character map and convert integer labels to an text sequence""" + string = [] + for i in labels: + string.append(self.index_map[i]) + return "".join(string).replace("", " ") + + +train_audio_transforms = nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), +) + +valid_audio_transforms = torchaudio.transforms.MelSpectrogram() + +text_transform = TextTransform() + + +def data_processing(data, data_type="train"): + """Return the spectrograms, labels, and their lengths.""" + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + for x in data: + if data_type == "train": + spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) + elif data_type == "valid": + spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) + else: + raise ValueError("data_type should be train or valid") + spectrograms.append(spec) + label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower())) + labels.append(label) + input_lengths.append(spec.shape[0] // 2) + label_lengths.append(len(label)) + + spectrograms = ( + nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + .unsqueeze(1) + .transpose(2, 3) + ) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return spectrograms, labels, input_lengths, label_lengths + + +def GreedyDecoder( + output, labels, label_lengths, blank_label=28, collapse_repeated=True +): + """Greedily decode a sequence.""" + arg_maxes = torch.argmax(output, dim=2) + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append( + text_transform.int_to_text(labels[i][: label_lengths[i]].tolist()) + ) + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(text_transform.int_to_text(decode)) + return decodes, targets + + +class CNNLayerNorm(nn.Module): + """Layer normalization built for cnns input""" + + def __init__(self, n_feats: int): + super(CNNLayerNorm, self).__init__() + self.layer_norm = nn.LayerNorm(n_feats) + + def forward(self, x): + """x (batch, channel, feature, time)""" + x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature) + x = self.layer_norm(x) + return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) + + +class ResidualCNN(nn.Module): + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + stride: int, + dropout: float, + n_feats: int, + ): + super(ResidualCNN, self).__init__() + + self.cnn1 = nn.Conv2d( + in_channels, out_channels, kernel, stride, padding=kernel // 2 + ) + self.cnn2 = nn.Conv2d( + out_channels, + out_channels, + kernel, + stride, + padding=kernel // 2, + ) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.layer_norm1 = CNNLayerNorm(n_feats) + self.layer_norm2 = CNNLayerNorm(n_feats) + + def forward(self, x): + residual = x # (batch, channel, feature, time) + x = self.layer_norm1(x) + x = F.gelu(x) + x = self.dropout1(x) + x = self.cnn1(x) + x = self.layer_norm2(x) + x = F.gelu(x) + x = self.dropout2(x) + x = self.cnn2(x) + x += residual + return x # (batch, channel, feature, time) + + +class BidirectionalGRU(nn.Module): + """BIdirectional GRU with Layer Normalization and Dropout""" + + def __init__( + self, + rnn_dim: int, + hidden_size: int, + dropout: float, + batch_first: bool, + ): + super(BidirectionalGRU, self).__init__() + + self.BiGRU = nn.GRU( + input_size=rnn_dim, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=True, + ) + self.layer_norm = nn.LayerNorm(rnn_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.layer_norm(x) + x = F.gelu(x) + x = self.dropout(x) + x, _ = self.BiGRU(x) + return x + + +class SpeechRecognitionModel(nn.Module): + def __init__( + self, + n_cnn_layers: int, + n_rnn_layers: int, + rnn_dim: int, + n_class: int, + n_feats: int, + stride: int = 2, + dropout: float = 0.1, + ): + super(SpeechRecognitionModel, self).__init__() + n_feats //= 2 + self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + # n residual cnn layers with filter size of 32 + self.rescnn_layers = nn.Sequential( + *[ + ResidualCNN( + 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats + ) + for _ in range(n_cnn_layers) + ] + ) + self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) + self.birnn_layers = nn.Sequential( + *[ + BidirectionalGRU( + rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, + hidden_size=rnn_dim, + dropout=dropout, + batch_first=i == 0, + ) + for i in range(n_rnn_layers) + ] + ) + self.classifier = nn.Sequential( + nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(rnn_dim, n_class), + ) + + def forward(self, x): + x = self.cnn(x) + x = self.rescnn_layers(x) + sizes = x.size() + x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) + x = x.transpose(1, 2) # (batch, time, feature) + x = self.fully_connected(x) + x = self.birnn_layers(x) + x = self.classifier(x) + return x + + +class IterMeter(object): + """keeps track of total iterations""" + + def __init__(self): + self.val = 0 + + def step(self): + self.val += 1 + + def get(self): + return self.val + + +def train( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, +): + model.train() + data_len = len(train_loader.dataset) + for batch_idx, _data in enumerate(train_loader): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + + optimizer.zero_grad() + + output = model(spectrograms) # (batch, time, n_class) + output = F.log_softmax(output, dim=2) + output = output.transpose(0, 1) # (time, batch, n_class) + + loss = criterion(output, labels, input_lengths, label_lengths) + loss.backward() + + optimizer.step() + scheduler.step() + iter_meter.step() + if batch_idx % 100 == 0 or batch_idx == data_len: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(spectrograms), + data_len, + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + + +def test(model, device, test_loader, criterion): + print("\nevaluating...") + model.eval() + test_loss = 0 + test_cer, test_wer = [], [] + with torch.no_grad(): + for _data in test_loader: + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + + output = model(spectrograms) # (batch, time, n_class) + output = F.log_softmax(output, dim=2) + output = output.transpose(0, 1) # (time, batch, n_class) + + loss = criterion(output, labels, input_lengths, label_lengths) + test_loss += loss.item() / len(test_loader) + + decoded_preds, decoded_targets = GreedyDecoder( + output.transpose(0, 1), labels, label_lengths + ) + for j in range(len(decoded_preds)): + test_cer.append(cer(decoded_targets[j], decoded_preds[j])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + + avg_cer = sum(test_cer) / len(test_cer) + avg_wer = sum(test_wer) / len(test_wer) + + print( + f"Test set: Average loss:\ + {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" + ) + + +def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: + """Runs the training script.""" + hparams = { + "n_cnn_layers": 3, + "n_rnn_layers": 5, + "rnn_dim": 512, + "n_class": 46, + "n_feats": 128, + "stride": 2, + "dropout": 0.1, + "learning_rate": learning_rate, + "batch_size": batch_size, + "epochs": epochs, + } + + use_cuda = torch.cuda.is_available() + torch.manual_seed(42) + device = torch.device("cuda" if use_cuda else "cpu") + # device = torch.device("mps") + + train_dataset = MultilingualLibriSpeech( + "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False + ) + test_dataset = MultilingualLibriSpeech( + "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False + ) + + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} + + train_loader = DataLoader( + train_dataset, + batch_size=hparams["batch_size"], + shuffle=True, + collate_fn=lambda x: data_processing(x, "train"), + **kwargs, + ) + + test_loader = DataLoader( + test_dataset, + batch_size=hparams["batch_size"], + shuffle=True, + collate_fn=lambda x: data_processing(x, "train"), + **kwargs, + ) + + model = SpeechRecognitionModel( + hparams["n_cnn_layers"], + hparams["n_rnn_layers"], + hparams["rnn_dim"], + hparams["n_class"], + hparams["n_feats"], + hparams["stride"], + hparams["dropout"], + ).to(device) + + print( + "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) + ) + + optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + criterion = nn.CTCLoss(blank=28).to(device) + + scheduler = optim.lr_scheduler.OneCycleLR( + optimizer, + max_lr=hparams["learning_rate"], + steps_per_epoch=int(len(train_loader)), + epochs=hparams["epochs"], + anneal_strategy="linear", + ) + + iter_meter = IterMeter() + for epoch in range(1, epochs + 1): + train( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, + ) + test(model=model, device=device, test_loader=test_loader, criterion=criterion) + + +@click.command() +@click.option("--learning-rate", default=1e-3, help="Learning rate") +@click.option("--batch_size", default=1, help="Batch size") +@click.option("--epochs", default=1, help="Number of epochs") +def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None: + """Runs the training script.""" + run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs) + + +if __name__ == "__main__": + run(learning_rate=5e-4, batch_size=16, epochs=1) diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py deleted file mode 100644 index bea5bf4..0000000 --- a/swr2_asr/train_2.py +++ /dev/null @@ -1,464 +0,0 @@ -"""Training script for the ASR model.""" -from AudioLoader.speech.mls import MultilingualLibriSpeech -import torch -import torch.nn as nn -import torch.optim as optim -import torch.nn.functional as F -from torch.utils.data import DataLoader -import torchaudio -from .loss_scores import cer, wer - - -class TextTransform: - """Maps characters to integers and vice versa""" - - def __init__(self): - char_map_str = """ - ' 0 - 1 - a 2 - b 3 - c 4 - d 5 - e 6 - f 7 - g 8 - h 9 - i 10 - j 11 - k 12 - l 13 - m 14 - n 15 - o 16 - p 17 - q 18 - r 19 - s 20 - t 21 - u 22 - v 23 - w 24 - x 25 - y 26 - z 27 - ä 28 - ö 29 - ü 30 - ß 31 - - 32 - é 33 - è 34 - à 35 - ù 36 - ç 37 - â 38 - ê 39 - î 40 - ô 41 - û 42 - ë 43 - ï 44 - ü 45 - """ - self.char_map = {} - self.index_map = {} - for line in char_map_str.strip().split("\n"): - ch, index = line.split() - self.char_map[ch] = int(index) - self.index_map[int(index)] = ch - self.index_map[1] = " " - - def text_to_int(self, text): - """Use a character map and convert text to an integer sequence""" - int_sequence = [] - for c in text: - if c == " ": - ch = self.char_map[""] - else: - ch = self.char_map[c] - int_sequence.append(ch) - return int_sequence - - def int_to_text(self, labels): - """Use a character map and convert integer labels to an text sequence""" - string = [] - for i in labels: - string.append(self.index_map[i]) - return "".join(string).replace("", " ") - - -train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) - -valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - -text_transform = TextTransform() - - -def data_processing(data, data_type="train"): - """Return the spectrograms, labels, and their lengths.""" - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - for x in data: - if data_type == "train": - spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) - elif data_type == "valid": - spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1) - else: - raise ValueError("data_type should be train or valid") - spectrograms.append(spec) - label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower())) - labels.append(label) - input_lengths.append(spec.shape[0] // 2) - label_lengths.append(len(label)) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) - - return spectrograms, labels, input_lengths, label_lengths - - -def GreedyDecoder( - output, labels, label_lengths, blank_label=28, collapse_repeated=True -): - """Greedily decode a sequence.""" - arg_maxes = torch.argmax(output, dim=2) - decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): - decode = [] - targets.append( - text_transform.int_to_text(labels[i][: label_lengths[i]].tolist()) - ) - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(text_transform.int_to_text(decode)) - return decodes, targets - - -class CNNLayerNorm(nn.Module): - """Layer normalization built for cnns input""" - - def __init__(self, n_feats: int): - super(CNNLayerNorm, self).__init__() - self.layer_norm = nn.LayerNorm(n_feats) - - def forward(self, x): - """x (batch, channel, feature, time)""" - x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature) - x = self.layer_norm(x) - return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) - - -class ResidualCNN(nn.Module): - """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super(ResidualCNN, self).__init__() - - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.layer_norm1 = CNNLayerNorm(n_feats) - self.layer_norm2 = CNNLayerNorm(n_feats) - - def forward(self, x): - residual = x # (batch, channel, feature, time) - x = self.layer_norm1(x) - x = F.gelu(x) - x = self.dropout1(x) - x = self.cnn1(x) - x = self.layer_norm2(x) - x = F.gelu(x) - x = self.dropout2(x) - x = self.cnn2(x) - x += residual - return x # (batch, channel, feature, time) - - -class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" - - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super(BidirectionalGRU, self).__init__() - - self.BiGRU = nn.GRU( - input_size=rnn_dim, - hidden_size=hidden_size, - num_layers=1, - batch_first=batch_first, - bidirectional=True, - ) - self.layer_norm = nn.LayerNorm(rnn_dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, x): - x = self.layer_norm(x) - x = F.gelu(x) - x = self.dropout(x) - x, _ = self.BiGRU(x) - return x - - -class SpeechRecognitionModel(nn.Module): - def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - stride: int = 2, - dropout: float = 0.1, - ): - super(SpeechRecognitionModel, self).__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) - # n residual cnn layers with filter size of 32 - self.rescnn_layers = nn.Sequential( - *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) - for _ in range(n_cnn_layers) - ] - ) - self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) - self.birnn_layers = nn.Sequential( - *[ - BidirectionalGRU( - rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, - hidden_size=rnn_dim, - dropout=dropout, - batch_first=i == 0, - ) - for i in range(n_rnn_layers) - ] - ) - self.classifier = nn.Sequential( - nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(rnn_dim, n_class), - ) - - def forward(self, x): - x = self.cnn(x) - x = self.rescnn_layers(x) - sizes = x.size() - x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) - x = x.transpose(1, 2) # (batch, time, feature) - x = self.fully_connected(x) - x = self.birnn_layers(x) - x = self.classifier(x) - return x - - -class IterMeter(object): - """keeps track of total iterations""" - - def __init__(self): - self.val = 0 - - def step(self): - self.val += 1 - - def get(self): - return self.val - - -def train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, -): - model.train() - data_len = len(train_loader.dataset) - for batch_idx, _data in enumerate(train_loader): - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) - - optimizer.zero_grad() - - output = model(spectrograms) # (batch, time, n_class) - output = F.log_softmax(output, dim=2) - output = output.transpose(0, 1) # (time, batch, n_class) - - loss = criterion(output, labels, input_lengths, label_lengths) - loss.backward() - - optimizer.step() - scheduler.step() - iter_meter.step() - if batch_idx % 100 == 0 or batch_idx == data_len: - print( - "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( - epoch, - batch_idx * len(spectrograms), - data_len, - 100.0 * batch_idx / len(train_loader), - loss.item(), - ) - ) - - -def test(model, device, test_loader, criterion): - print("\nevaluating...") - model.eval() - test_loss = 0 - test_cer, test_wer = [], [] - with torch.no_grad(): - for _data in test_loader: - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) - - output = model(spectrograms) # (batch, time, n_class) - output = F.log_softmax(output, dim=2) - output = output.transpose(0, 1) # (time, batch, n_class) - - loss = criterion(output, labels, input_lengths, label_lengths) - test_loss += loss.item() / len(test_loader) - - decoded_preds, decoded_targets = GreedyDecoder( - output.transpose(0, 1), labels, label_lengths - ) - for j in range(len(decoded_preds)): - test_cer.append(cer(decoded_targets[j], decoded_preds[j])) - test_wer.append(wer(decoded_targets[j], decoded_preds[j])) - - avg_cer = sum(test_cer) / len(test_cer) - avg_wer = sum(test_wer) / len(test_wer) - - print( - f"Test set: Average loss:\ - {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" - ) - - -def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None: - hparams = { - "n_cnn_layers": 3, - "n_rnn_layers": 5, - "rnn_dim": 512, - "n_class": 46, - "n_feats": 128, - "stride": 2, - "dropout": 0.1, - "learning_rate": lr, - "batch_size": batch_size, - "epochs": epochs, - } - - use_cuda = torch.cuda.is_available() - torch.manual_seed(42) - device = torch.device("cuda" if use_cuda else "cpu") - # device = torch.device("mps") - - train_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False - ) - test_dataset = MultilingualLibriSpeech( - "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False - ) - - kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} - - train_loader = DataLoader( - train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), - **kwargs, - ) - - test_loader = DataLoader( - test_dataset, - batch_size=hparams["batch_size"], - shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), - **kwargs, - ) - - model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], - ).to(device) - - print( - "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) - ) - - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(blank=28).to(device) - - scheduler = optim.lr_scheduler.OneCycleLR( - optimizer, - max_lr=hparams["learning_rate"], - steps_per_epoch=int(len(train_loader)), - epochs=hparams["epochs"], - anneal_strategy="linear", - ) - - iter_meter = IterMeter() - for epoch in range(1, epochs + 1): - train( - model, - device, - train_loader, - criterion, - optimizer, - scheduler, - epoch, - iter_meter, - ) - test(model=model, device=device, test_loader=test_loader, criterion=criterion) - - -if __name__ == "__main__": - run(lr=5e-4, batch_size=16, epochs=1) diff --git a/swr2_asr/train_3.py b/swr2_asr/train_3.py deleted file mode 100644 index a6b0010..0000000 --- a/swr2_asr/train_3.py +++ /dev/null @@ -1,74 +0,0 @@ -"""Training script for the ASR model.""" -from AudioLoader.speech.mls import MultilingualLibriSpeech -import torch -import torchaudio -import torchaudio.functional as F - - -class GreedyCTCDecoder(torch.nn.Module): - def __init__(self, labels, blank=0) -> None: - super().__init__() - self.labels = labels - self.blank = blank - - def forward(self, emission: torch.Tensor) -> str: - """Given a sequence emission over labels, get the best path string - Args: - emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`. - - Returns: - str: The resulting transcript - """ - indices = torch.argmax(emission, dim=-1) # [num_seq,] - indices = torch.unique_consecutive(indices, dim=-1) - indices = [i for i in indices if i != self.blank] - return "".join([self.labels[i] for i in indices]) - - -def main() -> None: - """Main function.""" - # choose between cuda, cpu and mps devices - device = "cuda" if torch.cuda.is_available() else "cpu" - # device = "mps" - device = torch.device(device) - - torch.random.manual_seed(42) - - bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H - - print(f"Sample rate (model): {bundle.sample_rate}") - print(f"Labels (model): {bundle.get_labels()}") - - model = bundle.get_model().to(device) - - print(model.__class__) - - # only do all things for one single sample - dataset = MultilingualLibriSpeech( - "data", "mls_german_opus", split="train", download=True - ) - - print(dataset[0]) - - # load waveforms and sample rate from dataset - waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"] - - if sample_rate != bundle.sample_rate: - waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate)) - - waveform.to(device) - - with torch.inference_mode(): - features, _ = model.extract_features(waveform) - - with torch.inference_mode(): - emission, _ = model(waveform) - - decoder = GreedyCTCDecoder(labels=bundle.get_labels()) - transcript = decoder(emission[0]) - - print(transcript) - - -if __name__ == "__main__": - main() -- cgit v1.2.3