From 33f09080aee10bddb4797a557d676ee1f7b8de31 Mon Sep 17 00:00:00 2001
From: Pherkel
Date: Sun, 3 Sep 2023 19:30:33 +0200
Subject: idk, hopefully this works

---
 swr2_asr/utils.py | 85 ++++++++++++++++++++++++++++++++-----------------------
 1 file changed, 50 insertions(+), 35 deletions(-)

diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 3b9b3ca..efecb56 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -8,7 +8,6 @@ import torch
 import torchaudio
 from tokenizers import Tokenizer
 from torch.utils.data import Dataset
-from tqdm import tqdm
 
 from swr2_asr.tokenizer import TokenizerType
 
@@ -24,26 +23,26 @@ class MLSSplit(str, Enum):
     """Enum specifying dataset as they are defined in the
     Multilingual LibriSpeech dataset"""
 
-    train = "train"
-    test = "test"
-    dev = "dev"
+    TRAIN = "train"
+    TEST = "test"
+    DEV = "dev"
 
 
 class Split(str, Enum):
     """Extending the MLSSplit class to allow for a custom validatio split"""
 
-    train = "train"
-    valid = "valid"
-    test = "test"
-    dev = "dev"
+    TRAIN = "train"
+    VALID = "valid"
+    TEST = "test"
+    DEV = "dev"
 
 
-def split_to_mls_split(split: Split) -> MLSSplit:
+def split_to_mls_split(split_name: Split) -> MLSSplit:
     """Converts the custom split to a MLSSplit"""
-    if split == Split.valid:
-        return MLSSplit.train
+    if split_name == Split.VALID:
+        return MLSSplit.TRAIN
     else:
-        return split  # type: ignore
+        return split_name  # type: ignore
 
 
 class Sample(TypedDict):
@@ -89,14 +88,21 @@ class MLSDataset(Dataset):
     __
     """
 
-    def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None):
+    def __init__(
+        self,
+        dataset_path: str,
+        language: str,
+        split: Split,
+        download: bool,
+        spectrogram_hparams: dict | None,
+    ):
         """Initializes the dataset"""
         self.dataset_path = dataset_path
         self.language = language
         self.file_ext = ".opus" if "opus" in language else ".flac"
         self.mls_split: MLSSplit = split_to_mls_split(split)  # split path on disk
         self.split: Split = split  # split used internally
-        
+
         if spectrogram_hparams is None:
             self.spectrogram_hparams = {
                 "sample_rate": 16000,
@@ -110,7 +116,7 @@ class MLSDataset(Dataset):
             }
         else:
             self.spectrogram_hparams = spectrogram_hparams
-            
+
         self.dataset_lookup = []
         self.tokenizer: type[TokenizerType]
 
@@ -118,13 +124,14 @@ class MLSDataset(Dataset):
         self._validate_local_directory()
         self.initialize()
 
-
     def initialize(self) -> None:
         """Initializes the dataset
-        
+
         Reads the transcripts.txt file and creates a lookup table
         """
-        transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt")
+        transcripts_path = os.path.join(
+            self.dataset_path, self.language, self.mls_split, "transcripts.txt"
+        )
 
         with open(transcripts_path, "r", encoding="utf-8") as script_file:
             # read all lines in transcripts.txt
@@ -135,12 +142,12 @@ class MLSDataset(Dataset):
             identifier = [identifier.strip() for identifier, _ in transcripts]  # type: ignore
             identifier = [path.split("_") for path in identifier]
 
-            if self.split == Split.valid:
+            if self.split == Split.VALID:
                 np.random.seed(42)
                 indices = np.random.choice(len(utterances), int(len(utterances) * 0.2))
                 utterances = [utterances[i] for i in indices]
                 identifier = [identifier[i] for i in indices]
-            elif self.split == Split.train:
+            elif self.split == Split.TRAIN:
                 np.random.seed(42)
                 indices = np.random.choice(len(utterances), int(len(utterances) * 0.8))
                 utterances = [utterances[i] for i in indices]
@@ -212,13 +219,19 @@ class MLSDataset(Dataset):
 
         # resample if necessary
         if sample_rate != self.spectrogram_hparams["sample_rate"]:
-            resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"])
+            resampler = torchaudio.transforms.Resample(
+                sample_rate, self.spectrogram_hparams["sample_rate"]
+            )
             waveform = resampler(waveform)
 
-        spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1)
+        spec = (
+            torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
+            .squeeze(0)
+            .transpose(0, 1)
+        )
 
         input_length = spec.shape[0] // 2
-        
+
         utterance_length = len(utterance)
 
         utterance = self.tokenizer.encode(utterance)
@@ -236,11 +249,11 @@ class MLSDataset(Dataset):
             book_id=self.dataset_lookup[idx]["bookid"],
             chapter_id=self.dataset_lookup[idx]["chapterid"],
         )
-        
+
 
 def collate_fn(samples: list[Sample]) -> dict:
     """Collate function for the dataloader
-    
+
     pads all tensors within a batch to the same dimensions
     """
     waveforms = []
@@ -248,18 +261,20 @@ def collate_fn(samples: list[Sample]) -> dict:
     labels = []
     input_lengths = []
     label_lengths = []
-    
+
     for sample in samples:
         waveforms.append(sample["waveform"].transpose(0, 1))
         spectrograms.append(sample["spectrogram"])
         labels.append(sample["utterance"])
         input_lengths.append(sample["spectrogram"].shape[0] // 2)
         label_lengths.append(len(sample["utterance"]))
-        
+
     waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
-    spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3)
+    spectrograms = (
+        torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+    )
     labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
-    
+
     return {
         "waveform": waveforms,
         "spectrogram": spectrograms,
@@ -267,15 +282,15 @@ def collate_fn(samples: list[Sample]) -> dict:
         "utterance": labels,
         "utterance_length": label_lengths,
     }
-    
+
 
 if __name__ == "__main__":
-    dataset_path = "/Volumes/pherkel/SWR2-ASR"
-    language = "mls_german_opus"
-    split = Split.train
-    download = False
+    DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
+    LANGUAGE = "mls_german_opus"
+    split = Split.TRAIN
+    DOWNLOAD = False
 
-    dataset = MLSDataset(dataset_path, language, split, download, None)
+    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None)
 
     tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
     dataset.set_tokenizer(tok)
-- 
cgit v1.2.3
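Below is a minimal usage sketch (not part of the patch) showing how the reworked MLSDataset and collate_fn might feed a PyTorch DataLoader. The dataset path, language, and tokenizer file are the placeholder values from the patch's __main__ block; the batch size and shuffle flag are arbitrary assumptions.

# Usage sketch only; assumes the MLS data and the tokenizer file exist locally.
from tokenizers import Tokenizer
from torch.utils.data import DataLoader

from swr2_asr.utils import MLSDataset, Split, collate_fn

dataset = MLSDataset(
    dataset_path="/Volumes/pherkel/SWR2-ASR",  # placeholder path from the patch
    language="mls_german_opus",
    split=Split.TRAIN,
    download=False,
    spectrogram_hparams=None,  # None selects the 16 kHz defaults set in __init__
)
dataset.set_tokenizer(Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json"))

loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))

# collate_fn pads each field per batch; after unsqueeze(1).transpose(2, 3) the
# spectrograms come back as (batch, 1, n_mels, time)
print(batch["spectrogram"].shape, batch["input_length"][:3])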