-rw-r--r-- | swr2_asr/tokenizer.py | 28 |
-rw-r--r-- | swr2_asr/train.py     |  3 |
-rw-r--r-- | swr2_asr/utils.py     | 14 |
3 files changed, 39 insertions, 6 deletions
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index f02d4f5..64227a4 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -187,6 +187,7 @@ class CharTokenizer(TokenizerType):
 
     def save(self, path: str):
         """Save the tokenizer to a file"""
+        os.makedirs(os.path.dirname(path), exist_ok=True)
         with open(path, "w", encoding="utf-8") as file:
             # save it in the following format:
             # {"char_map": {"a": 0, "b": 1, ...}, "index_map": {0: "a", 1: "b", ...}}
@@ -217,6 +218,23 @@ class CharTokenizer(TokenizerType):
 @click.option("--download", default=True, help="Whether to download the dataset")
 @click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
 @click.option("--vocab_size", default=2000, help="Size of the vocabulary")
+def train_bpe_tokenizer_cli(
+    dataset_path: str,
+    language: str,
+    split: str,
+    out_path: str,
+    download: bool,
+    vocab_size: int,
+):
+    train_bpe_tokenizer(
+        dataset_path,
+        language,
+        split,
+        out_path,
+        download,
+        vocab_size,
+    )
+
 def train_bpe_tokenizer(
     dataset_path: str,
     language: str,
@@ -251,6 +269,7 @@ def train_bpe_tokenizer(
     for s_plit in splits:
         transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
         if download and not os.path.exists(transcripts_path):
+            # TODO: move to own dataset
             MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
 
         with open(
@@ -337,6 +356,15 @@ def train_bpe_tokenizer(
 @click.option("--split", default="train", help="Split to use")
 @click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
 @click.option("--download", default=True, help="Whether to download the dataset")
+def train_char_tokenizer_cli(
+    dataset_path: str,
+    language: str,
+    split: str,
+    out_path: str,
+    download: bool,
+):
+    train_char_tokenizer(dataset_path, language, split, out_path, download)
+
 def train_char_tokenizer(
     dataset_path: str,
     language: str,
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 8fc0b78..aea99e0 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -173,7 +173,6 @@ def run(
         split="all",
         download=False,
         out_path="data/tokenizers/char_tokenizer_german.json",
-        vocab_size=3000,
     )
 
     tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
@@ -305,4 +304,4 @@ def run_cli(
 
 
 if __name__ == "__main__":
-    run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")
+    run_cli()
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 0330cd2..1c755a6 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -168,22 +168,28 @@ class MLSDataset(Dataset):
         """Sets the tokenizer"""
         self.tokenizer = tokenizer
 
-    def _handle_download_dataset(self, download: bool):
+    def _handle_download_dataset(self, download: bool) -> None:
         """Download the dataset"""
         if not download:
+            print("Download flag not set, skipping download")
             return
         # zip exists:
         if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download:
             print(f"Found dataset at {self.dataset_path}. Skipping download")
         # zip does not exist:
         else:
-            os.makedirs(self.dataset_path)
+            os.makedirs(self.dataset_path, exist_ok=True)
             url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
 
-            torch.hub.download_url_to_file(url, self.dataset_path)
+            torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz")
 
         # unzip the dataset
-        extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz")
+        if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
+            print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}")
+            extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True)
+        else:
+            print("Dataset is already unzipped, validating it now")
+            return
 
     def _validate_local_directory(self):
         # check if dataset_path exists