aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils
diff options
context:
space:
mode:
author Pherkel 2023-09-11 20:45:32 +0200
committer Pherkel 2023-09-11 20:45:32 +0200
commit 8be140b38183b7465b5888a15b536a5f7fa66db6 (patch)
tree 68737b56d9859c139eb8e998cf50813ec7c68bdf /swr2_asr/utils
parent c078ce6789c134aa05607903d3bf9e4be64df45d (diff)
added tokenizer to git and tokenizer training routine
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--swr2_asr/utils/tokenizer.py110
1 files changed, 60 insertions, 50 deletions
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index d92465a..5482bbe 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -1,59 +1,18 @@
"""Tokenizer for Multilingual Librispeech datasets"""
+from datetime import datetime
+import os
+
+from tqdm.autonotebook import tqdm
+
+
class CharTokenizer:
"""Maps characters to integers and vice versa"""
def __init__(self):
- char_map_str = """
- _
- <BLANK>
- <UNK>
- <SPACE>
- a
- b
- c
- d
- e
- f
- g
- h
- i
- j
- k
- l
- m
- n
- o
- p
- q
- r
- s
- t
- u
- v
- w
- x
- y
- z
- é
- à
- ä
- ö
- ß
- ü
- -
- '
-
- """
-
self.char_map = {}
self.index_map = {}
- for idx, char in enumerate(char_map_str.strip().split("\n")):
- char = char.strip()
- self.char_map[char] = idx
- self.index_map[idx] = char
- self.index_map[1] = " "
def encode(self, text: str) -> list[int]:
"""Use a character map and convert text to an integer sequence"""
@@ -91,7 +50,59 @@ class CharTokenizer:
"""Get the integer representation of the <SPACE> character"""
return self.char_map["<SPACE>"]
- # TODO: add train function
@staticmethod
def train(dataset_path: str, language: str) -> "CharTokenizer":
    """Train a character tokenizer on the transcripts of a dataset.

    Every sub-directory of ``<dataset_path>/<language>`` is treated as a
    split whose ``transcripts.txt`` is read. Each line is cut at its first
    space; the part after it is stripped, lower-cased and its characters
    are added to the vocabulary. The resulting tokenizer is saved under
    ``data/tokenizers`` with a timestamped file name and then returned.

    Args:
        dataset_path: Directory that contains the per-language folders.
        language: Name of the language folder inside ``dataset_path``.

    Returns:
        A ``CharTokenizer`` whose maps hold the four reserved tokens at
        indices 0-3 followed by the sorted characters found in the data.

    Raises:
        FileNotFoundError: If the language folder does not exist or a split
            directory has no ``transcripts.txt``.
    """
    special_tokens = ["_", "<BLANK>", "<UNK>", "<SPACE>"]

    chars = set()
    root_path = os.path.join(dataset_path, language)
    for split in os.listdir(root_path):
        split_dir = os.path.join(root_path, split)
        if not os.path.isdir(split_dir):
            continue

        transcript_path = os.path.join(split_dir, "transcripts.txt")
        with open(transcript_path, "r", encoding="utf-8") as transcripts:
            lines = transcripts.readlines()

        for line in tqdm(lines, desc=f"Training tokenizer on {split_dir} split"):
            # partition() yields an empty remainder for blank or id-only
            # lines instead of raising IndexError like split(" ", 1)[1].
            _, _, text = line.partition(" ")
            chars.update(text.strip().lower())

    # The space is represented by the reserved <SPACE> token; discard() does
    # not fail when the transcripts happen to contain no space at all.
    chars.discard(" ")
    # A literal "_" in the data must not get a second index, otherwise
    # char_map and index_map would no longer be inverses of each other.
    chars.difference_update(special_tokens)

    train_tokenizer = CharTokenizer()
    for idx, token in enumerate(special_tokens + sorted(chars)):
        train_tokenizer.char_map[token] = idx
        train_tokenizer.index_map[idx] = token

    train_tokenizer_dir = "data/tokenizers"
    train_tokenizer_path = os.path.join(
        train_tokenizer_dir,
        f"char_tokenizer_{language}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.json",
    )

    # Create the target directory itself (not merely its parent) if missing.
    os.makedirs(train_tokenizer_dir, exist_ok=True)
    train_tokenizer.save(train_tokenizer_path)

    return train_tokenizer
def save(self, path: str) -> None:
"""Save the tokenizer to a file"""
@@ -114,8 +125,7 @@ class CharTokenizer:
if __name__ == "__main__":
- tokenizer = CharTokenizer()
- tokenizer.save("data/tokenizers/char_tokenizer_german.json")
+ tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus")
print(tokenizer.char_map)
print(tokenizer.index_map)
print(tokenizer.get_vocab_size())