Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r--  swr2_asr/tokenizer.py  41
1 file changed, 27 insertions(+), 14 deletions(-)
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index d32e60d..d9cd622 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,4 +1,6 @@
"""Tokenizer for use with Multilingual Librispeech"""
+from dataclasses import dataclass
+import json
import os
import click
from tqdm import tqdm
@@ -11,6 +13,13 @@ from tokenizers.trainers import BpeTrainer
 from tokenizers.pre_tokenizers import Whitespace
 
 
+@dataclass
+class Encoding:
+    """Simple dataclass to represent an encoding"""
+
+    ids: list[int]
+
+
 class CharTokenizer:
     """Very simple tokenizer for use with Multilingual Librispeech
@@ -98,7 +107,7 @@ class CharTokenizer:
             else:
                 mapped_char = self.char_map[char]
             int_sequence.append(mapped_char)
-        return int_sequence
+        return Encoding(ids=int_sequence)
 
     def decode(self, labels: list[int], remove_special_tokens: bool = True):
         """Use a character map and convert integer labels to a text sequence
@@ -110,11 +119,11 @@ class CharTokenizer:
"""
string = []
for i in labels:
- if remove_special_tokens and self.index_map[i] == "<UNK>":
+ if remove_special_tokens and self.index_map[f"{i}"] == "<UNK>":
continue
- if remove_special_tokens and self.index_map[i] == "<SPACE>":
+ if remove_special_tokens and self.index_map[f"{i}"] == "<SPACE>":
string.append(" ")
- string.append(self.index_map[i])
+ string.append(self.index_map[f"{i}"])
return "".join(string).replace("<SPACE>", " ")
def decode_batch(self, labels: list[list[int]]):
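The switch from self.index_map[i] to self.index_map[f"{i}"] follows from the JSON persistence introduced below: JSON object keys are always strings, so after a save/load round trip the formerly integer keys of index_map come back as str. A small self-contained demonstration:

    import json

    index_map = {0: "a", 1: "b"}
    loaded = json.loads(json.dumps(index_map))
    print(loaded)          # {'0': 'a', '1': 'b'} -- keys are now strings
    print(loaded[f"{0}"])  # 'a'

Note that if train() still fills index_map with int keys, decode will only work on a tokenizer that has been re-loaded via from_file.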
@@ -134,16 +143,22 @@ class CharTokenizer:
     def save(self, path: str):
         """Save the tokenizer to a file"""
         with open(path, "w", encoding="utf-8") as file:
-            for char, index in self.char_map.items():
-                file.write(f"{char} {index}\n")
+            # save it in the following format:
+            # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+            json.dump(
+                {"char_map": self.char_map, "index_map": self.index_map},
+                file,
+                ensure_ascii=False,
+            )
 
     def from_file(self, path: str):
         """Load the tokenizer from a file"""
         with open(path, "r", encoding="utf-8") as file:
-            for line in file.readlines():
-                char, index = line.split(" ")
-                self.char_map[char] = int(index)
-                self.index_map[int(index)] = char
+            # load it in the following format:
+            # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
+            saved_file = json.load(file)
+            self.char_map = saved_file["char_map"]
+            self.index_map = saved_file["index_map"]
 
 
 @click.command()
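save and from_file now round-trip through a single JSON object; ensure_ascii=False keeps German umlauts readable in the file instead of \uXXXX escapes. A round-trip sketch (the path and the tiny char maps are made up):

    tokenizer = CharTokenizer()
    tokenizer.char_map = {"<UNK>": 0, "ä": 1}
    tokenizer.index_map = {0: "<UNK>", 1: "ä"}
    tokenizer.save("char_tokenizer.json")

    restored = CharTokenizer()
    restored.from_file("char_tokenizer.json")
    # restored.index_map == {"0": "<UNK>", "1": "ä"} -- str keys after loading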
@@ -303,8 +318,6 @@ def train_char_tokenizer(
 
 
 if __name__ == "__main__":
     tokenizer = CharTokenizer()
-    tokenizer.train("/Volumes/pherkel 2/SWR2-ASR", "mls_german_opus", "all")
-
-    print(tokenizer.decode(tokenizer.encode("Fichier non trouvé")))
+    tokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
-    tokenizer.save("tokenizer_chars.txt")
+    print(tokenizer.decode(tokenizer.encode("Fichier non trouvé").ids))
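Because encode now returns an Encoding rather than a raw list, the smoke test has to unwrap it with .ids before decoding. The same pattern in isolation:

    encoding = tokenizer.encode("Fichier non trouvé")  # Encoding(ids=[...])
    print(tokenizer.decode(encoding.ids))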