Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r--  swr2_asr/tokenizer.py  310
1 files changed, 310 insertions, 0 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
new file mode 100644
index 0000000..d32e60d
--- /dev/null
+++ b/swr2_asr/tokenizer.py
@@ -0,0 +1,310 @@
+"""Tokenizer for use with Multilingual Librispeech"""
+import os
+import click
+from tqdm import tqdm
+
+from AudioLoader.speech import MultilingualLibriSpeech
+
+from tokenizers import Tokenizer, normalizers
+from tokenizers.models import BPE
+from tokenizers.trainers import BpeTrainer
+from tokenizers.pre_tokenizers import Whitespace
+
+
+class CharTokenizer:
+ """Very simple tokenizer for use with Multilingual Librispeech
+
+ Simply checks what characters are in the dataset and uses them as tokens.
+
+ Exposes the same interface as tokenizers from the huggingface library, i.e.
+ encode, decode, decode_batch, save, from_file and train.
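+
+    A minimal usage sketch (dataset path, language and split are placeholders;
+    the ids produced depend on the characters seen during training):
+
+        tokenizer = CharTokenizer()
+        tokenizer.train("data", "mls_german_opus", "train", download=False)
+        ids = tokenizer.encode("hallo welt")
+        text = tokenizer.decode(ids)  # -> "hallo welt"
+        tokenizer.save("tokenizer_chars.txt")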
+ """
+
+ def __init__(self):
+ self.char_map = {}
+ self.index_map = {}
+ self.add_tokens(["<UNK>", "<SPACE>"])
+
+ def add_tokens(self, tokens: list[str]):
+ """Manually add tokens to the tokenizer
+
+ Args:
+ tokens (list[str]): List of tokens to add
+ """
+ for token in tokens:
+ if token not in self.char_map:
+ self.char_map[token] = len(self.char_map)
+ self.index_map[len(self.index_map)] = token
+
+ def train(
+ self, dataset_path: str, language: str, split: str, download: bool = True
+ ):
+ """Train the tokenizer on the given dataset
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+            split (str): Split to use
+            download (bool): Whether to download the dataset if it is not
+                present. Defaults to True.
+        """
+ if split not in ["train", "dev", "test", "all"]:
+ raise ValueError("Split must be one of train, dev, test, all")
+
+ if split == "all":
+ splits = ["train", "dev", "test"]
+ else:
+ splits = [split]
+
+ chars = set()
+ for sp in splits:
+ transcript_path = os.path.join(
+ dataset_path, language, sp, "transcripts.txt"
+ )
+
+ # check if dataset is downloaded, download if not
+ if download and not os.path.exists(transcript_path):
+ MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+
+ with open(
+ transcript_path,
+ "r",
+ encoding="utf-8",
+ ) as file:
+ lines = file.readlines()
+ lines = [line.split(" ", 1)[1] for line in lines]
+ lines = [line.strip() for line in lines]
+
+ for line in tqdm(lines, desc=f"Training tokenizer on {sp} split"):
+ chars.update(line)
+        # assign ids in sorted order so repeated training runs produce the
+        # same character-to-id mapping
+        offset = len(self.char_map)
+        for i, char in enumerate(sorted(chars), start=offset):
+            self.char_map[char] = i
+            self.index_map[i] = char
+
+ def encode(self, text: str):
+ """Use a character map and convert text to an integer sequence
+
+        Automatically maps spaces to <SPACE> and lowercases the text.
+        Unknown characters are mapped to the <UNK> token.
+
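+        Example (ids of ordinary characters depend on the trained vocabulary;
+        the special tokens added in __init__ are <UNK> = 0 and <SPACE> = 1):
+
+            tokenizer.encode("a b")  # -> [<id of "a">, 1, <id of "b">]
+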
+ """
+ int_sequence = []
+ text = text.lower()
+ for char in text:
+ if char == " ":
+ mapped_char = self.char_map["<SPACE>"]
+ elif char not in self.char_map:
+ mapped_char = self.char_map["<UNK>"]
+ else:
+ mapped_char = self.char_map[char]
+ int_sequence.append(mapped_char)
+ return int_sequence
+
+ def decode(self, labels: list[int], remove_special_tokens: bool = True):
+ """Use a character map and convert integer labels to an text sequence
+
+ Args:
+ labels (list[int]): List of integer labels
+ remove_special_tokens (bool): Whether to remove special tokens.
+ Defaults to True.
+ """
+        string = []
+        for i in labels:
+            if remove_special_tokens and self.index_map[i] == "<UNK>":
+                continue
+            if self.index_map[i] == "<SPACE>":
+                string.append(" ")
+                continue
+            string.append(self.index_map[i])
+        return "".join(string)
+
+ def decode_batch(self, labels: list[list[int]]):
+ """Use a character map and convert integer labels to an text sequence"""
+ strings = []
+ for label in labels:
+ string = []
+ for i in label:
+ if self.index_map[i] == "<UNK>":
+ continue
+ if self.index_map[i] == "<SPACE>":
+ string.append(" ")
+ string.append(self.index_map[i])
+ strings.append("".join(string).replace("<SPACE>", " "))
+ return strings
+
+ def save(self, path: str):
+ """Save the tokenizer to a file"""
+ with open(path, "w", encoding="utf-8") as file:
+ for char, index in self.char_map.items():
+ file.write(f"{char} {index}\n")
+
+    def from_file(self, path: str):
+        """Load the tokenizer from a file"""
+        with open(path, "r", encoding="utf-8") as file:
+            for line in file:
+                # split from the right so tokens that themselves contain a
+                # space (e.g. the literal " " character) are read back correctly
+                char, index = line.rstrip("\n").rsplit(" ", 1)
+                self.char_map[char] = int(index)
+                self.index_map[int(index)] = char
+
+
+@click.command()
+@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
+@click.option("--language", default="mls_german_opus", help="Language to use")
+@click.option("--split", default="train", help="Split to use (including all)")
+@click.option("--download", default=True, help="Whether to download the dataset")
+@click.option(
+ "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
+)
+@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+ vocab_size: int,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset
+
+ Assumes that the MLS dataset is located in the dataset_path and there is a
+ transcripts.txt file in the split folder.
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ download (bool): Whether to download the dataset if it is not present
+ out_path (str): Path to save the tokenizer to
+ vocab_size (int): Size of the vocabulary
+ """
+ if split not in ["train", "dev", "test", "all"]:
+ raise ValueError("Split must be one of train, dev, test, all")
+
+ if split == "all":
+ splits = ["train", "dev", "test"]
+ else:
+ splits = [split]
+
+ lines = []
+
+ for sp in splits:
+ transcripts_path = os.path.join(dataset_path, language, sp, "transcripts.txt")
+ if download and not os.path.exists(transcripts_path):
+ MultilingualLibriSpeech(dataset_path, language, sp, download=True)
+
+ with open(
+ transcripts_path,
+ "r",
+ encoding="utf-8",
+ ) as file:
+ sp_lines = file.readlines()
+ sp_lines = [line.split(" ", 1)[1] for line in sp_lines]
+ sp_lines = [line.strip() for line in sp_lines]
+
+        lines.extend(sp_lines)
+
+ bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
+
+    # lower-case ASCII letters plus German umlauts, ß, hyphen and the French
+    # accented characters that occur in the German MLS transcripts
+    initial_alphabet = list("abcdefghijklmnopqrstuvwxyz") + [
+        "ä", "ö", "ü", "ß", "-",
+        "é", "è", "à", "ù", "ç",
+        "â", "ê", "î", "ô", "û",
+        "ë", "ï",
+    ]
+
+ trainer = BpeTrainer(
+ special_tokens=["[UNK]"],
+ vocab_size=vocab_size,
+ initial_alphabet=initial_alphabet,
+ show_progress=True,
+ ) # type: ignore
+
+ bpe_tokenizer.pre_tokenizer = Whitespace() # type: ignore
+
+ bpe_tokenizer.normalizer = normalizers.Lowercase() # type: ignore
+
+ bpe_tokenizer.train_from_iterator(lines, trainer=trainer)
+
+ bpe_tokenizer.save(out_path)
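+
+    # A minimal sketch of using the saved tokenizer later (follows the
+    # Hugging Face `tokenizers` API; not executed here):
+    #   loaded = Tokenizer.from_file(out_path)
+    #   ids = loaded.encode("eine beispielzeile").ids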
+
+
+@click.command()
+@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
+@click.option("--language", default="mls_german_opus", help="Language to use")
+@click.option("--split", default="train", help="Split to use")
+@click.option(
+ "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
+)
+@click.option("--download", default=True, help="Whether to download the dataset")
+def train_char_tokenizer(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ download: bool,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset
+
+ Assumes that the MLS dataset is located in the dataset_path and there is a
+ transcripts.txt file in the split folder.
+
+ Args:
+ dataset_path (str): Path to the MLS dataset
+ language (str): Language to use
+ split (str): Split to use
+ download (bool): Whether to download the dataset if it is not present
+ out_path (str): Path to save the tokenizer to
+ """
+ char_tokenizer = CharTokenizer()
+
+ char_tokenizer.train(dataset_path, language, split, download)
+
+ char_tokenizer.save(out_path)
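+
+    # Sketch of reloading the saved character tokenizer (not executed here):
+    #   loaded = CharTokenizer()
+    #   loaded.from_file(out_path)
+    #   print(loaded.decode(loaded.encode("ein beispiel")))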
+
+
+if __name__ == "__main__":
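+    # quick manual round-trip check against a local copy of the dataset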
+ tokenizer = CharTokenizer()
+ tokenizer.train("/Volumes/pherkel 2/SWR2-ASR", "mls_german_opus", "all")
+
+ print(tokenizer.decode(tokenizer.encode("Fichier non trouvé")))
+
+ tokenizer.save("tokenizer_chars.txt")