diff options
Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r-- | swr2_asr/tokenizer.py | 72 |
1 files changed, 48 insertions, 24 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 4dbb386..5758da7 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -1,16 +1,50 @@ """Tokenizer for use with Multilingual Librispeech""" -from dataclasses import dataclass import json import os -import click -from tqdm import tqdm +from dataclasses import dataclass +from typing import Type +import click from AudioLoader.speech import MultilingualLibriSpeech - from tokenizers import Tokenizer, normalizers from tokenizers.models import BPE -from tokenizers.trainers import BpeTrainer from tokenizers.pre_tokenizers import Whitespace +from tokenizers.trainers import BpeTrainer +from tqdm import tqdm + + +class TokenizerType: + def encode(self, sequence: str) -> list[int]: + raise NotImplementedError + + def decode(self, labels: list[int], remove_special_tokens: bool) -> str: + raise NotImplementedError + + def decode_batch(self, labels: list[list[int]]) -> list[str]: + raise NotImplementedError + + def get_vocab_size(self) -> int: + raise NotImplementedError + + def enable_padding( + self, + length: int = -1, + direction: str = "right", + pad_id: int = 0, + pad_type_id: int = 0, + pad_token: str = "[PAD]", + ) -> None: + raise NotImplementedError + + def save(self, path: str) -> None: + raise NotImplementedError + + @staticmethod + def from_file(path: str) -> "TokenizerType": + raise NotImplementedError + + +tokenizer_type = Type[TokenizerType] @dataclass @@ -20,7 +54,7 @@ class Encoding: ids: list[int] -class CharTokenizer: +class CharTokenizer(TokenizerType): """Very simple tokenizer for use with Multilingual Librispeech Simply checks what characters are in the dataset and uses them as tokens. @@ -45,9 +79,7 @@ class CharTokenizer: self.char_map[token] = len(self.char_map) self.index_map[len(self.index_map)] = token - def train( - self, dataset_path: str, language: str, split: str, download: bool = True - ): + def train(self, dataset_path: str, language: str, split: str, download: bool = True): """Train the tokenizer on the given dataset Args: @@ -65,9 +97,7 @@ class CharTokenizer: chars: set = set() for s_plit in splits: - transcript_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) + transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") # check if dataset is downloaded, download if not if download and not os.path.exists(transcript_path): @@ -90,7 +120,7 @@ class CharTokenizer: self.char_map[char] = i self.index_map[i] = char - def encode(self, text: str): + def encode(self, sequence: str): """Use a character map and convert text to an integer sequence automatically maps spaces to <SPACE> and makes everything lowercase @@ -98,8 +128,8 @@ class CharTokenizer: """ int_sequence = [] - text = text.lower() - for char in text: + sequence = sequence.lower() + for char in sequence: if char == " ": mapped_char = self.char_map["<SPACE>"] elif char not in self.char_map: @@ -174,9 +204,7 @@ class CharTokenizer: @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use (including all)") @click.option("--download", default=True, help="Whether to download the dataset") -@click.option( - "--out_path", default="tokenizer.json", help="Path to save the tokenizer to" -) +@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") @click.option("--vocab_size", default=2000, help="Size of the vocabulary") def train_bpe_tokenizer( dataset_path: str, @@ -210,9 +238,7 @@ def train_bpe_tokenizer( lines = [] for s_plit in splits: - transcripts_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) + transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") if download and not os.path.exists(transcripts_path): MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) @@ -296,9 +322,7 @@ def train_bpe_tokenizer( @click.option("--dataset_path", default="data", help="Path to the MLS dataset") @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use") -@click.option( - "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to" -) +@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") @click.option("--download", default=True, help="Whether to download the dataset") def train_char_tokenizer( dataset_path: str, |