diff options
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- | swr2_asr/utils.py | 108 |
1 files changed, 80 insertions, 28 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 786fbcf..404661d 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,6 +1,7 @@ """Class containing utils for the ASR system.""" import os from enum import Enum +from multiprocessing import Pool from typing import TypedDict import numpy as np @@ -8,6 +9,8 @@ import torch import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset +from tqdm import tqdm +import audio_metadata from swr2_asr.tokenizer import CharTokenizer, TokenizerType @@ -133,11 +136,13 @@ class MLSDataset(Dataset): for path, utterance in zip(identifier, utterances, strict=False) ] + self.max_spec_length = 0 + self.max_utterance_length = 0 + def set_tokenizer(self, tokenizer: type[TokenizerType]): """Sets the tokenizer""" self.tokenizer = tokenizer - - self.calc_paddings() + # self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" @@ -158,27 +163,73 @@ class MLSDataset(Dataset): if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - def calc_paddings(self): - """Sets the maximum length of the spectrogram""" + def _calculate_max_length(self, chunk): + """Calculates the maximum length of the spectrogram and the utterance + + to be called in a multiprocessing pool + """ + max_spec_length = 0 + max_utterance_length = 0 + + for sample in chunk: + audio_path = os.path.join( + self.dataset_path, + self.language, + self.mls_split, + "audio", + sample["speakerid"], + sample["bookid"], + "_".join( + [ + sample["speakerid"], + sample["bookid"], + sample["chapterid"], + ] + ) + + self.file_ext, + ) + metadata = audio_metadata.load(audio_path) + audio_duration = metadata.streaminfo.duration + sample_rate = metadata.streaminfo.sample_rate + + max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200)) + max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids)) + + return max_spec_length, max_utterance_length + + def calc_paddings(self) -> None: + """Sets the maximum length of the spectrogram and the utterance""" # check if dataset has been loaded and tokenizer has been set if not self.dataset_lookup: raise ValueError("Dataset not loaded") if not self.tokenizer: raise ValueError("Tokenizer not set") - - max_spec_length = 0 - max_uterance_length = 0 - for sample in self.dataset_lookup: - spec_length = sample["spectrogram"].shape[0] - if spec_length > max_spec_length: - max_spec_length = spec_length - - utterance_length = sample["utterance"].shape[0] - if utterance_length > max_uterance_length: - max_uterance_length = utterance_length - - self.max_spec_length = max_spec_length - self.max_utterance_length = max_uterance_length + # check if paddings have been calculated already + if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")): + print("Paddings already calculated") + with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f: + self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()] + return + else: + print("Calculating paddings...") + + thread_count = os.cpu_count() + if thread_count is None: + thread_count = 4 + chunk_size = len(self.dataset_lookup) // thread_count + chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)] + + with Pool(thread_count) as p: + results = list(p.imap(self._calculate_max_length, chunks)) + + for spec, utterance in results: + self.max_spec_length = max(self.max_spec_length, spec) + self.max_utterance_length = max(self.max_utterance_length, utterance) + + # write to file + with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f: + f.write(f"{self.max_spec_length}\n") + f.write(f"{self.max_utterance_length}") def __len__(self): """Returns the length of the dataset""" @@ -208,18 +259,18 @@ class MLSDataset(Dataset): ) waveform, sample_rate = torchaudio.load(audio_path) # type: ignore - # TODO: figure out if we have to resample or not - # TODO: pad correctly (manually) + + # resample if necessary + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) + waveform = resampler(waveform) + sample_rate = 16000 + spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) - print(f"spec.shape: {spec.shape}") - input_length = spec.shape[0] // 2 - spec = ( - torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0) - .unsqueeze(1) - .transpose(2, 3) - ) + input_length = spec.shape[0] // 2 utterance_length = len(utterance) + self.tokenizer.enable_padding() utterance = self.tokenizer.encode( utterance, @@ -260,4 +311,5 @@ if __name__ == "__main__": dataset.set_tokenizer(tok) dataset.calc_paddings() - print(dataset[41]["spectrogram"].shape) + print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}") + print(f"Utterance shape: {dataset[41]['utterance'].shape}") |