path: root/swr2_asr
author	Pherkel	2023-08-20 15:50:36 +0200
committer	GitHub	2023-08-20 15:50:36 +0200
commit	14ceeb5ad36beea2f05214aa26260cdd1d86590b (patch)
tree	891cedeb665913af1a078a3778afffbccd37bae7 /swr2_asr
parent	f88c9afc6e9efcb6f79a959779114095c23e0cef (diff)
parent	899a5e1cd7ca9b0601ed64ca3157e2052dd3e669 (diff)
Merge pull request #22 from Algo-Boys/tokenizer
Tokenizer
Diffstat (limited to 'swr2_asr')
-rw-r--r--	swr2_asr/loss_scores.py	36
-rw-r--r--	swr2_asr/tokenizer.py	329
-rw-r--r--	swr2_asr/train.py	119
3 files changed, 372 insertions, 112 deletions
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
index 977462d..c49cc15 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/loss_scores.py
@@ -1,7 +1,9 @@
+"""Methods for determining the loss and scores of the model."""
import numpy as np
def avg_wer(wer_scores, combined_ref_len):
+ """Calculate the average word error rate (WER) of the model."""
return float(sum(wer_scores)) / float(combined_ref_len)
@@ -13,34 +15,34 @@ def _levenshtein_distance(ref, hyp):
extend the edits to the word level when calculating the Levenshtein distance for
two sentences.
"""
- m = len(ref)
- n = len(hyp)
+ len_ref = len(ref)
+ len_hyp = len(hyp)
# special case
if ref == hyp:
return 0
- if m == 0:
- return n
- if n == 0:
- return m
+ if len_ref == 0:
+ return len_hyp
+ if len_hyp == 0:
+ return len_ref
- if m < n:
+ if len_ref < len_hyp:
ref, hyp = hyp, ref
- m, n = n, m
+ len_ref, len_hyp = len_hyp, len_ref
# use O(min(m, n)) space
- distance = np.zeros((2, n + 1), dtype=np.int32)
+ distance = np.zeros((2, len_hyp + 1), dtype=np.int32)
# initialize distance matrix
- for j in range(0, n + 1):
+ for j in range(0, len_hyp + 1):
distance[0][j] = j
# calculate levenshtein distance
- for i in range(1, m + 1):
+ for i in range(1, len_ref + 1):
prev_row_idx = (i - 1) % 2
cur_row_idx = i % 2
distance[cur_row_idx][0] = i
- for j in range(1, n + 1):
+ for j in range(1, len_hyp + 1):
if ref[i - 1] == hyp[j - 1]:
distance[cur_row_idx][j] = distance[prev_row_idx][j - 1]
else:
@@ -49,7 +51,7 @@ def _levenshtein_distance(ref, hyp):
d_num = distance[prev_row_idx][j] + 1
distance[cur_row_idx][j] = min(s_num, i_num, d_num)
- return distance[m % 2][n]
+ return distance[len_ref % 2][len_hyp]
def word_errors(
@@ -143,8 +145,8 @@ def wer(reference: str, hypothesis: str, ignore_case=False, delimiter=" "):
if ref_len == 0:
raise ValueError("Reference's word number should be greater than 0.")
- wer = float(edit_distance) / ref_len
- return wer
+ word_error_rate = float(edit_distance) / ref_len
+ return word_error_rate
def cer(reference, hypothesis, ignore_case=False, remove_space=False):
@@ -181,5 +183,5 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
if ref_len == 0:
raise ValueError("Length of reference should be greater than 0.")
- cer = float(edit_distance) / ref_len
- return cer
+ char_error_rate = float(edit_distance) / ref_len
+ return char_error_rate
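The hunks above only rename local variables (m/n to len_ref/len_hyp, wer to word_error_rate, cer to char_error_rate), so the public helpers keep their signatures. A minimal usage sketch, assuming the module is importable as swr2_asr.loss_scores:

from swr2_asr.loss_scores import wer, cer

reference = "das ist ein test"
hypothesis = "das ist der test"

# one of four reference words is substituted -> WER = 0.25
print(wer(reference, hypothesis, ignore_case=True))
# the character-level analogue counts edits over reference characters
print(cer(reference, hypothesis, ignore_case=True))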
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
new file mode 100644
index 0000000..a665159
--- /dev/null
+++ b/swr2_asr/tokenizer.py
@@ -0,0 +1,329 @@
+"""Tokenizer for use with Multilingual Librispeech"""
+from dataclasses import dataclass
+import json
+import os
+import click
+from tqdm import tqdm
+
+from AudioLoader.speech import MultilingualLibriSpeech
+
+from tokenizers import Tokenizer, normalizers
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import Whitespace
+
+
+@dataclass
+class Encoding:
+ """Simple dataclass to represent an encoding"""
+
+ ids: list[int]
+
+
+class CharTokenizer:
+ """Very simple tokenizer for use with Multilingual Librispeech
+
+ Simply checks what characters are in the dataset and uses them as tokens.
+
+ Exposes the same interface as tokenizers from the huggingface library, i.e.
+ encode, decode, decode_batch, get_vocab_size, save, from_file and train.
+ """
+
+ def __init__(self):
+ self.char_map = {}
+ self.index_map = {}
+ self.add_tokens(["<UNK>", "<SPACE>"])
+
+ def add_tokens(self, tokens: list[str]):
+ """Manually add tokens to the tokenizer
+
+ Args:
+ tokens (list[str]): List of tokens to add
+ """
+ for token in tokens:
+ if token not in self.char_map:
+ self.char_map[token] = len(self.char_map)
+ self.index_map[len(self.index_map)] = token
+
+ def train(
+ self, dataset_path: str, language: str, split: str, download: bool = True
+ ):
+ """Train the tokenizer on the given dataset
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ """
+ if split not in ["train", "dev", "test", "all"]:
+ raise ValueError("Split must be one of train, dev, test, all")
+
+ if split == "all":
+ splits = ["train", "dev", "test"]
+ else:
+ splits = [split]
+
+ chars: set = set()
+ for s_plit in splits:
+ transcript_path = os.path.join(
+ dataset_path, language, s_plit, "transcripts.txt"
+ )
+
+ # check if dataset is downloaded, download if not
+ if download and not os.path.exists(transcript_path):
+ MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+
+ with open(
+ transcript_path,
+ "r",
+ encoding="utf-8",
+ ) as file:
+ lines = file.readlines()
+ lines = [line.split(" ", 1)[1] for line in lines]
+ lines = [line.strip() for line in lines]
+
+ for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"):
+ chars.update(line)
+ offset = len(self.char_map)
+ for i, char in enumerate(chars):
+ i += offset
+ self.char_map[char] = i
+ self.index_map[i] = char
+
+ def encode(self, text: str):
+ """Use a character map and convert text to an integer sequence
+
+ Automatically maps spaces to <SPACE> and lowercases the text;
+ unknown characters are mapped to the <UNK> token.
+
+ """
+ int_sequence = []
+ text = text.lower()
+ for char in text:
+ if char == " ":
+ mapped_char = self.char_map["<SPACE>"]
+ elif char not in self.char_map:
+ mapped_char = self.char_map["<UNK>"]
+ else:
+ mapped_char = self.char_map[char]
+ int_sequence.append(mapped_char)
+ return Encoding(ids=int_sequence)
+
+ def decode(self, labels: list[int], remove_special_tokens: bool = True):
+ """Use a character map and convert integer labels to an text sequence
+
+ Args:
+ labels (list[int]): List of integer labels
+ remove_special_tokens (bool): Whether to remove special tokens.
+ Defaults to True.
+ """
+ string = []
+ for i in labels:
+ if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>":
+ continue
+ if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>":
+ string.append(" ")
+ string.append(self.index_map[f"{i}"])
+ return "".join(string).replace("<SPACE>", " ")
+
+ def decode_batch(self, labels: list[list[int]]):
+ """Use a character map and convert integer labels to an text sequence"""
+ strings = []
+ for label in labels:
+ string = []
+ for i in label:
+ if self.index_map[i] == "<UNK>":
+ continue
+ if self.index_map[i] == "<SPACE>":
+ string.append(" ")
+ string.append(self.index_map[i])
+ strings.append("".join(string).replace("<SPACE>", " "))
+ return strings
+
+ def get_vocab_size(self):
+ """Get the size of the vocabulary"""
+ return len(self.char_map)
+
+ def save(self, path: str):
+ """Save the tokenizer to a file"""
+ with open(path, "w", encoding="utf-8") as file:
+ # save it in the following format:
+ # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+ json.dump(
+ {"char_map": self.char_map, "index_map": self.index_map},
+ file,
+ ensure_ascii=False,
+ )
+
+ def from_file(self, path: str):
+ """Load the tokenizer from a file"""
+ with open(path, "r", encoding="utf-8") as file:
+ # load it in the following format:
+ # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+ saved_file = json.load(file)
+ self.char_map = saved_file["char_map"]
+ self.index_map = saved_file["index_map"]
+
+
+@click.command()
+@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
+@click.option("--language", default="mls_german_opus", help="Language to use")
+@click.option("--split", default="train", help="Split to use (including all)")
+@click.option("--download", default=True, help="Whether to download the dataset")
+@click.option(
+ "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
+)
+@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+ vocab_size: int,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset
+
+ Assumes that the MLS dataset is located in the dataset_path and there is a
+ transcripts.txt file in the split folder.
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ download (bool): Whether to download the dataset if it is not present
+ out_path (str): Path to save the tokenizer to
+ vocab_size (int): Size of the vocabulary
+ """
+ if split not in ["train", "dev", "test", "all"]:
+ raise ValueError("Split must be one of train, dev, test, all")
+
+ if split == "all":
+ splits = ["train", "dev", "test"]
+ else:
+ splits = [split]
+
+ lines = []
+
+ for s_plit in splits:
+ transcripts_path = os.path.join(
+ dataset_path, language, s_plit, "transcripts.txt"
+ )
+ if download and not os.path.exists(transcripts_path):
+ MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+
+ with open(
+ transcripts_path,
+ "r",
+ encoding="utf-8",
+ ) as file:
+ sp_lines = file.readlines()
+ sp_lines = [line.split(" ", 1)[1] for line in sp_lines]
+ sp_lines = [line.strip() for line in sp_lines]
+
+ lines.append(sp_lines)
+
+ bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+
+ initial_alphabet = [
+ "a",
+ "b",
+ "c",
+ "d",
+ "e",
+ "f",
+ "g",
+ "h",
+ "i",
+ "j",
+ "k",
+ "l",
+ "m",
+ "n",
+ "o",
+ "p",
+ "q",
+ "r",
+ "s",
+ "t",
+ "u",
+ "v",
+ "w",
+ "x",
+ "y",
+ "z",
+ "ä",
+ "ö",
+ "ü",
+ "ß",
+ "-",
+ "é",
+ "è",
+ "à",
+ "ù",
+ "ç",
+ "â",
+ "ê",
+ "î",
+ "ô",
+ "û",
+ "ë",
+ "ï",
+ "ü",
+ ]
+
+ trainer = BpeTrainer(
+ special_tokens=["[UNK]"],
+ vocab_size=vocab_size,
+ initial_alphabet=initial_alphabet,
+ show_progress=True,
+ ) # type: ignore
+
+ bpe_tokenizer.pre_tokenizer = Whitespace() # type: ignore
+
+ bpe_tokenizer.normalizer = normalizers.Lowercase() # type: ignore
+
+ bpe_tokenizer.train_from_iterator(lines, trainer=trainer)
+
+ bpe_tokenizer.save(out_path)
+
+
+@click.command()
+@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
+@click.option("--language", default="mls_german_opus", help="Language to use")
+@click.option("--split", default="train", help="Split to use")
+@click.option(
+ "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
+)
+@click.option("--download", default=True, help="Whether to download the dataset")
+def train_char_tokenizer(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset
+
+ Assumes that the MLS dataset is located in the dataset_path and there is a
+ transcripts.txt file in the split folder.
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ download (bool): Whether to download the dataset if it is not present
+ out_path (str): Path to save the tokenizer to
+ """
+ char_tokenizer = CharTokenizer()
+
+ char_tokenizer.train(dataset_path, language, split, download)
+
+ char_tokenizer.save(out_path)
+
+
+if __name__ == "__main__":
+ tokenizer = CharTokenizer()
+ tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+
+ print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids))
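For orientation, a minimal round trip with the new CharTokenizer could look like the sketch below; the dataset path, split, and output filename are placeholders rather than files the repository ships, and the MLS transcripts are assumed to already be on disk.

from swr2_asr.tokenizer import CharTokenizer

tokenizer = CharTokenizer()
tokenizer.train("data", language="mls_german_opus", split="dev", download=False)
tokenizer.save("tokenizer_chars.json")

# reload before decoding, mirroring the __main__ block above:
# from_file restores the JSON index_map, whose string keys are what decode looks up
tokenizer = CharTokenizer()
tokenizer.from_file("tokenizer_chars.json")
ids = tokenizer.encode("guten morgen").ids  # spaces map to <SPACE>, unseen characters to <UNK>
print(tokenizer.decode(ids))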
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ad8c9e9..6af1e80 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,94 +1,16 @@
"""Training script for the ASR model."""
-from AudioLoader.speech import MultilingualLibriSpeech
import os
import click
import torch
-import torch.nn as nn
-import torch.optim as optim
import torch.nn.functional as F
-from torch.utils.data import DataLoader
import torchaudio
-from .loss_scores import cer, wer
-
-
-class TextTransform:
- """Maps characters to integers and vice versa"""
-
- def __init__(self):
- char_map_str = """
- ' 0
- <SPACE> 1
- a 2
- b 3
- c 4
- d 5
- e 6
- f 7
- g 8
- h 9
- i 10
- j 11
- k 12
- l 13
- m 14
- n 15
- o 16
- p 17
- q 18
- r 19
- s 20
- t 21
- u 22
- v 23
- w 24
- x 25
- y 26
- z 27
- ä 28
- ö 29
- ü 30
- ß 31
- - 32
- é 33
- è 34
- à 35
- ù 36
- ç 37
- â 38
- ê 39
- î 40
- ô 41
- û 42
- ë 43
- ï 44
- ü 45
- """
- self.char_map = {}
- self.index_map = {}
- for line in char_map_str.strip().split("\n"):
- char, index = line.split()
- self.char_map[char] = int(index)
- self.index_map[int(index)] = char
- self.index_map[1] = " "
-
- def text_to_int(self, text):
- """Use a character map and convert text to an integer sequence"""
- int_sequence = []
- for char in text:
- if char == " ":
- mapped_char = self.char_map["<SPACE>"]
- else:
- mapped_char = self.char_map[char]
- int_sequence.append(mapped_char)
- return int_sequence
-
- def int_to_text(self, labels):
- """Use a character map and convert integer labels to an text sequence"""
- string = []
- for i in labels:
- string.append(self.index_map[i])
- return "".join(string).replace("<SPACE>", " ")
+from AudioLoader.speech import MultilingualLibriSpeech
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from tokenizers import Tokenizer
+from .tokenizer import CharTokenizer
+from .loss_scores import cer, wer
train_audio_transforms = nn.Sequential(
torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
@@ -98,7 +20,9 @@ train_audio_transforms = nn.Sequential(
valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-text_transform = TextTransform()
+# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+text_transform = CharTokenizer()
+text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
def data_processing(data, data_type="train"):
@@ -115,7 +39,7 @@ def data_processing(data, data_type="train"):
else:
raise ValueError("data_type should be train or valid")
spectrograms.append(spec)
- label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower()))
+ label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
labels.append(label)
input_lengths.append(spec.shape[0] // 2)
label_lengths.append(len(label))
@@ -133,6 +57,7 @@ def data_processing(data, data_type="train"):
def greedy_decoder(
output, labels, label_lengths, blank_label=28, collapse_repeated=True
):
+ # TODO: adapt to support both tokenizers
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
@@ -140,22 +65,25 @@ def greedy_decoder(
for i, args in enumerate(arg_maxes):
decode = []
targets.append(
- text_transform.int_to_text(labels[i][: label_lengths[i]].tolist())
+ text_transform.decode(
+ [int(x) for x in labels[i][: label_lengths[i]].tolist()]
+ )
)
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
- decodes.append(text_transform.int_to_text(decode))
+ decodes.append(text_transform.decode(decode))
return decodes, targets
+# TODO: restructure into own file / class
class CNNLayerNorm(nn.Module):
"""Layer normalization built for cnns input"""
def __init__(self, n_feats: int):
- super(CNNLayerNorm, self).__init__()
+ super().__init__()
self.layer_norm = nn.LayerNorm(n_feats)
def forward(self, data):
@@ -177,7 +105,7 @@ class ResidualCNN(nn.Module):
dropout: float,
n_feats: int,
):
- super(ResidualCNN, self).__init__()
+ super().__init__()
self.cnn1 = nn.Conv2d(
in_channels, out_channels, kernel, stride, padding=kernel // 2
@@ -219,7 +147,7 @@ class BidirectionalGRU(nn.Module):
dropout: float,
batch_first: bool,
):
- super(BidirectionalGRU, self).__init__()
+ super().__init__()
self.bi_gru = nn.GRU(
input_size=rnn_dim,
@@ -253,7 +181,7 @@ class SpeechRecognitionModel(nn.Module):
stride: int = 2,
dropout: float = 0.1,
):
- super(SpeechRecognitionModel, self).__init__()
+ super().__init__()
n_feats //= 2
self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
# n residual cnn layers with filter size of 32
@@ -299,7 +227,7 @@ class SpeechRecognitionModel(nn.Module):
return data
-class IterMeter(object):
+class IterMeter:
"""keeps track of total iterations"""
def __init__(self):
@@ -354,6 +282,7 @@ def train(
return loss.item()
+# TODO: check how dataloader can be made more efficient
def test(model, device, test_loader, criterion):
"""Test"""
print("\nevaluating...")
@@ -401,7 +330,7 @@ def run(
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 46,
+ "n_class": 36, # TODO: dynamically determine this from vocab size
"n_feats": 128,
"stride": 2,
"dropout": 0.1,
@@ -452,7 +381,7 @@ def run(
).to(device)
print(
- "Num Model Parameters", sum([param.nelement() for param in model.parameters()])
+ "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
)
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
criterion = nn.CTCLoss(blank=28).to(device)
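One hedged way to resolve the n_class TODO above is to derive the class count from the loaded tokenizer instead of hard-coding 36. The sketch below only illustrates that idea: it places the CTC blank after the vocabulary, which differs from the blank=28 convention currently used in greedy_decoder and CTCLoss, so all three spots would have to change together.

import torch
from torch import nn
from swr2_asr.tokenizer import CharTokenizer

tokenizer = CharTokenizer()
tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")

n_class = tokenizer.get_vocab_size() + 1  # one extra output for the CTC blank
blank_id = n_class - 1                    # keep the blank outside the real vocabulary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
criterion = nn.CTCLoss(blank=blank_id).to(device)
print(n_class, blank_id)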