diff options
author | Pherkel | 2023-08-19 11:06:33 +0200 |
---|---|---|
committer | Pherkel | 2023-08-19 11:06:33 +0200 |
commit | 3880339761062a588f467c4ea891338838c533e9 (patch) | |
tree | 5638b0a224d9bb1d41a8b392efec30bbbd1b175b /swr2_asr | |
parent | 7c375276b44dc933b3311c87601e0ac6945f5be8 (diff) | |
parent | 9b4592caac90a41eb6ce18558588ef504c49f58e (diff) |
Merge branch 'wave2vec2'
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/inference_test.py | 74 | ||||
-rw-r--r-- | swr2_asr/loss_scores.py | 185 | ||||
-rw-r--r-- | swr2_asr/train.py | 482 |
3 files changed, 734 insertions, 7 deletions
# ---------------------------------------------------------------------------
# File: swr2_asr/inference_test.py   (new file in this commit)
# ---------------------------------------------------------------------------
"""Smoke test: transcribe a single MLS sample with a pretrained wav2vec2 model."""
from AudioLoader.speech.mls import MultilingualLibriSpeech
import torch
import torchaudio
import torchaudio.functional as F


class GreedyCTCDecoder(torch.nn.Module):
    """Best-path (greedy) CTC decoder."""

    def __init__(self, labels, blank=0) -> None:
        """Store the label table and the index of the CTC blank symbol."""
        super().__init__()
        self.labels = labels
        self.blank = blank

    def forward(self, emission: torch.Tensor) -> str:
        """Given a sequence emission over labels, get the best path string.

        Args:
            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.

        Returns:
            str: The resulting transcript
        """
        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
        indices = torch.unique_consecutive(indices, dim=-1)
        # Convert to plain ints once so label lookup does not go through
        # 0-dim tensors.
        return "".join(
            self.labels[i] for i in indices.tolist() if i != self.blank
        )


def main() -> None:
    """Run the pretrained model on the first training sample and print the result."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    torch.random.manual_seed(42)

    # NOTE(review): this bundle is fine-tuned on English LibriSpeech while the
    # dataset below is German -- presumably intentional for a pipeline smoke
    # test; confirm.
    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H

    print(f"Sample rate (model): {bundle.sample_rate}")
    print(f"Labels (model): {bundle.get_labels()}")

    model = bundle.get_model().to(device)

    print(model.__class__)

    # only do all things for one single sample
    dataset = MultilingualLibriSpeech(
        "data", "mls_german_opus", split="train", download=True
    )

    # Index the dataset once; every dataset[0] access re-loads the sample.
    sample = dataset[0]
    print(sample)

    waveform, sample_rate = sample["waveform"], sample["sample_rate"]

    if sample_rate != bundle.sample_rate:
        waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate))

    # Tensor.to() is not in-place: the result has to be re-bound, otherwise
    # the waveform stays on the CPU while the model lives on `device`.
    waveform = waveform.to(device)

    with torch.inference_mode():
        emission, _ = model(waveform)

    decoder = GreedyCTCDecoder(labels=bundle.get_labels())
    transcript = decoder(emission[0])

    print(transcript)


if __name__ == "__main__":
    main()


# ---------------------------------------------------------------------------
# File: swr2_asr/loss_scores.py   (new file in this commit)
# ---------------------------------------------------------------------------
"""Word- and character-level error rates based on the Levenshtein distance."""
import numpy as np


def avg_wer(wer_scores, combined_ref_len):
    """Return the sum of per-utterance error counts divided by the total reference length."""
    return float(sum(wer_scores)) / float(combined_ref_len)


def _levenshtein_distance(ref, hyp):
    """Levenshtein distance is a string metric for measuring the difference
    between two sequences. Informally, the Levenshtein distance is defined as
    the minimum number of single-character edits (substitutions, insertions or
    deletions) required to change one word into the other. We can naturally
    extend the edits to word level when calculating the Levenshtein distance
    for two sentences.

    :param ref: Reference sequence (string or list of words).
    :param hyp: Hypothesis sequence (string or list of words).
    :return: Edit distance between the two sequences.
    :rtype: int
    """
    m = len(ref)
    n = len(hyp)

    # special cases
    if ref == hyp:
        return 0
    if m == 0:
        return n
    if n == 0:
        return m

    # keep the shorter sequence on the column axis
    if m < n:
        ref, hyp = hyp, ref
        m, n = n, m

    # use O(min(m, n)) space: only two rows of the DP table are kept
    distance = np.zeros((2, n + 1), dtype=np.int32)

    # initialize distance matrix
    for j in range(n + 1):
        distance[0][j] = j

    # calculate levenshtein distance
    for i in range(1, m + 1):
        prev_row_idx = (i - 1) % 2
        cur_row_idx = i % 2
        distance[cur_row_idx][0] = i
        for j in range(1, n + 1):
            if ref[i - 1] == hyp[j - 1]:
                distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
            else:
                s_num = distance[prev_row_idx][j - 1] + 1
                i_num = distance[cur_row_idx][j - 1] + 1
                d_num = distance[prev_row_idx][j] + 1
                distance[cur_row_idx][j] = min(s_num, i_num, d_num)

    # Return a plain int (like the early-exit branches) instead of np.int32.
    return int(distance[m % 2][n])


def word_errors(
    reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
):
    """Compute the Levenshtein distance between reference sequence and
    hypothesis sequence at word level.

    Empty items produced by splitting (leading, trailing or repeated
    delimiters) are discarded, so an empty reference has a length of 0.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Levenshtein distance and word number of reference sentence.
    :rtype: tuple
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    ref_words = [word for word in reference.split(delimiter) if word]
    hyp_words = [word for word in hypothesis.split(delimiter) if word]

    edit_distance = _levenshtein_distance(ref_words, hyp_words)
    return float(edit_distance), len(ref_words)


def char_errors(
    reference: str,
    hypothesis: str,
    ignore_case: bool = False,
    remove_space: bool = False,
):
    """Compute the Levenshtein distance between reference sequence and
    hypothesis sequence at character level.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters
    :type remove_space: bool
    :return: Levenshtein distance and length of reference sentence.
    :rtype: tuple
    """
    if ignore_case:
        reference = reference.lower()
        hypothesis = hypothesis.lower()

    join_char = "" if remove_space else " "

    # Splitting and re-joining trims the ends and collapses runs of spaces.
    reference = join_char.join(filter(None, reference.split(" ")))
    hypothesis = join_char.join(filter(None, hypothesis.split(" ")))

    edit_distance = _levenshtein_distance(reference, hypothesis)
    return float(edit_distance), len(reference)


def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
    """Calculate word error rate (WER). WER compares reference text and
    hypothesis text at word level. WER is defined as:

    .. math::
        WER = (Sw + Dw + Iw) / Nw

    where

    .. code-block:: text

        Sw is the number of words substituted,
        Dw is the number of words deleted,
        Iw is the number of words inserted,
        Nw is the number of words in the reference

    We can use the Levenshtein distance to calculate WER. Please note that
    empty items will be removed when splitting sentences by delimiter.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param delimiter: Delimiter of input sentences.
    :type delimiter: char
    :return: Word error rate.
    :rtype: float
    :raises ValueError: If word number of reference is zero.
    """
    edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)

    if ref_len == 0:
        raise ValueError("Reference's word number should be greater than 0.")

    return float(edit_distance) / ref_len


def cer(reference, hypothesis, ignore_case=False, remove_space=False):
    """Calculate character error rate (CER). CER compares reference text and
    hypothesis text at character level. CER is defined as:

    .. math::
        CER = (Sc + Dc + Ic) / Nc

    where

    .. code-block:: text

        Sc is the number of characters substituted,
        Dc is the number of characters deleted,
        Ic is the number of characters inserted
        Nc is the number of characters in the reference

    We can use the Levenshtein distance to calculate CER. Please note that
    leading and trailing space characters will be truncated and multiple
    consecutive space characters in a sentence will be replaced by one space
    character.

    :param reference: The reference sentence.
    :type reference: str
    :param hypothesis: The hypothesis sentence.
    :type hypothesis: str
    :param ignore_case: Whether to ignore case.
    :type ignore_case: bool
    :param remove_space: Whether to remove internal space characters
    :type remove_space: bool
    :return: Character error rate.
    :rtype: float
    :raises ValueError: If the reference length is zero.
    """
    edit_distance, ref_len = char_errors(
        reference, hypothesis, ignore_case, remove_space
    )

    if ref_len == 0:
        raise ValueError("Length of reference should be greater than 0.")

    return float(edit_distance) / ref_len


# ---------------------------------------------------------------------------
# File: swr2_asr/train.py   (modified in this commit; post-image of the diff)
# ---------------------------------------------------------------------------
"""Training script for the ASR model."""
from AudioLoader.speech.mls import MultilingualLibriSpeech
import click
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchaudio

from .loss_scores import cer, wer


class TextTransform:
    """Maps characters to integers and vice versa"""

    def __init__(self):
        # Index 28 is the CTC blank used by greedy_decoder (blank_label=28)
        # and by nn.CTCLoss(blank=28); CTC targets must never contain the
        # blank index, so no character may be mapped to it. The table used to
        # list "ü" twice (30 and 45), which left index 30 unreachable; "ä" now
        # takes that free slot and 28 is reserved.
        char_map_str = """
        ' 0
        <SPACE> 1
        a 2
        b 3
        c 4
        d 5
        e 6
        f 7
        g 8
        h 9
        i 10
        j 11
        k 12
        l 13
        m 14
        n 15
        o 16
        p 17
        q 18
        r 19
        s 20
        t 21
        u 22
        v 23
        w 24
        x 25
        y 26
        z 27
        <BLANK> 28
        ö 29
        ä 30
        ß 31
        - 32
        é 33
        è 34
        à 35
        ù 36
        ç 37
        â 38
        ê 39
        î 40
        ô 41
        û 42
        ë 43
        ï 44
        ü 45
        """
        self.char_map = {}
        self.index_map = {}
        for line in char_map_str.strip().split("\n"):
            char, index = line.split()
            self.char_map[char] = int(index)
            self.index_map[int(index)] = char
        self.index_map[1] = " "

    def text_to_int(self, text):
        """Use a character map and convert text to an integer sequence.

        Raises:
            KeyError: if `text` contains a character missing from the map.
        """
        space_index = self.char_map["<SPACE>"]
        return [
            space_index if char == " " else self.char_map[char] for char in text
        ]

    def int_to_text(self, labels):
        """Use a character map and convert integer labels to a text sequence"""
        return "".join(self.index_map[i] for i in labels).replace("<SPACE>", " ")


train_audio_transforms = nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)

# Same front end as training, spelled out explicitly, without the masking.
valid_audio_transforms = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000, n_mels=128
)

text_transform = TextTransform()


def data_processing(data, data_type="train"):
    """Collate a list of samples into padded spectrograms and labels.

    Args:
        data: iterable of samples with "waveform" and "utterance" entries.
        data_type: "train" (with masking augmentation) or "valid".

    Returns:
        (spectrograms, labels, input_lengths, label_lengths); spectrograms are
        padded to (batch, 1, feature, time), labels to (batch, max_label_len).

    Raises:
        ValueError: if `data_type` is neither "train" nor "valid".
    """
    if data_type == "train":
        audio_transform = train_audio_transforms
    elif data_type == "valid":
        audio_transform = valid_audio_transforms
    else:
        raise ValueError("data_type should be train or valid")

    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []
    for sample in data:
        spec = audio_transform(sample["waveform"]).squeeze(0).transpose(0, 1)
        spectrograms.append(spec)
        # Class indices are integers; build them as long rather than float32.
        label = torch.tensor(
            text_transform.text_to_int(sample["utterance"].lower()),
            dtype=torch.long,
        )
        labels.append(label)
        # The first conv layer of the model uses stride 2 along time.
        input_lengths.append(spec.shape[0] // 2)
        label_lengths.append(len(label))

    spectrograms = (
        nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
        .unsqueeze(1)
        .transpose(2, 3)
    )
    labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return spectrograms, labels, input_lengths, label_lengths


def greedy_decoder(
    output, labels, label_lengths, blank_label=28, collapse_repeated=True
):
    """Greedily decode a batch of model outputs.

    Args:
        output: scores of shape (batch, time, n_class).
        labels: padded target indices, one row per batch item.
        label_lengths: true length of each row in `labels`.
        blank_label: index of the CTC blank.
        collapse_repeated: drop a symbol equal to the previous frame's symbol.

    Returns:
        (decodes, targets): decoded hypothesis strings and reference strings.
    """
    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
    decodes = []
    targets = []
    for i, args in enumerate(arg_maxes):
        target_ids = [int(t) for t in labels[i][: label_lengths[i]].tolist()]
        targets.append(text_transform.int_to_text(target_ids))

        path = args.tolist()
        decode = []
        for j, index in enumerate(path):
            if index == blank_label:
                continue
            if collapse_repeated and j != 0 and index == path[j - 1]:
                continue
            decode.append(index)
        decodes.append(text_transform.int_to_text(decode))
    return decodes, targets


class CNNLayerNorm(nn.Module):
    """Layer normalization built for cnns input"""

    def __init__(self, n_feats: int):
        super().__init__()
        self.layer_norm = nn.LayerNorm(n_feats)

    def forward(self, data):
        """x (batch, channel, feature, time)"""
        data = data.transpose(2, 3).contiguous()  # (batch, channel, time, feature)
        data = self.layer_norm(data)
        return data.transpose(2, 3).contiguous()  # (batch, channel, feature, time)


class ResidualCNN(nn.Module):
    """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel: int,
        stride: int,
        dropout: float,
        n_feats: int,
    ):
        super().__init__()

        self.cnn1 = nn.Conv2d(
            in_channels, out_channels, kernel, stride, padding=kernel // 2
        )
        self.cnn2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel,
            stride,
            padding=kernel // 2,
        )
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.layer_norm1 = CNNLayerNorm(n_feats)
        self.layer_norm2 = CNNLayerNorm(n_feats)

    def forward(self, data):
        """x (batch, channel, feature, time)"""
        residual = data  # (batch, channel, feature, time)
        data = self.layer_norm1(data)
        data = F.gelu(data)
        data = self.dropout1(data)
        data = self.cnn1(data)
        data = self.layer_norm2(data)
        data = F.gelu(data)
        data = self.dropout2(data)
        data = self.cnn2(data)
        data += residual
        return data  # (batch, channel, feature, time)


class BidirectionalGRU(nn.Module):
    """Bidirectional GRU with Layer Normalization and Dropout"""

    def __init__(
        self,
        rnn_dim: int,
        hidden_size: int,
        dropout: float,
        batch_first: bool,
    ):
        super().__init__()

        self.bi_gru = nn.GRU(
            input_size=rnn_dim,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=batch_first,
            bidirectional=True,
        )
        self.layer_norm = nn.LayerNorm(rnn_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, data):
        """data (batch, time, feature) when constructed with batch_first=True"""
        data = self.layer_norm(data)
        data = F.gelu(data)
        data = self.dropout(data)
        data, _ = self.bi_gru(data)  # output feature size is 2 * hidden_size
        return data


class SpeechRecognitionModel(nn.Module):
    """Speech Recognition Model Inspired by DeepSpeech 2"""

    def __init__(
        self,
        n_cnn_layers: int,
        n_rnn_layers: int,
        rnn_dim: int,
        n_class: int,
        n_feats: int,
        stride: int = 2,
        dropout: float = 0.1,
    ):
        super().__init__()
        # NOTE(review): the halving matches the default stride of 2 in the
        # first conv layer; other strides would need a different value.
        n_feats //= 2
        self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
        # n residual cnn layers with filter size of 32
        self.rescnn_layers = nn.Sequential(
            *[
                ResidualCNN(
                    32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
                )
                for _ in range(n_cnn_layers)
            ]
        )
        self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
        self.birnn_layers = nn.Sequential(
            *[
                BidirectionalGRU(
                    rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
                    hidden_size=rnn_dim,
                    dropout=dropout,
                    # Activations stay (batch, time, feature) through the
                    # whole stack, so every layer must be batch-first.
                    # `batch_first=i == 0` made layers 2..n run their
                    # recurrence over the batch axis instead of time.
                    batch_first=True,
                )
                for i in range(n_rnn_layers)
            ]
        )
        self.classifier = nn.Sequential(
            nn.Linear(rnn_dim * 2, rnn_dim),  # birnn returns rnn_dim*2
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(rnn_dim, n_class),
        )

    def forward(self, data):
        """data (batch, channel, feature, time)"""
        data = self.cnn(data)
        data = self.rescnn_layers(data)
        sizes = data.size()
        data = data.view(
            sizes[0], sizes[1] * sizes[2], sizes[3]
        )  # (batch, feature, time)
        data = data.transpose(1, 2)  # (batch, time, feature)
        data = self.fully_connected(data)
        data = self.birnn_layers(data)
        data = self.classifier(data)
        return data


class IterMeter(object):
    """keeps track of total iterations"""

    def __init__(self):
        self.val = 0

    def step(self):
        """Increment the iteration counter by one."""
        self.val += 1

    def get(self):
        """Return the number of iterations counted so far."""
        return self.val


def train(
    model,
    device,
    train_loader,
    criterion,
    optimizer,
    scheduler,
    epoch,
    iter_meter,
):
    """Train the model for one epoch, stepping optimizer and scheduler per batch."""
    model.train()
    data_len = len(train_loader.dataset)
    num_batches = len(train_loader)
    for batch_idx, _data in enumerate(train_loader):
        spectrograms, labels, input_lengths, label_lengths = _data
        spectrograms, labels = spectrograms.to(device), labels.to(device)

        optimizer.zero_grad()

        output = model(spectrograms)  # (batch, time, n_class)
        output = F.log_softmax(output, dim=2)
        output = output.transpose(0, 1)  # (time, batch, n_class)

        loss = criterion(output, labels, input_lengths, label_lengths)
        loss.backward()

        optimizer.step()
        scheduler.step()
        iter_meter.step()
        # Log every 100 batches and on the last batch. The old check compared
        # the batch index with the number of samples and could never fire.
        if batch_idx % 100 == 0 or batch_idx == num_batches - 1:
            print(
                f"Train Epoch: {epoch} "
                f"[{batch_idx * len(spectrograms)}/{data_len} "
                f"({100.0 * batch_idx / num_batches:.0f}%)]\t"
                f"Loss: {loss.item():.6f}"
            )


def test(model, device, test_loader, criterion):
    """Evaluate the model and print average loss, CER and WER."""
    print("\nevaluating...")
    model.eval()
    test_loss = 0
    test_cer, test_wer = [], []
    with torch.no_grad():
        for _data in test_loader:
            spectrograms, labels, input_lengths, label_lengths = _data
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            output = model(spectrograms)  # (batch, time, n_class)
            output = F.log_softmax(output, dim=2)
            output = output.transpose(0, 1)  # (time, batch, n_class)

            loss = criterion(output, labels, input_lengths, label_lengths)
            test_loss += loss.item() / len(test_loader)

            decoded_preds, decoded_targets = greedy_decoder(
                output.transpose(0, 1), labels, label_lengths
            )
            for target, pred in zip(decoded_targets, decoded_preds):
                test_cer.append(cer(target, pred))
                test_wer.append(wer(target, pred))

    # An empty loader would otherwise raise ZeroDivisionError here.
    avg_cer = sum(test_cer) / len(test_cer) if test_cer else float("nan")
    avg_wer = sum(test_wer) / len(test_wer) if test_wer else float("nan")

    print(
        f"Test set: Average loss: {test_loss}, "
        f"Average CER: {avg_cer} Average WER: {avg_wer}\n"
    )

def run(
    learning_rate: float = 5e-4,
    batch_size: int = 8,
    epochs: int = 3,
    dataset_path: str = "/Volumes/pherkel/SWR2-ASR/",
) -> None:
    """Runs the training script.

    Args:
        learning_rate: peak learning rate for AdamW / OneCycleLR.
        batch_size: batch size for both the training and the test loader.
        epochs: number of passes over the training data.
        dataset_path: root directory of the already-downloaded MLS data.
    """
    # Local import: keeps this block self-contained.
    from functools import partial

    hparams = {
        "n_cnn_layers": 3,
        "n_rnn_layers": 5,
        "rnn_dim": 512,
        "n_class": 46,
        "n_feats": 128,
        "stride": 2,
        "dropout": 0.1,
        "learning_rate": learning_rate,
        "batch_size": batch_size,
        "epochs": epochs,
    }

    use_cuda = torch.cuda.is_available()
    torch.manual_seed(42)
    device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member

    # NOTE(review): training reads the "dev" split -- presumably a smaller
    # subset for quick experiments; confirm before a full training run.
    train_dataset = MultilingualLibriSpeech(
        dataset_path, "mls_german_opus", split="dev", download=False
    )
    test_dataset = MultilingualLibriSpeech(
        dataset_path, "mls_german_opus", split="test", download=False
    )

    kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}

    # functools.partial instead of a lambda: a module-level function wrapped
    # in partial can be pickled for worker processes, a lambda cannot.
    train_loader = DataLoader(
        train_dataset,
        batch_size=hparams["batch_size"],
        shuffle=True,
        collate_fn=partial(data_processing, data_type="train"),
        **kwargs,
    )

    # Evaluation uses the "valid" pipeline (no frequency/time masking) and a
    # fixed order. It previously reused the augmenting "train" collate.
    test_loader = DataLoader(
        test_dataset,
        batch_size=hparams["batch_size"],
        shuffle=False,
        collate_fn=partial(data_processing, data_type="valid"),
        **kwargs,
    )

    model = SpeechRecognitionModel(
        hparams["n_cnn_layers"],
        hparams["n_rnn_layers"],
        hparams["rnn_dim"],
        hparams["n_class"],
        hparams["n_feats"],
        hparams["stride"],
        hparams["dropout"],
    ).to(device)

    print(
        "Num Model Parameters", sum(param.nelement() for param in model.parameters())
    )

    optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
    # NOTE(review): blank=28 must stay in sync with greedy_decoder's
    # blank_label default, and no character in TextTransform's table may use
    # index 28 -- verify against the character map.
    criterion = nn.CTCLoss(blank=28).to(device)

    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=hparams["learning_rate"],
        steps_per_epoch=int(len(train_loader)),
        epochs=hparams["epochs"],
        anneal_strategy="linear",
    )

    iter_meter = IterMeter()
    # Loop over the same epoch count the scheduler was sized with.
    for epoch in range(1, hparams["epochs"] + 1):
        train(
            model,
            device,
            train_loader,
            criterion,
            optimizer,
            scheduler,
            epoch,
            iter_meter,
        )
        test(model=model, device=device, test_loader=test_loader, criterion=criterion)


@click.command()
@click.option("--learning-rate", default=1e-3, help="Learning rate")
# "--batch-size" matches the spelling of "--learning-rate"; the original
# "--batch_size" is kept as an alias so existing invocations still work.
@click.option("--batch-size", "--batch_size", "batch_size", default=1, help="Batch size")
@click.option("--epochs", default=1, help="Number of epochs")
def run_cli(learning_rate: float, batch_size: int, epochs: int) -> None:
    """Runs the training script."""
    run(learning_rate=learning_rate, batch_size=batch_size, epochs=epochs)


if __name__ == "__main__":
    # NOTE(review): run_cli is defined above but not used as the entry point;
    # the hard-coded call is kept so the defaults of a direct run are unchanged.
    run(learning_rate=5e-4, batch_size=16, epochs=1)