Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r--  swr2_asr/tokenizer.py | 72
1 file changed, 48 insertions(+), 24 deletions(-)
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 4dbb386..5758da7 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,16 +1,50 @@
"""Tokenizer for use with Multilingual Librispeech"""
-from dataclasses import dataclass
import json
import os
-import click
-from tqdm import tqdm
+from dataclasses import dataclass
+from typing import Type
+import click
from AudioLoader.speech import MultilingualLibriSpeech
-
from tokenizers import Tokenizer, normalizers
from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.trainers import BpeTrainer
+from tqdm import tqdm
+
+
+class TokenizerType:
+    def encode(self, sequence: str) -> list[int]:
+        raise NotImplementedError
+
+    def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+        raise NotImplementedError
+
+    def decode_batch(self, labels: list[list[int]]) -> list[str]:
+        raise NotImplementedError
+
+    def get_vocab_size(self) -> int:
+        raise NotImplementedError
+
+    def enable_padding(
+        self,
+        length: int = -1,
+        direction: str = "right",
+        pad_id: int = 0,
+        pad_type_id: int = 0,
+        pad_token: str = "[PAD]",
+    ) -> None:
+        raise NotImplementedError
+
+    def save(self, path: str) -> None:
+        raise NotImplementedError
+
+    @staticmethod
+    def from_file(path: str) -> "TokenizerType":
+        raise NotImplementedError
+
+
+tokenizer_type = Type[TokenizerType]
@dataclass
@@ -20,7 +54,7 @@ class Encoding:
    ids: list[int]
-class CharTokenizer:
+class CharTokenizer(TokenizerType):
    """Very simple tokenizer for use with Multilingual Librispeech
    Simply checks what characters are in the dataset and uses them as tokens.
@@ -45,9 +79,7 @@ class CharTokenizer:
            self.char_map[token] = len(self.char_map)
            self.index_map[len(self.index_map)] = token
-    def train(
-        self, dataset_path: str, language: str, split: str, download: bool = True
-    ):
+    def train(self, dataset_path: str, language: str, split: str, download: bool = True):
        """Train the tokenizer on the given dataset
        Args:
@@ -65,9 +97,7 @@ class CharTokenizer:
        chars: set = set()
        for s_plit in splits:
-            transcript_path = os.path.join(
-                dataset_path, language, s_plit, "transcripts.txt"
-            )
+            transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
            # check if dataset is downloaded, download if not
            if download and not os.path.exists(transcript_path):
@@ -90,7 +120,7 @@ class CharTokenizer:
            self.char_map[char] = i
            self.index_map[i] = char
-    def encode(self, text: str):
+    def encode(self, sequence: str):
        """Use a character map and convert text to an integer sequence
        automatically maps spaces to <SPACE> and makes everything lowercase
@@ -98,8 +128,8 @@ class CharTokenizer:
"""
int_sequence = []
- text = text.lower()
- for char in text:
+ sequence = sequence.lower()
+ for char in sequence:
if char == " ":
mapped_char = self.char_map["<SPACE>"]
elif char not in self.char_map:
@@ -174,9 +204,7 @@ class CharTokenizer:
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use (including all)")
@click.option("--download", default=True, help="Whether to download the dataset")
-@click.option(
-    "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
-)
+@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
def train_bpe_tokenizer(
    dataset_path: str,
@@ -210,9 +238,7 @@ def train_bpe_tokenizer(
    lines = []
    for s_plit in splits:
-        transcripts_path = os.path.join(
-            dataset_path, language, s_plit, "transcripts.txt"
-        )
+        transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
        if download and not os.path.exists(transcripts_path):
            MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
@@ -296,9 +322,7 @@ def train_bpe_tokenizer(
@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use")
-@click.option(
-    "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
-)
+@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
@click.option("--download", default=True, help="Whether to download the dataset")
def train_char_tokenizer(
    dataset_path: str,
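
For readers skimming the change: the new TokenizerType base class gives CharTokenizer and the BPE tokenizer wrapper a shared interface, so downstream code can hold either one behind a single type (the tokenizer_type alias). The sketch below is illustrative only and not part of this commit; DummyTokenizer and the sample strings are hypothetical, and it assumes the module path shown in the diff (swr2_asr/tokenizer.py).

# Hypothetical sketch, not part of this commit: a minimal subclass of the
# TokenizerType interface introduced above. Methods that are not overridden
# (enable_padding, save, from_file) still raise NotImplementedError, since
# TokenizerType is a plain base class rather than an abc.ABC.
from swr2_asr.tokenizer import TokenizerType


class DummyTokenizer(TokenizerType):
    """Maps each character to its Unicode code point; purely illustrative."""

    def encode(self, sequence: str) -> list[int]:
        return [ord(char) for char in sequence]

    def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
        return "".join(chr(label) for label in labels)

    def decode_batch(self, labels: list[list[int]]) -> list[str]:
        return [self.decode(item, remove_special_tokens=True) for item in labels]

    def get_vocab_size(self) -> int:
        return 1_114_112  # number of Unicode code points


tokenizer: TokenizerType = DummyTokenizer()
ids = tokenizer.encode("hello world")
assert tokenizer.decode(ids, remove_special_tokens=True) == "hello world"

Because both concrete tokenizers expose the same methods, training and decoding code elsewhere in the project can accept any TokenizerType instead of branching on the concrete class.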