diff options
author | Pherkel | 2023-08-24 00:03:56 +0200 |
---|---|---|
committer | Pherkel | 2023-08-24 00:03:56 +0200 |
commit | 403472ca4e65e8ed404e8a73fb9b3fbafe3f2a53 (patch) | |
tree | e5bfaca7d1982f1fadb1abe1da023d2020151363 /swr2_asr/utils.py | |
parent | d65e728575e07a54cec52ccb57af3cafedaac1a2 (diff) |
wip: commit before going on vacation :)
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- | swr2_asr/utils.py | 207 |
1 files changed, 93 insertions, 114 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index c4aeb0b..786fbcf 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,16 +1,21 @@ """Class containing utils for the ASR system.""" -from dataclasses import dataclass import os -from AudioLoader.speech import MultilingualLibriSpeech +from enum import Enum +from typing import TypedDict + import numpy as np import torch import torchaudio -from torch import nn -from torch.utils.data import Dataset, DataLoader -from enum import Enum - from tokenizers import Tokenizer -from swr2_asr.tokenizer import CharTokenizer +from torch.utils.data import Dataset + +from swr2_asr.tokenizer import CharTokenizer, TokenizerType + +train_audio_transforms = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), +) # create enum specifiying dataset splits @@ -40,34 +45,20 @@ def split_to_mls_split(split: Split) -> MLSSplit: return split # type: ignore -@dataclass -class Sample: - """Dataclass for a sample in the dataset""" +class Sample(TypedDict): + """Type for a sample in the dataset""" waveform: torch.Tensor spectrogram: torch.Tensor - utterance: str + input_length: int + utterance: torch.Tensor + utterance_length: int sample_rate: int speaker_id: str book_id: str chapter_id: str -def tokenizer_factory(tokenizer_path: str, tokenizer_type: str = "BPE"): - """Factory for Tokenizer class - - Args: - tokenizer_type (str, optional): Type of tokenizer to use. Defaults to "BPE". - - Returns: - nn.Module: Tokenizer class - """ - if tokenizer_type == "BPE": - return Tokenizer.from_file(tokenizer_path) - elif tokenizer_type == "char": - return CharTokenizer.from_file(tokenizer_path) - - class MLSDataset(Dataset): """Custom Dataset for reading Multilingual LibriSpeech @@ -105,23 +96,33 @@ class MLSDataset(Dataset): self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally self.dataset_lookup = [] + self.tokenizer: type[TokenizerType] self._handle_download_dataset(download) self._validate_local_directory() - transcripts_path = os.path.join( - dataset_path, language, self.mls_split, "transcripts.txt" - ) + transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt") with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt transcripts = script_file.readlines() # split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>) - transcripts = [line.strip().split("\t", 1) for line in transcripts] - utterances = [utterance.strip() for _, utterance in transcripts] - identifier = [identifier.strip() for identifier, _ in transcripts] + transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore + utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore + identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore identifier = [path.split("_") for path in identifier] + if self.split == Split.valid: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + elif self.split == Split.train: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + self.dataset_lookup = [ { "speakerid": path[0], @@ -129,27 +130,23 @@ class MLSDataset(Dataset): "chapterid": path[2], "utterance": utterance, } - for path, utterance in zip(identifier, utterances) + for path, utterance in zip(identifier, utterances, strict=False) ] - # save dataset_lookup as list of dicts, where each dict contains - # the speakerid, bookid and chapterid, as well as the utterance - # we can then use this to map the utterance to the audio file + def set_tokenizer(self, tokenizer: type[TokenizerType]): + """Sets the tokenizer""" + self.tokenizer = tokenizer + + self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" - if ( - not os.path.exists(os.path.join(self.dataset_path, self.language)) - and download - ): + if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download: os.makedirs(self.dataset_path) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" torch.hub.download_url_to_file(url, self.dataset_path) - elif ( - not os.path.exists(os.path.join(self.dataset_path, self.language)) - and not download - ): + elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download: raise ValueError("Dataset not found. Set download to True to download it") def _validate_local_directory(self): @@ -158,18 +155,32 @@ class MLSDataset(Dataset): raise ValueError("Dataset path does not exist") if not os.path.exists(os.path.join(self.dataset_path, self.language)): raise ValueError("Language not found in dataset") - if not os.path.exists( - os.path.join(self.dataset_path, self.language, self.mls_split) - ): + if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - # checks if the transcripts.txt file exists - if not os.path.exists( - os.path.join(dataset_path, language, split, "transcripts.txt") - ): - raise ValueError("transcripts.txt not found in dataset") - - def __get_len__(self): + def calc_paddings(self): + """Sets the maximum length of the spectrogram""" + # check if dataset has been loaded and tokenizer has been set + if not self.dataset_lookup: + raise ValueError("Dataset not loaded") + if not self.tokenizer: + raise ValueError("Tokenizer not set") + + max_spec_length = 0 + max_uterance_length = 0 + for sample in self.dataset_lookup: + spec_length = sample["spectrogram"].shape[0] + if spec_length > max_spec_length: + max_spec_length = spec_length + + utterance_length = sample["utterance"].shape[0] + if utterance_length > max_uterance_length: + max_uterance_length = utterance_length + + self.max_spec_length = max_spec_length + self.max_utterance_length = max_uterance_length + + def __len__(self): """Returns the length of the dataset""" return len(self.dataset_lookup) @@ -197,13 +208,32 @@ class MLSDataset(Dataset): ) waveform, sample_rate = torchaudio.load(audio_path) # type: ignore + # TODO: figure out if we have to resample or not + # TODO: pad correctly (manually) + spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) + print(f"spec.shape: {spec.shape}") + input_length = spec.shape[0] // 2 + spec = ( + torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0) + .unsqueeze(1) + .transpose(2, 3) + ) + + utterance_length = len(utterance) + self.tokenizer.enable_padding() + utterance = self.tokenizer.encode( + utterance, + ).ids + + utterance = torch.Tensor(utterance) return Sample( + # TODO: add flag to only return spectrogram or waveform or both waveform=waveform, - spectrogram=torchaudio.transforms.MelSpectrogram( - sample_rate=16000, n_mels=128 - )(waveform), + spectrogram=spec, + input_length=input_length, utterance=utterance, + utterance_length=utterance_length, sample_rate=sample_rate, speaker_id=self.dataset_lookup[idx]["speakerid"], book_id=self.dataset_lookup[idx]["bookid"], @@ -218,62 +248,6 @@ class MLSDataset(Dataset): torch.hub.download_url_to_file(url, dataset_path) -class DataProcessor: - """Factory for DataProcessingclass - - Transforms the dataset into spectrograms and labels, as well as a tokenizer - """ - - def __init__( - self, - dataset: MultilingualLibriSpeech, - tokenizer_path: str, - data_type: str = "train", - tokenizer_type: str = "BPE", - ): - self.dataset = dataset - self.data_type = data_type - - self.train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), - ) - - self.valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - self.tokenizer = tokenizer_factory( - tokenizer_path=tokenizer_path, tokenizer_type=tokenizer_type - ) - - def __call__(self) -> tuple[np.ndarray, np.ndarray, int, int]: - """Returns spectrograms, labels and their lenghts""" - for sample in self.dataset: - if self.data_type == "train": - spec = ( - self.train_audio_transforms(sample["waveform"]) - .squeeze(0) - .transpose(0, 1) - ) - elif self.data_type == "valid": - spec = ( - self.valid_audio_transforms(sample["waveform"]) - .squeeze(0) - .transpose(0, 1) - ) - else: - raise ValueError("data_type should be train or valid") - label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) - - yield spec, label, spec.shape[0] // 2, len(labels) - - if __name__ == "__main__": dataset_path = "/Volumes/pherkel/SWR2-ASR" language = "mls_german_opus" @@ -281,4 +255,9 @@ if __name__ == "__main__": download = False dataset = MLSDataset(dataset_path, language, split, download) - print(dataset[0]) + + tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") + dataset.set_tokenizer(tok) + dataset.calc_paddings() + + print(dataset[41]["spectrogram"].shape) |