Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r-- | swr2_asr/tokenizer.py | 144
1 files changed, 104 insertions, 40 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index a665159..e4df93b 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,16 +1,60 @@
 """Tokenizer for use with Multilingual Librispeech"""
-from dataclasses import dataclass
 import json
 import os
-import click
-from tqdm import tqdm
-
-from AudioLoader.speech import MultilingualLibriSpeech
+from dataclasses import dataclass
+from typing import Type

+import click
 from tokenizers import Tokenizer, normalizers
 from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
 from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.trainers import BpeTrainer
+from tqdm import tqdm
+
+
+class TokenizerType:
+    """Base class for tokenizers.
+
+    exposes the same interface as tokenizers from the huggingface library"""
+
+    def encode(self, sequence: str) -> list[int]:
+        """Encode a sequence to a list of integer labels"""
+        raise NotImplementedError
+
+    def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+        """Decode a list of integer labels to a sequence"""
+        raise NotImplementedError
+
+    def decode_batch(self, labels: list[list[int]]) -> list[str]:
+        """Decode a batch of integer labels to a list of sequences"""
+        raise NotImplementedError
+
+    def get_vocab_size(self) -> int:
+        """Get the size of the vocabulary"""
+        raise NotImplementedError
+
+    def enable_padding(
+        self,
+        length: int = -1,
+        direction: str = "right",
+        pad_id: int = 0,
+        pad_type_id: int = 0,
+        pad_token: str = "[PAD]",
+    ) -> None:
+        """Enable padding for the tokenizer"""
+        raise NotImplementedError
+
+    def save(self, path: str) -> None:
+        """Save the tokenizer to a file"""
+        raise NotImplementedError
+
+    @staticmethod
+    def from_file(path: str) -> "TokenizerType":
+        """Load the tokenizer from a file"""
+        raise NotImplementedError
+
+
+MyTokenizerType = Type[TokenizerType]


 @dataclass
@@ -18,9 +62,10 @@ class Encoding:
     """Simple dataclass to represent an encoding"""

     ids: list[int]
+    tokens: list[str]


-class CharTokenizer:
+class CharTokenizer(TokenizerType):
     """Very simple tokenizer for use with Multilingual Librispeech

     Simply checks what characters are in the dataset and uses them as tokens.
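The new TokenizerType base class mirrors the encode/decode/save surface of the Hugging Face tokenizers library, so downstream code can hold either a CharTokenizer or a tokenizers.Tokenizer behind a single annotation. A minimal consumer sketch (encode_corpus and texts are illustrative, not part of this commit; note that both concrete tokenizers actually return an Encoding-like object, hence the .ids access despite the base class's list[int] annotation):

    from swr2_asr.tokenizer import TokenizerType

    def encode_corpus(tokenizer: TokenizerType, texts: list[str]) -> list[list[int]]:
        # Works with any implementation of the TokenizerType interface;
        # both CharTokenizer and the BPE Tokenizer yield objects with .ids.
        return [tokenizer.encode(text).ids for text in texts]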
@@ -45,9 +90,7 @@ class CharTokenizer:
             self.char_map[token] = len(self.char_map)
             self.index_map[len(self.index_map)] = token

-    def train(
-        self, dataset_path: str, language: str, split: str, download: bool = True
-    ):
+    def train(self, dataset_path: str, language: str, split: str):
         """Train the tokenizer on the given dataset

         Args:
@@ -65,13 +108,7 @@ class CharTokenizer:
         chars: set = set()

         for s_plit in splits:
-            transcript_path = os.path.join(
-                dataset_path, language, s_plit, "transcripts.txt"
-            )
-
-            # check if dataset is downloaded, download if not
-            if download and not os.path.exists(transcript_path):
-                MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+            transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")

             with open(
                 transcript_path,
@@ -90,7 +127,7 @@ class CharTokenizer:
             self.char_map[char] = i
             self.index_map[i] = char

-    def encode(self, text: str):
+    def encode(self, sequence: str):
         """Use a character map and convert text to an integer sequence

         automatically maps spaces to <SPACE> and makes everything lowercase
@@ -98,8 +135,8 @@ class CharTokenizer:
         """
         int_sequence = []
-        text = text.lower()
-        for char in text:
+        sequence = sequence.lower()
+        for char in sequence:
             if char == " ":
                 mapped_char = self.char_map["<SPACE>"]
             elif char not in self.char_map:
@@ -107,7 +144,7 @@ class CharTokenizer:
             else:
                 mapped_char = self.char_map[char]
             int_sequence.append(mapped_char)
-        return Encoding(ids=int_sequence)
+        return Encoding(ids=int_sequence, tokens=list(sequence))

     def decode(self, labels: list[int], remove_special_tokens: bool = True):
         """Use a character map and convert integer labels to an text sequence
@@ -146,6 +183,7 @@ class CharTokenizer:

     def save(self, path: str):
         """Save the tokenizer to a file"""
+        os.makedirs(os.path.dirname(path), exist_ok=True)
         with open(path, "w", encoding="utf-8") as file:
             # save it in the following format:
             # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
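With the new tokens field on Encoding, encode now exposes the characters alongside their integer labels, and save creates missing parent directories before writing. A small usage sketch (the sample text and output path are illustrative; exact ids depend on the trained character map):

    from swr2_asr.tokenizer import CharTokenizer

    tokenizer = CharTokenizer()
    tokenizer.train("data", "mls_german_opus", "train")  # dataset must already be on disk

    encoding = tokenizer.encode("Hallo Welt")
    print(encoding.tokens)  # ['h', 'a', 'l', 'l', 'o', ' ', 'w', 'e', 'l', 't'] (lowercased)
    print(encoding.ids)     # one integer per character, " " mapped via <SPACE>

    # save() now runs os.makedirs(..., exist_ok=True) first, so a nested path is fine
    tokenizer.save("data/tokenizers/char_tokenizer_german.json")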
@@ -155,31 +193,48 @@ class CharTokenizer:
                 ensure_ascii=False,
             )

-    def from_file(self, path: str):
+    @staticmethod
+    def from_file(path: str) -> "CharTokenizer":
         """Load the tokenizer from a file"""
+        char_tokenizer = CharTokenizer()
         with open(path, "r", encoding="utf-8") as file:
             # load it in the following format:
             # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
             saved_file = json.load(file)
-        self.char_map = saved_file["char_map"]
-        self.index_map = saved_file["index_map"]
+        char_tokenizer.char_map = saved_file["char_map"]
+        char_tokenizer.index_map = saved_file["index_map"]
+
+        return char_tokenizer


 @click.command()
 @click.option("--dataset_path", default="data", help="Path to the MLS dataset")
 @click.option("--language", default="mls_german_opus", help="Language to use")
 @click.option("--split", default="train", help="Split to use (including all)")
-@click.option("--download", default=True, help="Whether to download the dataset")
-@click.option(
-    "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
-)
+@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
 @click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer_cli(
+    dataset_path: str,
+    language: str,
+    split: str,
+    out_path: str,
+    vocab_size: int,
+):
+    """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+    train_bpe_tokenizer(
+        dataset_path,
+        language,
+        split,
+        out_path,
+        vocab_size,
+    )
+
+
 def train_bpe_tokenizer(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
     vocab_size: int,
 ):
     """Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -206,11 +261,12 @@ def train_bpe_tokenizer(
     lines = []

     for s_plit in splits:
-        transcripts_path = os.path.join(
-            dataset_path, language, s_plit, "transcripts.txt"
-        )
-        if download and not os.path.exists(transcripts_path):
-            MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+        transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
+        if not os.path.exists(transcripts_path):
+            raise FileNotFoundError(
+                f"Could not find transcripts.txt in {transcripts_path}. "
+                "Please make sure that the dataset is downloaded."
+            )

         with open(
             transcripts_path,
@@ -226,6 +282,7 @@ def train_bpe_tokenizer(

     bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
     initial_alphabet = [
+        " ",
         "a",
         "b",
         "c",
@@ -272,6 +329,7 @@ def train_bpe_tokenizer(
         "ü",
     ]

+    # TODO: add padding token / whitespace token / special tokens
     trainer = BpeTrainer(
         special_tokens=["[UNK]"],
         vocab_size=vocab_size,
@@ -292,16 +350,22 @@
 @click.option("--dataset_path", default="data", help="Path to the MLS dataset")
 @click.option("--language", default="mls_german_opus", help="Language to use")
 @click.option("--split", default="train", help="Split to use")
-@click.option(
-    "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
-)
-@click.option("--download", default=True, help="Whether to download the dataset")
+@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
+def train_char_tokenizer_cli(
+    dataset_path: str,
+    language: str,
+    split: str,
+    out_path: str,
+):
+    """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+    train_char_tokenizer(dataset_path, language, split, out_path)
+
+
 def train_char_tokenizer(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
 ):
     """Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -317,7 +381,7 @@ def train_char_tokenizer(
     """

     char_tokenizer = CharTokenizer()
-    char_tokenizer.train(dataset_path, language, split, download)
+    char_tokenizer.train(dataset_path, language, split)
     char_tokenizer.save(out_path)
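Both entry points now assume the dataset is already downloaded: instead of the old --download flag triggering a MultilingualLibriSpeech fetch, a missing transcripts.txt raises FileNotFoundError. Splitting train_bpe_tokenizer_cli from train_bpe_tokenizer also keeps the click decorators off the reusable function, so it can be driven from Python as well as from the command line. A programmatic sketch mirroring the CLI defaults ("all" trains on every split, per the --split help text):

    from swr2_asr.tokenizer import train_bpe_tokenizer

    # Raises FileNotFoundError if data/mls_german_opus/<split>/transcripts.txt is absent.
    train_bpe_tokenizer(
        dataset_path="data",
        language="mls_german_opus",
        split="all",
        out_path="tokenizer.json",
        vocab_size=2000,
    )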