diff options
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- | swr2_asr/utils.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index efecb56..0330cd2 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -8,6 +8,7 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset +from torchaudio.datasets.utils import _extract_tar as extract_archive from swr2_asr.tokenizer import TokenizerType @@ -169,20 +170,27 @@ class MLSDataset(Dataset): def _handle_download_dataset(self, download: bool): """Download the dataset""" - if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download: + if not download: + return + # zip exists: + if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: + print(f"Found dataset at {self.dataset_path}. Skipping download") + # zip does not exist: + else: os.makedirs(self.dataset_path) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" torch.hub.download_url_to_file(url, self.dataset_path) - elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download: - raise ValueError("Dataset not found. Set download to True to download it") + + # unzip the dataset + extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz") def _validate_local_directory(self): # check if dataset_path exists if not os.path.exists(self.dataset_path): raise ValueError("Dataset path does not exist") if not os.path.exists(os.path.join(self.dataset_path, self.language)): - raise ValueError("Language not found in dataset") + raise ValueError("Language not downloaded!") if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") |