diff options
Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r-- | swr2_asr/tokenizer.py | 17 |
1 files changed, 15 insertions, 2 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 8e3bf09..c8d3793 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -14,16 +14,24 @@ from tqdm import tqdm
 
 
 class TokenizerType:
+    """Base class for tokenizers.
+
+    exposes the same interface as tokenizers from the huggingface library"""
+
     def encode(self, sequence: str) -> list[int]:
+        """Encode a sequence to a list of integer labels"""
         raise NotImplementedError
 
     def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+        """Decode a list of integer labels to a sequence"""
         raise NotImplementedError
 
     def decode_batch(self, labels: list[list[int]]) -> list[str]:
+        """Decode a batch of integer labels to a list of sequences"""
         raise NotImplementedError
 
     def get_vocab_size(self) -> int:
+        """Get the size of the vocabulary"""
         raise NotImplementedError
 
     def enable_padding(
@@ -34,17 +42,20 @@ class TokenizerType:
         pad_type_id: int = 0,
         pad_token: str = "[PAD]",
     ) -> None:
+        """Enable padding for the tokenizer"""
         raise NotImplementedError
 
     def save(self, path: str) -> None:
+        """Save the tokenizer to a file"""
         raise NotImplementedError
 
     @staticmethod
     def from_file(path: str) -> "TokenizerType":
+        """Load the tokenizer from a file"""
         raise NotImplementedError
 
 
-tokenizer_type = Type[TokenizerType]
+MyTokenizerType = Type[TokenizerType]
 
 
 @dataclass
@@ -52,6 +63,7 @@ class Encoding:
     """Simple dataclass to represent an encoding"""
 
     ids: list[int]
+    tokens: list[str]
 
 
 class CharTokenizer(TokenizerType):
@@ -137,7 +149,7 @@ class CharTokenizer(TokenizerType):
             else:
                 mapped_char = self.char_map[char]
             int_sequence.append(mapped_char)
-        return Encoding(ids=int_sequence)
+        return Encoding(ids=int_sequence, tokens=list(sequence))
 
     def decode(self, labels: list[int], remove_special_tokens: bool = True):
         """Use a character map and convert integer labels to an text sequence
@@ -256,6 +268,7 @@ def train_bpe_tokenizer(
     bpe_tokenizer = Tokenizer(BPE(unk_token="[UNK]"))
 
     initial_alphabet = [
+        " ",
         "a",
         "b",
         "c",