diff options
-rw-r--r--  swr2_asr/tokenizer.py | 45
-rw-r--r--  swr2_asr/train.py     |  2
-rw-r--r--  swr2_asr/utils.py     | 16
3 files changed, 32 insertions, 31 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 64227a4..e4df93b 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -90,7 +90,7 @@ class CharTokenizer(TokenizerType): self.char_map[token] = len(self.char_map) self.index_map[len(self.index_map)] = token - def train(self, dataset_path: str, language: str, split: str, download: bool = True): + def train(self, dataset_path: str, language: str, split: str): """Train the tokenizer on the given dataset Args: @@ -110,10 +110,6 @@ class CharTokenizer(TokenizerType): for s_plit in splits: transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - # check if dataset is downloaded, download if not - if download and not os.path.exists(transcript_path): - MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) - with open( transcript_path, "r", @@ -215,7 +211,6 @@ class CharTokenizer(TokenizerType): @click.option("--dataset_path", default="data", help="Path to the MLS dataset") @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use (including all)") -@click.option("--download", default=True, help="Whether to download the dataset") @click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") @click.option("--vocab_size", default=2000, help="Size of the vocabulary") def train_bpe_tokenizer_cli( @@ -223,24 +218,23 @@ def train_bpe_tokenizer_cli( language: str, split: str, out_path: str, - download: bool, vocab_size: int, ): - train_bpe_tokenizer( - dataset_path, - language, - split, - out_path, - download, - vocab_size, -) + """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" + train_bpe_tokenizer( + dataset_path, + language, + split, + out_path, + vocab_size, + ) + def train_bpe_tokenizer( dataset_path: str, language: str, split: str, out_path: str, - download: bool, vocab_size: int, ): """Train a Byte-Pair Encoder tokenizer on the MLS 
dataset @@ -268,9 +262,11 @@ def train_bpe_tokenizer( for s_plit in splits: transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") - if download and not os.path.exists(transcripts_path): - # TODO: move to own dataset - MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) + if not os.path.exists(transcripts_path): + raise FileNotFoundError( + f"Could not find transcripts.txt in {transcripts_path}. " + "Please make sure that the dataset is downloaded." + ) with open( transcripts_path, @@ -355,22 +351,21 @@ def train_bpe_tokenizer( @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use") @click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") -@click.option("--download", default=True, help="Whether to download the dataset") def train_char_tokenizer_cli( dataset_path: str, language: str, split: str, out_path: str, - download: bool, - ): - train_char_tokenizer(dataset_path, language, split, out_path, download) +): + """Train a Byte-Pair Encoder tokenizer on the MLS dataset""" + train_char_tokenizer(dataset_path, language, split, out_path) + def train_char_tokenizer( dataset_path: str, language: str, split: str, out_path: str, - download: bool, ): """Train a Byte-Pair Encoder tokenizer on the MLS dataset @@ -386,7 +381,7 @@ def train_char_tokenizer( """ char_tokenizer = CharTokenizer() - char_tokenizer.train(dataset_path, language, split, download) + char_tokenizer.train(dataset_path, language, split) char_tokenizer.save(out_path) diff --git a/swr2_asr/train.py b/swr2_asr/train.py index aea99e0..63deb72 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -304,4 +304,4 @@ def run_cli( if __name__ == "__main__": - run_cli() + run_cli() # pylint: disable=no-value-for-parameter diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 1c755a6..8a950ab 100644 --- a/swr2_asr/utils.py +++ 
b/swr2_asr/utils.py @@ -181,15 +181,21 @@ class MLSDataset(Dataset): os.makedirs(self.dataset_path, exist_ok=True) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" - torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz") + torch.hub.download_url_to_file( + url, os.path.join(self.dataset_path, self.language) + ".tar.gz" + ) # unzip the dataset if not os.path.isdir(os.path.join(self.dataset_path, self.language)): - print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}") - extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True) + print( + f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" + ) + extract_archive( + os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True + ) else: - print("Dataset is already unzipped, validating it now") - return + print("Dataset is already unzipped, validating it now") + return def _validate_local_directory(self): # check if dataset_path exists |