diff options
author | Pherkel | 2023-09-03 20:52:15 +0200 |
---|---|---|
committer | Pherkel | 2023-09-03 20:52:15 +0200 |
commit | acafe88a1a360832b727651b713806ce0404db3f (patch) | |
tree | 4aa801b011fb61e6cf5d08b1cea2ac48f969d3fb /swr2_asr/utils.py | |
parent | 6a73a1ed117d9ace8546aa5114daada91e69b069 (diff) |
Fix dataset archive not being extracted after download
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- | swr2_asr/utils.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index efecb56..0330cd2 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -8,6 +8,7 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset +from torchaudio.datasets.utils import _extract_tar as extract_archive from swr2_asr.tokenizer import TokenizerType @@ -169,20 +170,27 @@ class MLSDataset(Dataset): def _handle_download_dataset(self, download: bool): """Download the dataset""" - if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download: + if not download: + return + # zip exists: + if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: + print(f"Found dataset at {self.dataset_path}. Skipping download") + # zip does not exist: + else: os.makedirs(self.dataset_path) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" torch.hub.download_url_to_file(url, self.dataset_path) - elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download: - raise ValueError("Dataset not found. Set download to True to download it") + + # unzip the dataset + extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz") def _validate_local_directory(self): # check if dataset_path exists if not os.path.exists(self.dataset_path): raise ValueError("Dataset path does not exist") if not os.path.exists(os.path.join(self.dataset_path, self.language)): - raise ValueError("Language not found in dataset") + raise ValueError("Language not downloaded!") if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") |