Diffstat (limited to 'swr2_asr')
-rw-r--r--  swr2_asr/tokenizer.py  466
1 file changed, 99 insertions, 367 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 2e2fb57..69ced81 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,393 +1,125 @@
-"""Tokenizer for use with Multilingual Librispeech"""
-import json
-import os
-from dataclasses import dataclass
-from typing import Type
+"""Tokenizer for Multilingual Librispeech datasets"""
-import click
-from tokenizers import Tokenizer, normalizers
-from tokenizers.models import BPE
-from tokenizers.pre_tokenizers import Whitespace
-from tokenizers.trainers import BpeTrainer
-from tqdm import tqdm
-
-class TokenizerType:
- """Base class for tokenizers.
-
- exposes the same interface as tokenizers from the huggingface library"""
-
- def encode(self, sequence: str) -> list[int]:
- """Encode a sequence to a list of integer labels"""
- raise NotImplementedError
-
- def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
- """Decode a list of integer labels to a sequence"""
- raise NotImplementedError
-
- def decode_batch(self, labels: list[list[int]]) -> list[str]:
- """Decode a batch of integer labels to a list of sequences"""
- raise NotImplementedError
-
- def get_vocab_size(self) -> int:
- """Get the size of the vocabulary"""
- raise NotImplementedError
-
- def enable_padding(
- self,
- length: int = -1,
- direction: str = "right",
- pad_id: int = 0,
- pad_type_id: int = 0,
- pad_token: str = "[PAD]",
- ) -> None:
- """Enable padding for the tokenizer"""
- raise NotImplementedError
-
- def save(self, path: str) -> None:
- """Save the tokenizer to a file"""
- raise NotImplementedError
-
- @staticmethod
- def from_file(path: str) -> "TokenizerType":
- """Load the tokenizer from a file"""
- raise NotImplementedError
-
-
-MyTokenizerType = Type[TokenizerType]
-
-
-@dataclass
-class Encoding:
- """Simple dataclass to represent an encoding"""
-
- ids: list[int]
- tokens: list[str]
-
-
-class CharTokenizer(TokenizerType):
- """Very simple tokenizer for use with Multilingual Librispeech
-
- Simply checks what characters are in the dataset and uses them as tokens.
-
- Exposes the same interface as tokenizers from the huggingface library, i.e.
- encode, decode, decode_batch, get_vocab_size, save, from_file and train.
- """
+class CharTokenizer:
+ """Maps characters to integers and vice versa"""
def __init__(self):
- self.char_map = {}
- self.index_map = {}
- self.add_tokens(["<UNK>", "<SPACE>"])
-
- def add_tokens(self, tokens: list[str]):
- """Manually add tokens to the tokenizer
-
- Args:
- tokens (list[str]): List of tokens to add
- """
- for token in tokens:
- if token not in self.char_map:
- self.char_map[token] = len(self.char_map)
- self.index_map[len(self.index_map)] = token
-
- def train(self, dataset_path: str, language: str, split: str):
- """Train the tokenizer on the given dataset
-
- Args:
- dataset_path (str): Path to the MLS dataset
- language (str): Language to use
- split (str): Split to use
+ char_map_str = """
+ _
+ <BLANK>
+ <UNK>
+ <SPACE>
+ a
+ b
+ c
+ d
+ e
+ f
+ g
+ h
+ i
+ j
+ k
+ l
+ m
+ n
+ o
+ p
+ q
+ r
+ s
+ t
+ u
+ v
+ w
+ x
+ y
+ z
+ é
+ à
+ ä
+ ö
+ ß
+ ü
+ -
+ '
"""
- if split not in ["train", "dev", "test", "all"]:
- raise ValueError("Split must be one of train, dev, test, all")
- if split == "all":
- splits = ["train", "dev", "test"]
- else:
- splits = [split]
-
- chars: set = set()
- for s_plit in splits:
- transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
-
- with open(
- transcript_path,
- "r",
- encoding="utf-8",
- ) as file:
- lines = file.readlines()
- lines = [line.split(" ", 1)[1] for line in lines]
- lines = [line.strip() for line in lines]
-
- for line in tqdm(lines, desc=f"Training tokenizer on {s_plit} split"):
- chars.update(line)
- offset = len(self.char_map)
- for i, char in enumerate(chars):
- i += offset
- self.char_map[char] = i
- self.index_map[i] = char
-
- def encode(self, sequence: str):
- """Use a character map and convert text to an integer sequence
-
- automatically maps spaces to <SPACE> and makes everything lowercase
- unknown characters are mapped to the <UNK> token
-
- """
+ self.char_map = {}
+ self.index_map = {}
+ for idx, char in enumerate(char_map_str.strip().split("\n")):
+ char = char.strip()
+ self.char_map[char] = idx
+ self.index_map[idx] = char
+ self.index_map[self.char_map["<SPACE>"]] = " "
+
+ def encode(self, text: str) -> list[int]:
+ """Use a character map and convert text to an integer sequence"""
int_sequence = []
- sequence = sequence.lower()
- for char in sequence:
+ for char in text:
if char == " ":
- mapped_char = self.char_map["<SPACE>"]
+ char = self.char_map["<SPACE>"]
elif char not in self.char_map:
- mapped_char = self.char_map["<UNK>"]
+ char = self.char_map["<UNK>"]
else:
- mapped_char = self.char_map[char]
- int_sequence.append(mapped_char)
- return Encoding(ids=int_sequence, tokens=list(sequence))
+ char = self.char_map[char]
+ int_sequence.append(char)
+ return int_sequence
- def decode(self, labels: list[int], remove_special_tokens: bool = True):
- """Use a character map and convert integer labels to an text sequence
-
- Args:
- labels (list[int]): List of integer labels
- remove_special_tokens (bool): Whether to remove special tokens.
- Defaults to True.
- """
+ def decode(self, labels: list[int]) -> str:
+ """Use a character map and convert integer labels to an text sequence"""
string = []
for i in labels:
- if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>":
- continue
- if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>":
- string.append(" ")
- string.append(self.index_map[f"{i}"])
+ string.append(self.index_map[i])
return "".join(string).replace("<SPACE>", " ")
- def decode_batch(self, labels: list[list[int]]):
- """Use a character map and convert integer labels to an text sequence"""
- strings = []
- for label in labels:
- string = []
- for i in label:
- if self.index_map[i] == "<UNK>":
- continue
- if self.index_map[i] == "<SPACE>":
- string.append(" ")
- string.append(self.index_map[i])
- strings.append("".join(string).replace("<SPACE>", " "))
- return strings
-
- def get_vocab_size(self):
- """Get the size of the vocabulary"""
+ def get_vocab_size(self) -> int:
+ """Get the number of unique characters in the dataset"""
return len(self.char_map)
- def save(self, path: str):
- """Save the tokenizer to a file"""
- os.makedirs(os.path.dirname(path), exist_ok=True)
- with open(path, "w", encoding="utf-8") as file:
- # save it in the following format:
- # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
- json.dump(
- {"char_map": self.char_map, "index_map": self.index_map},
- file,
- ensure_ascii=False,
- )
-
- @staticmethod
- def from_file(path: str) -> "CharTokenizer":
- """Load the tokenizer from a file"""
- char_tokenizer = CharTokenizer()
- with open(path, "r", encoding="utf-8") as file:
- # load it in the following format:
- # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
- saved_file = json.load(file)
- char_tokenizer.char_map = saved_file["char_map"]
- char_tokenizer.index_map = saved_file["index_map"]
-
- return char_tokenizer
-
-
-@click.command()
-@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
-@click.option("--language", default="mls_german_opus", help="Language to use")
-@click.option("--split", default="train", help="Split to use (including all)")
-@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
-@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
-def train_bpe_tokenizer_cli(
- dataset_path: str,
- language: str,
- split: str,
- out_path: str,
- vocab_size: int,
-):
- """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
- train_bpe_tokenizer(
- dataset_path,
- language,
- split,
- out_path,
- vocab_size,
- )
-
-
-def train_bpe_tokenizer(
- dataset_path: str,
- language: str,
- split: str,
- out_path: str,
- vocab_size: int,
-):
- """Train a Byte-Pair Encoder tokenizer on the MLS dataset
-
- Assumes that the MLS dataset is located in the dataset_path and there is a
- transcripts.txt file in the split folder.
-
- Args:
- dataset_path (str): Path to the MLS dataset
- language (str): Language to use
- split (str): Split to use
- download (bool): Whether to download the dataset if it is not present
- out_path (str): Path to save the tokenizer to
- vocab_size (int): Size of the vocabulary
- """
- if split not in ["train", "dev", "test", "all"]:
- raise ValueError("Split must be one of train, dev, test, all")
-
- if split == "all":
- splits = ["train", "dev", "test"]
- else:
- splits = [split]
-
- lines = []
+ def get_blank_token(self) -> int:
+ """Get the integer representation of the <BLANK> character"""
+ return self.char_map["<BLANK>"]
- for s_plit in splits:
- transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
- if not os.path.exists(transcripts_path):
- raise FileNotFoundError(
- f"Could not find transcripts.txt in {transcripts_path}. "
- "Please make sure that the dataset is downloaded."
- )
+ def get_unk_token(self) -> int:
+ """Get the integer representation of the <UNK> character"""
+ return self.char_map["<UNK>"]
- with open(
- transcripts_path,
- "r",
- encoding="utf-8",
- ) as file:
- sp_lines = file.readlines()
- sp_lines = [line.split(" ", 1)[1] for line in sp_lines]
- sp_lines = [line.strip() for line in sp_lines]
+ def get_space_token(self) -> int:
+ """Get the integer representation of the <SPACE> character"""
+ return self.char_map["<SPACE>"]
- lines.append(sp_lines)
+ # TODO: add train function
- bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
-
- initial_alphabet = [
- " ",
- "a",
- "b",
- "c",
- "d",
- "e",
- "f",
- "g",
- "h",
- "i",
- "j",
- "k",
- "l",
- "m",
- "n",
- "o",
- "p",
- "q",
- "r",
- "s",
- "t",
- "u",
- "v",
- "w",
- "x",
- "y",
- "z",
- "ä",
- "ö",
- "ü",
- "ß",
- "-",
- "é",
- "è",
- "à",
- "ù",
- "ç",
- "â",
- "ê",
- "î",
- "ô",
- "û",
- "ë",
- "ï",
- "ü",
- ]
-
-
- trainer = BpeTrainer(
- special_tokens=["[UNK]"],
- vocab_size=vocab_size,
- initial_alphabet=initial_alphabet,
- show_progress=True,
- ) # type: ignore
-
- bpe_tokenizer.pre_tokenizer = Whitespace() # type: ignore
-
- bpe_tokenizer.normalizer = normalizers.Lowercase() # type: ignore
-
- bpe_tokenizer.train_from_iterator(lines, trainer=trainer)
-
- bpe_tokenizer.save(out_path)
-
-
-@click.command()
-@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
-@click.option("--language", default="mls_german_opus", help="Language to use")
-@click.option("--split", default="train", help="Split to use")
-@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
-def train_char_tokenizer_cli(
- dataset_path: str,
- language: str,
- split: str,
- out_path: str,
-):
- """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
- train_char_tokenizer(dataset_path, language, split, out_path)
-
-
-def train_char_tokenizer(
- dataset_path: str,
- language: str,
- split: str,
- out_path: str,
-):
- """Train a Byte-Pair Encoder tokenizer on the MLS dataset
-
- Assumes that the MLS dataset is located in the dataset_path and there is a
- transcripts.txt file in the split folder.
-
- Args:
- dataset_path (str): Path to the MLS dataset
- language (str): Language to use
- split (str): Split to use
- download (bool): Whether to download the dataset if it is not present
- out_path (str): Path to save the tokenizer to
- """
- char_tokenizer = CharTokenizer()
-
- char_tokenizer.train(dataset_path, language, split)
+ def save(self, path: str) -> None:
+ """Save the tokenizer to a file"""
+ with open(path, "w", encoding="utf-8") as file:
+ for char, index in self.char_map.items():
+ file.write(f"{char} {index}\n")
- char_tokenizer.save(out_path)
+ @staticmethod
+ def from_file(tokenizer_file: str) -> "CharTokenizer":
+ """Instantiate a CharTokenizer from a file"""
+ load_tokenizer = CharTokenizer()
+ with open(tokenizer_file, "r", encoding="utf-8") as file:
+ for line in file:
+ line = line.strip()
+ if line:
+ char, index = line.split()
+ load_tokenizer.char_map[char] = int(index)
+ load_tokenizer.index_map[int(index)] = char
+ return load_tokenizer
if __name__ == "__main__":
tokenizer = CharTokenizer()
- tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
-
- print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids))
+ tokenizer.save("data/tokenizers/char_tokenizer_german.json")
+ print(tokenizer.char_map)
+ print(tokenizer.index_map)
+ print(tokenizer.get_vocab_size())
+ print(tokenizer.get_blank_token())
+ print(tokenizer.get_unk_token())
+ print(tokenizer.get_space_token())
+ print(tokenizer.encode("hallo welt"))
+ print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
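
A minimal usage sketch of the rewritten CharTokenizer, assuming the module is importable as swr2_asr.tokenizer; the file name passed to save()/from_file() below is only an example, not a path used by the project:

# Usage sketch (not part of the diff); the tokenizer file name is hypothetical.
from swr2_asr.tokenizer import CharTokenizer

tokenizer = CharTokenizer()

# Round-trip a transcript: spaces become <SPACE>, unknown characters <UNK>.
ids = tokenizer.encode("hallo welt")
print(ids)
print(tokenizer.decode(ids))  # "hallo welt"

# CTC-style decoding typically drops the blank index before calling decode.
blank = tokenizer.get_blank_token()
ids_with_blanks = [ids[0], blank, ids[1], blank]
print(tokenizer.decode([i for i in ids_with_blanks if i != blank]))  # "ha"

# Persist the character map and load it back.
tokenizer.save("char_tokenizer_german.txt")
restored = CharTokenizer.from_file("char_tokenizer_german.txt")
assert restored.get_vocab_size() == tokenizer.get_vocab_size()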