diff options
-rw-r--r-- | swr2_asr/tokenizer.py | 466 |
1 files changed, 99 insertions, 367 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 2e2fb57..69ced81 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -1,393 +1,125 @@ -"""Tokenizer for use with Multilingual Librispeech""" -import json -import os -from dataclasses import dataclass -from typing import Type +"""Tokenizer for Multilingual Librispeech datasets""" -import click -from tokenizers import Tokenizer, normalizers -from tokenizers.models import BPE -from tokenizers.pre_tokenizers import Whitespace -from tokenizers.trainers import BpeTrainer -from tqdm import tqdm - -class TokenizerType: - """Base class for tokenizers. - - exposes the same interface as tokenizers from the huggingface library""" - - def encode(self, sequence: str) -> list[int]: - """Encode a sequence to a list of integer labels""" - raise NotImplementedError - - def decode(self, labels: list[int], remove_special_tokens: bool) -> str: - """Decode a list of integer labels to a sequence""" - raise NotImplementedError - - def decode_batch(self, labels: list[list[int]]) -> list[str]: - """Decode a batch of integer labels to a list of sequences""" - raise NotImplementedError - - def get_vocab_size(self) -> int: - """Get the size of the vocabulary""" - raise NotImplementedError - - def enable_padding( - self, - length: int = -1, - direction: str = "right", - pad_id: int = 0, - pad_type_id: int = 0, - pad_token: str = "[PAD]", - ) -> None: - """Enable padding for the tokenizer""" - raise NotImplementedError - - def save(self, path: str) -> None: - """Save the tokenizer to a file""" - raise NotImplementedError - - @staticmethod - def from_file(path: str) -> "TokenizerType": - """Load the tokenizer from a file""" - raise NotImplementedError - - -MyTokenizerType = Type[TokenizerType] - - -@dataclass -class Encoding: - """Simple dataclass to represent an encoding""" - - ids: list[int] - tokens: list[str] - - -class CharTokenizer(TokenizerType): - """Very simple tokenizer for use with Multilingual Librispeech - - Simply checks what characters are in the dataset and uses them as tokens. - - Exposes the same interface as tokenizers from the huggingface library, i.e. - encode, decode, decode_batch, get_vocab_size, save, from_file and train. - """ +class CharTokenizer: + """Maps characters to integers and vice versa""" def __init__(self): - self.char_map = {} - self.index_map = {} - self.add_tokens(["<UNK>", "<SPACE>"]) - - def add_tokens(self, tokens: list[str]): - """Manually add tokens to the tokenizer - - Args: - tokens (list[str]): List of tokens to add - """ - for token in tokens: - if token not in self.char_map: - self.char_map[token] = len(self.char_map) - self.index_map[len(self.index_map)] = token - - def train(self, dataset_path: str, language: str, split: str): - """Train the tokenizer on the given dataset - - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use + char_map_str = """ + _ + <BLANK> + <UNK> + <SPACE> + a + b + c + d + e + f + g + h + i + j + k + l + m + n + o + p + q + r + s + t + u + v + w + x + y + z + é + à + ä + ö + ß + ü + - + ' """ - if split not in ["train", "dev", "test", "all"]: - raise ValueError("Split must be one of train, dev, test, all") - if split == "all": - splits = ["train", "dev", "test"] - else: - splits = [split] - - chars: set = set() - for s_plit in splits: - transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - - with open( - transcript_path, - "r", - encoding="utf-8", - ) as file: - lines = file.readlines() - lines = [line.split(" ", 1)[1] for line in lines] - lines = [line.strip() for line in lines] - - for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"): - chars.update(line) - offset = len(self.char_map) - for i, char in enumerate(chars): - i += offset - self.char_map[char] = i - self.index_map[i] = char - - def encode(self, sequence: str): - """Use a character map and convert text to an integer sequence - - automatically maps spaces to <SPACE> and makes everything lowercase - unknown characters are mapped to the <UNK> token - - """ + self.char_map = {} + self.index_map = {} + for idx, char in enumerate(char_map_str.strip().split("\n")): + char = char.strip() + self.char_map[char] = idx + self.index_map[idx] = char + self.index_map[1] = " " + + def encode(self, text: str) -> list[int]: + """Use a character map and convert text to an integer sequence""" int_sequence = [] - sequence = sequence.lower() - for char in sequence: + for char in text: if char == " ": - mapped_char = self.char_map["<SPACE>"] + char = self.char_map["<SPACE>"] elif char not in self.char_map: - mapped_char = self.char_map["<UNK>"] + char = self.char_map["<UNK>"] else: - mapped_char = self.char_map[char] - int_sequence.append(mapped_char) - return Encoding(ids=int_sequence, tokens=list(sequence)) + char = self.char_map[char] + int_sequence.append(char) + return int_sequence - def decode(self, labels: list[int], remove_special_tokens: bool = True): - """Use a character map and convert integer labels to an text sequence - - Args: - labels (list[int]): List of integer labels - remove_special_tokens (bool): Whether to remove special tokens. - Defaults to True. - """ + def decode(self, labels: list[int]) -> str: + """Use a character map and convert integer labels to an text sequence""" string = [] for i in labels: - if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>": - continue - if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>": - string.append(" ") - string.append(self.index_map[f"{i}"]) + string.append(self.index_map[i]) return "".join(string).replace("<SPACE>", " ") - def decode_batch(self, labels: list[list[int]]): - """Use a character map and convert integer labels to an text sequence""" - strings = [] - for label in labels: - string = [] - for i in label: - if self.index_map[i] == "<UNK>": - continue - if self.index_map[i] == "<SPACE>": - string.append(" ") - string.append(self.index_map[i]) - strings.append("".join(string).replace("<SPACE>", " ")) - return strings - - def get_vocab_size(self): - """Get the size of the vocabulary""" + def get_vocab_size(self) -> int: + """Get the number of unique characters in the dataset""" return len(self.char_map) - def save(self, path: str): - """Save the tokenizer to a file""" - os.makedirs(os.path.dirname(path), exist_ok=True) - with open(path, "w", encoding="utf-8") as file: - # save it in the following format: - # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} - json.dump( - {"char_map": self.char_map, "index_map": self.index_map}, - file, - ensure_ascii=False, - ) - - @staticmethod - def from_file(path: str) -> "CharTokenizer": - """Load the tokenizer from a file""" - char_tokenizer = CharTokenizer() - with open(path, "r", encoding="utf-8") as file: - # load it in the following format: - # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}} - saved_file = json.load(file) - char_tokenizer.char_map = saved_file["char_map"] - char_tokenizer.index_map = saved_file["index_map"] - - return char_tokenizer - - -@click.command() -@click.option("--dataset_path", default="data", help="Path to the MLS dataset") -@click.option("--language", default="mls_german_opus", help="Language to use") -@click.option("--split", default="train", help="Split to use (including all)") -@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") -@click.option("--vocab_size", default=2000, help="Size of the vocabulary") -def train_bpe_tokenizer_cli( - dataset_path: str, - language: str, - split: str, - out_path: str, - vocab_size: int, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" - train_bpe_tokenizer( - dataset_path, - language, - split, - out_path, - vocab_size, - ) - - -def train_bpe_tokenizer( - dataset_path: str, - language: str, - split: str, - out_path: str, - vocab_size: int, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset - - Assumes that the MLS dataset is located in the dataset_path and there is a - transcripts.txt file in the split folder. - - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use - download (bool): Whether to download the dataset if it is not present - out_path (str): Path to save the tokenizer to - vocab_size (int): Size of the vocabulary - """ - if split not in ["train", "dev", "test", "all"]: - raise ValueError("Split must be one of train, dev, test, all") - - if split == "all": - splits = ["train", "dev", "test"] - else: - splits = [split] - - lines = [] + def get_blank_token(self) -> int: + """Get the integer representation of the <BLANK> character""" + return self.char_map["<BLANK>"] - for s_plit in splits: - transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - if not os.path.exists(transcripts_path): - raise FileNotFoundError( - f"Could not find transcripts.txt in {transcripts_path}. " - "Please make sure that the dataset is downloaded." - ) + def get_unk_token(self) -> int: + """Get the integer representation of the <UNK> character""" + return self.char_map["<UNK>"] - with open( - transcripts_path, - "r", - encoding="utf-8", - ) as file: - sp_lines = file.readlines() - sp_lines = [line.split(" ", 1)[1] for line in sp_lines] - sp_lines = [line.strip() for line in sp_lines] + def get_space_token(self) -> int: + """Get the integer representation of the <SPACE> character""" + return self.char_map["<SPACE>"] - lines.append(sp_lines) + # TODO: add train function - bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]")) - - initial_alphabet = [ - " ", - "a", - "b", - "c", - "d", - "e", - "f", - "g", - "h", - "i", - "j", - "k", - "l", - "m", - "n", - "o", - "p", - "q", - "r", - "s", - "t", - "u", - "v", - "w", - "x", - "y", - "z", - "ä", - "ö", - "ü", - "ß", - "-", - "é", - "è", - "à", - "ù", - "ç", - "â", - "ê", - "î", - "ô", - "û", - "ë", - "ï", - "ü", - ] - - - trainer = BpeTrainer( - special_tokens=["[UNK]"], - vocab_size=vocab_size, - initial_alphabet=initial_alphabet, - show_progress=True, - ) # type: ignore - - bpe_tokenizer.pre_tokenizer = Whitespace() # type: ignore - - bpe_tokenizer.normalizer = normalizers.Lowercase() # type: ignore - - bpe_tokenizer.train_from_iterator(lines, trainer=trainer) - - bpe_tokenizer.save(out_path) - - -@click.command() -@click.option("--dataset_path", default="data", help="Path to the MLS dataset") -@click.option("--language", default="mls_german_opus", help="Language to use") -@click.option("--split", default="train", help="Split to use") -@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") -def train_char_tokenizer_cli( - dataset_path: str, - language: str, - split: str, - out_path: str, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" - train_char_tokenizer(dataset_path, language, split, out_path) - - -def train_char_tokenizer( - dataset_path: str, - language: str, - split: str, - out_path: str, -): - """Train a Byte-Pair Encoder tokenizer on the MLS dataset - - Assumes that the MLS dataset is located in the dataset_path and there is a - transcripts.txt file in the split folder. - - Args: - dataset_path (str): Path to the MLS dataset - language (str): Language to use - split (str): Split to use - download (bool): Whether to download the dataset if it is not present - out_path (str): Path to save the tokenizer to - """ - char_tokenizer = CharTokenizer() - - char_tokenizer.train(dataset_path, language, split) + def save(self, path: str) -> None: + """Save the tokenizer to a file""" + with open(path, "w", encoding="utf-8") as file: + for char, index in self.char_map.items(): + file.write(f"{char} {index}\n") - char_tokenizer.save(out_path) + @staticmethod + def from_file(tokenizer_file: str) -> "CharTokenizer": + """Instantiate a CharTokenizer from a file""" + load_tokenizer = CharTokenizer() + with open(tokenizer_file, "r", encoding="utf-8") as file: + for line in file: + line = line.strip() + if line: + char, index = line.split() + tokenizer.char_map[char] = int(index) + tokenizer.index_map[int(index)] = char + return load_tokenizer if __name__ == "__main__": tokenizer = CharTokenizer() - tokenizer.from_file("data/tokenizers/char_tokenizer_german.json") - - print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids)) + tokenizer.save("data/tokenizers/char_tokenizer_german.json") + print(tokenizer.char_map) + print(tokenizer.index_map) + print(tokenizer.get_vocab_size()) + print(tokenizer.get_blank_token()) + print(tokenizer.get_unk_token()) + print(tokenizer.get_space_token()) + print(tokenizer.encode("hallo welt")) + print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) |