aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--swr2_asr/loss_scores.py185
-rw-r--r--swr2_asr/train_2.py452
2 files changed, 637 insertions, 0 deletions
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
new file mode 100644
index 0000000..977462d
--- /dev/null
+++ b/swr2_asr/loss_scores.py
@@ -0,0 +1,185 @@
+import numpy as np
+
+
+def avg_wer(wer_scores, combined_ref_len):
+ return float(sum(wer_scores)) / float(combined_ref_len)
+
+
+def _levenshtein_distance(ref, hyp):
+ """Levenshtein distance is a string metric for measuring the difference
+ between two sequences. Informally, the levenshtein disctance is defined as
+ the minimum number of single-character edits (substitutions, insertions or
+ deletions) required to change one word into the other. We can naturally
+ extend the edits to word level when calculate levenshtein disctance for
+ two sentences.
+ """
+ m = len(ref)
+ n = len(hyp)
+
+ # special case
+ if ref == hyp:
+ return 0
+ if m == 0:
+ return n
+ if n == 0:
+ return m
+
+ if m < n:
+ ref, hyp = hyp, ref
+ m, n = n, m
+
+ # use O(min(m, n)) space
+ distance = np.zeros((2, n + 1), dtype=np.int32)
+
+ # initialize distance matrix
+ for j in range(0, n + 1):
+ distance[0][j] = j
+
+ # calculate levenshtein distance
+ for i in range(1, m + 1):
+ prev_row_idx = (i - 1) % 2
+ cur_row_idx = i % 2
+ distance[cur_row_idx][0] = i
+ for j in range(1, n + 1):
+ if ref[i - 1] == hyp[j - 1]:
+ distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
+ else:
+ s_num = distance[prev_row_idx][j - 1] + 1
+ i_num = distance[cur_row_idx][j - 1] + 1
+ d_num = distance[prev_row_idx][j] + 1
+ distance[cur_row_idx][j] = min(s_num, i_num, d_num)
+
+ return distance[m % 2][n]
+
+
+def word_errors(
+ reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
+):
+ """Compute the levenshtein distance between reference sequence and
+ hypothesis sequence in word-level.
+ :param reference: The reference sentence.
+ :type reference: basestring
+ :param hypothesis: The hypothesis sentence.
+ :type hypothesis: basestring
+ :param ignore_case: Whether case-sensitive or not.
+ :type ignore_case: bool
+ :param delimiter: Delimiter of input sentences.
+ :type delimiter: char
+ :return: Levenshtein distance and word number of reference sentence.
+ :rtype: list
+ """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ ref_words = reference.split(delimiter)
+ hyp_words = hypothesis.split(delimiter)
+
+ edit_distance = _levenshtein_distance(ref_words, hyp_words)
+ return float(edit_distance), len(ref_words)
+
+
+def char_errors(
+ reference: str,
+ hypothesis: str,
+ ignore_case: bool = False,
+ remove_space: bool = False,
+):
+ """Compute the levenshtein distance between reference sequence and
+ hypothesis sequence in char-level.
+ :param reference: The reference sentence.
+ :type reference: basestring
+ :param hypothesis: The hypothesis sentence.
+ :type hypothesis: basestring
+ :param ignore_case: Whether case-sensitive or not.
+ :type ignore_case: bool
+ :param remove_space: Whether remove internal space characters
+ :type remove_space: bool
+ :return: Levenshtein distance and length of reference sentence.
+ :rtype: list
+ """
+ if ignore_case:
+ reference = reference.lower()
+ hypothesis = hypothesis.lower()
+
+ join_char = " "
+ if remove_space:
+ join_char = ""
+
+ reference = join_char.join(filter(None, reference.split(" ")))
+ hypothesis = join_char.join(filter(None, hypothesis.split(" ")))
+
+ edit_distance = _levenshtein_distance(reference, hypothesis)
+ return float(edit_distance), len(reference)
+
+
+def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
+ """Calculate word error rate (WER). WER compares reference text and
+ hypothesis text in word-level. WER is defined as:
+ .. math::
+ WER = (Sw + Dw + Iw) / Nw
+ where
+ .. code-block:: text
+ Sw is the number of words subsituted,
+ Dw is the number of words deleted,
+ Iw is the number of words inserted,
+ Nw is the number of words in the reference
+ We can use levenshtein distance to calculate WER. Please draw an attention
+ that empty items will be removed when splitting sentences by delimiter.
+ :param reference: The reference sentence.
+ :type reference: basestring
+ :param hypothesis: The hypothesis sentence.
+ :type hypothesis: basestring
+ :param ignore_case: Whether case-sensitive or not.
+ :type ignore_case: bool
+ :param delimiter: Delimiter of input sentences.
+ :type delimiter: char
+ :return: Word error rate.
+ :rtype: float
+ :raises ValueError: If word number of reference is zero.
+ """
+ edit_distance, ref_len = word_errors(reference, hypothesis, ignore_case, delimiter)
+
+ if ref_len == 0:
+ raise ValueError("Reference's word number should be greater than 0.")
+
+ wer = float(edit_distance) / ref_len
+ return wer
+
+
+def cer(reference, hypothesis, ignore_case=False, remove_space=False):
+ """Calculate charactor error rate (CER). CER compares reference text and
+ hypothesis text in char-level. CER is defined as:
+ .. math::
+ CER = (Sc + Dc + Ic) / Nc
+ where
+ .. code-block:: text
+ Sc is the number of characters substituted,
+ Dc is the number of characters deleted,
+ Ic is the number of characters inserted
+ Nc is the number of characters in the reference
+ We can use levenshtein distance to calculate CER. Chinese input should be
+ encoded to unicode. Please draw an attention that the leading and tailing
+ space characters will be truncated and multiple consecutive space
+ characters in a sentence will be replaced by one space character.
+ :param reference: The reference sentence.
+ :type reference: basestring
+ :param hypothesis: The hypothesis sentence.
+ :type hypothesis: basestring
+ :param ignore_case: Whether case-sensitive or not.
+ :type ignore_case: bool
+ :param remove_space: Whether remove internal space characters
+ :type remove_space: bool
+ :return: Character error rate.
+ :rtype: float
+ :raises ValueError: If the reference length is zero.
+ """
+ edit_distance, ref_len = char_errors(
+ reference, hypothesis, ignore_case, remove_space
+ )
+
+ if ref_len == 0:
+ raise ValueError("Length of reference should be greater than 0.")
+
+ cer = float(edit_distance) / ref_len
+ return cer
diff --git a/swr2_asr/train_2.py b/swr2_asr/train_2.py
new file mode 100644
index 0000000..2e690e2
--- /dev/null
+++ b/swr2_asr/train_2.py
@@ -0,0 +1,452 @@
+"""Training script for the ASR model."""
+from AudioLoader.speech.mls import MultilingualLibriSpeech
+import click
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+import torchaudio
+import torchaudio.functional as AF
+from .loss_scores import cer, wer
+
+
+class TextTransform:
+ """Maps characters to integers and vice versa"""
+
+ def __init__(self):
+ char_map_str = """
+ ' 0
+ <SPACE> 1
+ a 2
+ b 3
+ c 4
+ d 5
+ e 6
+ f 7
+ g 8
+ h 9
+ i 10
+ j 11
+ k 12
+ l 13
+ m 14
+ n 15
+ o 16
+ p 17
+ q 18
+ r 19
+ s 20
+ t 21
+ u 22
+ v 23
+ w 24
+ x 25
+ y 26
+ z 27
+ ä 28
+ ö 29
+ ü 30
+ ß 31
+ """
+ self.char_map = {}
+ self.index_map = {}
+ for line in char_map_str.strip().split("\n"):
+ ch, index = line.split()
+ self.char_map[ch] = int(index)
+ self.index_map[int(index)] = ch
+ self.index_map[1] = " "
+
+ def text_to_int(self, text):
+ """Use a character map and convert text to an integer sequence"""
+ int_sequence = []
+ for c in text:
+ if c == " ":
+ ch = self.char_map["<SPACE>"]
+ else:
+ ch = self.char_map[c]
+ int_sequence.append(ch)
+ return int_sequence
+
+ def int_to_text(self, labels):
+ """Use a character map and convert integer labels to an text sequence"""
+ string = []
+ for i in labels:
+ string.append(self.index_map[i])
+ return "".join(string).replace("<SPACE>", " ")
+
+
+train_audio_transforms = nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+)
+
+valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
+
+text_transform = TextTransform()
+
+
+def data_processing(data, data_type="train"):
+ """Return the spectrograms, labels, and their lengths."""
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+ for waveform, _, utterance, _, _, _ in data:
+ if data_type == "train":
+ spec = train_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+ elif data_type == "valid":
+ spec = valid_audio_transforms(waveform).squeeze(0).transpose(0, 1)
+ else:
+ raise Exception("data_type should be train or valid")
+ spectrograms.append(spec)
+ label = torch.Tensor(text_transform.text_to_int(utterance.lower()))
+ labels.append(label)
+ input_lengths.append(spec.shape[0] // 2)
+ label_lengths.append(len(label))
+
+ spectrograms = (
+ nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
+ .unsqueeze(1)
+ .transpose(2, 3)
+ )
+ labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return spectrograms, labels, input_lengths, label_lengths
+
+
+def GreedyDecoder(
+ output, labels, label_lengths, blank_label=28, collapse_repeated=True
+):
+ arg_maxes = torch.argmax(output, dim=2)
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(
+ text_transform.int_to_text(labels[i][: label_lengths[i]].tolist())
+ )
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(text_transform.int_to_text(decode))
+ return decodes, targets
+
+
+class CNNLayerNorm(nn.Module):
+ """Layer normalization built for cnns input"""
+
+ def __init__(self, n_feats: int):
+ super(CNNLayerNorm, self).__init__()
+ self.layer_norm = nn.LayerNorm(n_feats)
+
+ def forward(self, x):
+ """x (batch, channel, feature, time)"""
+ x = x.transpose(2, 3).contiguous() # (batch, channel, time, feature)
+ x = self.layer_norm(x)
+ return x.transpose(2, 3).contiguous() # (batch, channel, feature, time)
+
+
+class ResidualCNN(nn.Module):
+ """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel: int,
+ stride: int,
+ dropout: float,
+ n_feats: int,
+ ):
+ super(ResidualCNN, self).__init__()
+
+ self.cnn1 = nn.Conv2d(
+ in_channels, out_channels, kernel, stride, padding=kernel // 2
+ )
+ self.cnn2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel,
+ stride,
+ padding=kernel // 2,
+ )
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.layer_norm1 = CNNLayerNorm(n_feats)
+ self.layer_norm2 = CNNLayerNorm(n_feats)
+
+ def forward(self, x):
+ residual = x # (batch, channel, feature, time)
+ x = self.layer_norm1(x)
+ x = F.gelu(x)
+ x = self.dropout1(x)
+ x = self.cnn1(x)
+ x = self.layer_norm2(x)
+ x = F.gelu(x)
+ x = self.dropout2(x)
+ x = self.cnn2(x)
+ x += residual
+ return x # (batch, channel, feature, time)
+
+
+class BidirectionalGRU(nn.Module):
+ """BIdirectional GRU with Layer Normalization and Dropout"""
+
+ def __init__(
+ self,
+ rnn_dim: int,
+ hidden_size: int,
+ dropout: float,
+ batch_first: bool,
+ ):
+ super(BidirectionalGRU, self).__init__()
+
+ self.BiGRU = nn.GRU(
+ input_size=rnn_dim,
+ hidden_size=hidden_size,
+ num_layers=1,
+ batch_first=batch_first,
+ bidirectional=True,
+ )
+ self.layer_norm = nn.LayerNorm(rnn_dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, x):
+ x = self.layer_norm(x)
+ x = F.gelu(x)
+ x = self.dropout(x)
+ x, _ = self.BiGRU(x)
+ return x
+
+
+class SpeechRecognitionModel(nn.Module):
+ def __init__(
+ self,
+ n_cnn_layers: int,
+ n_rnn_layers: int,
+ rnn_dim: int,
+ n_class: int,
+ n_feats: int,
+ stride: int = 2,
+ dropout: float = 0.1,
+ ):
+ super(SpeechRecognitionModel, self).__init__()
+ n_feats //= 2
+ self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
+ # n residual cnn layers with filter size of 32
+ self.rescnn_layers = nn.Sequential(
+ *[
+ ResidualCNN(
+ 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
+ )
+ for _ in range(n_cnn_layers)
+ ]
+ )
+ self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
+ self.birnn_layers = nn.Sequential(
+ *[
+ BidirectionalGRU(
+ rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
+ hidden_size=rnn_dim,
+ dropout=dropout,
+ batch_first=i == 0,
+ )
+ for i in range(n_rnn_layers)
+ ]
+ )
+ self.classifier = nn.Sequential(
+ nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(rnn_dim, n_class),
+ )
+
+ def forward(self, x):
+ x = self.cnn(x)
+ x = self.rescnn_layers(x)
+ sizes = x.size()
+ x = x.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time)
+ x = x.transpose(1, 2) # (batch, time, feature)
+ x = self.fully_connected(x)
+ x = self.birnn_layers(x)
+ x = self.classifier(x)
+ return x
+
+
+class IterMeter(object):
+ """keeps track of total iterations"""
+
+ def __init__(self):
+ self.val = 0
+
+ def step(self):
+ self.val += 1
+
+ def get(self):
+ return self.val
+
+
+def train(
+ model,
+ device,
+ train_loader,
+ criterion,
+ optimizer,
+ scheduler,
+ epoch,
+ iter_meter,
+):
+ model.train()
+ data_len = len(train_loader.dataset)
+ for batch_idx, _data in enumerate(train_loader):
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
+
+ optimizer.zero_grad()
+
+ output = model(spectrograms) # (batch, time, n_class)
+ output = F.log_softmax(output, dim=2)
+ output = output.transpose(0, 1) # (time, batch, n_class)
+
+ loss = criterion(output, labels, input_lengths, label_lengths)
+ loss.backward()
+
+ optimizer.step()
+ scheduler.step()
+ iter_meter.step()
+ if batch_idx % 100 == 0 or batch_idx == data_len:
+ print(
+ "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}".format(
+ epoch,
+ batch_idx * len(spectrograms),
+ data_len,
+ 100.0 * batch_idx / len(train_loader),
+ loss.item(),
+ )
+ )
+
+
+def test(model, device, test_loader, criterion, epoch, iter_meter):
+ print("\nevaluating...")
+ model.eval()
+ test_loss = 0
+ test_cer, test_wer = [], []
+ with torch.no_grad():
+ for i, _data in enumerate(test_loader):
+ spectrograms, labels, input_lengths, label_lengths = _data
+ spectrograms, labels = spectrograms.to(device), labels.to(device)
+
+ output = model(spectrograms) # (batch, time, n_class)
+ output = F.log_softmax(output, dim=2)
+ output = output.transpose(0, 1) # (time, batch, n_class)
+
+ loss = criterion(output, labels, input_lengths, label_lengths)
+ test_loss += loss.item() / len(test_loader)
+
+ decoded_preds, decoded_targets = GreedyDecoder(
+ output.transpose(0, 1), labels, label_lengths
+ )
+ for j in range(len(decoded_preds)):
+ test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
+ test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
+
+ avg_cer = sum(test_cer) / len(test_cer)
+ avg_wer = sum(test_wer) / len(test_wer)
+
+ print(
+ "Test set: Average loss: {:.4f}, Average CER: {:4f} Average WER: {:.4f}\n".format(
+ test_loss, avg_cer, avg_wer
+ )
+ )
+
+
+def run(lr: float, batch_size: int, epochs: int) -> None:
+ hparams = {
+ "n_cnn_layers": 3,
+ "n_rnn_layers": 5,
+ "rnn_dim": 512,
+ "n_class": 33,
+ "n_feats": 128,
+ "stride": 2,
+ "dropout": 0.1,
+ "learning_rate": lr,
+ "batch_size": batch_size,
+ "epochs": epochs,
+ }
+
+ use_cuda = torch.cuda.is_available()
+ torch.manual_seed(42)
+ device = torch.device("cuda" if use_cuda else "cpu")
+ device = torch.device("mps")
+
+ train_dataset = MultilingualLibriSpeech(
+ "data", "mls_german_opus", split="train", download=False
+ )
+ test_dataset = MultilingualLibriSpeech(
+ "data", "mls_german_opus", split="test", download=False
+ )
+
+ kwargs = {"num_workers": 1, "pin_memory": True} if use_cuda else {}
+
+ train_loader = DataLoader(
+ train_dataset,
+ batch_size=hparams["batch_size"],
+ shuffle=True,
+ collate_fn=lambda x: data_processing(x, "train"),
+ **kwargs,
+ )
+
+ test_loader = DataLoader(
+ train_dataset,
+ batch_size=hparams["batch_size"],
+ shuffle=True,
+ collate_fn=lambda x: data_processing(x, "train"),
+ **kwargs,
+ )
+
+ model = SpeechRecognitionModel(
+ hparams["n_cnn_layers"],
+ hparams["n_rnn_layers"],
+ hparams["rnn_dim"],
+ hparams["n_class"],
+ hparams["n_feats"],
+ hparams["stride"],
+ hparams["dropout"],
+ ).to(device)
+
+ print(
+ "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+ )
+
+ optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
+ criterion = nn.CTCLoss(blank=28).to(device)
+
+ scheduler = optim.lr_scheduler.OneCycleLR(
+ optimizer,
+ max_lr=hparams["learning_rate"],
+ steps_per_epoch=int(len(train_loader)),
+ epochs=hparams["epochs"],
+ anneal_strategy="linear",
+ )
+
+ iter_meter = IterMeter()
+ for epoch in range(1, epochs + 1):
+ train(
+ model,
+ device,
+ train_loader,
+ criterion,
+ optimizer,
+ scheduler,
+ epoch,
+ iter_meter,
+ )
+ test(model, device, test_loader, criterion, epoch, iter_meter)
+
+
+if __name__ == "__main__":
+ run(lr=5e-4, batch_size=20, epochs=10)