diff options
author | Pherkel | 2023-09-11 15:07:08 +0200 |
---|---|---|
committer | Pherkel | 2023-09-11 15:07:08 +0200 |
commit | effde1d9e71864a2c5bd8464db0958f5bf2d1733 (patch) | |
tree | 5125e57ca8454ff007207d11a3b4bde579b85755 /swr2_asr/utils | |
parent | 9dc3bc07424908dd7cf3f052708f506fd58b6e2c (diff) |
added small stuff to data utilities
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r-- | swr2_asr/utils/data.py | 29 |
1 file changed, 25 insertions, 4 deletions
```diff
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 93f4a9a..74d10c9 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -123,9 +123,9 @@ class MLSDataset(Dataset):
         self,
         dataset_path: str,
         language: str,
-        split: Split,
-        limited: bool,
-        download: bool,
+        split: Split,  # pylint: disable=redefined-outer-name
+        limited: bool = False,
+        download: bool = True,
         size: float = 0.2,
     ):
         """Initializes the dataset"""
@@ -365,7 +365,28 @@ class MLSDataset(Dataset):
 
 
 if __name__ == "__main__":
+    from torch.utils.data import DataLoader
+
     DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
     LANGUAGE = "mls_german_opus"
-    split = Split.TRAIN
+    split = Split.DEV
     DOWNLOAD = False
+
+    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, download=DOWNLOAD)
+
+    dataloader = DataLoader(
+        dataset,
+        batch_size=1,
+        shuffle=True,
+        collate_fn=DataProcessing(
+            "train", CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+        ),
+    )
+
+    for batch in dataloader:
+        print(batch)
+        break
+
+    print(len(dataset))
+
+    print(dataset[0])
```