author    | Pherkel | 2023-09-04 22:48:40 +0200
committer | Pherkel | 2023-09-04 22:48:40 +0200
commit    | 0d70a19e1fea6eda3f7b16ad0084591613f2de72 (patch)
tree      | 9f413c226283db990116c9559ffffb9124b911d8 /swr2_asr/tokenizer.py
parent    | cd15a49ccee83c21ada481d6815d004f134147fe (diff)
please the linter
Diffstat (limited to 'swr2_asr/tokenizer.py')
-rw-r--r-- | swr2_asr/tokenizer.py | 45
1 file changed, 20 insertions(+), 25 deletions(-)
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 64227a4..e4df93b 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -90,7 +90,7 @@ class CharTokenizer(TokenizerType):
         self.char_map[token] = len(self.char_map)
         self.index_map[len(self.index_map)] = token
 
-    def train(self, dataset_path: str, language: str, split: str, download: bool = True):
+    def train(self, dataset_path: str, language: str, split: str):
         """Train the tokenizer on the given dataset
 
         Args:
@@ -110,10 +110,6 @@ class CharTokenizer(TokenizerType):
         for s_plit in splits:
             transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
 
-            # check if dataset is downloaded, download if not
-            if download and not os.path.exists(transcript_path):
-                MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
-
             with open(
                 transcript_path,
                 "r",
@@ -215,7 +211,6 @@ class CharTokenizer(TokenizerType):
 @click.option("--dataset_path", default="data", help="Path to the MLS dataset")
 @click.option("--language", default="mls_german_opus", help="Language to use")
 @click.option("--split", default="train", help="Split to use (including all)")
-@click.option("--download", default=True, help="Whether to download the dataset")
 @click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
 @click.option("--vocab_size", default=2000, help="Size of the vocabulary")
 def train_bpe_tokenizer_cli(
@@ -223,24 +218,23 @@ def train_bpe_tokenizer_cli(
     language: str,
     split: str,
     out_path: str,
-    download: bool,
     vocab_size: int,
 ):
-    train_bpe_tokenizer(
-        dataset_path,
-        language,
-        split,
-        out_path,
-        download,
-        vocab_size,
-)
+    """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+    train_bpe_tokenizer(
+        dataset_path,
+        language,
+        split,
+        out_path,
+        vocab_size,
+    )
+
 
 def train_bpe_tokenizer(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
     vocab_size: int,
 ):
     """Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -268,9 +262,11 @@ def train_bpe_tokenizer(
     for s_plit in splits:
         transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
 
-        if download and not os.path.exists(transcripts_path):
-            # TODO: move to own dataset
-            MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
+        if not os.path.exists(transcripts_path):
+            raise FileNotFoundError(
+                f"Could not find transcripts.txt in {transcripts_path}. "
+                "Please make sure that the dataset is downloaded."
+            )
 
         with open(
             transcripts_path,
@@ -355,22 +351,21 @@ def train_bpe_tokenizer(
 @click.option("--language", default="mls_german_opus", help="Language to use")
 @click.option("--split", default="train", help="Split to use")
 @click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
-@click.option("--download", default=True, help="Whether to download the dataset")
 def train_char_tokenizer_cli(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
-    ):
-    train_char_tokenizer(dataset_path, language, split, out_path, download)
+):
+    """Train a Byte-Pair Encoder tokenizer on the MLS dataset"""
+    train_char_tokenizer(dataset_path, language, split, out_path)
+
 
 def train_char_tokenizer(
     dataset_path: str,
     language: str,
     split: str,
     out_path: str,
-    download: bool,
 ):
     """Train a Byte-Pair Encoder tokenizer on the MLS dataset
@@ -386,7 +381,7 @@ def train_char_tokenizer(
     """
     char_tokenizer = CharTokenizer()
 
-    char_tokenizer.train(dataset_path, language, split, download)
+    char_tokenizer.train(dataset_path, language, split)
     char_tokenizer.save(out_path)
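
Net effect of the commit: neither tokenizer trainer downloads MLS on demand anymore. The --download flag is removed from both CLIs, and a missing transcripts.txt now raises FileNotFoundError instead of triggering a download. A minimal caller-side sketch of the new contract, using the defaults from the click options above; the import path for MultilingualLibriSpeech is an assumption (the diff shows only the class name):

    import os

    # Hypothetical import path; the diff references MultilingualLibriSpeech by name only.
    from swr2_asr.utils import MultilingualLibriSpeech
    from swr2_asr.tokenizer import train_bpe_tokenizer

    dataset_path = "data"         # default of --dataset_path
    language = "mls_german_opus"  # default of --language
    split = "train"               # default of --split

    # Download explicitly up front: train_bpe_tokenizer() no longer does this
    # itself and raises FileNotFoundError when transcripts.txt is absent.
    transcripts = os.path.join(dataset_path, language, split, "transcripts.txt")
    if not os.path.exists(transcripts):
        MultilingualLibriSpeech(dataset_path, language, split, download=True)

    train_bpe_tokenizer(
        dataset_path,
        language,
        split,
        out_path="tokenizer.json",  # default of --out_path
        vocab_size=2000,            # default of --vocab_size
    )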