-rw-r--r--  poetry.lock                     46
-rw-r--r--  pyproject.toml                  10
-rw-r--r--  swr2_asr/inference_test.py      11
-rw-r--r--  swr2_asr/loss_scores.py          8
-rw-r--r--  swr2_asr/model_deep_speech.py  145
-rw-r--r--  swr2_asr/tokenizer.py          144
-rw-r--r--  swr2_asr/train.py              354
-rw-r--r--  swr2_asr/utils.py              316
8 files changed, 702 insertions(+), 332 deletions(-)
diff --git a/poetry.lock b/poetry.lock
index 1f3609a..c322398 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "astroid"
@@ -20,21 +20,6 @@ wrapt = [
]
[[package]]
-name = "AudioLoader"
-version = "0.1.4"
-description = "A collection of PyTorch audio datasets for speech and music applications"
-optional = false
-python-versions = ">=3.6"
-files = []
-develop = false
-
-[package.source]
-type = "git"
-url = "https://github.com/marvinborner/AudioLoader.git"
-reference = "HEAD"
-resolved_reference = "8fb829bf7fb98f26f8456dc22ef0fe2c7bb38ac2"
-
-[[package]]
name = "black"
version = "23.7.0"
description = "The uncompromising code formatter."
@@ -149,18 +134,21 @@ graph = ["objgraph (>=1.7.2)"]
[[package]]
name = "filelock"
-version = "3.12.2"
+version = "3.12.3"
description = "A platform independent file lock."
optional = false
-python-versions = ">=3.7"
+python-versions = ">=3.8"
files = [
- {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"},
- {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"},
+ {file = "filelock-3.12.3-py3-none-any.whl", hash = "sha256:f067e40ccc40f2b48395a80fcbd4728262fab54e232e090a4063ab804179efeb"},
+ {file = "filelock-3.12.3.tar.gz", hash = "sha256:0ecc1dd2ec4672a10c8550a8182f1bd0c0a5088470ecd5a125e45f49472fac3d"},
]
+[package.dependencies]
+typing-extensions = {version = ">=4.7.1", markers = "python_version < \"3.11\""}
+
[package.extras]
-docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"]
-testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"]
+docs = ["furo (>=2023.7.26)", "sphinx (>=7.1.2)", "sphinx-autodoc-typehints (>=1.24)"]
+testing = ["covdefaults (>=2.3)", "coverage (>=7.3)", "diff-cover (>=7.7)", "pytest (>=7.4)", "pytest-cov (>=4.1)", "pytest-mock (>=3.11.1)", "pytest-timeout (>=2.1)"]
[[package]]
name = "isort"
@@ -975,13 +963,13 @@ tutorials = ["matplotlib", "pandas", "tabulate"]
[[package]]
name = "types-tqdm"
-version = "4.66.0.1"
+version = "4.66.0.2"
description = "Typing stubs for tqdm"
optional = false
python-versions = "*"
files = [
- {file = "types-tqdm-4.66.0.1.tar.gz", hash = "sha256:6457c90f03cc5a0fe8dd11839c8cbf5572bf542b438b1af74233801728b5dfbc"},
- {file = "types_tqdm-4.66.0.1-py3-none-any.whl", hash = "sha256:6a1516788cbb33d725803439b79c25bfed7e8176b8d782020b5c24aedac1649b"},
+ {file = "types-tqdm-4.66.0.2.tar.gz", hash = "sha256:9553a5e44c1d485fce19f505b8bd65c0c3e87e870678d1f2ed764ae59a55d45f"},
+ {file = "types_tqdm-4.66.0.2-py3-none-any.whl", hash = "sha256:13dddd38908834abdf0acdc2b70cab7ac4bcc5ad7356ced450471662e58a0ffc"},
]
[[package]]
@@ -997,13 +985,13 @@ files = [
[[package]]
name = "wheel"
-version = "0.41.1"
+version = "0.41.2"
description = "A built-package format for Python"
optional = false
python-versions = ">=3.7"
files = [
- {file = "wheel-0.41.1-py3-none-any.whl", hash = "sha256:473219bd4cbedc62cea0cb309089b593e47c15c4a2531015f94e4e3b9a0f6981"},
- {file = "wheel-0.41.1.tar.gz", hash = "sha256:12b911f083e876e10c595779709f8a88a59f45aacc646492a67fe9ef796c1b47"},
+ {file = "wheel-0.41.2-py3-none-any.whl", hash = "sha256:75909db2664838d015e3d9139004ee16711748a52c8f336b52882266540215d8"},
+ {file = "wheel-0.41.2.tar.gz", hash = "sha256:0c5ac5ff2afb79ac23ab82bab027a0be7b5dbcf2e54dc50efe4bf507de1f7985"},
]
[package.extras]
@@ -1096,4 +1084,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "^3.10"
-content-hash = "6b42e36364178f1670267137f73e8d2b2f3fc1d534a2b198d4ca3f65457d55c2"
+content-hash = "a65a10595cd1536a6d09b3fcf6e95c29b03f7fab4574522f241dfdc8c6455b70"
diff --git a/pyproject.toml b/pyproject.toml
index fabe364..57c60c9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -11,7 +11,6 @@ packages = [{include = "swr2_asr"}]
python = "^3.10"
torch = "2.0.0"
torchaudio = "2.0.1"
-audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"}
tqdm = "^4.66.1"
numpy = "^1.25.2"
mido = "^1.3.0"
@@ -25,6 +24,15 @@ pylint = "^2.17.5"
ruff = "^0.0.285"
types-tqdm = "^4.66.0.1"
+[tool.ruff]
+select = ["E", "F", "B", "I"]
+fixable = ["ALL"]
+line-length = 120
+target-version = "py310"
+
+[tool.black]
+line-length = 100
+
[tool.poetry.scripts]
train = "swr2_asr.train:run_cli"
train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer"
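# Note on the tooling changes: the train-bpe-tokenizer entry above still points
# at swr2_asr.tokenizer:train_bpe_tokenizer, but this commit renames the click
# commands to train_bpe_tokenizer_cli / train_char_tokenizer_cli, so the entry
# points likely need to move to the *_cli wrappers. Also note that ruff's
# line-length (120) is looser than black's (100), so black remains the
# effective formatting limit.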
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
index a6b0010..96277fd 100644
--- a/swr2_asr/inference_test.py
+++ b/swr2_asr/inference_test.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
import torch
import torchaudio
import torchaudio.functional as F
class GreedyCTCDecoder(torch.nn.Module):
+ """Greedy CTC decoder for the wav2vec2 model."""
+
def __init__(self, labels, blank=0) -> None:
super().__init__()
self.labels = labels
@@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module):
return "".join([self.labels[i] for i in indices])
+'''
+Sorry Marvin, please fix this to use the new dataset.
def main() -> None:
"""Main function."""
# choose between cuda, cpu and mps devices
@@ -44,9 +47,7 @@ def main() -> None:
print(model.__class__)
# only do all things for one single sample
- dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="train", download=True
- )
+ dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True)
print(dataset[0])
@@ -68,7 +69,7 @@ def main() -> None:
transcript = decoder(emission[0])
print(transcript)
-
+'''
if __name__ == "__main__":
main()
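# Note: main() is now inside the ''' ... ''' block above, so the
# `if __name__ == "__main__": main()` guard raises NameError when the script
# runs. A minimal stub (an assumption, until inference is ported to the new
# dataset) would keep the module runnable:
def main() -> None:
    """Placeholder until inference is ported to the new MLSDataset."""
    raise NotImplementedError("inference_test.py still needs porting to MLSDataset")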
diff --git a/swr2_asr/loss_scores.py b/swr2_asr/loss_scores.py
index c49cc15..ef37b0a 100644
--- a/swr2_asr/loss_scores.py
+++ b/swr2_asr/loss_scores.py
@@ -54,9 +54,7 @@ def _levenshtein_distance(ref, hyp):
return distance[len_ref % 2][len_hyp]
-def word_errors(
- reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "
-):
+def word_errors(reference: str, hypothesis: str, ignore_case: bool = False, delimiter: str = " "):
"""Compute the levenshtein distance between reference sequence and
hypothesis sequence in word-level.
:param reference: The reference sentence.
@@ -176,9 +174,7 @@ def cer(reference, hypothesis, ignore_case=False, remove_space=False):
:rtype: float
:raises ValueError: If the reference length is zero.
"""
- edit_distance, ref_len = char_errors(
- reference, hypothesis, ignore_case, remove_space
- )
+ edit_distance, ref_len = char_errors(reference, hypothesis, ignore_case, remove_space)
if ref_len == 0:
raise ValueError("Length of reference should be greater than 0.")
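# Usage sketch for the metrics in swr2_asr/loss_scores.py; wer and cer are the
# module's public helpers (train.py imports both) and the strings are
# illustrative:
from swr2_asr.loss_scores import cer, wer

print(wer("dies ist ein test", "dies ist ein tst"))  # 1 word error / 4 words = 0.25
print(cer("hallo welt", "hallo wel"))                # 1 char error / 10 chars = 0.1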
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
new file mode 100644
index 0000000..dd07ff9
--- /dev/null
+++ b/swr2_asr/model_deep_speech.py
@@ -0,0 +1,145 @@
+"""Main definition of model"""
+import torch.nn.functional as F
+from torch import nn
+
+
+class CNNLayerNorm(nn.Module):
+ """Layer normalization built for cnns input"""
+
+ def __init__(self, n_feats: int):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(n_feats)
+
+ def forward(self, data):
+ """x (batch, channel, feature, time)"""
+ data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
+ data = self.layer_norm(data)
+ return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
+
+
+class ResidualCNN(nn.Module):
+ """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel: int,
+ stride: int,
+ dropout: float,
+ n_feats: int,
+ ):
+ super().__init__()
+
+ self.cnn1 = nn.Conv2d(in_channels, out_channels, kernel, stride, padding=kernel // 2)
+ self.cnn2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel,
+ stride,
+ padding=kernel // 2,
+ )
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.layer_norm1 = CNNLayerNorm(n_feats)
+ self.layer_norm2 = CNNLayerNorm(n_feats)
+
+ def forward(self, data):
+ """x (batch, channel, feature, time)"""
+ residual = data # (batch, channel, feature, time)
+ data = self.layer_norm1(data)
+ data = F.gelu(data)
+ data = self.dropout1(data)
+ data = self.cnn1(data)
+ data = self.layer_norm2(data)
+ data = F.gelu(data)
+ data = self.dropout2(data)
+ data = self.cnn2(data)
+ data += residual
+ return data # (batch, channel, feature, time)
+
+
+class BidirectionalGRU(nn.Module):
+ """BIdirectional GRU with Layer Normalization and Dropout"""
+
+ def __init__(
+ self,
+ rnn_dim: int,
+ hidden_size: int,
+ dropout: float,
+ batch_first: bool,
+ ):
+ super().__init__()
+
+ self.bi_gru = nn.GRU(
+ input_size=rnn_dim,
+ hidden_size=hidden_size,
+ num_layers=1,
+ batch_first=batch_first,
+ bidirectional=True,
+ )
+ self.layer_norm = nn.LayerNorm(rnn_dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, data):
+ """data (batch, time, feature)"""
+ data = self.layer_norm(data)
+ data = F.gelu(data)
+ data = self.dropout(data)
+ data, _ = self.bi_gru(data)
+ return data
+
+
+class SpeechRecognitionModel(nn.Module):
+ """Speech Recognition Model Inspired by DeepSpeech 2"""
+
+ def __init__(
+ self,
+ n_cnn_layers: int,
+ n_rnn_layers: int,
+ rnn_dim: int,
+ n_class: int,
+ n_feats: int,
+ stride: int = 2,
+ dropout: float = 0.1,
+ ):
+ super().__init__()
+ n_feats //= 2
+ self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
+ # n residual cnn layers with filter size of 32
+ self.rescnn_layers = nn.Sequential(
+ *[
+ ResidualCNN(32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats)
+ for _ in range(n_cnn_layers)
+ ]
+ )
+ self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
+ self.birnn_layers = nn.Sequential(
+ *[
+ BidirectionalGRU(
+ rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
+ hidden_size=rnn_dim,
+ dropout=dropout,
+ batch_first=i == 0,
+ )
+ for i in range(n_rnn_layers)
+ ]
+ )
+ self.classifier = nn.Sequential(
+ nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(rnn_dim, n_class),
+ )
+
+ def forward(self, data):
+ """data (batch, channel, feature, time)"""
+ data = self.cnn(data)
+ data = self.rescnn_layers(data)
+ sizes = data.size()
+ data = data.view(sizes[0], sizes[1] * sizes[2], sizes[3]) # (batch, feature, time)
+ data = data.transpose(1, 2) # (batch, time, feature)
+ data = self.fully_connected(data)
+ data = self.birnn_layers(data)
+ data = self.classifier(data)
+ return data
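# A minimal shape check for SpeechRecognitionModel above (a sketch; the
# hyperparameter values are illustrative, not the project's defaults, since
# train.py derives n_class from the tokenizer's vocabulary size):
import torch

from swr2_asr.model_deep_speech import SpeechRecognitionModel

model = SpeechRecognitionModel(
    n_cnn_layers=3, n_rnn_layers=5, rnn_dim=512, n_class=36, n_feats=128
)
spec = torch.randn(4, 1, 128, 100)  # (batch, channel, n_feats, time)
out = model(spec)
print(out.shape)  # torch.Size([4, 50, 36]); the stride-2 CNN halves the time axis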
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index a665159..e4df93b 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,16 +1,60 @@
"""Tokenizer for use with Multilingual Librispeech"""
-from dataclasses import dataclass
import json
import os
-import click
-from tqdm import tqdm
-
-from AudioLoader.speech import MultilingualLibriSpeech
+from dataclasses import dataclass
+from typing import Type
+import click
from tokenizers import Tokenizer, normalizers
from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.trainers import BpeTrainer
+from tqdm import tqdm
+
+
+class TokenizerType:
+ """Base class for tokenizers.
+
+    Exposes the same interface as tokenizers from the Hugging Face library."""
+
+ def encode(self, sequence: str) -> list[int]:
+ """Encode a sequence to a list of integer labels"""
+ raise NotImplementedError
+
+ def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+ """Decode a list of integer labels to a sequence"""
+ raise NotImplementedError
+
+ def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ """Decode a batch of integer labels to a list of sequences"""
+ raise NotImplementedError
+
+ def get_vocab_size(self) -> int:
+ """Get the size of the vocabulary"""
+ raise NotImplementedError
+
+ def enable_padding(
+ self,
+ length: int = -1,
+ direction: str = "right",
+ pad_id: int = 0,
+ pad_type_id: int = 0,
+ pad_token: str = "[PAD]",
+ ) -> None:
+ """Enable padding for the tokenizer"""
+ raise NotImplementedError
+
+ def save(self, path: str) -> None:
+ """Save the tokenizer to a file"""
+ raise NotImplementedError
+
+ @staticmethod
+ def from_file(path: str) -> "TokenizerType":
+ """Load the tokenizer from a file"""
+ raise NotImplementedError
+
+
+MyTokenizerType = Type[TokenizerType]
@dataclass
@@ -18,9 +62,10 @@ class Encoding:
"""Simple dataclass to represent an encoding"""
ids: list[int]
+ tokens: list[str]
-class CharTokenizer:
+class CharTokenizer(TokenizerType):
"""Very simple tokenizer for use with Multilingual Librispeech
Simply checks what characters are in the dataset and uses them as tokens.
@@ -45,9 +90,7 @@ class CharTokenizer:
self.char_map[token] = len(self.char_map)
self.index_map[len(self.index_map)] = token
- def train(
- self, dataset_path: str, language: str, split: str, download: bool = True
- ):
+ def train(self, dataset_path: str, language: str, split: str):
"""Train the tokenizer on the given dataset
Args:
@@ -65,13 +108,7 @@ class CharTokenizer:
chars: set = set()
for s_plit in splits:
- transcript_path = os.path.join(
- dataset_path, language, s_plit, "transcripts.txt"
- )
-
- # check if dataset is downloaded, download if not
- if download and not os.path.exists(transcript_path):
- MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+ transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
with open(
transcript_path,
@@ -90,7 +127,7 @@ class CharTokenizer:
self.char_map[char] = i
self.index_map[i] = char
- def encode(self, text: str):
+ def encode(self, sequence: str):
"""Use a character map and convert text to an integer sequence
automatically maps spaces to <SPACE> and makes everything lowercase
@@ -98,8 +135,8 @@ class CharTokenizer:
"""
int_sequence = []
- text = text.lower()
- for char in text:
+ sequence = sequence.lower()
+ for char in sequence:
if char == " ":
mapped_char = self.char_map["<SPACE>"]
elif char not in self.char_map:
@@ -107,7 +144,7 @@ class CharTokenizer:
else:
mapped_char = self.char_map[char]
int_sequence.append(mapped_char)
- return Encoding(ids=int_sequence)
+ return Encoding(ids=int_sequence, tokens=list(sequence))
def decode(self, labels: list[int], remove_special_tokens: bool = True):
"""Use a character map and convert integer labels to an text sequence
@@ -146,6 +183,7 @@ class CharTokenizer:
def save(self, path: str):
"""Save the tokenizer to a file"""
+ os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as file:
# save it in the following format:
# {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
@@ -155,31 +193,48 @@ class CharTokenizer:
ensure_ascii=False,
)
- def from_file(self, path: str):
+ @staticmethod
+ def from_file(path: str) -> "CharTokenizer":
"""Load the tokenizer from a file"""
+ char_tokenizer = CharTokenizer()
with open(path, "r", encoding="utf-8") as file:
# load it in the following format:
# {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
saved_file = json.load(file)
- self.char_map = saved_file["char_map"]
- self.index_map = saved_file["index_map"]
+ char_tokenizer.char_map = saved_file["char_map"]
+ char_tokenizer.index_map = saved_file["index_map"]
+
+ return char_tokenizer
@click.command()
@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use (including all)")
-@click.option("--download", default=True, help="Whether to download the dataset")
-@click.option(
- "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
-)
+@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer_cli(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ vocab_size: int,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+ train_bpe_tokenizer(
+ dataset_path,
+ language,
+ split,
+ out_path,
+ vocab_size,
+ )
+
+
def train_bpe_tokenizer(
dataset_path: str,
language: str,
split: str,
out_path: str,
- download: bool,
vocab_size: int,
):
"""Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -206,11 +261,12 @@ def train_bpe_tokenizer(
lines = []
for s_plit in splits:
- transcripts_path = os.path.join(
- dataset_path, language, s_plit, "transcripts.txt"
- )
- if download and not os.path.exists(transcripts_path):
- MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+ transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
+ if not os.path.exists(transcripts_path):
+ raise FileNotFoundError(
+ f"Could not find transcripts.txt in {transcripts_path}. "
+ "Please make sure that the dataset is downloaded."
+ )
with open(
transcripts_path,
@@ -226,6 +282,7 @@ def train_bpe_tokenizer(
bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
initial_alphabet = [
+ " ",
"a",
"b",
"c",
@@ -272,6 +329,7 @@ def train_bpe_tokenizer(
"ü",
]
+ # TODO: add padding token / whitespace token / special tokens
trainer = BpeTrainer(
special_tokens=["[UNK]"],
vocab_size=vocab_size,
@@ -292,16 +350,22 @@ def train_bpe_tokenizer(
@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use")
-@click.option(
- "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
-)
-@click.option("--download", default=True, help="Whether to download the dataset")
+@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
+def train_char_tokenizer_cli(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+ train_char_tokenizer(dataset_path, language, split, out_path)
+
+
def train_char_tokenizer(
dataset_path: str,
language: str,
split: str,
out_path: str,
- download: bool,
):
"""Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -317,7 +381,7 @@ def train_char_tokenizer(
"""
char_tokenizer = CharTokenizer()
- char_tokenizer.train(dataset_path, language, split, download)
+ char_tokenizer.train(dataset_path, language, split)
char_tokenizer.save(out_path)
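# Round-trip sketch for the CharTokenizer API above; the tokenizer file path
# is an assumption (it matches the default used in swr2_asr/train.py):
from swr2_asr.tokenizer import CharTokenizer

tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
encoding = tokenizer.encode("hallo welt")
print(encoding.ids)                    # integer labels; " " is mapped to <SPACE>
print(tokenizer.decode(encoding.ids))  # back to text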
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 6af1e80..63deb72 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,232 +1,57 @@
"""Training script for the ASR model."""
import os
+from typing import TypedDict
+
import click
import torch
import torch.nn.functional as F
-import torchaudio
-from AudioLoader.speech import MultilingualLibriSpeech
from torch import nn, optim
from torch.utils.data import DataLoader
-from tokenizers import Tokenizer
-from .tokenizer import CharTokenizer
+from tqdm import tqdm
+
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
+from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
-train_audio_transforms = nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
- torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
- torchaudio.transforms.TimeMasking(time_mask_param=100),
-)
+# TODO: improve naming of functions
-valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-
-# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-text_transform = CharTokenizer()
-text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
-
-
-def data_processing(data, data_type="train"):
- """Return the spectrograms, labels, and their lengths."""
- spectrograms = []
- labels = []
- input_lengths = []
- label_lengths = []
- for sample in data:
- if data_type == "train":
- spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- elif data_type == "valid":
- spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- else:
- raise ValueError("data_type should be train or valid")
- spectrograms.append(spec)
- label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
- labels.append(label)
- input_lengths.append(spec.shape[0] // 2)
- label_lengths.append(len(label))
-
- spectrograms = (
- nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
- .unsqueeze(1)
- .transpose(2, 3)
- )
- labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
- return spectrograms, labels, input_lengths, label_lengths
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
-def greedy_decoder(
- output, labels, label_lengths, blank_label=28, collapse_repeated=True
-):
- # TODO: adopt to support both tokenizers
+
+def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
"""Greedily decode a sequence."""
+ print("output shape", output.shape)
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = tokenizer.encode(" ").ids[0]
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(
- text_transform.decode(
- [int(x) for x in labels[i][: label_lengths[i]].tolist()]
- )
- )
+ targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
- decodes.append(text_transform.decode(decode))
+ decodes.append(tokenizer.decode(decode))
return decodes, targets
-# TODO: restructure into own file / class
-class CNNLayerNorm(nn.Module):
- """Layer normalization built for cnns input"""
-
- def __init__(self, n_feats: int):
- super().__init__()
- self.layer_norm = nn.LayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
- data = self.layer_norm(data)
- return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
-
-
-class ResidualCNN(nn.Module):
- """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel: int,
- stride: int,
- dropout: float,
- n_feats: int,
- ):
- super().__init__()
-
- self.cnn1 = nn.Conv2d(
- in_channels, out_channels, kernel, stride, padding=kernel // 2
- )
- self.cnn2 = nn.Conv2d(
- out_channels,
- out_channels,
- kernel,
- stride,
- padding=kernel // 2,
- )
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.layer_norm1 = CNNLayerNorm(n_feats)
- self.layer_norm2 = CNNLayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- residual = data # (batch, channel, feature, time)
- data = self.layer_norm1(data)
- data = F.gelu(data)
- data = self.dropout1(data)
- data = self.cnn1(data)
- data = self.layer_norm2(data)
- data = F.gelu(data)
- data = self.dropout2(data)
- data = self.cnn2(data)
- data += residual
- return data # (batch, channel, feature, time)
-
-
-class BidirectionalGRU(nn.Module):
- """BIdirectional GRU with Layer Normalization and Dropout"""
-
- def __init__(
- self,
- rnn_dim: int,
- hidden_size: int,
- dropout: float,
- batch_first: bool,
- ):
- super().__init__()
-
- self.bi_gru = nn.GRU(
- input_size=rnn_dim,
- hidden_size=hidden_size,
- num_layers=1,
- batch_first=batch_first,
- bidirectional=True,
- )
- self.layer_norm = nn.LayerNorm(rnn_dim)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, data):
- """data (batch, time, feature)"""
- data = self.layer_norm(data)
- data = F.gelu(data)
- data = self.dropout(data)
- data, _ = self.bi_gru(data)
- return data
-
-
-class SpeechRecognitionModel(nn.Module):
- """Speech Recognition Model Inspired by DeepSpeech 2"""
-
- def __init__(
- self,
- n_cnn_layers: int,
- n_rnn_layers: int,
- rnn_dim: int,
- n_class: int,
- n_feats: int,
- stride: int = 2,
- dropout: float = 0.1,
- ):
- super().__init__()
- n_feats //= 2
- self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
- # n residual cnn layers with filter size of 32
- self.rescnn_layers = nn.Sequential(
- *[
- ResidualCNN(
- 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
- )
- for _ in range(n_cnn_layers)
- ]
- )
- self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
- self.birnn_layers = nn.Sequential(
- *[
- BidirectionalGRU(
- rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
- hidden_size=rnn_dim,
- dropout=dropout,
- batch_first=i == 0,
- )
- for i in range(n_rnn_layers)
- ]
- )
- self.classifier = nn.Sequential(
- nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
- nn.GELU(),
- nn.Dropout(dropout),
- nn.Linear(rnn_dim, n_class),
- )
-
- def forward(self, data):
- """data (batch, channel, feature, time)"""
- data = self.cnn(data)
- data = self.rescnn_layers(data)
- sizes = data.size()
- data = data.view(
- sizes[0], sizes[1] * sizes[2], sizes[3]
- ) # (batch, feature, time)
- data = data.transpose(1, 2) # (batch, time, feature)
- data = self.fully_connected(data)
- data = self.birnn_layers(data)
- data = self.classifier(data)
- return data
-
-
class IterMeter:
"""keeps track of total iterations"""
@@ -254,36 +79,30 @@ def train(
):
"""Train"""
model.train()
- data_len = len(train_loader.dataset)
- for batch_idx, _data in enumerate(train_loader):
- spectrograms, labels, input_lengths, label_lengths = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
-
+ print(f"Epoch: {epoch}")
+ losses = []
+ for _data in tqdm(train_loader, desc="batches"):
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
loss.backward()
optimizer.step()
scheduler.step()
iter_meter.step()
- if batch_idx % 100 == 0 or batch_idx == data_len:
- print(
- f"Train Epoch: \
- {epoch} \
- [{batch_idx * len(spectrograms)}/{data_len} \
- ({100.0 * batch_idx / len(train_loader)}%)]\t \
- Loss: {loss.item()}"
- )
- return loss.item()
+ losses.append(loss.item())
+
+ print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
+ return sum(losses) / len(losses)
-# TODO: check how dataloader can be made more efficient
-def test(model, device, test_loader, criterion):
+
+def test(model, device, test_loader, criterion, tokenizer):
"""Test"""
print("\nevaluating...")
model.eval()
@@ -291,18 +110,20 @@ def test(model, device, test_loader, criterion):
test_cer, test_wer = [], []
with torch.no_grad():
for _data in test_loader:
- spectrograms, labels, input_lengths, label_lengths = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
test_loss += loss.item() / len(test_loader)
decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths
+ output=output.transpose(0, 1),
+ labels=labels,
+ label_lengths=_data["utterance_length"],
+ tokenizer=tokenizer,
)
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
@@ -313,9 +134,11 @@ def test(model, device, test_loader, criterion):
print(
f"Test set: Average loss:\
- {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
+ {test_loss}, Average CER: {None} Average WER: {None}\n"
)
+ return test_loss, avg_cer, avg_wer
+
def run(
learning_rate: float,
@@ -324,46 +147,66 @@ def run(
load: bool,
path: str,
dataset_path: str,
+ language: str,
) -> None:
"""Runs the training script."""
- hparams = {
- "n_cnn_layers": 3,
- "n_rnn_layers": 5,
- "rnn_dim": 512,
- "n_class": 36, # TODO: dynamically determine this from vocab size
- "n_feats": 128,
- "stride": 2,
- "dropout": 0.1,
- "learning_rate": learning_rate,
- "batch_size": batch_size,
- "epochs": epochs,
- }
-
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
- download_dataset = not os.path.isdir(path)
- train_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="dev", download=download_dataset
+ # load dataset
+ train_dataset = MLSDataset(
+ dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
)
- test_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="test", download=False
+ valid_dataset = MLSDataset(
+ dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+ )
+
+ # load tokenizer (bpe by default):
+ if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
+ print("There is no tokenizer available. Do you want to train it on the dataset?")
+ input("Press Enter to continue...")
+ train_char_tokenizer(
+ dataset_path=dataset_path,
+ language=language,
+ split="all",
+ out_path="data/tokenizers/char_tokenizer_german.json",
+ )
+
+ tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+
+ train_dataset.set_tokenizer(tokenizer) # type: ignore
+ valid_dataset.set_tokenizer(tokenizer) # type: ignore
+
+ print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
+
+ hparams = HParams(
+ n_cnn_layers=3,
+ n_rnn_layers=5,
+ rnn_dim=512,
+ n_class=tokenizer.get_vocab_size(),
+ n_feats=128,
+ stride=2,
+ dropout=0.1,
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
)
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
+ collate_fn=lambda x: collate_fn(x),
)
- test_loader = DataLoader(
- test_dataset,
+ valid_loader = DataLoader(
+ valid_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
+ collate_fn=lambda x: collate_fn(x),
)
# enable flag to find the most compatible algorithms in advance
@@ -379,12 +222,10 @@ def run(
hparams["stride"],
hparams["dropout"],
).to(device)
-
- print(
- "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
- )
+ print(tokenizer.encode(" "))
+ print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
- criterion = nn.CTCLoss(blank=28).to(device)
+ criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
if load:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
@@ -412,7 +253,13 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ test(
+ model=model,
+ device=device,
+ test_loader=valid_loader,
+ criterion=criterion,
+ tokenizer=tokenizer,
+ )
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -452,4 +299,9 @@ def run_cli(
load=load,
path=path,
dataset_path=dataset_path,
+ language="mls_german_opus",
)
+
+
+if __name__ == "__main__":
+ run_cli() # pylint: disable=no-value-for-parameter
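# Note: test() above formats {None} in place of the computed avg_cer/avg_wer
# in its summary print, although both values are still returned, and the blank
# index for nn.CTCLoss is taken from the space token's id. A sketch of the
# (time, batch, n_class) convention the loss expects, on random data:
import torch
import torch.nn.functional as F
from torch import nn

batch, time, n_class = 4, 50, 36
blank = 1  # stands in for tokenizer.encode(" ").ids[0]
log_probs = F.log_softmax(torch.randn(batch, time, n_class), dim=2).transpose(0, 1)
labels = torch.randint(2, n_class, (batch, 20))  # keep targets off the blank index
input_lengths = torch.full((batch,), time, dtype=torch.long)
label_lengths = torch.full((batch,), 20, dtype=torch.long)
loss = nn.CTCLoss(blank=blank)(log_probs, labels, input_lengths, label_lengths)
print(loss.item())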
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
new file mode 100644
index 0000000..8a950ab
--- /dev/null
+++ b/swr2_asr/utils.py
@@ -0,0 +1,316 @@
+"""Class containing utils for the ASR system."""
+import os
+from enum import Enum
+from typing import TypedDict
+
+import numpy as np
+import torch
+import torchaudio
+from tokenizers import Tokenizer
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import _extract_tar as extract_archive
+
+from swr2_asr.tokenizer import TokenizerType
+
+train_audio_transforms = torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+)
+
+
+# enum specifying dataset splits
+class MLSSplit(str, Enum):
+    """Enum specifying dataset splits as they are defined in the
+    Multilingual LibriSpeech dataset"""
+
+ TRAIN = "train"
+ TEST = "test"
+ DEV = "dev"
+
+
+class Split(str, Enum):
+ """Extending the MLSSplit class to allow for a custom validatio split"""
+
+ TRAIN = "train"
+ VALID = "valid"
+ TEST = "test"
+ DEV = "dev"
+
+
+def split_to_mls_split(split_name: Split) -> MLSSplit:
+ """Converts the custom split to a MLSSplit"""
+ if split_name == Split.VALID:
+ return MLSSplit.TRAIN
+ else:
+ return split_name # type: ignore
+
+
+class Sample(TypedDict):
+ """Type for a sample in the dataset"""
+
+ waveform: torch.Tensor
+ spectrogram: torch.Tensor
+ input_length: int
+ utterance: torch.Tensor
+ utterance_length: int
+ sample_rate: int
+ speaker_id: str
+ book_id: str
+ chapter_id: str
+
+
+class MLSDataset(Dataset):
+ """Custom Dataset for reading Multilingual LibriSpeech
+
+ Attributes:
+ dataset_path (str):
+ path to the dataset
+ language (str):
+ language of the dataset
+ split (Split):
+ split of the dataset
+ mls_split (MLSSplit):
+ split of the dataset as defined in the Multilingual LibriSpeech dataset
+ dataset_lookup (list):
+ list of dicts containing the speakerid, bookid, chapterid and utterance
+
+ directory structure:
+ <dataset_path>
+ ├── <language>
+ │ ├── train
+ │ │ ├── transcripts.txt
+ │ │ └── audio
+ │ │ └── <speakerid>
+ │ │ └── <bookid>
+ │ │ └── <speakerid>_<bookid>_<chapterid>.opus / .flac
+
+ each line in transcripts.txt has the following format:
+ <speakerid>_<bookid>_<chapterid> <utterance>
+ """
+
+ def __init__(
+ self,
+ dataset_path: str,
+ language: str,
+ split: Split,
+ download: bool,
+ spectrogram_hparams: dict | None,
+ ):
+ """Initializes the dataset"""
+ self.dataset_path = dataset_path
+ self.language = language
+ self.file_ext = ".opus" if "opus" in language else ".flac"
+ self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
+ self.split: Split = split # split used internally
+
+ if spectrogram_hparams is None:
+ self.spectrogram_hparams = {
+ "sample_rate": 16000,
+ "n_fft": 400,
+ "win_length": 400,
+ "hop_length": 160,
+ "n_mels": 128,
+ "f_min": 0,
+ "f_max": 8000,
+ "power": 2.0,
+ }
+ else:
+ self.spectrogram_hparams = spectrogram_hparams
+
+ self.dataset_lookup = []
+ self.tokenizer: type[TokenizerType]
+
+ self._handle_download_dataset(download)
+ self._validate_local_directory()
+ self.initialize()
+
+ def initialize(self) -> None:
+ """Initializes the dataset
+
+ Reads the transcripts.txt file and creates a lookup table
+ """
+ transcripts_path = os.path.join(
+ self.dataset_path, self.language, self.mls_split, "transcripts.txt"
+ )
+
+ with open(transcripts_path, "r", encoding="utf-8") as script_file:
+ # read all lines in transcripts.txt
+ transcripts = script_file.readlines()
+ # split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>)
+ transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore
+ utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore
+ identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore
+ identifier = [path.split("_") for path in identifier]
+
+ if self.split == Split.VALID:
+ np.random.seed(42)
+ indices = np.random.choice(len(utterances), int(len(utterances) * 0.2))
+ utterances = [utterances[i] for i in indices]
+ identifier = [identifier[i] for i in indices]
+ elif self.split == Split.TRAIN:
+ np.random.seed(42)
+ indices = np.random.choice(len(utterances), int(len(utterances) * 0.8))
+ utterances = [utterances[i] for i in indices]
+ identifier = [identifier[i] for i in indices]
+
+ self.dataset_lookup = [
+ {
+ "speakerid": path[0],
+ "bookid": path[1],
+ "chapterid": path[2],
+ "utterance": utterance,
+ }
+ for path, utterance in zip(identifier, utterances, strict=False)
+ ]
+
+ def set_tokenizer(self, tokenizer: type[TokenizerType]):
+ """Sets the tokenizer"""
+ self.tokenizer = tokenizer
+
+ def _handle_download_dataset(self, download: bool) -> None:
+ """Download the dataset"""
+ if not download:
+ print("Download flag not set, skipping download")
+ return
+ # zip exists:
+ if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download:
+ print(f"Found dataset at {self.dataset_path}. Skipping download")
+ # zip does not exist:
+ else:
+ os.makedirs(self.dataset_path, exist_ok=True)
+ url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
+
+ torch.hub.download_url_to_file(
+ url, os.path.join(self.dataset_path, self.language) + ".tar.gz"
+ )
+
+ # unzip the dataset
+ if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
+ print(
+ f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
+ )
+ extract_archive(
+ os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True
+ )
+ else:
+ print("Dataset is already unzipped, validating it now")
+ return
+
+ def _validate_local_directory(self):
+ # check if dataset_path exists
+ if not os.path.exists(self.dataset_path):
+ raise ValueError("Dataset path does not exist")
+ if not os.path.exists(os.path.join(self.dataset_path, self.language)):
+ raise ValueError("Language not downloaded!")
+ if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
+ raise ValueError("Split not found in dataset")
+
+ def __len__(self):
+ """Returns the length of the dataset"""
+ return len(self.dataset_lookup)
+
+ def __getitem__(self, idx: int) -> Sample:
+ """One sample"""
+ if self.tokenizer is None:
+ raise ValueError("No tokenizer set")
+ # get the utterance
+ utterance = self.dataset_lookup[idx]["utterance"]
+
+ # get the audio file
+ audio_path = os.path.join(
+ self.dataset_path,
+ self.language,
+ self.mls_split,
+ "audio",
+ self.dataset_lookup[idx]["speakerid"],
+ self.dataset_lookup[idx]["bookid"],
+ "_".join(
+ [
+ self.dataset_lookup[idx]["speakerid"],
+ self.dataset_lookup[idx]["bookid"],
+ self.dataset_lookup[idx]["chapterid"],
+ ]
+ )
+ + self.file_ext,
+ )
+
+ waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
+
+ # resample if necessary
+ if sample_rate != self.spectrogram_hparams["sample_rate"]:
+ resampler = torchaudio.transforms.Resample(
+ sample_rate, self.spectrogram_hparams["sample_rate"]
+ )
+ waveform = resampler(waveform)
+
+ spec = (
+ torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
+ .squeeze(0)
+ .transpose(0, 1)
+ )
+
+ input_length = spec.shape[0] // 2
+
+ utterance_length = len(utterance)
+
+ utterance = self.tokenizer.encode(utterance)
+
+ utterance = torch.LongTensor(utterance.ids)
+
+ return Sample(
+ waveform=waveform,
+ spectrogram=spec,
+ input_length=input_length,
+ utterance=utterance,
+ utterance_length=utterance_length,
+ sample_rate=self.spectrogram_hparams["sample_rate"],
+ speaker_id=self.dataset_lookup[idx]["speakerid"],
+ book_id=self.dataset_lookup[idx]["bookid"],
+ chapter_id=self.dataset_lookup[idx]["chapterid"],
+ )
+
+
+def collate_fn(samples: list[Sample]) -> dict:
+ """Collate function for the dataloader
+
+ pads all tensors within a batch to the same dimensions
+ """
+ waveforms = []
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+
+ for sample in samples:
+ waveforms.append(sample["waveform"].transpose(0, 1))
+ spectrograms.append(sample["spectrogram"])
+ labels.append(sample["utterance"])
+ input_lengths.append(sample["spectrogram"].shape[0] // 2)
+ label_lengths.append(len(sample["utterance"]))
+
+ waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
+ spectrograms = (
+ torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return {
+ "waveform": waveforms,
+ "spectrogram": spectrograms,
+ "input_length": input_lengths,
+ "utterance": labels,
+ "utterance_length": label_lengths,
+ }
+
+
+if __name__ == "__main__":
+ DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
+ LANGUAGE = "mls_german_opus"
+ split = Split.TRAIN
+ DOWNLOAD = False
+
+ dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None)
+
+ tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+ dataset.set_tokenizer(tok)
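# Wiring MLSDataset into a DataLoader with the collate_fn above; the dataset
# path and tokenizer file are assumptions. One caveat in initialize():
# np.random.choice defaults to replace=True and the TRAIN and VALID draws are
# made independently, so the two splits can overlap and contain duplicates.
from torch.utils.data import DataLoader

from swr2_asr.tokenizer import CharTokenizer
from swr2_asr.utils import MLSDataset, Split, collate_fn

dataset = MLSDataset(
    "data", "mls_german_opus", Split.TRAIN, download=False, spectrogram_hparams=None
)
dataset.set_tokenizer(CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json"))
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))
print(batch["spectrogram"].shape)  # (batch, 1, n_mels, padded_time)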