From 335b8a32f8bba5d37c00af6b4ecd1b9fc520f964 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Wed, 30 Aug 2023 17:11:51 +0200 Subject: wörks now!°!! --- swr2_asr/utils.py | 164 ++++++++++++++++++++++-------------------------------- 1 file changed, 65 insertions(+), 99 deletions(-) (limited to 'swr2_asr/utils.py') diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 404661d..4c751d5 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,7 +1,6 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from multiprocessing import Pool from typing import TypedDict import numpy as np @@ -10,9 +9,8 @@ import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset from tqdm import tqdm -import audio_metadata -from swr2_asr.tokenizer import CharTokenizer, TokenizerType +from swr2_asr.tokenizer import TokenizerType train_audio_transforms = torch.nn.Sequential( torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), @@ -91,20 +89,42 @@ class MLSDataset(Dataset): __ """ - def __init__(self, dataset_path: str, language: str, split: Split, download: bool): + def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None): """Initializes the dataset""" self.dataset_path = dataset_path self.language = language self.file_ext = ".opus" if "opus" in language else ".flac" self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally + + if spectrogram_hparams is None: + self.spectrogram_hparams = { + "sample_rate": 16000, + "n_fft": 400, + "win_length": 400, + "hop_length": 160, + "n_mels": 128, + "f_min": 0, + "f_max": 8000, + "power": 2.0, + } + else: + self.spectrogram_hparams = spectrogram_hparams + self.dataset_lookup = [] self.tokenizer: type[TokenizerType] self._handle_download_dataset(download) self._validate_local_directory() + self.initialize() - transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt") + + def initialize(self) -> None: + """Initializes the dataset + + Reads the transcripts.txt file and creates a lookup table + """ + transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt") with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt @@ -136,13 +156,9 @@ class MLSDataset(Dataset): for path, utterance in zip(identifier, utterances, strict=False) ] - self.max_spec_length = 0 - self.max_utterance_length = 0 - def set_tokenizer(self, tokenizer: type[TokenizerType]): """Sets the tokenizer""" self.tokenizer = tokenizer - # self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" @@ -163,80 +179,14 @@ class MLSDataset(Dataset): if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - def _calculate_max_length(self, chunk): - """Calculates the maximum length of the spectrogram and the utterance - - to be called in a multiprocessing pool - """ - max_spec_length = 0 - max_utterance_length = 0 - - for sample in chunk: - audio_path = os.path.join( - self.dataset_path, - self.language, - self.mls_split, - "audio", - sample["speakerid"], - sample["bookid"], - "_".join( - [ - sample["speakerid"], - sample["bookid"], - sample["chapterid"], - ] - ) - + self.file_ext, - ) - metadata = audio_metadata.load(audio_path) - audio_duration = metadata.streaminfo.duration - sample_rate = metadata.streaminfo.sample_rate - - max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200)) - max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids)) - - return max_spec_length, max_utterance_length - - def calc_paddings(self) -> None: - """Sets the maximum length of the spectrogram and the utterance""" - # check if dataset has been loaded and tokenizer has been set - if not self.dataset_lookup: - raise ValueError("Dataset not loaded") - if not self.tokenizer: - raise ValueError("Tokenizer not set") - # check if paddings have been calculated already - if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")): - print("Paddings already calculated") - with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f: - self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()] - return - else: - print("Calculating paddings...") - - thread_count = os.cpu_count() - if thread_count is None: - thread_count = 4 - chunk_size = len(self.dataset_lookup) // thread_count - chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)] - - with Pool(thread_count) as p: - results = list(p.imap(self._calculate_max_length, chunks)) - - for spec, utterance in results: - self.max_spec_length = max(self.max_spec_length, spec) - self.max_utterance_length = max(self.max_utterance_length, utterance) - - # write to file - with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f: - f.write(f"{self.max_spec_length}\n") - f.write(f"{self.max_utterance_length}") - def __len__(self): """Returns the length of the dataset""" return len(self.dataset_lookup) def __getitem__(self, idx: int) -> Sample: """One sample""" + if self.tokenizer is None: + raise ValueError("No tokenizer set") # get the utterance utterance = self.dataset_lookup[idx]["utterance"] @@ -261,42 +211,62 @@ class MLSDataset(Dataset): waveform, sample_rate = torchaudio.load(audio_path) # type: ignore # resample if necessary - if sample_rate != 16000: - resampler = torchaudio.transforms.Resample(sample_rate, 16000) + if sample_rate != self.spectrogram_hparams["sample_rate"]: + resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) - sample_rate = 16000 - spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) + spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1) input_length = spec.shape[0] // 2 + utterance_length = len(utterance) - self.tokenizer.enable_padding() - utterance = self.tokenizer.encode( - utterance, - ).ids + utterance = self.tokenizer.encode(utterance) - utterance = torch.Tensor(utterance) + utterance = torch.LongTensor(utterance.ids) return Sample( - # TODO: add flag to only return spectrogram or waveform or both waveform=waveform, spectrogram=spec, input_length=input_length, utterance=utterance, utterance_length=utterance_length, - sample_rate=sample_rate, + sample_rate=self.spectrogram_hparams["sample_rate"], speaker_id=self.dataset_lookup[idx]["speakerid"], book_id=self.dataset_lookup[idx]["bookid"], chapter_id=self.dataset_lookup[idx]["chapterid"], ) - def download(self, dataset_path: str, language: str): - """Download the dataset""" - os.makedirs(dataset_path) - url = f"https://dl.fbaipublicfiles.com/mls/{language}.tar.gz" - - torch.hub.download_url_to_file(url, dataset_path) +def collate_fn(samples: list[Sample]) -> dict: + """Collate function for the dataloader + + pads all tensors within a batch to the same dimensions + """ + waveforms = [] + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + + for sample in samples: + waveforms.append(sample["waveform"].transpose(0, 1)) + spectrograms.append(sample["spectrogram"]) + labels.append(sample["utterance"]) + input_lengths.append(sample["spectrogram"].shape[0] // 2) + label_lengths.append(len(sample["utterance"])) + + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return { + "waveform": waveforms, + "spectrogram": spectrograms, + "input_length": input_lengths, + "utterance": labels, + "utterance_length": label_lengths, + } + if __name__ == "__main__": @@ -305,11 +275,7 @@ if __name__ == "__main__": split = Split.train download = False - dataset = MLSDataset(dataset_path, language, split, download) + dataset = MLSDataset(dataset_path, language, split, download, None) tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) - dataset.calc_paddings() - - print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}") - print(f"Utterance shape: {dataset[41]['utterance'].shape}") -- cgit v1.2.3