path: root/swr2_asr/utils/tokenizer.py
"""Tokenizer for Multilingual Librispeech datasets"""


from datetime import datetime
import os

from tqdm.autonotebook import tqdm


class CharTokenizer:
    """Maps characters to integers and vice versa"""

    def __init__(self):
        self.char_map = {}
        self.index_map = {}

    def encode(self, text: str) -> list[int]:
        """Use a character map and convert text to an integer sequence"""
        int_sequence = []
        for char in text:
            if char == " ":
                token = self.char_map["<SPACE>"]
            elif char not in self.char_map:
                token = self.char_map["<UNK>"]
            else:
                token = self.char_map[char]
            int_sequence.append(token)
        return int_sequence

    def decode(self, labels: list[int]) -> str:
        """Use a character map and convert integer labels to an text sequence"""
        string = []
        for i in labels:
            string.append(self.index_map[i])
        return "".join(string).replace("<SPACE>", " ")

    def get_vocab_size(self) -> int:
        """Get the number of unique characters in the dataset"""
        return len(self.char_map)

    def get_blank_token(self) -> int:
        """Get the integer representation of the <BLANK> character"""
        return self.char_map["<BLANK>"]

    def get_unk_token(self) -> int:
        """Get the integer representation of the <UNK> character"""
        return self.char_map["<UNK>"]

    def get_space_token(self) -> int:
        """Get the integer representation of the <SPACE> character"""
        return self.char_map["<SPACE>"]

    @staticmethod
    def train(dataset_path: str, language: str) -> "CharTokenizer":
        """Train the tokenizer on a dataset"""
        chars = set()
        root_path = os.path.join(dataset_path, language)
        for split in os.listdir(root_path):
            split_dir = os.path.join(root_path, split)
            if os.path.isdir(split_dir):
                transcript_path = os.path.join(split_dir, "transcripts.txt")

                with open(transcript_path, "r", encoding="utf-8") as transcripts_file:
                    lines = transcripts_file.readlines()
                lines = [line.split(" ", 1)[1] for line in lines]
                lines = [line.strip() for line in lines]
                lines = [line.lower() for line in lines]

                for line in tqdm(lines, desc=f"Training tokenizer on {split_dir} split"):
                    chars.update(line)

        # drop the literal space (it is encoded as <SPACE>) and sort the rest
        chars.discard(" ")
        chars = sorted(chars)

        train_tokenizer = CharTokenizer()

        train_tokenizer.char_map["_"] = 0
        train_tokenizer.char_map["<BLANK>"] = 1
        train_tokenizer.char_map["<UNK>"] = 2
        train_tokenizer.char_map["<SPACE>"] = 3

        train_tokenizer.index_map[0] = "_"
        train_tokenizer.index_map[1] = "<BLANK>"
        train_tokenizer.index_map[2] = "<UNK>"
        train_tokenizer.index_map[3] = "<SPACE>"

        offset = 4

        for idx, char in enumerate(chars, start=offset):
            train_tokenizer.char_map[char] = idx
            train_tokenizer.index_map[idx] = char

        train_tokenizer_dir = os.path.join("data", "tokenizers")
        train_tokenizer_path = os.path.join(
            train_tokenizer_dir,
            f"char_tokenizer_{language}_{datetime.now().strftime('%Y-%m-%d_%H-%M')}.json",
        )

        os.makedirs(train_tokenizer_dir, exist_ok=True)
        train_tokenizer.save(train_tokenizer_path)

        return train_tokenizer

    def save(self, path: str) -> None:
        """Save the tokenizer to a file"""
        with open(path, "w", encoding="utf-8") as file:
            for char, index in self.char_map.items():
                file.write(f"{char} {index}\n")

    @staticmethod
    def from_file(tokenizer_file: str) -> "CharTokenizer":
        """Instantiate a CharTokenizer from a file"""
        load_tokenizer = CharTokenizer()
        with open(tokenizer_file, "r", encoding="utf-8") as file:
            for line in file:
                line = line.strip()
                if line:
                    char, index = line.split()
                    load_tokenizer.char_map[char] = int(index)
                    load_tokenizer.index_map[int(index)] = char
        return load_tokenizer


if __name__ == "__main__":
    tokenizer = CharTokenizer.train("/Volumes/pherkel 1/SWR2-ASR", "mls_german_opus")
    print(tokenizer.char_map)
    print(tokenizer.index_map)
    print(tokenizer.get_vocab_size())
    print(tokenizer.get_blank_token())
    print(tokenizer.get_unk_token())
    print(tokenizer.get_space_token())
    print(tokenizer.encode("hallo welt"))
    print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
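
    # Illustrative round-trip sketch: persist the freshly trained tokenizer and
    # reload it via from_file to check that both mappings survive serialization.
    # The file name below is only an example, not a path used elsewhere in the repo.
    tokenizer.save("char_tokenizer_roundtrip.txt")
    reloaded = CharTokenizer.from_file("char_tokenizer_roundtrip.txt")
    print(reloaded.decode(tokenizer.encode("hallo welt")))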