Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r-- | swr2_asr/tokenizer.py | 41
1 file changed, 27 insertions, 14 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index d32e60d..d9cd622 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,4 +1,6 @@
 """Tokenizer for use with Multilingual Librispeech"""
+from dataclasses import dataclass
+import json
 import os
 import click
 from tqdm import tqdm
@@ -11,6 +13,13 @@
 from tokenizers.trainers import BpeTrainer
 from tokenizers.pre_tokenizers import Whitespace
 
 
+@dataclass
+class Encoding:
+    """Simple dataclass to represent an encoding"""
+
+    ids: list[int]
+
+
 class CharTokenizer:
     """Very simple tokenizer for use with Multilingual Librispeech
@@ -98,7 +107,7 @@ class CharTokenizer:
             else:
                 mapped_char = self.char_map[char]
             int_sequence.append(mapped_char)
-        return int_sequence
+        return Encoding(ids=int_sequence)
 
     def decode(self, labels: list[int], remove_special_tokens: bool = True):
         """Use a character map and convert integer labels to an text sequence
@@ -110,11 +119,11 @@
         """
         string = []
         for i in labels:
-            if remove_special_tokens and self.index_map[i] == "<UNK>":
+            if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>":
                 continue
-            if remove_special_tokens and self.index_map[i] == "<SPACE>":
+            if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>":
                 string.append(" ")
-            string.append(self.index_map[i])
+            string.append(self.index_map[f"{i}"])
         return "".join(string).replace("<SPACE>", " ")
 
     def decode_batch(self, labels: list[list[int]]):
@@ -134,16 +143,22 @@
     def save(self, path: str):
         """Save the tokenizer to a file"""
         with open(path, "w", encoding="utf-8") as file:
-            for char, index in self.char_map.items():
-                file.write(f"{char} {index}\n")
+            # save it in the following format:
+            # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+            json.dump(
+                {"char_map": self.char_map, "index_map": self.index_map},
+                file,
+                ensure_ascii=False,
+            )
 
     def from_file(self, path: str):
         """Load the tokenizer from a file"""
         with open(path, "r", encoding="utf-8") as file:
-            for line in file.readlines():
-                char, index = line.split(" ")
-                self.char_map[char] = int(index)
-                self.index_map[int(index)] = char
+            # load it in the following format:
+            # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+            saved_file = json.load(file)
+            self.char_map = saved_file["char_map"]
+            self.index_map = saved_file["index_map"]
 
 
 @click.command()
@@ -303,8 +318,6 @@ def train_char_tokenizer(
 
 if __name__ == "__main__":
     tokenizer = CharTokenizer()
-    tokenizer.train("/Volumes/pherkel 2/SWR2-ASR", "mls_german_opus", "all")
-
-    print(tokenizer.decode(tokenizer.encode("Fichier non trouvé")))
+    tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
 
-    tokenizer.save("tokenizer_chars.txt")
+    print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids))
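
For context, a minimal sketch of how the changed API fits together, assuming a tokenizer JSON previously written by CharTokenizer.save() (the sample input string is hypothetical; the path is the one used in the diff's __main__ block). encode() now returns the Encoding dataclass rather than a bare list[int], so callers read the labels from .ids; and because JSON object keys are always strings, index_map comes back with string keys after from_file(), which is why decode() now indexes with f"{i}".

from swr2_asr.tokenizer import CharTokenizer

# Load a tokenizer saved in the new JSON format
# ({"char_map": {...}, "index_map": {...}}).
tokenizer = CharTokenizer()
tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")

# encode() now wraps the labels in an Encoding dataclass,
# so the integer sequence lives on the .ids attribute.
encoding = tokenizer.encode("hallo welt")

# decode() takes the raw label list and looks entries up with
# string keys (f"{i}"), matching the JSON round trip.
print(tokenizer.decode(encoding.ids))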