about | summary | refs | log | tree | commit | diff
path: root/swr2_asr/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- swr2_asr/utils.py 14
1 files changed, 10 insertions, 4 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 0330cd2..1c755a6 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -168,22 +168,28 @@ class MLSDataset(Dataset):
"""Sets the tokenizer"""
self.tokenizer = tokenizer
- def _handle_download_dataset(self, download: bool):
+ def _handle_download_dataset(self, download: bool) -> None:
"""Download the dataset"""
if not download:
+ print("Download flag not set, skipping download")
return
# zip exists:
if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download:
print(f"Found dataset at {self.dataset_path}. Skipping download")
# zip does not exist:
else:
- os.makedirs(self.dataset_path)
+ os.makedirs(self.dataset_path, exist_ok=True)
url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
- torch.hub.download_url_to_file(url, self.dataset_path)
+ torch.hub.download_url_to_file(url, os.path.join(self.dataset_path, self.language) + ".tar.gz")
# unzip the dataset
- extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz")
+ if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
+ print(f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}")
+ extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True)
+ else:
+ print("Dataset is already unzipped, validating it now")
+ return
def _validate_local_directory(self):
# check if dataset_path exists