aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r--swr2_asr/utils.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index efecb56..0330cd2 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -8,6 +8,7 @@ import torch
import torchaudio
from tokenizers import Tokenizer
from torch.utils.data import Dataset
+from torchaudio.datasets.utils import _extract_tar as extract_archive
from swr2_asr.tokenizer import TokenizerType
@@ -169,20 +170,27 @@ class MLSDataset(Dataset):
def _handle_download_dataset(self, download: bool):
"""Download the dataset"""
- if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download:
+ if not download:
+ return
+ # zip exists:
+ if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download:
+ print(f"Found dataset at {self.dataset_path}. Skipping download")
+ # zip does not exist:
+ else:
os.makedirs(self.dataset_path)
url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
torch.hub.download_url_to_file(url, self.dataset_path)
- elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download:
- raise ValueError("Dataset not found. Set download to True to download it")
+
+ # unzip the dataset
+ extract_archive(os.path.join(self.dataset_path, self.language) + ".tar.gz")
def _validate_local_directory(self):
# check if dataset_path exists
if not os.path.exists(self.dataset_path):
raise ValueError("Dataset path does not exist")
if not os.path.exists(os.path.join(self.dataset_path, self.language)):
- raise ValueError("Language not found in dataset")
+ raise ValueError("Language not downloaded!")
if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
raise ValueError("Split not found in dataset")