about summary refs log tree commit diff
path: root/swr2_asr/tokenizer.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r--  swr2_asr/tokenizer.py  17
1 file changed, 15 insertions(+), 2 deletions(-)
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 8e3bf09..c8d3793 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -14,16 +14,24 @@ from tqdm import tqdm
class TokenizerType:
+ """Base class for tokenizers.
+
+ exposes the same interface as tokenizers from the huggingface library"""
+
def encode(self, sequence: str) -> list[int]:
+ """Encode a sequence to a list of integer labels"""
raise NotImplementedError
def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+ """Decode a list of integer labels to a sequence"""
raise NotImplementedError
def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ """Decode a batch of integer labels to a list of sequences"""
raise NotImplementedError
def get_vocab_size(self) -> int:
+ """Get the size of the vocabulary"""
raise NotImplementedError
def enable_padding(
@@ -34,17 +42,20 @@ class TokenizerType:
pad_type_id: int = 0,
pad_token: str = "[PAD]",
) -> None:
+ """Enable padding for the tokenizer"""
raise NotImplementedError
def save(self, path: str) -> None:
+ """Save the tokenizer to a file"""
raise NotImplementedError
@staticmethod
def from_file(path: str) -> "TokenizerType":
+ """Load the tokenizer from a file"""
raise NotImplementedError
-tokenizer_type = Type[TokenizerType]
+MyTokenizerType = Type[TokenizerType]
@dataclass
@@ -52,6 +63,7 @@ class Encoding:
"""Simple dataclass to represent an encoding"""
ids: list[int]
+ tokens: list[str]
class CharTokenizer(TokenizerType):
@@ -137,7 +149,7 @@ class CharTokenizer(TokenizerType):
else:
mapped_char = self.char_map[char]
int_sequence.append(mapped_char)
- return Encoding(ids=int_sequence)
+ return Encoding(ids=int_sequence, tokens=list(sequence))
def decode(self, labels: list[int], remove_special_tokens: bool = True):
"""Use a character map and convert integer labels to an text sequence
@@ -256,6 +268,7 @@ def train_bpe_tokenizer(
bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
initial_alphabet = [
+ " ",
"a",
"b",
"c",