Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r--  swr2_asr/tokenizer.py  144
1 file changed, 104 insertions(+), 40 deletions(-)
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index a665159..e4df93b 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,16 +1,60 @@
"""Tokenizer for use with Multilingual Librispeech"""
-from dataclasses import dataclass
import json
import os
-import click
-from tqdm import tqdm
-
-from AudioLoader.speech import MultilingualLibriSpeech
+from dataclasses import dataclass
+from typing import Type
+import click
from tokenizers import Tokenizer, normalizers
from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.trainers import BpeTrainer
+from tqdm import tqdm
+
+
+class TokenizerType:
+ """Base class for tokenizers.
+
+ Exposes the same interface as tokenizers from the HuggingFace tokenizers library."""
+
+ def encode(self, sequence: str) -> "Encoding":
+ """Encode a sequence into an Encoding holding integer ids"""
+ raise NotImplementedError
+
+ def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+ """Decode a list of integer labels to a sequence"""
+ raise NotImplementedError
+
+ def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ """Decode a batch of integer labels to a list of sequences"""
+ raise NotImplementedError
+
+ def get_vocab_size(self) -> int:
+ """Get the size of the vocabulary"""
+ raise NotImplementedError
+
+ def enable_padding(
+ self,
+ length: int = -1,
+ direction: str = "right",
+ pad_id: int = 0,
+ pad_type_id: int = 0,
+ pad_token: str = "[PAD]",
+ ) -> None:
+ """Enable padding for the tokenizer"""
+ raise NotImplementedError
+
+ def save(self, path: str) -> None:
+ """Save the tokenizer to a file"""
+ raise NotImplementedError
+
+ @staticmethod
+ def from_file(path: str) -> "TokenizerType":
+ """Load the tokenizer from a file"""
+ raise NotImplementedError
+
+
+MyTokenizerType = Type[TokenizerType]
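The alias and base class let downstream code accept either tokenizer behind one interface. A minimal sketch of that polymorphism (the helper name is hypothetical, not part of this diff):

    def encode_transcript(tokenizer: TokenizerType, transcript: str) -> list[int]:
        # Works for any implementation of the shared interface,
        # e.g. a CharTokenizer or a BPE-backed tokenizer.
        return tokenizer.encode(transcript).ids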
@dataclass
@@ -18,9 +62,10 @@ class Encoding:
"""Simple dataclass to represent an encoding"""
ids: list[int]
+ tokens: list[str]
-class CharTokenizer:
+class CharTokenizer(TokenizerType):
"""Very simple tokenizer for use with Multilingual Librispeech
Simply checks what characters are in the dataset and uses them as tokens.
@@ -45,9 +90,7 @@ class CharTokenizer:
self.char_map[token] = len(self.char_map)
self.index_map[len(self.index_map)] = token
- def train(
- self, dataset_path: str, language: str, split: str, download: bool = True
- ):
+ def train(self, dataset_path: str, language: str, split: str):
"""Train the tokenizer on the given dataset
Args:
@@ -65,13 +108,7 @@ class CharTokenizer:
chars: set = set()
for s_plit in splits:
- transcript_path = os.path.join(
- dataset_path, language, s_plit, "transcripts.txt"
- )
-
- # check if dataset is downloaded, download if not
- if download and not os.path.exists(transcript_path):
- MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+ transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
with open(
transcript_path,
@@ -90,7 +127,7 @@ class CharTokenizer:
self.char_map[char] = i
self.index_map[i] = char
- def encode(self, text: str):
+ def encode(self, sequence: str):
"""Use a character map and convert text to an integer sequence
automatically maps spaces to <SPACE> and makes everything lowercase
@@ -98,8 +135,8 @@ class CharTokenizer:
"""
int_sequence = []
- text = text.lower()
- for char in text:
+ sequence = sequence.lower()
+ for char in sequence:
if char == " ":
mapped_char = self.char_map["<SPACE>"]
elif char not in self.char_map:
@@ -107,7 +144,7 @@ class CharTokenizer:
else:
mapped_char = self.char_map[char]
int_sequence.append(mapped_char)
- return Encoding(ids=int_sequence)
+ return Encoding(ids=int_sequence, tokens=list(sequence))
def decode(self, labels: list[int], remove_special_tokens: bool = True):
"""Use a character map and convert integer labels to an text sequence
@@ -146,6 +183,7 @@ class CharTokenizer:
def save(self, path: str):
"""Save the tokenizer to a file"""
+ os.makedirs(os.path.dirname(path), exist_ok=True)
with open(path, "w", encoding="utf-8") as file:
# save it in the following format:
# {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
@@ -155,31 +193,48 @@ class CharTokenizer:
ensure_ascii=False,
)
- def from_file(self, path: str):
+ @staticmethod
+ def from_file(path: str) -> "CharTokenizer":
"""Load the tokenizer from a file"""
+ char_tokenizer = CharTokenizer()
with open(path, "r", encoding="utf-8") as file:
# load it in the following format:
# {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
saved_file = json.load(file)
- self.char_map = saved_file["char_map"]
- self.index_map = saved_file["index_map"]
+ char_tokenizer.char_map = saved_file["char_map"]
+ # JSON stores dict keys as strings, so restore the integer keys
+ char_tokenizer.index_map = {int(k): v for k, v in saved_file["index_map"].items()}
+
+ return char_tokenizer
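Since save() now creates missing parent directories and from_file() is a static constructor, persisting a trained tokenizer is a clean round trip; a minimal sketch:

    tokenizer.save("models/tokenizer_chars.txt")   # models/ is created if absent
    restored = CharTokenizer.from_file("models/tokenizer_chars.txt")
    assert restored.char_map == tokenizer.char_map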
@click.command()
@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use (including all)")
-@click.option("--download", default=True, help="Whether to download the dataset")
-@click.option(
- "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
-)
+@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer_cli(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+ vocab_size: int,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+ train_bpe_tokenizer(
+ dataset_path,
+ language,
+ split,
+ out_path,
+ vocab_size,
+ )
+
+
def train_bpe_tokenizer(
dataset_path: str,
language: str,
split: str,
out_path: str,
- download: bool,
vocab_size: int,
):
"""Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -206,11 +261,12 @@ def train_bpe_tokenizer(
lines = []
for s_plit in splits:
- transcripts_path = os.path.join(
- dataset_path, language, s_plit, "transcripts.txt"
- )
- if download and not os.path.exists(transcripts_path):
- MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+ transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
+ if not os.path.exists(transcripts_path):
+ raise FileNotFoundError(
+ f"Could not find transcripts.txt in {transcripts_path}. "
+ "Please make sure that the dataset is downloaded."
+ )
with open(
transcripts_path,
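With the download path removed, the function now assumes the MLS archive is already unpacked in the layout implied by the os.path.join above:

    data/
    └── mls_german_opus/
        └── train/
            └── transcripts.txt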
@@ -226,6 +282,7 @@ def train_bpe_tokenizer(
bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
initial_alphabet = [
+ " ",
"a",
"b",
"c",
@@ -272,6 +329,7 @@ def train_bpe_tokenizer(
"ü",
]
+ # TODO: add padding token / whitespace token / special tokens
trainer = BpeTrainer(
special_tokens=["[UNK]"],
vocab_size=vocab_size,
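The hunk ends before the trainer is used; with the imported Whitespace pre-tokenizer and the huggingface tokenizers API, the omitted training step presumably looks something like this sketch:

    bpe_tokenizer.pre_tokenizer = Whitespace()
    bpe_tokenizer.train_from_iterator(lines, trainer=trainer)
    bpe_tokenizer.save(out_path)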
@@ -292,16 +350,22 @@ def train_bpe_tokenizer(
@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use")
-@click.option(
- "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
-)
-@click.option("--download", default=True, help="Whether to download the dataset")
+@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
+def train_char_tokenizer_cli(
+ dataset_path: str,
+ language: str,
+ split: str,
+ out_path: str,
+):
+ """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+ train_char_tokenizer(dataset_path, language, split, out_path)
+
+
def train_char_tokenizer(
dataset_path: str,
language: str,
split: str,
out_path: str,
- download: bool,
):
"""Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -317,7 +381,7 @@ def train_char_tokenizer(
"""
char_tokenizer = CharTokenizer()
- char_tokenizer.train(dataset_path, language, split, download)
+ char_tokenizer.train(dataset_path, language, split)
char_tokenizer.save(out_path)
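End to end, training and reusing the character tokenizer then reduces to a few calls; a minimal sketch:

    train_char_tokenizer("data", "mls_german_opus", "train", "tokenizer_chars.txt")
    tokenizer = CharTokenizer.from_file("tokenizer_chars.txt")
    print(tokenizer.encode("hallo welt").ids)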