-rw-r--r--  .vscode/settings.json    8
-rw-r--r--  poetry.lock             82
-rw-r--r--  pyproject.toml           4
-rw-r--r--  readme.md               11
-rw-r--r--  swr2_asr/tokenizer.py  323
-rw-r--r--  swr2_asr/train.py      104
6 files changed, 439 insertions, 93 deletions
diff --git a/.vscode/settings.json b/.vscode/settings.json
index bd8762b..0054bca 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -3,10 +3,12 @@
"editor.formatOnType": true,
"editor.defaultFormatter": "ms-python.black-formatter",
"editor.formatOnSave": true,
- "editor.rulers": [80, 120],
+ "editor.rulers": [88, 120],
},
"black-formatter.importStrategy": "fromEnvironment",
"python.analysis.typeCheckingMode": "basic",
- "python.linting.pylintEnabled": true,
- "python.linting.enabled": true,
+ "ruff.organizeImports": true,
+ "ruff.importStrategy": "fromEnvironment",
+ "ruff.fixAll": true,
+ "ruff.run": "onType"
}
\ No newline at end of file
diff --git a/poetry.lock b/poetry.lock
index d5a40f1..49d37d1 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -692,6 +692,32 @@ spelling = ["pyenchant (>=3.2,<4.0)"]
testutils = ["gitpython (>3)"]
[[package]]
+name = "ruff"
+version = "0.0.285"
+description = "An extremely fast Python linter, written in Rust."
+optional = false
+python-versions = ">=3.7"
+files = [
+ {file = "ruff-0.0.285-py3-none-macosx_10_7_x86_64.whl", hash = "sha256:72a3a0936369b986b0e959f9090206ed3c18f9e5e439ea5b8e6867c6707aded5"},
+ {file = "ruff-0.0.285-py3-none-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:0d9ab6ad16742eb78919e0fba09f914f042409df40ad63423c34bb20d350162a"},
+ {file = "ruff-0.0.285-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c48926156288b8ac005eb1db5e77c15e8a37309ae49d9fb6771d5cf5f777590"},
+ {file = "ruff-0.0.285-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:1d2a60c102e7a5e147b58fc2cbea12a563c565383effc527c987ea2086a05742"},
+ {file = "ruff-0.0.285-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b02aae62f922d088bb01943e1dbd861688ada13d735b78b8348a7d90121fd292"},
+ {file = "ruff-0.0.285-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:f572c4296d8c7ddd22c3204de4031965be524fdd1fdaaef273945932912b28c5"},
+ {file = "ruff-0.0.285-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80effdf4fe69763d69eb4ab9443e186fd09e668b59fe70ba4b49f4c077d15a1b"},
+ {file = "ruff-0.0.285-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5977ce304da35c263f5e082901bd7ac0bd2be845a8fcfd1a29e4d6680cddb307"},
+ {file = "ruff-0.0.285-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:72a087712d474fa17b915d7cb9ef807e1256182b12ddfafb105eb00aeee48d1a"},
+ {file = "ruff-0.0.285-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:7ce67736cd8dfe97162d1e7adfc2d9a1bac0efb9aaaff32e4042c7cde079f54b"},
+ {file = "ruff-0.0.285-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:5473a4c6cac34f583bff08c5f63b8def5599a0ea4dc96c0302fbd2cc0b3ecbad"},
+ {file = "ruff-0.0.285-py3-none-musllinux_1_2_i686.whl", hash = "sha256:e6b1c961d608d373a032f047a20bf3c55ad05f56c32e7b96dcca0830a2a72348"},
+ {file = "ruff-0.0.285-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:2933cc9631f453305399c7b8fb72b113ad76b49ae1d7103cc4afd3a423bed164"},
+ {file = "ruff-0.0.285-py3-none-win32.whl", hash = "sha256:770c5eb6376de024111443022cda534fb28980a9dd3b4abc83992a8770167ba6"},
+ {file = "ruff-0.0.285-py3-none-win_amd64.whl", hash = "sha256:a8c6ad6b9cd77489bf6d1510950cbbe47a843aa234adff0960bae64bd06c3b6d"},
+ {file = "ruff-0.0.285-py3-none-win_arm64.whl", hash = "sha256:de44fbc6c3b25fccee473ddf851416fd4e246fc6027b2197c395b1b3b3897921"},
+ {file = "ruff-0.0.285.tar.gz", hash = "sha256:45866048d1dcdcc80855998cb26c4b2b05881f9e043d2e3bfe1aa36d9a2e8f28"},
+]
+
+[[package]]
name = "setuptools"
version = "68.1.2"
description = "Easily download, build, install, upgrade, and uninstall Python packages"
@@ -722,6 +748,60 @@ files = [
mpmath = ">=0.19"
[[package]]
+name = "tokenizers"
+version = "0.13.3"
+description = "Fast and Customizable Tokenizers"
+optional = false
+python-versions = "*"
+files = [
+ {file = "tokenizers-0.13.3-cp310-cp310-macosx_10_11_x86_64.whl", hash = "sha256:f3835c5be51de8c0a092058a4d4380cb9244fb34681fd0a295fbf0a52a5fdf33"},
+ {file = "tokenizers-0.13.3-cp310-cp310-macosx_12_0_arm64.whl", hash = "sha256:4ef4c3e821730f2692489e926b184321e887f34fb8a6b80b8096b966ba663d07"},
+ {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c5fd1a6a25353e9aa762e2aae5a1e63883cad9f4e997c447ec39d071020459bc"},
+ {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee0b1b311d65beab83d7a41c56a1e46ab732a9eed4460648e8eb0bd69fc2d059"},
+ {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ef4215284df1277dadbcc5e17d4882bda19f770d02348e73523f7e7d8b8d396"},
+ {file = "tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a4d53976079cff8a033f778fb9adca2d9d69d009c02fa2d71a878b5f3963ed30"},
+ {file = "tokenizers-0.13.3-cp310-cp310-win32.whl", hash = "sha256:1f0e3b4c2ea2cd13238ce43548959c118069db7579e5d40ec270ad77da5833ce"},
+ {file = "tokenizers-0.13.3-cp310-cp310-win_amd64.whl", hash = "sha256:89649c00d0d7211e8186f7a75dfa1db6996f65edce4b84821817eadcc2d3c79e"},
+ {file = "tokenizers-0.13.3-cp311-cp311-macosx_10_11_universal2.whl", hash = "sha256:56b726e0d2bbc9243872b0144515ba684af5b8d8cd112fb83ee1365e26ec74c8"},
+ {file = "tokenizers-0.13.3-cp311-cp311-macosx_12_0_arm64.whl", hash = "sha256:cc5c022ce692e1f499d745af293ab9ee6f5d92538ed2faf73f9708c89ee59ce6"},
+ {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f55c981ac44ba87c93e847c333e58c12abcbb377a0c2f2ef96e1a266e4184ff2"},
+ {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f247eae99800ef821a91f47c5280e9e9afaeed9980fc444208d5aa6ba69ff148"},
+ {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4b3e3215d048e94f40f1c95802e45dcc37c5b05eb46280fc2ccc8cd351bff839"},
+ {file = "tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ba2b0bf01777c9b9bc94b53764d6684554ce98551fec496f71bc5be3a03e98b"},
+ {file = "tokenizers-0.13.3-cp311-cp311-win32.whl", hash = "sha256:cc78d77f597d1c458bf0ea7c2a64b6aa06941c7a99cb135b5969b0278824d808"},
+ {file = "tokenizers-0.13.3-cp311-cp311-win_amd64.whl", hash = "sha256:ecf182bf59bd541a8876deccf0360f5ae60496fd50b58510048020751cf1724c"},
+ {file = "tokenizers-0.13.3-cp37-cp37m-macosx_10_11_x86_64.whl", hash = "sha256:0527dc5436a1f6bf2c0327da3145687d3bcfbeab91fed8458920093de3901b44"},
+ {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:07cbb2c307627dc99b44b22ef05ff4473aa7c7cc1fec8f0a8b37d8a64b1a16d2"},
+ {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4560dbdeaae5b7ee0d4e493027e3de6d53c991b5002d7ff95083c99e11dd5ac0"},
+ {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:64064bd0322405c9374305ab9b4c07152a1474370327499911937fd4a76d004b"},
+ {file = "tokenizers-0.13.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b8c6e2ab0f2e3d939ca66aa1d596602105fe33b505cd2854a4c1717f704c51de"},
+ {file = "tokenizers-0.13.3-cp37-cp37m-win32.whl", hash = "sha256:6cc29d410768f960db8677221e497226e545eaaea01aa3613fa0fdf2cc96cff4"},
+ {file = "tokenizers-0.13.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fc2a7fdf864554a0dacf09d32e17c0caa9afe72baf9dd7ddedc61973bae352d8"},
+ {file = "tokenizers-0.13.3-cp38-cp38-macosx_10_11_x86_64.whl", hash = "sha256:8791dedba834c1fc55e5f1521be325ea3dafb381964be20684b92fdac95d79b7"},
+ {file = "tokenizers-0.13.3-cp38-cp38-macosx_12_0_arm64.whl", hash = "sha256:d607a6a13718aeb20507bdf2b96162ead5145bbbfa26788d6b833f98b31b26e1"},
+ {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3791338f809cd1bf8e4fee6b540b36822434d0c6c6bc47162448deee3f77d425"},
+ {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c2f35f30e39e6aab8716f07790f646bdc6e4a853816cc49a95ef2a9016bf9ce6"},
+ {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:310204dfed5aa797128b65d63538a9837cbdd15da2a29a77d67eefa489edda26"},
+ {file = "tokenizers-0.13.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a0f9b92ea052305166559f38498b3b0cae159caea712646648aaa272f7160963"},
+ {file = "tokenizers-0.13.3-cp38-cp38-win32.whl", hash = "sha256:9a3fa134896c3c1f0da6e762d15141fbff30d094067c8f1157b9fdca593b5806"},
+ {file = "tokenizers-0.13.3-cp38-cp38-win_amd64.whl", hash = "sha256:8e7b0cdeace87fa9e760e6a605e0ae8fc14b7d72e9fc19c578116f7287bb873d"},
+ {file = "tokenizers-0.13.3-cp39-cp39-macosx_10_11_x86_64.whl", hash = "sha256:00cee1e0859d55507e693a48fa4aef07060c4bb6bd93d80120e18fea9371c66d"},
+ {file = "tokenizers-0.13.3-cp39-cp39-macosx_12_0_arm64.whl", hash = "sha256:a23ff602d0797cea1d0506ce69b27523b07e70f6dda982ab8cf82402de839088"},
+ {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70ce07445050b537d2696022dafb115307abdffd2a5c106f029490f84501ef97"},
+ {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:280ffe95f50eaaf655b3a1dc7ff1d9cf4777029dbbc3e63a74e65a056594abc3"},
+ {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97acfcec592f7e9de8cadcdcda50a7134423ac8455c0166b28c9ff04d227b371"},
+ {file = "tokenizers-0.13.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd7730c98a3010cd4f523465867ff95cd9d6430db46676ce79358f65ae39797b"},
+ {file = "tokenizers-0.13.3-cp39-cp39-win32.whl", hash = "sha256:48625a108029cb1ddf42e17a81b5a3230ba6888a70c9dc14e81bc319e812652d"},
+ {file = "tokenizers-0.13.3-cp39-cp39-win_amd64.whl", hash = "sha256:bc0a6f1ba036e482db6453571c9e3e60ecd5489980ffd95d11dc9f960483d783"},
+ {file = "tokenizers-0.13.3.tar.gz", hash = "sha256:2e546dbb68b623008a5442353137fbb0123d311a6d7ba52f2667c8862a75af2e"},
+]
+
+[package.extras]
+dev = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
+docs = ["setuptools-rust", "sphinx", "sphinx-rtd-theme"]
+testing = ["black (==22.3)", "datasets", "numpy", "pytest", "requests"]
+
+[[package]]
name = "tomli"
version = "2.0.1"
description = "A lil' TOML parser"
@@ -999,4 +1079,4 @@ files = [
[metadata]
lock-version = "2.0"
python-versions = "~3.10"
-content-hash = "db7fb92ac025d1ec61cee65ac267365851f2974063c79c6ba7dcca964c1ba2e8"
+content-hash = "a72b4e5791a6216b58b53a72bf68d97dbdbc95978b3974fddd9e5f9b76e36321"
diff --git a/pyproject.toml b/pyproject.toml
index 1c29b7c..eb17479 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,15 +15,19 @@ audioloader = {git = "https://github.com/marvinborner/AudioLoader.git"}
tqdm = "^4.66.1"
numpy = "^1.25.2"
mido = "^1.3.0"
+tokenizers = "^0.13.3"
click = "^8.1.7"
[tool.poetry.group.dev.dependencies]
black = "^23.7.0"
mypy = "^1.5.1"
pylint = "^2.17.5"
+ruff = "^0.0.285"
[tool.poetry.scripts]
train = "swr2_asr.train:run_cli"
+train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer"
+train-char-tokenizer = "swr2_asr.tokenizer:train_char_tokenizer"
[build-system]
requires = ["poetry-core"]
diff --git a/readme.md b/readme.md
index 47d9a31..795283b 100644
--- a/readme.md
+++ b/readme.md
@@ -10,6 +10,17 @@ poetry install
# Usage
+## Training the tokenizer
+We use a byte-pair encoding (BPE) tokenizer. To train the tokenizer, run
+```
+poetry run train-bpe-tokenizer --dataset_path="DATA_PATH" --language=mls_german_opus --split=all --out_path="data/tokenizers/bpe_tokenizer_german_3000.json" --vocab_size=3000
+```
+with the desired values for `DATA_PATH` and `vocab_size`.
+
+You can also use a character-level tokenizer, which can be trained with
+```
+poetry run train-char-tokenizer --dataset_path="DATA_PATH" --language=mls_german_opus --split=all --out_path="data/tokenizers/char_tokenizer_german.json"
+```
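+
+Once a tokenizer is trained, it can be loaded directly in Python. A minimal sketch (the paths are the default output paths from the commands above; adjust them to your setup):
+```
+from tokenizers import Tokenizer
+
+from swr2_asr.tokenizer import CharTokenizer
+
+# BPE tokenizer, saved via the huggingface tokenizers library
+bpe_tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+print(bpe_tokenizer.encode("dies ist ein test").ids)
+
+# character-level tokenizer defined in swr2_asr/tokenizer.py
+char_tokenizer = CharTokenizer()
+char_tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+print(char_tokenizer.decode(char_tokenizer.encode("dies ist ein test").ids))
+```
+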
## Training
Train using the provided train script:
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
new file mode 100644
index 0000000..d9cd622
--- /dev/null
+++ b/swr2_asr/tokenizer.py
@@ -0,0 +1,323 @@
+"""Tokenizer for use with Multilingual Librispeech"""
+from dataclasses import dataclass
+import json
+import os
+import click
+from tqdm import tqdm
+
+from AudioLoader.speech import MultilingualLibriSpeech
+
+from tokenizers import Tokenizer, normalizers
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import Whitespace
+
+
+@dataclass
+class Encoding:
+ """Simple dataclass to represent an encoding"""
+
+ ids: list[int]
+
+
+class CharTokenizer:
+ """Very simple tokenizer for use with Multilingual Librispeech
+
+ Simply checks what characters are in the dataset and uses them as tokens.
+
+ Exposes the same interface as tokenizers from the huggingface library, i.e.
+ encode, decode, decode_batch, save, from_file and train.
+ """
+
+ def __init__(self):
+ self.char_map = {}
+ self.index_map = {}
+ self.add_tokens(["<UNK>", "<SPACE>"])
+
+ def add_tokens(self, tokens: list[str]):
+ """Manually add tokens to the tokenizer
+
+ Args:
+ tokens (list[str]): List of tokens to add
+ """
+ for token in tokens:
+ if token not in self.char_map:
+ self.char_map[token] = len(self.char_map)
+ self.index_map[len(self.index_map)] = token
+
+ def train(
+ self, dataset_path: str, language: str, split: str, download: bool = True
+ ):
+ """Train the tokenizer on the given dataset
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ """
+ if split not in ["train", "dev", "test", "all"]:
+ raise ValueError("Split must be one of train, dev, test, all")
+
+ if split == "all":
+ splits = ["train", "dev", "test"]
+ else:
+ splits = [split]
+
+ chars = set()
+ for sp in splits:
+ transcript_path = os.path.join(
+ dataset_path, language, sp, "transcripts.txt"
+ )
+
+ # check if dataset is downloaded, download if not
+ if download and not os.path.exists(transcript_path):
+ MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+
+ with open(
+ transcript_path,
+ "r",
+ encoding="utf-8",
+ ) as file:
+ lines = file.readlines()
+ lines = [line.split(" ", 1)[1] for line in lines]
+ lines = [line.strip() for line in lines]
+
+ for line in tqdm(lines, desc=f"Training tokenizer on {sp} split"):
+ chars.update(line)
+        # assign indices once, after collecting characters from all requested splits
+        offset = len(self.char_map)
+        for i, char in enumerate(chars):
+            self.char_map[char] = i + offset
+            self.index_map[i + offset] = char
+
+ def encode(self, text: str):
+ """Use a character map and convert text to an integer sequence
+
+ automatically maps spaces to <SPACE> and makes everything lowercase
+ unknown characters are mapped to the <UNK> token
+
+ """
+ int_sequence = []
+ text = text.lower()
+ for char in text:
+ if char == " ":
+ mapped_char = self.char_map["<SPACE>"]
+ elif char not in self.char_map:
+ mapped_char = self.char_map["<UNK>"]
+ else:
+ mapped_char = self.char_map[char]
+ int_sequence.append(mapped_char)
+ return Encoding(ids=int_sequence)
+
+ def decode(self, labels: list[int], remove_special_tokens: bool = True):
+ """Use a character map and convert integer labels to an text sequence
+
+ Args:
+ labels (list[int]): List of integer labels
+ remove_special_tokens (bool): Whether to remove special tokens.
+ Defaults to True.
+ """
+        string = []
+        for i in labels:
+            if remove_special_tokens and self.index_map[i] == "<UNK>":
+                continue
+            if remove_special_tokens and self.index_map[i] == "<SPACE>":
+                string.append(" ")
+                continue
+            string.append(self.index_map[i])
+        return "".join(string).replace("<SPACE>", " ")
+
+ def decode_batch(self, labels: list[list[int]]):
+ """Use a character map and convert integer labels to an text sequence"""
+ strings = []
+ for label in labels:
+ string = []
+ for i in label:
+ if self.index_map[i] == "<UNK>":
+ continue
+ if self.index_map[i] == "<SPACE>":
+ string.append(" ")
+ string.append(self.index_map[i])
+ strings.append("".join(string).replace("<SPACE>", " "))
+ return strings
+
+ def save(self, path: str):
+ """Save the tokenizer to a file"""
+ with open(path, "w", encoding="utf-8") as file:
+ # save it in the following format:
+ # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+ json.dump(
+ {"char_map": self.char_map, "index_map": self.index_map},
+ file,
+ ensure_ascii=False,
+ )
+
+ def from_file(self, path: str):
+ """Load the tokenizer from a file"""
+ with open(path, "r", encoding="utf-8") as file:
+ # load it in the following format:
+ # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+ saved_file = json.load(file)
+ self.char_map = saved_file["char_map"]
+ self.index_map = saved_file["index_map"]
+
+
+@click.command()
+@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
+@click.option("--language", default="mls_german_opus", help="Language to use")
+@click.option("--split", default="train", help="Split to use (including all)")
+@click.option("--download", default=True, help="Whether to download the dataset")
+@click.option(
+ "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
+)
+@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+ vocab_size: int,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset
+
+ Assumes that the MLS dataset is located in the dataset_path and there is a
+ transcripts.txt file in the split folder.
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ download (bool): Whether to download the dataset if it is not present
+ out_path (str): Path to save the tokenizer to
+ vocab_size (int): Size of the vocabulary
+ """
+ if split not in ["train", "dev", "test", "all"]:
+ raise ValueError("Split must be one of train, dev, test, all")
+
+ if split == "all":
+ splits = ["train", "dev", "test"]
+ else:
+ splits = [split]
+
+ lines = []
+
+ for sp in splits:
+ transcripts_path = os.path.join(dataset_path, language, sp, "transcripts.txt")
+ if download and not os.path.exists(transcripts_path):
+ MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+
+ with open(
+ transcripts_path,
+ "r",
+ encoding="utf-8",
+ ) as file:
+ sp_lines = file.readlines()
+ sp_lines = [line.split(" ", 1)[1] for line in sp_lines]
+ sp_lines = [line.strip() for line in sp_lines]
+
+ lines.append(sp_lines)
+
+ bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+
+ initial_alphabet = [
+ "a",
+ "b",
+ "c",
+ "d",
+ "e",
+ "f",
+ "g",
+ "h",
+ "i",
+ "j",
+ "k",
+ "l",
+ "m",
+ "n",
+ "o",
+ "p",
+ "q",
+ "r",
+ "s",
+ "t",
+ "u",
+ "v",
+ "w",
+ "x",
+ "y",
+ "z",
+ "ä",
+ "ö",
+ "ü",
+ "ß",
+ "-",
+ "é",
+ "è",
+ "à",
+ "ù",
+ "ç",
+ "â",
+ "ê",
+ "î",
+ "ô",
+ "û",
+ "ë",
+ "ï",
+ "ü",
+ ]
+
+ trainer = BpeTrainer(
+ special_tokens=["[UNK]"],
+ vocab_size=vocab_size,
+ initial_alphabet=initial_alphabet,
+ show_progress=True,
+ ) # type: ignore
+
+ bpe_tokenizer.pre_tokenizer = Whitespace() # type: ignore
+
+ bpe_tokenizer.normalizer = normalizers.Lowercase() # type: ignore
+
+ bpe_tokenizer.train_from_iterator(lines, trainer=trainer)
+
+ bpe_tokenizer.save(out_path)
+
+
+@click.command()
+@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
+@click.option("--language", default="mls_german_opus", help="Language to use")
+@click.option("--split", default="train", help="Split to use")
+@click.option(
+ "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
+)
+@click.option("--download", default=True, help="Whether to download the dataset")
+def train_char_tokenizer(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset
+
+ Assumes that the MLS dataset is located in the dataset_path and there is a
+ transcripts.txt file in the split folder.
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ download (bool): Whether to download the dataset if it is not present
+ out_path (str): Path to save the tokenizer to
+ """
+ char_tokenizer = CharTokenizer()
+
+ char_tokenizer.train(dataset_path, language, split, download)
+
+ char_tokenizer.save(out_path)
+
+
+if __name__ == "__main__":
+ tokenizer = CharTokenizer()
+ tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+
+ print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids))
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ad8c9e9..d13683f 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,94 +1,16 @@
"""Training script for the ASR model."""
-from AudioLoader.speech import MultilingualLibriSpeech
import os
import click
import torch
-import torch.nn as nn
-import torch.optim as optim
import torch.nn.functional as F
-from torch.utils.data import DataLoader
import torchaudio
-from .loss_scores import cer, wer
-
-
-class TextTransform:
- """Maps characters to integers and vice versa"""
-
- def __init__(self):
- char_map_str = """
- ' 0
- <SPACE> 1
- a 2
- b 3
- c 4
- d 5
- e 6
- f 7
- g 8
- h 9
- i 10
- j 11
- k 12
- l 13
- m 14
- n 15
- o 16
- p 17
- q 18
- r 19
- s 20
- t 21
- u 22
- v 23
- w 24
- x 25
- y 26
- z 27
- ä 28
- ö 29
- ü 30
- ß 31
- - 32
- é 33
- è 34
- à 35
- ù 36
- ç 37
- â 38
- ê 39
- î 40
- ô 41
- û 42
- ë 43
- ï 44
- ü 45
- """
- self.char_map = {}
- self.index_map = {}
- for line in char_map_str.strip().split("\n"):
- char, index = line.split()
- self.char_map[char] = int(index)
- self.index_map[int(index)] = char
- self.index_map[1] = " "
-
- def text_to_int(self, text):
- """Use a character map and convert text to an integer sequence"""
- int_sequence = []
- for char in text:
- if char == " ":
- mapped_char = self.char_map["<SPACE>"]
- else:
- mapped_char = self.char_map[char]
- int_sequence.append(mapped_char)
- return int_sequence
-
- def int_to_text(self, labels):
- """Use a character map and convert integer labels to an text sequence"""
- string = []
- for i in labels:
- string.append(self.index_map[i])
- return "".join(string).replace("<SPACE>", " ")
+from AudioLoader.speech import MultilingualLibriSpeech
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from tokenizers import Tokenizer
+from .tokenizer import CharTokenizer
+from .loss_scores import cer, wer
train_audio_transforms = nn.Sequential(
torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
@@ -98,7 +20,9 @@ train_audio_transforms = nn.Sequential(
valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-text_transform = TextTransform()
+# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+text_transform = CharTokenizer()
+text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
def data_processing(data, data_type="train"):
@@ -115,7 +39,7 @@ def data_processing(data, data_type="train"):
else:
raise ValueError("data_type should be train or valid")
spectrograms.append(spec)
- label = torch.Tensor(text_transform.text_to_int(sample["utterance"].lower()))
+ label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
labels.append(label)
input_lengths.append(spec.shape[0] // 2)
label_lengths.append(len(label))
@@ -140,14 +64,16 @@ def greedy_decoder(
for i, args in enumerate(arg_maxes):
decode = []
targets.append(
- text_transform.int_to_text(labels[i][: label_lengths[i]].tolist())
+ text_transform.decode(
+ [int(x) for x in labels[i][: label_lengths[i]].tolist()]
+ )
)
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
- decodes.append(text_transform.int_to_text(decode))
+ decodes.append(text_transform.decode(decode))
return decodes, targets
@@ -401,7 +327,7 @@ def run(
"n_cnn_layers": 3,
"n_rnn_layers": 5,
"rnn_dim": 512,
- "n_class": 46,
+ "n_class": 36,
"n_feats": 128,
"stride": 2,
"dropout": 0.1,