diff options
-rw-r--r-- | poetry.lock | 46 | ||||
-rw-r--r-- | pyproject.toml | 10 | ||||
-rw-r--r-- | swr2_asr/inference_test.py | 11 | ||||
-rw-r--r-- | swr2_asr/loss_scores.py | 8 | ||||
-rw-r--r-- | swr2_asr/model_deep_speech.py | 145 | ||||
-rw-r--r-- | swr2_asr/tokenizer.py | 144 | ||||
-rw-r--r-- | swr2_asr/train.py | 354 | ||||
-rw-r--r-- | swr2_asr/utils.py | 316 |
8 files changed, 702 insertions, 332 deletions
diff --git a/poetry.lock b/poetry.lock index 1f3609a..c322398 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. [[package]] name = "astroid" @@ -20,21 +20,6 @@ wrapt = [ ] [[package]] -name = "AudioLoader" -version = "0.1.4" -description = "A collection of PyTorch audio datasets for speech and music applications" -optional = false -python-versions = ">=3.6" -files = [] -develop = false - -[package.source] -type = "git" -url = "https://github.com/marvinborner/AudioLoader.git" -reference = "HEAD" -resolved_reference = "8fb829bf7fb98f26f8456dc22ef0fe2c7bb38ac2" - -[[package]] name = "black" version = "23.7.0" description = "The uncompromising code formatter." @@ -149,18 +134,21 @@ graph = ["objgraph (>=1.7.2)"] [[package]] name = "filelock" -version = "3.12.2" +version = "3.12.3" description = "A platform independent file lock." 
optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, - {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, + {file = "filelock-3.12.3-py3-none-any.whl", hash = "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb"}, + {file = "filelock-3.12.3.tar.gz", hash = "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d"}, ] +[package.dependencies] +typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.11\""} + [package.extras] -docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"] [[package]] name = "isort" @@ -975,13 +963,13 @@ tutorials = ["matplotlib", "pandas", "tabulate"] [[package]] name = "types-tqdm" -version = "4.66.0.1" +version = "4.66.0.2" description = "Typing stubs for tqdm" optional = false python-versions = "*" files = [ - {file = "types-tqdm-4.66.0.1.tar.gz", hash = "sha256:6457c90f03cc5a0fe8dd11839c8cbf5572bf542b438b1af74233801728b5dfbc"}, - {file = "types_tqdm-4.66.0.1-py3-none-any.whl", hash = "sha256:6a1516788cbb33d725803439b79c25bfed7e8176b8d782020b5c24aedac1649b"}, + {file = "types-tqdm-4.66.0.2.tar.gz", hash = "sha256:9553a5e44c1d485fce19f505b8bd65c0c3e87e870678d1f2ed764ae59a55d45f"}, + {file = "types_tqdm-4.66.0.2-py3-none-any.whl", hash = 
"sha256:13dddd38908834abdf0acdc2b70cab7ac4bcc5ad7356ced450471662e58a0ffc"}, ] [[package]] @@ -997,13 +985,13 @@ files = [ [[package]] name = "wheel" -version = "0.41.1" +version = "0.41.2" description = "A built-package format for Python" optional = false python-versions = ">=3.7" files = [ - {file = "wheel-0.41.1-py3-none-any.whl", hash = "sha256:473219bd4cbedc62cea0cb309089b593e47c15c4a2531015f94e4e3b9a0f6981"}, - {file = "wheel-0.41.1.tar.gz", hash = "sha256:12b911f083e876e10c595779709f8a88a59f45aacc646492a67fe9ef796c1b47"}, + {file = "wheel-0.41.2-py3-none-any.whl", hash = "sha256:75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8"}, + {file = "wheel-0.41.2.tar.gz", hash = "sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985"}, ] [package.extras] @@ -1096,4 +1084,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = "^3.10" -content-hash = "6b42e36364178f1670267137f73e8d2b2f3fc1d534a2b198d4ca3f65457d55c2" +content-hash = "a65a10595cd1536a6d09b3fcf6e95c29b03f7fab4574522f241dfdc8c6455b70" diff --git a/pyproject.toml b/pyproject.toml index fabe364..57c60c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,7 +11,6 @@ packages = [{include = "swr2_asr"}] python = "^3.10" torch = "2.0.0" torchaudio = "2.0.1" -audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"} tqdm = "^4.66.1" numpy = "^1.25.2" mido = "^1.3.0" @@ -25,6 +24,15 @@ pylint = "^2.17.5" ruff = "^0.0.285" types-tqdm = "^4.66.0.1" +[tool.ruff] +select = ["E", "F", "B", "I"] +fixable = ["ALL"] +line-length = 120 +target-version = "py310" + +[tool.black] +line-length = 100 + [tool.poetry.scripts] train = "swr2_asr.train:run_cli" train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer" diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py index a6b0010..96277fd 100644 --- a/swr2_asr/inference_test.py +++ b/swr2_asr/inference_test.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" -from AudioLoader.speech.mls 
import MultilingualLibriSpeech import torch import torchaudio import torchaudio.functional as F class GreedyCTCDecoder(torch.nn.Module): + """Greedy CTC decoder for the wav2vec2 model.""" + def __init__(self, labels, blank=0) -> None: super().__init__() self.labels = labels @@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module): return "".join([self.labels[i] for i in indices]) +''' +Sorry marvin, Please fix this to use the new dataset def main() -> None: """Main function.""" # choose between cuda, cpu and mps devices @@ -44,9 +47,7 @@ def main() -> None: print(model.__class__) # only do all things for one single sample - dataset = MultilingualLibriSpeech( - "data", "mls_german_opus", split="train", download=True - ) + dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True) print(dataset[0]) @@ -68,7 +69,7 @@ def main() -> None: transcript = decoder(emission[0]) print(transcript) - +''' if __name__ == "__main__": main() diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py index c49cc15..ef37b0a 100644 --- a/swr2_asr/loss_scores.py +++ b/swr2_asr/loss_scores.py @@ -54,9 +54,7 @@ def _levenshtein_distance(ref, hyp): return distance[len_ref % 2][len_hyp] -def word_errors( - reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " " -): +def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "): """Compute the levenshtein distance between reference sequence and hypothesis sequence in word-level. :param reference: The reference sentence. @@ -176,9 +174,7 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False): :rtype: float :raises ValueError: If the reference length is zero. 
""" - edit_distance, ref_len = char_errors( - reference, hypothesis, ignore_case, remove_space - ) + edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space) if ref_len == 0: raise ValueError("Length of reference should be greater than 0.") diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py new file mode 100644 index 0000000..dd07ff9 --- /dev/null +++ b/swr2_asr/model_deep_speech.py @@ -0,0 +1,145 @@ +"""Main definition of model""" +import torch.nn.functional as F +from torch import nn + + +class CNNLayerNorm(nn.Module): + """Layer normalization built for cnns input""" + + def __init__(self, n_feats: int): + super().__init__() + self.layer_norm = nn.LayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) + data = self.layer_norm(data) + return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) + + +class ResidualCNN(nn.Module): + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + stride: int, + dropout: float, + n_feats: int, + ): + super().__init__() + + self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2) + self.cnn2 = nn.Conv2d( + out_channels, + out_channels, + kernel, + stride, + padding=kernel // 2, + ) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.layer_norm1 = CNNLayerNorm(n_feats) + self.layer_norm2 = CNNLayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + residual = data # (batch, channel, feature, time) + data = self.layer_norm1(data) + data = F.gelu(data) + data = self.dropout1(data) + data = self.cnn1(data) + data = self.layer_norm2(data) + data = F.gelu(data) + data = self.dropout2(data) + data = self.cnn2(data) + data += residual + return data # (batch, channel, 
feature, time) + + +class BidirectionalGRU(nn.Module): + """Bidirectional GRU with Layer Normalization and Dropout""" + + def __init__( + self, + rnn_dim: int, + hidden_size: int, + dropout: float, + batch_first: bool, + ): + super().__init__() + + self.bi_gru = nn.GRU( + input_size=rnn_dim, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=True, + ) + self.layer_norm = nn.LayerNorm(rnn_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + """data (batch, time, feature)""" + data = self.layer_norm(data) + data = F.gelu(data) + data = self.dropout(data) + data, _ = self.bi_gru(data) + return data + + +class SpeechRecognitionModel(nn.Module): + """Speech Recognition Model Inspired by DeepSpeech 2""" + + def __init__( + self, + n_cnn_layers: int, + n_rnn_layers: int, + rnn_dim: int, + n_class: int, + n_feats: int, + stride: int = 2, + dropout: float = 0.1, + ): + super().__init__() + n_feats //= 2 + self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + # n residual cnn layers with filter size of 32 + self.rescnn_layers = nn.Sequential( + *[ + ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats) + for _ in range(n_cnn_layers) + ] + ) + self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) + self.birnn_layers = nn.Sequential( + *[ + BidirectionalGRU( + rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, + hidden_size=rnn_dim, + dropout=dropout, + batch_first=i == 0, + ) + for i in range(n_rnn_layers) + ] + ) + self.classifier = nn.Sequential( + nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(rnn_dim, n_class), + ) + + def forward(self, data): + """data (batch, channel, feature, time)""" + data = self.cnn(data) + data = self.rescnn_layers(data) + sizes = data.size() + data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time) + data = data.transpose(1, 2) # (batch, time, feature) + data = 
self.fully_connected(data) + data = self.birnn_layers(data) + data = self.classifier(data) + return data diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index a665159..e4df93b 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -1,16 +1,60 @@ """Tokenizer for use with Multilingual Librispeech""" -from dataclasses import dataclass import json import os -import click -from tqdm import tqdm - -from AudioLoader.speech import MultilingualLibriSpeech +from dataclasses import dataclass +from typing import Type +import click from tokenizers import Tokenizer, normalizers from tokenizers.models import BPE -from tokenizers.trainers import BpeTrainer from tokenizers.pre_tokenizers import Whitespace +from tokenizers.trainers import BpeTrainer +from tqdm import tqdm + + +class TokenizerType: + """Base class for tokenizers. + + exposes the same interface as tokenizers from the huggingface library""" + + def encode(self, sequence: str) -> list[int]: + """Encode a sequence to a list of integer labels""" + raise NotImplementedError + + def decode(self, labels: list[int], remove_special_tokens: bool) -> str: + """Decode a list of integer labels to a sequence""" + raise NotImplementedError + + def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Decode a batch of integer labels to a list of sequences""" + raise NotImplementedError + + def get_vocab_size(self) -> int: + """Get the size of the vocabulary""" + raise NotImplementedError + + def enable_padding( + self, + length: int = -1, + direction: str = "right", + pad_id: int = 0, + pad_type_id: int = 0, + pad_token: str = "[PAD]", + ) -> None: + """Enable padding for the tokenizer""" + raise NotImplementedError + + def save(self, path: str) -> None: + """Save the tokenizer to a file""" + raise NotImplementedError + + @staticmethod + def from_file(path: str) -> "TokenizerType": + """Load the tokenizer from a file""" + raise NotImplementedError + + +MyTokenizerType = Type[TokenizerType] @dataclass 
@@ -18,9 +62,10 @@ class Encoding: """Simple dataclass to represent an encoding""" ids: list[int] + tokens: list[str] -class CharTokenizer: +class CharTokenizer(TokenizerType): """Very simple tokenizer for use with Multilingual Librispeech Simply checks what characters are in the dataset and uses them as tokens. @@ -45,9 +90,7 @@ class CharTokenizer: self.char_map[token] = len(self.char_map) self.index_map[len(self.index_map)] = token - def train( - self, dataset_path: str, language: str, split: str, download: bool = True - ): + def train(self, dataset_path: str, language: str, split: str): """Train the tokenizer on the given dataset Args: @@ -65,13 +108,7 @@ class CharTokenizer: chars: set = set() for s_plit in splits: - transcript_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) - - # check if dataset is downloaded, download if not - if download and not os.path.exists(transcript_path): - MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) + transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") with open( transcript_path, @@ -90,7 +127,7 @@ class CharTokenizer: self.char_map[char] = i self.index_map[i] = char - def encode(self, text: str): + def encode(self, sequence: str): """Use a character map and convert text to an integer sequence automatically maps spaces to <SPACE> and makes everything lowercase @@ -98,8 +135,8 @@ class CharTokenizer: """ int_sequence = [] - text = text.lower() - for char in text: + sequence = sequence.lower() + for char in sequence: if char == " ": mapped_char = self.char_map["<SPACE>"] elif char not in self.char_map: @@ -107,7 +144,7 @@ class CharTokenizer: else: mapped_char = self.char_map[char] int_sequence.append(mapped_char) - return Encoding(ids=int_sequence) + return Encoding(ids=int_sequence, tokens=list(sequence)) def decode(self, labels: list[int], remove_special_tokens: bool = True): """Use a character map and convert integer labels to an text sequence @@ 
-146,6 +183,7 @@ class CharTokenizer: def save(self, path: str): """Save the tokenizer to a file""" + os.makedirs(os.path.dirname(path), exist_ok=True) with open(path, "w", encoding="utf-8") as file: # save it in the following format: # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} @@ -155,31 +193,48 @@ class CharTokenizer: ensure_ascii=False, ) - def from_file(self, path: str): + @staticmethod + def from_file(path: str) -> "CharTokenizer": """Load the tokenizer from a file""" + char_tokenizer = CharTokenizer() with open(path, "r", encoding="utf-8") as file: # load it in the following format: # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} saved_file = json.load(file) - self.char_map = saved_file["char_map"] - self.index_map = saved_file["index_map"] + char_tokenizer.char_map = saved_file["char_map"] + char_tokenizer.index_map = saved_file["index_map"] + + return char_tokenizer @click.command() @click.option("--dataset_path", default="data", help="Path to the MLS dataset") @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use (including all)") -@click.option("--download", default=True, help="Whether to download the dataset") -@click.option( - "--out_path", default="tokenizer.json", help="Path to save the tokenizer to" -) +@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") @click.option("--vocab_size", default=2000, help="Size of the vocabulary") +def train_bpe_tokenizer_cli( + dataset_path: str, + language: str, + split: str, + out_path: str, + vocab_size: int, +): + """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" + train_bpe_tokenizer( + dataset_path, + language, + split, + out_path, + vocab_size, + ) + + def train_bpe_tokenizer( dataset_path: str, language: str, split: str, out_path: str, - download: bool, vocab_size: int, ): """Train a Byte-Pair Encoder tokenizer on the MLS 
dataset @@ -206,11 +261,12 @@ def train_bpe_tokenizer( lines = [] for s_plit in splits: - transcripts_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) - if download and not os.path.exists(transcripts_path): - MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) + transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") + if not os.path.exists(transcripts_path): + raise FileNotFoundError( + f"Could not find transcripts.txt in {transcripts_path}. " + "Please make sure that the dataset is downloaded." + ) with open( transcripts_path, @@ -226,6 +282,7 @@ def train_bpe_tokenizer( bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]")) initial_alphabet = [ + " ", "a", "b", "c", @@ -272,6 +329,7 @@ def train_bpe_tokenizer( "ü", ] + # TODO: add padding token / whitespace token / special tokens trainer = BpeTrainer( special_tokens=["[UNK]"], vocab_size=vocab_size, @@ -292,16 +350,22 @@ def train_bpe_tokenizer( @click.option("--dataset_path", default="data", help="Path to the MLS dataset") @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use") -@click.option( - "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to" -) -@click.option("--download", default=True, help="Whether to download the dataset") +@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") +def train_char_tokenizer_cli( + dataset_path: str, + language: str, + split: str, + out_path: str, +): + """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" + train_char_tokenizer(dataset_path, language, split, out_path) + + def train_char_tokenizer( dataset_path: str, language: str, split: str, out_path: str, - download: bool, ): """Train a Byte-Pair Encoder tokenizer on the MLS dataset @@ -317,7 +381,7 @@ def train_char_tokenizer( """ char_tokenizer = CharTokenizer() - 
char_tokenizer.train(dataset_path, language, split, download) + char_tokenizer.train(dataset_path, language, split) char_tokenizer.save(out_path) diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6af1e80..63deb72 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,232 +1,57 @@ """Training script for the ASR model.""" import os +from typing import TypedDict + import click import torch import torch.nn.functional as F -import torchaudio -from AudioLoader.speech import MultilingualLibriSpeech from torch import nn, optim from torch.utils.data import DataLoader -from tokenizers import Tokenizer -from .tokenizer import CharTokenizer +from tqdm import tqdm + +from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer +from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer -train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) +# TODO: improve naming of functions -valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - -# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") -text_transform = CharTokenizer() -text_transform.from_file("data/tokenizers/char_tokenizer_german.json") - - -def data_processing(data, data_type="train"): - """Return the spectrograms, labels, and their lengths.""" - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - for sample in data: - if data_type == "train": - spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - elif data_type == "valid": - spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - else: - raise ValueError("data_type should be train or valid") - spectrograms.append(spec) - label = 
torch.Tensor(text_transform.encode(sample["utterance"]).ids) - labels.append(label) - input_lengths.append(spec.shape[0] // 2) - label_lengths.append(len(label)) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) - return spectrograms, labels, input_lengths, label_lengths +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int -def greedy_decoder( - output, labels, label_lengths, blank_label=28, collapse_repeated=True -): - # TODO: adopt to support both tokenizers + +def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True): """Greedily decode a sequence.""" + print("output shape", output.shape) arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + blank_label = tokenizer.encode(" ").ids[0] decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append( - text_transform.decode( - [int(x) for x in labels[i][: label_lengths[i]].tolist()] - ) - ) + targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: continue decode.append(index.item()) - decodes.append(text_transform.decode(decode)) + decodes.append(tokenizer.decode(decode)) return decodes, targets -# TODO: restructure into own file / class -class CNNLayerNorm(nn.Module): - """Layer normalization built for cnns input""" - - def __init__(self, n_feats: int): - super().__init__() - self.layer_norm = nn.LayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) - data = 
self.layer_norm(data) - return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) - - -class ResidualCNN(nn.Module): - """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super().__init__() - - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.layer_norm1 = CNNLayerNorm(n_feats) - self.layer_norm2 = CNNLayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - residual = data # (batch, channel, feature, time) - data = self.layer_norm1(data) - data = F.gelu(data) - data = self.dropout1(data) - data = self.cnn1(data) - data = self.layer_norm2(data) - data = F.gelu(data) - data = self.dropout2(data) - data = self.cnn2(data) - data += residual - return data # (batch, channel, feature, time) - - -class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" - - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super().__init__() - - self.bi_gru = nn.GRU( - input_size=rnn_dim, - hidden_size=hidden_size, - num_layers=1, - batch_first=batch_first, - bidirectional=True, - ) - self.layer_norm = nn.LayerNorm(rnn_dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, data): - """data (batch, time, feature)""" - data = self.layer_norm(data) - data = F.gelu(data) - data = self.dropout(data) - data, _ = self.bi_gru(data) - return data - - -class SpeechRecognitionModel(nn.Module): - """Speech Recognition Model Inspired by DeepSpeech 2""" - - def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - 
stride: int = 2, - dropout: float = 0.1, - ): - super().__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) - # n residual cnn layers with filter size of 32 - self.rescnn_layers = nn.Sequential( - *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) - for _ in range(n_cnn_layers) - ] - ) - self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) - self.birnn_layers = nn.Sequential( - *[ - BidirectionalGRU( - rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, - hidden_size=rnn_dim, - dropout=dropout, - batch_first=i == 0, - ) - for i in range(n_rnn_layers) - ] - ) - self.classifier = nn.Sequential( - nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(rnn_dim, n_class), - ) - - def forward(self, data): - """data (batch, channel, feature, time)""" - data = self.cnn(data) - data = self.rescnn_layers(data) - sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) - data = data.transpose(1, 2) # (batch, time, feature) - data = self.fully_connected(data) - data = self.birnn_layers(data) - data = self.classifier(data) - return data - - class IterMeter: """keeps track of total iterations""" @@ -254,36 +79,30 @@ def train( ): """Train""" model.train() - data_len = len(train_loader.dataset) - for batch_idx, _data in enumerate(train_loader): - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) - + print(f"Epoch: {epoch}") + losses = [] + for _data in tqdm(train_loader, desc="batches"): + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, 
labels, _data["input_length"], _data["utterance_length"]) loss.backward() optimizer.step() scheduler.step() iter_meter.step() - if batch_idx % 100 == 0 or batch_idx == data_len: - print( - f"Train Epoch: \ - {epoch} \ - [{batch_idx * len(spectrograms)}/{data_len} \ - ({100.0 * batch_idx / len(train_loader)}%)]\t \ - Loss: {loss.item()}" - ) - return loss.item() + losses.append(loss.item()) + + print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") + return sum(losses) / len(losses) -# TODO: check how dataloader can be made more efficient -def test(model, device, test_loader, criterion): + +def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") model.eval() @@ -291,18 +110,20 @@ def test(model, device, test_loader, criterion): test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data["input_length"], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths + output=output.transpose(0, 1), + labels=labels, + label_lengths=_data["utterance_length"], + tokenizer=tokenizer, ) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) @@ -313,9 +134,11 @@ def test(model, device, test_loader, criterion): print( f"Test set: Average loss:\ - {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n" + {test_loss}, Average CER: {None} Average WER: {None}\n" ) + return test_loss, avg_cer, avg_wer + def run( 
learning_rate: float, @@ -324,46 +147,66 @@ def run( load: bool, path: str, dataset_path: str, + language: str, ) -> None: """Runs the training script.""" - hparams = { - "n_cnn_layers": 3, - "n_rnn_layers": 5, - "rnn_dim": 512, - "n_class": 36, # TODO: dynamically determine this from vocab size - "n_feats": 128, - "stride": 2, - "dropout": 0.1, - "learning_rate": learning_rate, - "batch_size": batch_size, - "epochs": epochs, - } - use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") - download_dataset = not os.path.isdir(path) - train_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="dev", download=download_dataset + # load dataset + train_dataset = MLSDataset( + dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None ) - test_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="test", download=False + valid_dataset = MLSDataset( + dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None + ) + + # load tokenizer (bpe by default): + if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"): + print("There is no tokenizer available. 
Do you want to train it on the dataset?") + input("Press Enter to continue...") + train_char_tokenizer( + dataset_path=dataset_path, + language=language, + split="all", + download=False, + out_path="data/tokenizers/char_tokenizer_german.json", + ) + + tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + + train_dataset.set_tokenizer(tokenizer) # type: ignore + valid_dataset.set_tokenizer(tokenizer) # type: ignore + + print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") + + hparams = HParams( + n_cnn_layers=3, + n_rnn_layers=5, + rnn_dim=512, + n_class=tokenizer.get_vocab_size(), + n_feats=128, + stride=2, + dropout=0.1, + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, ) train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), + collate_fn=lambda x: collate_fn(x), ) - test_loader = DataLoader( - test_dataset, + valid_loader = DataLoader( + valid_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), + collate_fn=lambda x: collate_fn(x), ) # enable flag to find the most compatible algorithms in advance @@ -379,12 +222,10 @@ def run( hparams["stride"], hparams["dropout"], ).to(device) - - print( - "Num Model Parameters", sum((param.nelement() for param in model.parameters())) - ) + print(tokenizer.encode(" ")) + print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) - criterion = nn.CTCLoss(blank=28).to(device) + criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device) if load: checkpoint = torch.load(path) model.load_state_dict(checkpoint["model_state_dict"]) @@ -412,7 +253,13 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=test_loader, criterion=criterion) + test( + model=model, + device=device, + test_loader=valid_loader, + 
"""Class containing utils for the ASR system."""
import os
from enum import Enum
from typing import TypedDict

import numpy as np
import torch
import torchaudio
from tokenizers import Tokenizer
from torch.utils.data import Dataset
from torchaudio.datasets.utils import _extract_tar as extract_archive

from swr2_asr.tokenizer import TokenizerType

# SpecAugment-style training transform: mel spectrogram followed by
# frequency and time masking.
train_audio_transforms = torch.nn.Sequential(
    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
    torchaudio.transforms.TimeMasking(time_mask_param=100),
)


class MLSSplit(str, Enum):
    """Enum specifying dataset splits as they are defined in the
    Multilingual LibriSpeech dataset"""

    TRAIN = "train"
    TEST = "test"
    DEV = "dev"


class Split(str, Enum):
    """Extending the MLSSplit class to allow for a custom validation split"""

    TRAIN = "train"
    VALID = "valid"
    TEST = "test"
    DEV = "dev"


def split_to_mls_split(split_name: Split) -> MLSSplit:
    """Converts the custom split to a MLSSplit.

    VALID has no directory of its own on disk; it is carved out of the
    MLS "train" split, so it maps to MLSSplit.TRAIN. Every other value
    maps to itself.
    """
    if split_name == Split.VALID:
        return MLSSplit.TRAIN
    return split_name  # type: ignore


class Sample(TypedDict):
    """Type for a sample in the dataset"""

    waveform: torch.Tensor  # raw audio, resampled to the configured rate
    spectrogram: torch.Tensor  # mel spectrogram, shape (time, n_mels)
    input_length: int  # spectrogram frames // 2 (mirrors collate_fn)
    utterance: torch.Tensor  # encoded transcript token ids
    utterance_length: int  # transcript length in characters
    sample_rate: int
    speaker_id: str
    book_id: str
    chapter_id: str


class MLSDataset(Dataset):
    """Custom Dataset for reading Multilingual LibriSpeech

    Attributes:
        dataset_path (str):
            path to the dataset
        language (str):
            language of the dataset
        split (Split):
            split of the dataset
        mls_split (MLSSplit):
            split of the dataset as defined in the Multilingual LibriSpeech dataset
        dataset_lookup (list):
            list of dicts containing the speakerid, bookid, chapterid and utterance

    directory structure:
        <dataset_path>
        ├── <language>
        │  ├── train
        │  │  ├── transcripts.txt
        │  │  └── audio
        │  │     └── <speakerid>
        │  │        └── <bookid>
        │  │           └── <speakerid>_<bookid>_<chapterid>.opus / .flac

        each line in transcripts.txt has the following format:
        <speakerid>_<bookid>_<chapterid> <utterance>
    """

    def __init__(
        self,
        dataset_path: str,
        language: str,
        split: Split,
        download: bool,
        spectrogram_hparams: dict | None,
    ):
        """Initializes the dataset.

        Args:
            dataset_path: root directory holding one folder per language
            language: MLS language identifier, e.g. "mls_german_opus"
            split: split to load; TRAIN and VALID are deterministic,
                disjoint subsets of the on-disk MLS "train" split
            download: if True, download and extract the archive when missing
            spectrogram_hparams: kwargs for torchaudio MelSpectrogram;
                None selects the 16 kHz / 128-mel defaults below

        Raises:
            ValueError: if the dataset directory layout is invalid
                (via _validate_local_directory)
        """
        self.dataset_path = dataset_path
        self.language = language
        # MLS ships either opus or flac archives; the language name
        # encodes which one ("..._opus" vs. plain).
        self.file_ext = ".opus" if "opus" in language else ".flac"
        self.mls_split: MLSSplit = split_to_mls_split(split)  # split path on disk
        self.split: Split = split  # split used internally

        if spectrogram_hparams is None:
            self.spectrogram_hparams = {
                "sample_rate": 16000,
                "n_fft": 400,
                "win_length": 400,
                "hop_length": 160,
                "n_mels": 128,
                "f_min": 0,
                "f_max": 8000,
                "power": 2.0,
            }
        else:
            self.spectrogram_hparams = spectrogram_hparams

        self.dataset_lookup: list[dict[str, str]] = []
        # Must be set via set_tokenizer() before indexing the dataset.
        # Initialized explicitly to None so __getitem__ can raise the
        # intended ValueError instead of an AttributeError (the previous
        # version only annotated the attribute without assigning it).
        self.tokenizer: type[TokenizerType] | None = None

        self._handle_download_dataset(download)
        self._validate_local_directory()
        self.initialize()

    def initialize(self) -> None:
        """Initializes the dataset

        Reads the transcripts.txt file and creates a lookup table.

        TRAIN and VALID are produced by a seeded, disjoint 80/20
        partition of the utterances. (The previous implementation used
        np.random.choice, which samples WITH replacement: splits
        contained duplicates and, because both used the same seed, the
        validation set was a subset of the training set.)
        """
        transcripts_path = os.path.join(
            self.dataset_path, self.language, self.mls_split, "transcripts.txt"
        )

        with open(transcripts_path, "r", encoding="utf-8") as script_file:
            # read all lines in transcripts.txt
            transcripts = script_file.readlines()

        # split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>)
        transcripts = [line.strip().split("\t", 1) for line in transcripts]  # type: ignore
        utterances = [utterance.strip() for _, utterance in transcripts]  # type: ignore
        identifier = [identifier.strip() for identifier, _ in transcripts]  # type: ignore
        identifier = [path.split("_") for path in identifier]

        if self.split in (Split.TRAIN, Split.VALID):
            # Deterministic, disjoint 80/20 partition of the MLS train split.
            rng = np.random.default_rng(42)
            permutation = rng.permutation(len(utterances))
            valid_size = int(len(utterances) * 0.2)
            if self.split == Split.VALID:
                indices = permutation[:valid_size]
            else:
                indices = permutation[valid_size:]
            utterances = [utterances[i] for i in indices]
            identifier = [identifier[i] for i in indices]

        self.dataset_lookup = [
            {
                "speakerid": path[0],
                "bookid": path[1],
                "chapterid": path[2],
                "utterance": utterance,
            }
            for path, utterance in zip(identifier, utterances, strict=False)
        ]

    def set_tokenizer(self, tokenizer: type[TokenizerType]):
        """Sets the tokenizer used to encode utterances in __getitem__"""
        self.tokenizer = tokenizer

    def _handle_download_dataset(self, download: bool) -> None:
        """Download the dataset archive and extract it if necessary"""
        if not download:
            print("Download flag not set, skipping download")
            return

        archive_path = os.path.join(self.dataset_path, self.language) + ".tar.gz"
        language_path = os.path.join(self.dataset_path, self.language)

        # archive exists: nothing to fetch
        if os.path.isfile(archive_path):
            print(f"Found dataset at {self.dataset_path}. Skipping download")
        # archive does not exist: download it
        else:
            os.makedirs(self.dataset_path, exist_ok=True)
            url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"

            torch.hub.download_url_to_file(url, archive_path)

        # unzip the dataset
        if not os.path.isdir(language_path):
            print(f"Unzipping the dataset at {archive_path}")
            # NOTE(review): extract_archive is torchaudio's private
            # _extract_tar — confirm the installed version accepts the
            # overwrite keyword.
            extract_archive(archive_path, overwrite=True)
        else:
            print("Dataset is already unzipped, validating it now")

    def _validate_local_directory(self):
        """Checks that the expected directory layout exists on disk.

        Raises:
            ValueError: if dataset path, language folder or split folder
                is missing.
        """
        if not os.path.exists(self.dataset_path):
            raise ValueError("Dataset path does not exist")
        if not os.path.exists(os.path.join(self.dataset_path, self.language)):
            raise ValueError("Language not downloaded!")
        if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
            raise ValueError("Split not found in dataset")

    def __len__(self):
        """Returns the length of the dataset"""
        return len(self.dataset_lookup)

    def __getitem__(self, idx: int) -> Sample:
        """Loads, resamples and featurizes one sample.

        Raises:
            ValueError: if no tokenizer has been set via set_tokenizer().
        """
        if self.tokenizer is None:
            raise ValueError("No tokenizer set")

        entry = self.dataset_lookup[idx]
        utterance = entry["utterance"]

        # audio lives at <split>/audio/<speaker>/<book>/<speaker>_<book>_<chapter><ext>
        audio_path = os.path.join(
            self.dataset_path,
            self.language,
            self.mls_split,
            "audio",
            entry["speakerid"],
            entry["bookid"],
            "_".join([entry["speakerid"], entry["bookid"], entry["chapterid"]]) + self.file_ext,
        )

        waveform, sample_rate = torchaudio.load(audio_path)  # type: ignore

        # resample if necessary
        if sample_rate != self.spectrogram_hparams["sample_rate"]:
            resampler = torchaudio.transforms.Resample(
                sample_rate, self.spectrogram_hparams["sample_rate"]
            )
            waveform = resampler(waveform)

        spec = (
            torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
            .squeeze(0)
            .transpose(0, 1)
        )

        # halved frame count; matches the computation in collate_fn
        input_length = spec.shape[0] // 2

        utterance_length = len(utterance)

        utterance = self.tokenizer.encode(utterance)

        utterance = torch.LongTensor(utterance.ids)

        return Sample(
            waveform=waveform,
            spectrogram=spec,
            input_length=input_length,
            utterance=utterance,
            utterance_length=utterance_length,
            sample_rate=self.spectrogram_hparams["sample_rate"],
            speaker_id=entry["speakerid"],
            book_id=entry["bookid"],
            chapter_id=entry["chapterid"],
        )


def collate_fn(samples: list[Sample]) -> dict:
    """Collate function for the dataloader

    pads all tensors within a batch to the same dimensions
    """
    waveforms = []
    spectrograms = []
    labels = []
    input_lengths = []
    label_lengths = []

    for sample in samples:
        waveforms.append(sample["waveform"].transpose(0, 1))
        spectrograms.append(sample["spectrogram"])
        labels.append(sample["utterance"])
        input_lengths.append(sample["spectrogram"].shape[0] // 2)
        label_lengths.append(len(sample["utterance"]))

    waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
    spectrograms = (
        torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
    )
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)

    return {
        "waveform": waveforms,
        "spectrogram": spectrograms,
        "input_length": input_lengths,
        "utterance": labels,
        "utterance_length": label_lengths,
    }


if __name__ == "__main__":
    DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
    LANGUAGE = "mls_german_opus"
    split = Split.TRAIN
    DOWNLOAD = False

    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None)

    tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
    dataset.set_tokenizer(tok)