From 18abef41ed2f8060985f97248300e5ec8b6cb327 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Fri, 18 Aug 2023 16:37:07 +0200
Subject: added simple wav2vec2 inference

---
 swr2_asr/train.py   | 15 -----------
 swr2_asr/train_3.py | 74 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 15 deletions(-)
 delete mode 100644 swr2_asr/train.py
 create mode 100644 swr2_asr/train_3.py

diff --git a/swr2_asr/train.py b/swr2_asr/train.py
deleted file mode 100644
index f61776a..0000000
--- a/swr2_asr/train.py
+++ /dev/null
@@ -1,15 +0,0 @@
-"""Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
-
-
-def main() -> None:
-    """Main function."""
-    dataset = MultilingualLibriSpeech(
-        "data", "mls_polish_opus", split="train", download=True
-    )
-
-    print(dataset[1])
-
-
-if __name__ == "__main__":
-    main()
diff --git a/swr2_asr/train_3.py b/swr2_asr/train_3.py
new file mode 100644
index 0000000..a6b0010
--- /dev/null
+++ b/swr2_asr/train_3.py
@@ -0,0 +1,74 @@
+"""Training script for the ASR model."""
+from AudioLoader.speech.mls import MultilingualLibriSpeech
+import torch
+import torchaudio
+import torchaudio.functional as F
+
+
+class GreedyCTCDecoder(torch.nn.Module):
+    def __init__(self, labels, blank=0) -> None:
+        super().__init__()
+        self.labels = labels
+        self.blank = blank
+
+    def forward(self, emission: torch.Tensor) -> str:
+        """Given a sequence emission over labels, get the best path string
+        Args:
+            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
+
+        Returns:
+            str: The resulting transcript
+        """
+        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
+        indices = torch.unique_consecutive(indices, dim=-1)
+        indices = [i for i in indices if i != self.blank]
+        return "".join([self.labels[i] for i in indices])
+
+
+def main() -> None:
+    """Main function."""
+    # choose between cuda, cpu and mps devices
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    # device = "mps"
+    device = torch.device(device)
+
+    torch.random.manual_seed(42)
+
+    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
+
+    print(f"Sample rate (model): {bundle.sample_rate}")
+    print(f"Labels (model): {bundle.get_labels()}")
+
+    model = bundle.get_model().to(device)
+
+    print(model.__class__)
+
+    # only do all things for one single sample
+    dataset = MultilingualLibriSpeech(
+        "data", "mls_german_opus", split="train", download=True
+    )
+
+    print(dataset[0])
+
+    # load waveforms and sample rate from dataset
+    waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"]
+
+    if sample_rate != bundle.sample_rate:
+        waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate))
+
+    waveform.to(device)
+
+    with torch.inference_mode():
+        features, _ = model.extract_features(waveform)
+
+    with torch.inference_mode():
+        emission, _ = model(waveform)
+
+    decoder = GreedyCTCDecoder(labels=bundle.get_labels())
+    transcript = decoder(emission[0])
+
+    print(transcript)
+
+
+if __name__ == "__main__":
+    main()
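
Two details in the inference script above are easy to miss. `Tensor.to()` is not an in-place operation, so the bare `waveform.to(device)` never actually moves the data; and the decoder removes blanks only after collapsing repeated indices, which is exactly what lets genuine double letters survive. A minimal sketch (toy label set, not the wav2vec2 vocabulary):

```
import torch

# Tensor.to returns a new tensor; the line in the script is a no-op as written.
# waveform = waveform.to(device)  # the reassignment is required

# Greedy CTC collapse, in the same order GreedyCTCDecoder applies it:
labels = ["-", "h", "e", "l", "o"]  # index 0 acts as the CTC blank
best_path = torch.tensor([0, 1, 1, 2, 0, 3, 3, 0, 3, 4])
collapsed = torch.unique_consecutive(best_path)  # tensor([0, 1, 2, 0, 3, 0, 3, 4])
no_blank = [i for i in collapsed if i != 0]
print("".join(labels[i] for i in no_blank))  # "hello": the blank preserves "ll"
```
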
From dc597eff978e841707c6b58fb6039bd12dcfea79 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Fri, 18 Aug 2023 16:37:53 +0200
Subject: added assemblyai model

---
 swr2_asr/loss_scores.py | 185 ++++++++++++++++++++
 swr2_asr/train_2.py     | 452 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 637 insertions(+)
 create mode 100644 swr2_asr/loss_scores.py
 create mode 100644 swr2_asr/train_2.py

diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
new file mode 100644
index 0000000..977462d
--- /dev/null
+++ b/swr2_asr/loss_scores.py
@@ -0,0 +1,185 @@
+import numpy as np
+
+
+def avg_wer(wer_scores, combined_ref_len):
+    return float(sum(wer_scores)) / float(combined_ref_len)
+
+
+def _levenshtein_distance(ref, hyp):
+    """Levenshtein distance is a string metric for measuring the difference
+    between two sequences. Informally, the Levenshtein distance is defined as
+    the minimum number of single-character edits (substitutions, insertions or
+    deletions) required to change one word into the other. We can naturally
+    extend the edits to the word level when calculating the Levenshtein
+    distance for two sentences.
+    """
+    m = len(ref)
+    n = len(hyp)
+
+    # special case
+    if ref == hyp:
+        return 0
+    if m == 0:
+        return n
+    if n == 0:
+        return m
+
+    if m < n:
+        ref, hyp = hyp, ref
+        m, n = n, m
+
+    # use O(min(m, n)) space
+    distance = np.zeros((2, n + 1), dtype=np.int32)
+
+    # initialize distance matrix
+    for j in range(0, n + 1):
+        distance[0][j] = j
+
+    # calculate Levenshtein distance
+    for i in range(1, m + 1):
+        prev_row_idx = (i - 1) % 2
+        cur_row_idx = i % 2
+        distance[cur_row_idx][0] = i
+        for j in range(1, n + 1):
+            if ref[i - 1] == hyp[j - 1]:
+                distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
+            else:
+                s_num = distance[prev_row_idx][j - 1] + 1
+                i_num = distance[cur_row_idx][j - 1] + 1
+                d_num = distance[prev_row_idx][j] + 1
+                distance[cur_row_idx][j] = min(s_num, i_num, d_num)
+
+    return distance[m % 2][n]
+
+
+def word_errors(
+    reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
+):
+    """Compute the Levenshtein distance between reference sequence and
+    hypothesis sequence at the word level.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param delimiter: Delimiter of input sentences.
+    :type delimiter: char
+    :return: Levenshtein distance and word number of reference sentence.
+    :rtype: list
+    """
+    if ignore_case:
+        reference = reference.lower()
+        hypothesis = hypothesis.lower()
+
+    ref_words = reference.split(delimiter)
+    hyp_words = hypothesis.split(delimiter)
+
+    edit_distance = _levenshtein_distance(ref_words, hyp_words)
+    return float(edit_distance), len(ref_words)
+
+
+def char_errors(
+    reference: str,
+    hypothesis: str,
+    ignore_case: bool = False,
+    remove_space: bool = False,
+):
+    """Compute the Levenshtein distance between reference sequence and
+    hypothesis sequence at the character level.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param remove_space: Whether to remove internal space characters
+    :type remove_space: bool
+    :return: Levenshtein distance and length of reference sentence.
+    :rtype: list
+    """
+    if ignore_case:
+        reference = reference.lower()
+        hypothesis = hypothesis.lower()
+
+    join_char = " "
+    if remove_space:
+        join_char = ""
+
+    reference = join_char.join(filter(None, reference.split(" ")))
+    hypothesis = join_char.join(filter(None, hypothesis.split(" ")))
+
+    edit_distance = _levenshtein_distance(reference, hypothesis)
+    return float(edit_distance), len(reference)
+
+
+def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
+    """Calculate word error rate (WER). WER compares reference text and
+    hypothesis text at the word level. WER is defined as:
+    .. math::
+        WER = (Sw + Dw + Iw) / Nw
+    where
+    .. code-block:: text
+        Sw is the number of words substituted,
+        Dw is the number of words deleted,
+        Iw is the number of words inserted,
+        Nw is the number of words in the reference
+    We can use the Levenshtein distance to calculate WER. Note that empty
+    items will be removed when splitting sentences by delimiter.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param delimiter: Delimiter of input sentences.
+    :type delimiter: char
+    :return: Word error rate.
+    :rtype: float
+    :raises ValueError: If word number of reference is zero.
+    """
+    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)
+
+    if ref_len == 0:
+        raise ValueError("Reference's word number should be greater than 0.")
+
+    wer = float(edit_distance) / ref_len
+    return wer
+
+
+def cer(reference, hypothesis, ignore_case=False, remove_space=False):
+    """Calculate character error rate (CER). CER compares reference text and
+    hypothesis text at the character level. CER is defined as:
+    .. math::
+        CER = (Sc + Dc + Ic) / Nc
+    where
+    .. code-block:: text
+        Sc is the number of characters substituted,
+        Dc is the number of characters deleted,
+        Ic is the number of characters inserted
+        Nc is the number of characters in the reference
+    We can use the Levenshtein distance to calculate CER. Chinese input should
+    be encoded to unicode. Note that leading and trailing space characters
+    will be truncated and multiple consecutive space characters in a sentence
+    will be replaced by one space character.
+    :param reference: The reference sentence.
+    :type reference: basestring
+    :param hypothesis: The hypothesis sentence.
+    :type hypothesis: basestring
+    :param ignore_case: Whether to ignore case.
+    :type ignore_case: bool
+    :param remove_space: Whether to remove internal space characters
+    :type remove_space: bool
+    :return: Character error rate.
+    :rtype: float
+    :raises ValueError: If the reference length is zero.
+    """
+    edit_distance, ref_len = char_errors(
+        reference, hypothesis, ignore_case, remove_space
+    )
+
+    if ref_len == 0:
+        raise ValueError("Length of reference should be greater than 0.")
+
+    cer = float(edit_distance) / ref_len
+    return cer
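
To make the two metrics concrete, a small worked example (toy strings chosen for this note, using the module exactly as added above):

```
from swr2_asr.loss_scores import cer, wer

reference = "das ist ein test"
hypothesis = "das ist der test"

# One of four words differs -> word-level edit distance 1.
print(wer(reference, hypothesis))  # 1 / 4 = 0.25

# "ein" -> "der" costs three character substitutions; the reference
# has 16 characters including spaces.
print(cer(reference, hypothesis))  # 3 / 16 = 0.1875
```
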
+ """ + edit_distance, ref_len = char_errors( + reference, hypothesis, ignore_case, remove_space + ) + + if ref_len == 0: + raise ValueError("Length of reference should be greater than 0.") + + cer = float(edit_distance) / ref_len + return cer diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py new file mode 100644 index 0000000..2e690e2 --- /dev/null +++ b/swr2_asr/train_2.py @@ -0,0 +1,452 @@ +"""Training script for the ASR model.""" +from AudioLoader.speech.mls import MultilingualLibriSpeech +import click +import torch +import torch.nn as nn +import torch.optim as optim +import torch.nn.functional as F +from torch.utils.data import DataLoader +import torchaudio +import torchaudio.functional as AF +from .loss_scores import cer, wer + + +class TextTransform: + """Maps characters to integers and vice versa""" + + def __init__(self): + char_map_str = """ + ' 0 + 1 + a 2 + b 3 + c 4 + d 5 + e 6 + f 7 + g 8 + h 9 + i 10 + j 11 + k 12 + l 13 + m 14 + n 15 + o 16 + p 17 + q 18 + r 19 + s 20 + t 21 + u 22 + v 23 + w 24 + x 25 + y 26 + z 27 + ä 28 + ö 29 + ü 30 + ß 31 + """ + self.char_map = {} + self.index_map = {} + for line in char_map_str.strip().split("\n"): + ch, index = line.split() + self.char_map[ch] = int(index) + self.index_map[int(index)] = ch + self.index_map[1] = " " + + def text_to_int(self, text): + """Use a character map and convert text to an integer sequence""" + int_sequence = [] + for c in text: + if c == " ": + ch = self.char_map[""] + else: + ch = self.char_map[c] + int_sequence.append(ch) + return int_sequence + + def int_to_text(self, labels): + """Use a character map and convert integer labels to an text sequence""" + string = [] + for i in labels: + string.append(self.index_map[i]) + return "".join(string).replace("", " ") + + +train_audio_transforms = nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), +) + +valid_audio_transforms = torchaudio.transforms.MelSpectrogram() + +text_transform = TextTransform() + + +def data_processing(data, data_type="train"): + """Return the spectrograms, labels, and their lengths.""" + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + for waveform, _, utterance, _, _, _ in data: + if data_type == "train": + spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1) + elif data_type == "valid": + spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1) + else: + raise Exception("data_type should be train or valid") + spectrograms.append(spec) + label = torch.Tensor(text_transform.text_to_int(utterance.lower())) + labels.append(label) + input_lengths.append(spec.shape[0] // 2) + label_lengths.append(len(label)) + + spectrograms = ( + nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) + .unsqueeze(1) + .transpose(2, 3) + ) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return spectrograms, labels, input_lengths, label_lengths + + +def GreedyDecoder( + output, labels, label_lengths, blank_label=28, collapse_repeated=True +): + arg_maxes = torch.argmax(output, dim=2) + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append( + text_transform.int_to_text(labels[i][: label_lengths[i]].tolist()) + ) + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + 
decodes.append(text_transform.int_to_text(decode)) + return decodes, targets + + +class CNNLayerNorm(nn.Module): + """Layer normalization built for cnns input""" + + def __init__(self, n_feats: int): + super(CNNLayerNorm, self).__init__() + self.layer_norm = nn.LayerNorm(n_feats) + + def forward(self, x): + """x (batch, channel, feature, time)""" + x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature) + x = self.layer_norm(x) + return x.transpose(2, 3).contiguous() # (batch, channel, feature, time) + + +class ResidualCNN(nn.Module): + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + stride: int, + dropout: float, + n_feats: int, + ): + super(ResidualCNN, self).__init__() + + self.cnn1 = nn.Conv2d( + in_channels, out_channels, kernel, stride, padding=kernel // 2 + ) + self.cnn2 = nn.Conv2d( + out_channels, + out_channels, + kernel, + stride, + padding=kernel // 2, + ) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.layer_norm1 = CNNLayerNorm(n_feats) + self.layer_norm2 = CNNLayerNorm(n_feats) + + def forward(self, x): + residual = x # (batch, channel, feature, time) + x = self.layer_norm1(x) + x = F.gelu(x) + x = self.dropout1(x) + x = self.cnn1(x) + x = self.layer_norm2(x) + x = F.gelu(x) + x = self.dropout2(x) + x = self.cnn2(x) + x += residual + return x # (batch, channel, feature, time) + + +class BidirectionalGRU(nn.Module): + """BIdirectional GRU with Layer Normalization and Dropout""" + + def __init__( + self, + rnn_dim: int, + hidden_size: int, + dropout: float, + batch_first: bool, + ): + super(BidirectionalGRU, self).__init__() + + self.BiGRU = nn.GRU( + input_size=rnn_dim, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=True, + ) + self.layer_norm = nn.LayerNorm(rnn_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + x = self.layer_norm(x) + x = F.gelu(x) + x = self.dropout(x) + x, _ = self.BiGRU(x) + return x + + +class SpeechRecognitionModel(nn.Module): + def __init__( + self, + n_cnn_layers: int, + n_rnn_layers: int, + rnn_dim: int, + n_class: int, + n_feats: int, + stride: int = 2, + dropout: float = 0.1, + ): + super(SpeechRecognitionModel, self).__init__() + n_feats //= 2 + self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + # n residual cnn layers with filter size of 32 + self.rescnn_layers = nn.Sequential( + *[ + ResidualCNN( + 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats + ) + for _ in range(n_cnn_layers) + ] + ) + self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) + self.birnn_layers = nn.Sequential( + *[ + BidirectionalGRU( + rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, + hidden_size=rnn_dim, + dropout=dropout, + batch_first=i == 0, + ) + for i in range(n_rnn_layers) + ] + ) + self.classifier = nn.Sequential( + nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(rnn_dim, n_class), + ) + + def forward(self, x): + x = self.cnn(x) + x = self.rescnn_layers(x) + sizes = x.size() + x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) + x = x.transpose(1, 2) # (batch, time, feature) + x = self.fully_connected(x) + x = self.birnn_layers(x) + x = self.classifier(x) + return x + + +class IterMeter(object): + """keeps track of total iterations""" + + def __init__(self): + self.val = 0 + + def step(self): + self.val += 1 + + def get(self): + return self.val + + +def 
train( + model, + device, + train_loader, + criterion, + optimizer, + scheduler, + epoch, + iter_meter, +): + model.train() + data_len = len(train_loader.dataset) + for batch_idx, _data in enumerate(train_loader): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + + optimizer.zero_grad() + + output = model(spectrograms) # (batch, time, n_class) + output = F.log_softmax(output, dim=2) + output = output.transpose(0, 1) # (time, batch, n_class) + + loss = criterion(output, labels, input_lengths, label_lengths) + loss.backward() + + optimizer.step() + scheduler.step() + iter_meter.step() + if batch_idx % 100 == 0 or batch_idx == data_len: + print( + "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format( + epoch, + batch_idx * len(spectrograms), + data_len, + 100.0 * batch_idx / len(train_loader), + loss.item(), + ) + ) + + +def test(model, device, test_loader, criterion, epoch, iter_meter): + print("\nevaluating...") + model.eval() + test_loss = 0 + test_cer, test_wer = [], [] + with torch.no_grad(): + for i, _data in enumerate(test_loader): + spectrograms, labels, input_lengths, label_lengths = _data + spectrograms, labels = spectrograms.to(device), labels.to(device) + + output = model(spectrograms) # (batch, time, n_class) + output = F.log_softmax(output, dim=2) + output = output.transpose(0, 1) # (time, batch, n_class) + + loss = criterion(output, labels, input_lengths, label_lengths) + test_loss += loss.item() / len(test_loader) + + decoded_preds, decoded_targets = GreedyDecoder( + output.transpose(0, 1), labels, label_lengths + ) + for j in range(len(decoded_preds)): + test_cer.append(cer(decoded_targets[j], decoded_preds[j])) + test_wer.append(wer(decoded_targets[j], decoded_preds[j])) + + avg_cer = sum(test_cer) / len(test_cer) + avg_wer = sum(test_wer) / len(test_wer) + + print( + "Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n".format( + test_loss, avg_cer, avg_wer + ) + ) + + +def run(lr: float, batch_size: int, epochs: int) -> None: + hparams = { + "n_cnn_layers": 3, + "n_rnn_layers": 5, + "rnn_dim": 512, + "n_class": 33, + "n_feats": 128, + "stride": 2, + "dropout": 0.1, + "learning_rate": lr, + "batch_size": batch_size, + "epochs": epochs, + } + + use_cuda = torch.cuda.is_available() + torch.manual_seed(42) + device = torch.device("cuda" if use_cuda else "cpu") + device = torch.device("mps") + + train_dataset = MultilingualLibriSpeech( + "data", "mls_german_opus", split="train", download=False + ) + test_dataset = MultilingualLibriSpeech( + "data", "mls_german_opus", split="test", download=False + ) + + kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {} + + train_loader = DataLoader( + train_dataset, + batch_size=hparams["batch_size"], + shuffle=True, + collate_fn=lambda x: data_processing(x, "train"), + **kwargs, + ) + + test_loader = DataLoader( + train_dataset, + batch_size=hparams["batch_size"], + shuffle=True, + collate_fn=lambda x: data_processing(x, "train"), + **kwargs, + ) + + model = SpeechRecognitionModel( + hparams["n_cnn_layers"], + hparams["n_rnn_layers"], + hparams["rnn_dim"], + hparams["n_class"], + hparams["n_feats"], + hparams["stride"], + hparams["dropout"], + ).to(device) + + print( + "Num Model Parameters", sum([param.nelement() for param in model.parameters()]) + ) + + optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + criterion = nn.CTCLoss(blank=28).to(device) + + scheduler = optim.lr_scheduler.OneCycleLR( + 
+
+
+def run(lr: float, batch_size: int, epochs: int) -> None:
+    hparams = {
+        "n_cnn_layers": 3,
+        "n_rnn_layers": 5,
+        "rnn_dim": 512,
+        "n_class": 33,
+        "n_feats": 128,
+        "stride": 2,
+        "dropout": 0.1,
+        "learning_rate": lr,
+        "batch_size": batch_size,
+        "epochs": epochs,
+    }
+
+    use_cuda = torch.cuda.is_available()
+    torch.manual_seed(42)
+    device = torch.device("cuda" if use_cuda else "cpu")
+    device = torch.device("mps")
+
+    train_dataset = MultilingualLibriSpeech(
+        "data", "mls_german_opus", split="train", download=False
+    )
+    test_dataset = MultilingualLibriSpeech(
+        "data", "mls_german_opus", split="test", download=False
+    )
+
+    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=hparams["batch_size"],
+        shuffle=True,
+        collate_fn=lambda x: data_processing(x, "train"),
+        **kwargs,
+    )
+
+    test_loader = DataLoader(
+        train_dataset,
+        batch_size=hparams["batch_size"],
+        shuffle=True,
+        collate_fn=lambda x: data_processing(x, "train"),
+        **kwargs,
+    )
+
+    model = SpeechRecognitionModel(
+        hparams["n_cnn_layers"],
+        hparams["n_rnn_layers"],
+        hparams["rnn_dim"],
+        hparams["n_class"],
+        hparams["n_feats"],
+        hparams["stride"],
+        hparams["dropout"],
+    ).to(device)
+
+    print(
+        "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+    )
+
+    optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
+    criterion = nn.CTCLoss(blank=28).to(device)
+
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=hparams["learning_rate"],
+        steps_per_epoch=int(len(train_loader)),
+        epochs=hparams["epochs"],
+        anneal_strategy="linear",
+    )
+
+    iter_meter = IterMeter()
+    for epoch in range(1, epochs + 1):
+        train(
+            model,
+            device,
+            train_loader,
+            criterion,
+            optimizer,
+            scheduler,
+            epoch,
+            iter_meter,
+        )
+        test(model, device, test_loader, criterion, epoch, iter_meter)
+
+
+if __name__ == "__main__":
+    run(lr=5e-4, batch_size=20, epochs=10)
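
One relationship in this training script deserves a note: `data_processing` records `spec.shape[0] // 2` per sample because the model's entry convolution has stride 2, and `nn.CTCLoss` must be told how many *output* frames each utterance occupies. A toy shape check (random tensors and sizes invented for this note, no real audio or model):

```
import torch

batch, n_mels, time = 4, 128, 200
spectrograms = torch.randn(batch, 1, n_mels, time)  # shape of the collate output

# The stride-2 entry conv halves both the feature and the time axis,
# which is why data_processing stores spec.shape[0] // 2 per sample.
conv = torch.nn.Conv2d(1, 32, 3, stride=2, padding=1)
out = conv(spectrograms)
print(out.shape)  # torch.Size([4, 32, 64, 100]) -> 100 output frames

# CTCLoss consumes (time, batch, n_class) log-probs plus those lengths:
log_probs = torch.randn(100, batch, 33).log_softmax(2)
targets = torch.randint(low=0, high=28, size=(batch, 20))  # no blank in targets
loss = torch.nn.CTCLoss(blank=28)(
    log_probs, targets, tuple([100] * batch), tuple([20] * batch)
)
```
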
From 13a608530eba90cea4c003566e331938fbf34bda Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Fri, 18 Aug 2023 23:19:47 +0200
Subject: fixed training procedure

---
 .gitignore          |  7 ++-----
 readme.md           |  6 +++---
 swr2_asr/train_2.py | 36 +++++++++++++++++++++++++-----------
 tests/__init__.py   |  0
 4 files changed, 30 insertions(+), 19 deletions(-)
 create mode 100644 tests/__init__.py

diff --git a/.gitignore b/.gitignore
index 33600ee..8e64e4b 100644
--- a/.gitignore
+++ b/.gitignore
@@ -118,6 +118,8 @@ ipython_config.py
 # in version control.
 # https://pdm.fming.dev/#use-with-ide
 .pdm.toml
+.pdm-python
+.pdm-build/
 
 # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
 __pypackages__/
@@ -162,14 +164,9 @@ dmypy.json
 # Cython debug symbols
 cython_debug/
 
-# linter
-**/.ruff
-
 # PyCharm
 # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
 # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
 # and can be added to the global gitignore or merged into this file. For a more nuclear
 # option (not recommended) you can uncomment the following to ignore the entire idea folder.
 #.idea/
-
-data/
diff --git a/readme.md b/readme.md
index 99d741c..47d9a31 100644
--- a/readme.md
+++ b/readme.md
@@ -5,7 +5,7 @@ recognition 2 (SWR2) in the summer term 2023.
 # Installation
 
 ```
-pip install -r requirements.txt
+poetry install
 ```
 
 # Usage
@@ -14,13 +14,13 @@
 
 Train using the provided train script:
 
-    poetry run train --data PATH/TO/DATA --lr 0.01
+    poetry run train
 
 ## Evaluation
 
 ## Inference
 
-    poetry run recognize --data PATH/TO/FILE
+    poetry run recognize
 
 ## CI
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
index 2e690e2..b1b597a 100644
--- a/swr2_asr/train_2.py
+++ b/swr2_asr/train_2.py
@@ -48,6 +48,20 @@ class TextTransform:
         ö 29
         ü 30
         ß 31
+        - 32
+        é 33
+        è 34
+        à 35
+        ù 36
+        ç 37
+        â 38
+        ê 39
+        î 40
+        ô 41
+        û 42
+        ë 43
+        ï 44
+        ü 45
         """
         self.char_map = {}
         self.index_map = {}
@@ -93,15 +107,15 @@ def data_processing(data, data_type="train"):
     labels = []
     input_lengths = []
     label_lengths = []
-    for waveform, _, utterance, _, _, _ in data:
+    for x in data:
         if data_type == "train":
-            spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+            spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
         elif data_type == "valid":
-            spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+            spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
         else:
             raise Exception("data_type should be train or valid")
         spectrograms.append(spec)
-        label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
+        label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
         labels.append(label)
         input_lengths.append(spec.shape[0] // 2)
         label_lengths.append(len(label))
@@ -364,12 +378,12 @@
-def run(lr: float, batch_size: int, epochs: int) -> None:
+def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
     hparams = {
         "n_cnn_layers": 3,
         "n_rnn_layers": 5,
         "rnn_dim": 512,
-        "n_class": 33,
+        "n_class": 46,
         "n_feats": 128,
         "stride": 2,
         "dropout": 0.1,
@@ -381,13 +395,13 @@
     use_cuda = torch.cuda.is_available()
     torch.manual_seed(42)
     device = torch.device("cuda" if use_cuda else "cpu")
-    device = torch.device("mps")
+    # device = torch.device("mps")
 
     train_dataset = MultilingualLibriSpeech(
-        "data", "mls_german_opus", split="train", download=False
+        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
     )
     test_dataset = MultilingualLibriSpeech(
-        "data", "mls_german_opus", split="test", download=False
+        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
     )
 
     kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
@@ -401,7 +415,7 @@
     test_loader = DataLoader(
-        train_dataset,
+        test_dataset,
         batch_size=hparams["batch_size"],
         shuffle=True,
         collate_fn=lambda x: data_processing(x, "train"),
@@ -449,4 +463,4 @@
 if __name__ == "__main__":
-    run(lr=5e-4, batch_size=20, epochs=10)
+    run(lr=5e-4, batch_size=16, epochs=1)
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
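
The expanded character map makes German and French characters encodable, but two quirks ride along, shown in miniature below: `ü` now appears twice (indices 30 and 45), so `char_map` keeps only the second entry and index 30 can never be produced by `text_to_int`; and `nn.CTCLoss`/`GreedyDecoder` still use blank index 28, which `text_to_int` also assigns to `ä`. A minimal reproduction (indices copied from the diff):

```
char_map_str = """
ä 28
ü 30
ü 45
"""
char_map = {}
index_map = {}
for line in char_map_str.strip().split("\n"):
    ch, index = line.split()
    char_map[ch] = int(index)  # the second "ü" overwrites the first
    index_map[int(index)] = ch

print(char_map["ü"])  # 45 -> index 30 is unreachable from text_to_int
# Meanwhile nn.CTCLoss(blank=28) and GreedyDecoder(blank_label=28) treat
# class 28 as the CTC blank, even though targets now encode "ä" as 28.
```
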
From 8159a0b8a4519dced2490d77b7e1ae7fd1bbadef Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Fri, 18 Aug 2023 23:41:09 +0200
Subject: made linter changes (will still fail)

---
 .vscode/settings.json |  2 ++
 mypy.ini              |  6 ++++++
 swr2_asr/train_2.py   | 16 +++++++---------
 3 files changed, 15 insertions(+), 9 deletions(-)

diff --git a/.vscode/settings.json b/.vscode/settings.json
index 6d5637c..bd8762b 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -7,4 +7,6 @@
     },
     "black-formatter.importStrategy": "fromEnvironment",
     "python.analysis.typeCheckingMode": "basic",
+    "python.linting.pylintEnabled": true,
+    "python.linting.enabled": true,
 }
\ No newline at end of file
diff --git a/mypy.ini b/mypy.ini
index f5a713c..9f3a098 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -2,4 +2,10 @@
 ignore_missing_imports = True
 
 [mypy-torchaudio.*]
+ignore_missing_imports = true
+
+[mypy-torch.*]
+ignore_missing_imports = true
+
+[mypy-click.*]
 ignore_missing_imports = true
\ No newline at end of file
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
index b1b597a..bea5bf4 100644
--- a/swr2_asr/train_2.py
+++ b/swr2_asr/train_2.py
@@ -1,13 +1,11 @@
 """Training script for the ASR model."""
 from AudioLoader.speech.mls import MultilingualLibriSpeech
-import click
 import torch
 import torch.nn as nn
 import torch.optim as optim
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 import torchaudio
-import torchaudio.functional as AF
 from .loss_scores import cer, wer
 
 
@@ -113,7 +111,7 @@ def data_processing(data, data_type="train"):
         elif data_type == "valid":
             spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
         else:
-            raise Exception("data_type should be train or valid")
+            raise ValueError("data_type should be train or valid")
         spectrograms.append(spec)
         label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
         labels.append(label)
@@ -133,6 +131,7 @@ def data_processing(data, data_type="train"):
 def GreedyDecoder(
     output, labels, label_lengths, blank_label=28, collapse_repeated=True
 ):
+    """Greedily decode a sequence."""
     arg_maxes = torch.argmax(output, dim=2)
     decodes = []
     targets = []
@@ -344,13 +343,13 @@ def train(
 
 
-def test(model, device, test_loader, criterion, epoch, iter_meter):
+def test(model, device, test_loader, criterion):
     print("\nevaluating...")
     model.eval()
     test_loss = 0
     test_cer, test_wer = [], []
     with torch.no_grad():
-        for i, _data in enumerate(test_loader):
+        for _data in test_loader:
             spectrograms, labels, input_lengths, label_lengths = _data
             spectrograms, labels = spectrograms.to(device), labels.to(device)
 
@@ -372,9 +371,8 @@ def test(model, device, test_loader, criterion):
     avg_wer = sum(test_wer) / len(test_wer)
 
     print(
-        "Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n".format(
-            test_loss, avg_cer, avg_wer
-        )
+        f"Test set: Average loss:\
+         {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
     )
 
 
@@ -459,7 +457,7 @@ def run(lr: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
             epoch,
             iter_meter,
         )
-        test(model, device, test_loader, criterion, epoch, iter_meter)
+        test(model=model, device=device, test_loader=test_loader, criterion=criterion)
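
One stylistic trap in the new f-string: the backslash is a line continuation inside the string literal, so the indentation of the second source line becomes part of the printed message, and the numeric formatting of the old `.format()` version is lost. A small demonstration plus one possible cleanup (toy values, not repo output):

```
test_loss, avg_cer, avg_wer = 1.2345, 0.21, 0.47

# The continuation removes the newline but keeps the leading spaces
# of the second source line inside the string:
msg = f"Test set: Average loss:\
         {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
print(repr(msg))  # note the run of spaces after "loss:"

# Implicit concatenation of adjacent literals avoids both problems:
msg = (
    f"Test set: Average loss: {test_loss:.4f}, "
    f"Average CER: {avg_cer:.4f} Average WER: {avg_wer:.4f}\n"
)
```
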
From 9b4592caac90a41eb6ce18558588ef504c49f58e Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Sat, 19 Aug 2023 10:55:02 +0200
Subject: renamed files

---
 pyproject.toml             |   2 +-
 swr2_asr/inference_test.py |  74 +++++++
 swr2_asr/train.py          | 475 +++++++++++++++++++++++++++++++++++++++++++++
 swr2_asr/train_2.py        | 464 ------------------------------------------
 swr2_asr/train_3.py        |  74 -------
 5 files changed, 550 insertions(+), 539 deletions(-)
 create mode 100644 swr2_asr/inference_test.py
 create mode 100644 swr2_asr/train.py
 delete mode 100644 swr2_asr/train_2.py
 delete mode 100644 swr2_asr/train_3.py

diff --git a/pyproject.toml b/pyproject.toml
index 791a76e..8490aa5 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,7 +22,7 @@ mypy = "^1.5.1"
 pylint = "^2.17.5"
 
 [tool.poetry.scripts]
-train = "swr2_asr.train_2:run"
+train = "swr2_asr.train:run_cli"
 
 [build-system]
 requires = ["poetry-core"]
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
new file mode 100644
index 0000000..a6b0010
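
The new `inference_test.py` blob (a6b0010) is byte-identical to the `train_3.py` added in the first commit above, so its body is not repeated here. With pyproject now pointing `train` at `swr2_asr.train:run_cli`, the script is driven by click; note that `run_cli`'s option defaults (1e-3, 1, 1) differ from `run()`'s own defaults (5e-4, 8, 3), and the option names mix styles (`--learning-rate` but `--batch_size`). A hypothetical invocation sketch, with names taken from the diff below:

```
# 1) Via the console script (click option names as declared in train.py):
#      poetry run train --learning-rate 5e-4 --batch_size 16 --epochs 1

# 2) Directly from Python, bypassing click:
from swr2_asr.train import run

run(learning_rate=5e-4, batch_size=16, epochs=1)
```
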
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
new file mode 100644
index 0000000..8ee96b9
--- /dev/null
+++ b/swr2_asr/train.py
@@ -0,0 +1,475 @@
+"""Training script for the ASR model."""
+from AudioLoader.speech.mls import MultilingualLibriSpeech
+import click
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torchaudio
+from .loss_scores import cer, wer
+
+
+class TextTransform:
+    """Maps characters to integers and vice versa"""
+
+    def __init__(self):
+        char_map_str = """
+        ' 0
+        <SPACE> 1
+        a 2
+        b 3
+        c 4
+        d 5
+        e 6
+        f 7
+        g 8
+        h 9
+        i 10
+        j 11
+        k 12
+        l 13
+        m 14
+        n 15
+        o 16
+        p 17
+        q 18
+        r 19
+        s 20
+        t 21
+        u 22
+        v 23
+        w 24
+        x 25
+        y 26
+        z 27
+        ä 28
+        ö 29
+        ü 30
+        ß 31
+        - 32
+        é 33
+        è 34
+        à 35
+        ù 36
+        ç 37
+        â 38
+        ê 39
+        î 40
+        ô 41
+        û 42
+        ë 43
+        ï 44
+        ü 45
+        """
+        self.char_map = {}
+        self.index_map = {}
+        for line in char_map_str.strip().split("\n"):
+            ch, index = line.split()
+            self.char_map[ch] = int(index)
+            self.index_map[int(index)] = ch
+        self.index_map[1] = " "
+
+    def text_to_int(self, text):
+        """Use a character map and convert text to an integer sequence"""
+        int_sequence = []
+        for c in text:
+            if c == " ":
+                ch = self.char_map["<SPACE>"]
+            else:
+                ch = self.char_map[c]
+            int_sequence.append(ch)
+        return int_sequence
+
+    def int_to_text(self, labels):
+        """Use a character map and convert integer labels to a text sequence"""
+        string = []
+        for i in labels:
+            string.append(self.index_map[i])
+        return "".join(string).replace("<SPACE>", " ")
+
+
+train_audio_transforms = nn.Sequential(
+    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+    torchaudio.transforms.TimeMasking(time_mask_param=100),
+)
+
+valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
+
+text_transform = TextTransform()
+
+
+def data_processing(data, data_type="train"):
+    """Return the spectrograms, labels, and their lengths."""
+    spectrograms = []
+    labels = []
+    input_lengths = []
+    label_lengths = []
+    for x in data:
+        if data_type == "train":
+            spec = train_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
+        elif data_type == "valid":
+            spec = valid_audio_transforms(x["waveform"]).squeeze(0).transpose(0, 1)
+        else:
+            raise ValueError("data_type should be train or valid")
+        spectrograms.append(spec)
+        label = torch.Tensor(text_transform.text_to_int(x["utterance"].lower()))
+        labels.append(label)
+        input_lengths.append(spec.shape[0] // 2)
+        label_lengths.append(len(label))
+
+    spectrograms = (
+        nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
+        .unsqueeze(1)
+        .transpose(2, 3)
+    )
+    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+    return spectrograms, labels, input_lengths, label_lengths
+
+
+def GreedyDecoder(
+    output, labels, label_lengths, blank_label=28, collapse_repeated=True
+):
+    """Greedily decode a sequence."""
+    arg_maxes = torch.argmax(output, dim=2)
+    decodes = []
+    targets = []
+    for i, args in enumerate(arg_maxes):
+        decode = []
+        targets.append(
+            text_transform.int_to_text(labels[i][: label_lengths[i]].tolist())
+        )
+        for j, index in enumerate(args):
+            if index != blank_label:
+                if collapse_repeated and j != 0 and index == args[j - 1]:
+                    continue
+                decode.append(index.item())
+        decodes.append(text_transform.int_to_text(decode))
+    return decodes, targets
+
+
+class CNNLayerNorm(nn.Module):
+    """Layer normalization built for CNN inputs"""
+
+    def __init__(self, n_feats: int):
+        super(CNNLayerNorm, self).__init__()
+        self.layer_norm = nn.LayerNorm(n_feats)
+
+    def forward(self, x):
+        """x (batch, channel, feature, time)"""
+        x = x.transpose(2, 3).contiguous()  # (batch, channel, time, feature)
+        x = self.layer_norm(x)
+        return x.transpose(2, 3).contiguous()  # (batch, channel, feature, time)
+
+
+class ResidualCNN(nn.Module):
+    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel: int,
+        stride: int,
+        dropout: float,
+        n_feats: int,
+    ):
+        super(ResidualCNN, self).__init__()
+
+        self.cnn1 = nn.Conv2d(
+            in_channels, out_channels, kernel, stride, padding=kernel // 2
+        )
+        self.cnn2 = nn.Conv2d(
+            out_channels,
+            out_channels,
+            kernel,
+            stride,
+            padding=kernel // 2,
+        )
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.layer_norm1 = CNNLayerNorm(n_feats)
+        self.layer_norm2 = CNNLayerNorm(n_feats)
+
+    def forward(self, x):
+        residual = x  # (batch, channel, feature, time)
+        x = self.layer_norm1(x)
+        x = F.gelu(x)
+        x = self.dropout1(x)
+        x = self.cnn1(x)
+        x = self.layer_norm2(x)
+        x = F.gelu(x)
+        x = self.dropout2(x)
+        x = self.cnn2(x)
+        x += residual
+        return x  # (batch, channel, feature, time)
+
+
+class BidirectionalGRU(nn.Module):
+    """Bidirectional GRU with Layer Normalization and Dropout"""
+
+    def __init__(
+        self,
+        rnn_dim: int,
+        hidden_size: int,
+        dropout: float,
+        batch_first: bool,
+    ):
+        super(BidirectionalGRU, self).__init__()
+
+        self.BiGRU = nn.GRU(
+            input_size=rnn_dim,
+            hidden_size=hidden_size,
+            num_layers=1,
+            batch_first=batch_first,
+            bidirectional=True,
+        )
+        self.layer_norm = nn.LayerNorm(rnn_dim)
+        self.dropout = nn.Dropout(dropout)
+
+    def forward(self, x):
+        x = self.layer_norm(x)
+        x = F.gelu(x)
+        x = self.dropout(x)
+        x, _ = self.BiGRU(x)
+        return x
+
+
+class SpeechRecognitionModel(nn.Module):
+    def __init__(
+        self,
+        n_cnn_layers: int,
+        n_rnn_layers: int,
+        rnn_dim: int,
+        n_class: int,
+        n_feats: int,
+        stride: int = 2,
+        dropout: float = 0.1,
+    ):
+        super(SpeechRecognitionModel, self).__init__()
+        n_feats //= 2
+        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
+        # n residual cnn layers with filter size of 32
+        self.rescnn_layers = nn.Sequential(
+            *[
+                ResidualCNN(
+                    32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
+                )
+                for _ in range(n_cnn_layers)
+            ]
+        )
+        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
+        self.birnn_layers = nn.Sequential(
+            *[
+                BidirectionalGRU(
+                    rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
+                    hidden_size=rnn_dim,
+                    dropout=dropout,
+                    batch_first=i == 0,
+                )
+                for i in range(n_rnn_layers)
+            ]
+        )
+        self.classifier = nn.Sequential(
+            nn.Linear(rnn_dim * 2, rnn_dim),  # birnn returns rnn_dim*2
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(rnn_dim, n_class),
+        )
+
+    def forward(self, x):
+        x = self.cnn(x)
+        x = self.rescnn_layers(x)
+        sizes = x.size()
+        x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3])  # (batch, feature, time)
+        x = x.transpose(1, 2)  # (batch, time, feature)
+        x = self.fully_connected(x)
+        x = self.birnn_layers(x)
+        x = self.classifier(x)
+        return x
+
+
+class IterMeter(object):
+    """keeps track of total iterations"""
+
+    def __init__(self):
+        self.val = 0
+
+    def step(self):
+        self.val += 1
+
+    def get(self):
+        return self.val
+
+
+def train(
+    model,
+    device,
+    train_loader,
+    criterion,
+    optimizer,
+    scheduler,
+    epoch,
+    iter_meter,
+):
+    model.train()
+    data_len = len(train_loader.dataset)
+    for batch_idx, _data in enumerate(train_loader):
+        spectrograms, labels, input_lengths, label_lengths = _data
+        spectrograms, labels = spectrograms.to(device), labels.to(device)
+
+        optimizer.zero_grad()
+
+        output = model(spectrograms)  # (batch, time, n_class)
+        output = F.log_softmax(output, dim=2)
+        output = output.transpose(0, 1)  # (time, batch, n_class)
+
+        loss = criterion(output, labels, input_lengths, label_lengths)
+        loss.backward()
+
+        optimizer.step()
+        scheduler.step()
+        iter_meter.step()
+        if batch_idx % 100 == 0 or batch_idx == data_len:
+            print(
+                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+                    epoch,
+                    batch_idx * len(spectrograms),
+                    data_len,
+                    100.0 * batch_idx / len(train_loader),
+                    loss.item(),
+                )
+            )
+
+
+def test(model, device, test_loader, criterion):
+    print("\nevaluating...")
+    model.eval()
+    test_loss = 0
+    test_cer, test_wer = [], []
+    with torch.no_grad():
+        for _data in test_loader:
+            spectrograms, labels, input_lengths, label_lengths = _data
+            spectrograms, labels = spectrograms.to(device), labels.to(device)
+
+            output = model(spectrograms)  # (batch, time, n_class)
+            output = F.log_softmax(output, dim=2)
+            output = output.transpose(0, 1)  # (time, batch, n_class)
+
+            loss = criterion(output, labels, input_lengths, label_lengths)
+            test_loss += loss.item() / len(test_loader)
+
+            decoded_preds, decoded_targets = GreedyDecoder(
+                output.transpose(0, 1), labels, label_lengths
+            )
+            for j in range(len(decoded_preds)):
+                test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
+                test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+
+    avg_cer = sum(test_cer) / len(test_cer)
+    avg_wer = sum(test_wer) / len(test_wer)
+
+    print(
+        f"Test set: Average loss:\
+         {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
+    )
+
+
+def run(learning_rate: float = 5e-4, batch_size: int = 8, epochs: int = 3) -> None:
+    """Runs the training script."""
+    hparams = {
+        "n_cnn_layers": 3,
+        "n_rnn_layers": 5,
+        "rnn_dim": 512,
+        "n_class": 46,
+        "n_feats": 128,
+        "stride": 2,
+        "dropout": 0.1,
+        "learning_rate": learning_rate,
+        "batch_size": batch_size,
+        "epochs": epochs,
+    }
+
+    use_cuda = torch.cuda.is_available()
+    torch.manual_seed(42)
+    device = torch.device("cuda" if use_cuda else "cpu")
+    # device = torch.device("mps")
+
+    train_dataset = MultilingualLibriSpeech(
+        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="dev", download=False
+    )
+    test_dataset = MultilingualLibriSpeech(
+        "/Volumes/pherkel/SWR2-ASR/", "mls_german_opus", split="test", download=False
+    )
+
+    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=hparams["batch_size"],
+        shuffle=True,
+        collate_fn=lambda x: data_processing(x, "train"),
+        **kwargs,
+    )
+
+    test_loader = DataLoader(
+        test_dataset,
+        batch_size=hparams["batch_size"],
+        shuffle=True,
+        collate_fn=lambda x: data_processing(x, "train"),
+        **kwargs,
+    )
+
+    model = SpeechRecognitionModel(
+        hparams["n_cnn_layers"],
+        hparams["n_rnn_layers"],
+        hparams["rnn_dim"],
+        hparams["n_class"],
+        hparams["n_feats"],
+        hparams["stride"],
+        hparams["dropout"],
+    ).to(device)
+
+    print(
+        "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+    )
+
+    optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
+    criterion = nn.CTCLoss(blank=28).to(device)
+
+    scheduler = optim.lr_scheduler.OneCycleLR(
+        optimizer,
+        max_lr=hparams["learning_rate"],
+        steps_per_epoch=int(len(train_loader)),
+        epochs=hparams["epochs"],
+        anneal_strategy="linear",
+    )
+
+    iter_meter = IterMeter()
+    for epoch in range(1, epochs + 1):
+        train(
+            model,
+            device,
+            train_loader,
+            criterion,
+            optimizer,
+            scheduler,
+            epoch,
+            iter_meter,
+        )
+        test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+
+
+@click.command()
+@click.option("--learning-rate", default=1e-3, help="Learning rate")
+@click.option("--batch_size", default=1, help="Batch size")
+@click.option("--epochs", default=1, help="Number of epochs")
+def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
+    """Runs the training script."""
+    run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)
+
+
+if __name__ == "__main__":
+    run(learning_rate=5e-4, batch_size=16, epochs=1)
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
deleted file mode 100644
index bea5bf4..0000000
diff --git a/swr2_asr/train_3.py b/swr2_asr/train_3.py
deleted file mode 100644
index a6b0010..0000000
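
The two deleted blobs are the files reproduced in full earlier in this log (bea5bf4 is train_2.py as of the previous commit; a6b0010 is the train_3.py from the first commit), so their bodies are not repeated. One inherited issue the rename did not touch: inside run(), the test DataLoader still collates with data_processing(x, "train"), so the SpecAugment masking transforms are applied to evaluation batches, and shuffle=True makes the evaluation order nondeterministic. A possible replacement for that block in swr2_asr/train.py (a sketch, not part of the repo history):

```
    test_loader = DataLoader(
        test_dataset,
        batch_size=hparams["batch_size"],
        shuffle=False,  # deterministic evaluation
        collate_fn=lambda x: data_processing(x, "valid"),
        **kwargs,
    )
```
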