diff options
author | Pherkel | 2023-09-11 20:45:32 +0200 |
---|---|---|
committer | Pherkel | 2023-09-11 20:45:32 +0200 |
commit | 8be140b38183b7465b5888a15b536a5f7fa66db6 (patch) | |
tree | 68737b56d9859c139eb8e998cf50813ec7c68bdf /swr2_asr/utils | |
parent | c078ce6789c134aa05607903d3bf9e4be64df45d (diff) |
added tokenizer to git and tokenizer training routine
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r-- | swr2_asr/utils/tokenizer.py | 110 |
1 files changed, 60 insertions, 50 deletions
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index d92465a..5482bbe 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -1,59 +1,18 @@ """Tokenizer for Multilingual Librispeech datasets""" +from datetime import datetime +import os + +from tqdm.autonotebook import tqdm + + class CharTokenizer: """Maps characters to integers and vice versa""" def __init__(self): - char_map_str = """ - _ - <BLANK> - <UNK> - <SPACE> - a - b - c - d - e - f - g - h - i - j - k - l - m - n - o - p - q - r - s - t - u - v - w - x - y - z - é - à - ä - ö - ß - ü - - - ' - - """ - self.char_map = {} self.index_map = {} - for idx, char in enumerate(char_map_str.strip().split("\n")): - char = char.strip() - self.char_map[char] = idx - self.index_map[idx] = char - self.index_map[1] = " " def encode(self, text: str) -> list[int]: """Use a character map and convert text to an integer sequence""" @@ -91,7 +50,59 @@ class CharTokenizer: """Get the integer representation of the <SPACE> character""" return self.char_map["<SPACE>"] - # TODO: add train function + @staticmethod + def train(dataset_path: str, language: str) -> "CharTokenizer": + """Train the tokenizer on a dataset""" + chars = set() + root_path = os.path.join(dataset_path, language) + for split in os.listdir(root_path): + split_dir = os.path.join(root_path, split) + if os.path.isdir(split_dir): + transcript_path = os.path.join(split_dir, "transcripts.txt") + + with open(transcript_path, "r", encoding="utf-8") as transcrips: + lines = transcrips.readlines() + lines = [line.split(" ", 1)[1] for line in lines] + lines = [line.strip() for line in lines] + lines = [line.lower() for line in lines] + + for line in tqdm(lines, desc=f"Training tokenizer on {split_dir} split"): + chars.update(line) + + # sort chars + chars.remove(" ") + chars = sorted(chars) + + train_tokenizer = CharTokenizer() + + train_tokenizer.char_map["_"] = 0 + train_tokenizer.char_map["<BLANK>"] = 1 + 
train_tokenizer.char_map["<UNK>"] = 2 + train_tokenizer.char_map["<SPACE>"] = 3 + + train_tokenizer.index_map[0] = "_" + train_tokenizer.index_map[1] = "<BLANK>" + train_tokenizer.index_map[2] = "<UNK>" + train_tokenizer.index_map[3] = "<SPACE>" + + offset = 4 + + for idx, char in enumerate(chars): + idx += offset + train_tokenizer.char_map[char] = idx + train_tokenizer.index_map[idx] = char + + train_tokenizer_dir = os.path.join("data/tokenizers") + train_tokenizer_path = os.path.join( + train_tokenizer_dir, + f"char_tokenizer_{language}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.json", + ) + + if not os.path.exists(os.path.dirname(train_tokenizer_dir)): + os.makedirs(train_tokenizer_dir) + train_tokenizer.save(train_tokenizer_path) + + return train_tokenizer def save(self, path: str) -> None: """Save the tokenizer to a file""" @@ -114,8 +125,7 @@ class CharTokenizer: if __name__ == "__main__": - tokenizer = CharTokenizer() - tokenizer.save("data/tokenizers/char_tokenizer_german.json") + tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus") print(tokenizer.char_map) print(tokenizer.index_map) print(tokenizer.get_vocab_size()) |