aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils
diff options
context:
space:
mode:
authorPherkel2023-09-11 15:07:08 +0200
committerPherkel2023-09-11 15:07:08 +0200
commiteffde1d9e71864a2c5bd8464db0958f5bf2d1733 (patch)
tree5125e57ca8454ff007207d11a3b4bde579b85755 /swr2_asr/utils
parent9dc3bc07424908dd7cf3f052708f506fd58b6e2c (diff)
added small stuff to data utilities
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--swr2_asr/utils/data.py29
1 files changed, 25 insertions, 4 deletions
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 93f4a9a..74d10c9 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -123,9 +123,9 @@ class MLSDataset(Dataset):
self,
dataset_path: str,
language: str,
- split: Split,
- limited: bool,
- download: bool,
+ split: Split, # pylint: disable=redefined-outer-name
+ limited: bool = False,
+ download: bool = True,
size: float = 0.2,
):
"""Initializes the dataset"""
@@ -365,7 +365,28 @@ class MLSDataset(Dataset):
if __name__ == "__main__":
+ from torch.utils.data import DataLoader
+
DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
LANGUAGE = "mls_german_opus"
- split = Split.TRAIN
+ split = Split.DEV
DOWNLOAD = False
+
+ dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD)
+
+ dataloader = DataLoader(
+ dataset,
+ batch_size=1,
+ shuffle=True,
+ collate_fn=DataProcessing(
+ "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+ ),
+ )
+
+ for batch in dataloader:
+ print(batch)
+ break
+
+ print(len(dataset))
+
+ print(dataset[0])