From effde1d9e71864a2c5bd8464db0958f5bf2d1733 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 15:07:08 +0200 Subject: added small stuff to data utilities --- swr2_asr/utils/data.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) (limited to 'swr2_asr/utils') diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py index 93f4a9a..74d10c9 100644 --- a/swr2_asr/utils/data.py +++ b/swr2_asr/utils/data.py @@ -123,9 +123,9 @@ class MLSDataset(Dataset): self, dataset_path: str, language: str, - split: Split, - limited: bool, - download: bool, + split: Split, # pylint: disable=redefined-outer-name + limited: bool = False, + download: bool = True, size: float = 0.2, ): """Initializes the dataset""" @@ -365,7 +365,28 @@ class MLSDataset(Dataset): if __name__ == "__main__": + from torch.utils.data import DataLoader + DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" LANGUAGE = "mls_german_opus" - split = Split.TRAIN + split = Split.DEV DOWNLOAD = False + + dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD) + + dataloader = DataLoader( + dataset, + batch_size=1, + shuffle=True, + collate_fn=DataProcessing( + "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json") + ), + ) + + for batch in dataloader: + print(batch) + break + + print(len(dataset)) + + print(dataset[0]) -- cgit v1.2.3