From cd15a49ccee83c21ada481d6815d004f134147fe Mon Sep 17 00:00:00 2001 From: Philipp Merkel Date: Mon, 4 Sep 2023 14:07:54 +0000 Subject: applied fixes to download and tokenizers --- swr2_asr/utils.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 0330cd2..1c755a6 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -168,22 +168,28 @@ class MLSDataset(Dataset): """Sets the tokenizer""" self.tokenizer = tokenizer - def _handle_download_dataset(self, download: bool): + def _handle_download_dataset(self, download: bool) -> None: """Download the dataset""" if not download: + print("Download flag not set, skipping download") return # zip exists: if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: print(f"Found dataset at {self.dataset_path}. Skipping download") # zip does not exist: else: - os.makedirs(self.dataset_path) + os.makedirs(self.dataset_path, exist_ok=True) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" - torch.hub.download_url_to_file(url, self.dataset_path) + torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz") # unzip the dataset - extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz") + if not os.path.isdir(os.path.join(self.dataset_path, self.language)): + print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}") + extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True) + else: + print("Dataset is already unzipped, validating it now") + return def _validate_local_directory(self): # check if dataset_path exists -- cgit v1.2.3