author      Pherkel  2023-09-11 14:49:28 +0200
committer   Pherkel  2023-09-11 14:49:28 +0200
commit      9dc3bc07424908dd7cf3f052708f506fd58b6e2c (patch)
tree        cd45dc9b70977530669c271c09025246ebbb9fef
parent      01fae2b5e395e84db6a7e9819b6f98777c46e845 (diff)
refactor utilities (data, vis, tokenizer)
-rw-r--r--   swr2_asr/utils/__init__.py                                          0
-rw-r--r--   swr2_asr/utils/data.py (renamed from swr2_asr/utils.py)           206
-rw-r--r--   swr2_asr/utils/decoder.py                                          26
-rw-r--r--   swr2_asr/utils/tokenizer.py (renamed from swr2_asr/tokenizer.py)    0
-rw-r--r--   swr2_asr/utils/visualization.py                                    22
5 files changed, 128 insertions, 126 deletions
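
The diffstat describes a package split: the old `swr2_asr/utils.py` monolith becomes a `swr2_asr.utils` package (`data`, `decoder`, `tokenizer`, `visualization`). The sketch below shows how the relocated pieces are meant to be wired together after this commit; it is a minimal sketch, not part of the commit. The `CharTokenizer.from_file` constructor and the tokenizer/dataset paths are assumptions — the diff below only shows `encode`, `decode`, and `get_blank_token` on the tokenizer.

```python
# Minimal wiring sketch for the refactored package layout (not part of this
# commit). Assumptions are marked in the comments.
from torch.utils.data import DataLoader

from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
from swr2_asr.utils.tokenizer import CharTokenizer

# Assumption: CharTokenizer can be restored from a file; this constructor and
# path are hypothetical. Only encode/decode/get_blank_token appear in the diff.
tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")

# MLSDataset no longer takes spectrogram_hparams or a set_tokenizer() call;
# the new `size` parameter (default 0.2) keeps the first fraction of the
# dataset lookup table, as seen in __init__ in the diff below.
train_set = MLSDataset(
    "data/",            # dataset_path (assumed location)
    "mls_german_opus",  # language
    Split.TRAIN,
    limited=False,
    download=False,
)

# DataProcessing replaces the old module-level collate_fn: it consumes the raw
# (waveform, sample_rate, utterance, speaker_id, chapter_id, utterance_id)
# tuples returned by __getitem__ and yields padded spectrogram/label batches.
data_processing = DataProcessing("train", tokenizer)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, collate_fn=data_processing)

spectrograms, labels, input_lengths, label_lengths = next(iter(train_loader))
```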
diff --git a/swr2_asr/utils/__init__.py b/swr2_asr/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/swr2_asr/utils/__init__.py
diff --git a/swr2_asr/utils.py b/swr2_asr/utils/data.py
index a362b9e..93f4a9a 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils/data.py
@@ -3,21 +3,51 @@
 import os
 from enum import Enum
 from typing import TypedDict
 
-import matplotlib.pyplot as plt
 import numpy as np
 import torch
 import torchaudio
-from tokenizers import Tokenizer
+from torch import Tensor, nn
 from torch.utils.data import Dataset
-from torchaudio.datasets.utils import _extract_tar as extract_archive
+from torchaudio.datasets.utils import _extract_tar
 
-from swr2_asr.tokenizer import TokenizerType
+from swr2_asr.utils.tokenizer import CharTokenizer
 
-train_audio_transforms = torch.nn.Sequential(
-    torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
-    torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
-    torchaudio.transforms.TimeMasking(time_mask_param=100),
-)
+
+class DataProcessing:
+    """Data processing class for the dataloader"""
+
+    def __init__(self, data_type: str, tokenizer: CharTokenizer):
+        self.data_type = data_type
+        self.tokenizer = tokenizer
+
+        if data_type == "train":
+            self.audio_transform = torch.nn.Sequential(
+                torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+                torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+                torchaudio.transforms.TimeMasking(time_mask_param=100),
+            )
+        elif data_type == "valid":
+            self.audio_transform = torchaudio.transforms.MelSpectrogram()
+
+    def __call__(self, data) -> tuple[Tensor, Tensor, list, list]:
+        spectrograms = []
+        labels = []
+        input_lengths = []
+        label_lengths = []
+        for waveform, _, utterance, _, _, _ in data:
+            spec = self.audio_transform(waveform).squeeze(0).transpose(0, 1)
+            spectrograms.append(spec)
+            label = torch.Tensor(self.tokenizer.encode(utterance.lower()))
+            labels.append(label)
+            input_lengths.append(spec.shape[0] // 2)
+            label_lengths.append(len(label))
+
+        spectrograms = (
+            nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+        )
+        labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+        return spectrograms, labels, input_lengths, label_lengths
 
 # create enum specifiying dataset splits
@@ -31,7 +61,7 @@ class MLSSplit(str, Enum):
 
 
 class Split(str, Enum):
-    """Extending the MLSSplit class to allow for a custom validatio split"""
+    """Extending the MLSSplit class to allow for a custom validation split"""
 
     TRAIN = "train"
     VALID = "valid"
@@ -43,8 +73,7 @@ def split_to_mls_split(split_name: Split) -> MLSSplit:
     """Converts the custom split to a MLSSplit"""
     if split_name == Split.VALID:
         return MLSSplit.TRAIN
-    else:
-        return split_name  # type: ignore
+    return split_name  # type: ignore
 
 
 class Sample(TypedDict):
@@ -97,7 +126,7 @@ class MLSDataset(Dataset):
         split: Split,
         limited: bool,
         download: bool,
-        spectrogram_hparams: dict | None,
+        size: float = 0.2,
     ):
         """Initializes the dataset"""
         self.dataset_path = dataset_path
@@ -106,22 +135,7 @@
         self.mls_split: MLSSplit = split_to_mls_split(split)  # split path on disk
         self.split: Split = split  # split used internally
 
-        if spectrogram_hparams is None:
-            self.spectrogram_hparams = {
-                "sample_rate": 16000,
-                "n_fft": 400,
-                "win_length": 400,
-                "hop_length": 160,
-                "n_mels": 128,
-                "f_min": 0,
-                "f_max": 8000,
-                "power": 2.0,
-            }
-        else:
-            self.spectrogram_hparams = spectrogram_hparams
-
         self.dataset_lookup = []
-        self.tokenizer: type[TokenizerType]
 
         self._handle_download_dataset(download)
         self._validate_local_directory()
@@ -130,6 +144,8 @@
         else:
             self.initialize()
 
+        self.dataset_lookup = self.dataset_lookup[: int(len(self.dataset_lookup) * size)]
+
     def initialize_limited(self) -> None:
         """Initializes the limited supervision dataset"""
         # get file handles
@@ -246,10 +262,6 @@
             for path, utterance in zip(identifier, utterances, strict=False)
         ]
 
-    def set_tokenizer(self, tokenizer: type[TokenizerType]):
-        """Sets the tokenizer"""
-        self.tokenizer = tokenizer
-
     def _handle_download_dataset(self, download: bool) -> None:
         """Download the dataset"""
         if not download:
@@ -258,7 +270,9 @@
         # zip exists:
         if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download:
             print(f"Found dataset at {self.dataset_path}. Skipping download")
-        # zip does not exist:
+        # path exists:
+        elif os.path.isdir(os.path.join(self.dataset_path, self.language)) and download:
+            return
         else:
             os.makedirs(self.dataset_path, exist_ok=True)
             url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
@@ -273,9 +287,7 @@
                 f"Unzipping the dataset at \
                     {os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
             )
-            extract_archive(
-                os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True
-            )
+            _extract_tar(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True)
        else:
             print("Dataset is already unzipped, validating it now")
             return
@@ -293,12 +305,29 @@
         """Returns the length of the dataset"""
         return len(self.dataset_lookup)
 
-    def __getitem__(self, idx: int) -> Sample:
-        """One sample"""
-        if self.tokenizer is None:
-            raise ValueError("No tokenizer set")
+    def __getitem__(self, idx: int) -> tuple[Tensor, int, str, int, int, int]:
+        """One sample
+
+        Returns:
+            Tuple of the following items;
+
+            Tensor:
+                Waveform
+            int:
+                Sample rate
+            str:
+                Transcript
+            int:
+                Speaker ID
+            int:
+                Chapter ID
+            int:
+                Utterance ID
+        """
         # get the utterance
-        utterance = self.dataset_lookup[idx]["utterance"]
+        dataset_lookup_entry = self.dataset_lookup[idx]
+
+        utterance = dataset_lookup_entry["utterance"]
 
         # get the audio file
         audio_path = os.path.join(
@@ -321,70 +350,18 @@
         waveform, sample_rate = torchaudio.load(audio_path)  # pylint: disable=no-member
 
         # resample if necessary
-        if sample_rate != self.spectrogram_hparams["sample_rate"]:
-            resampler = torchaudio.transforms.Resample(
-                sample_rate, self.spectrogram_hparams["sample_rate"]
-            )
+        if sample_rate != 16000:
+            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
             waveform = resampler(waveform)
 
-        spec = (
-            torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
-            .squeeze(0)
-            .transpose(0, 1)
-        )
-
-        input_length = spec.shape[0] // 2
-
-        utterance_length = len(utterance)
-
-        utterance = self.tokenizer.encode(utterance)
-
-        utterance = torch.LongTensor(utterance.ids)  # pylint: disable=no-member
-
-        return Sample(
-            waveform=waveform,
-            spectrogram=spec,
-            input_length=input_length,
-            utterance=utterance,
-            utterance_length=utterance_length,
-            sample_rate=self.spectrogram_hparams["sample_rate"],
-            speaker_id=self.dataset_lookup[idx]["speakerid"],
-            book_id=self.dataset_lookup[idx]["bookid"],
-            chapter_id=self.dataset_lookup[idx]["chapterid"],
-        )
-
-
-def collate_fn(samples: list[Sample]) -> dict:
-    """Collate function for the dataloader
-
-    pads all tensors within a batch to the same dimensions
-    """
-    waveforms = []
-    spectrograms = []
-    labels = []
-    input_lengths = []
-    label_lengths = []
-
-    for sample in samples:
-        waveforms.append(sample["waveform"].transpose(0, 1))
-        spectrograms.append(sample["spectrogram"])
-        labels.append(sample["utterance"])
-        input_lengths.append(sample["spectrogram"].shape[0] // 2)
-        label_lengths.append(len(sample["utterance"]))
-
-    waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
-    spectrograms = (
-        torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
-    )
-    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
-
-    return {
-        "waveform": waveforms,
-        "spectrogram": spectrograms,
-        "input_length": input_lengths,
-        "utterance": labels,
-        "utterance_length": label_lengths,
-    }
+        return (
+            waveform,
+            sample_rate,
+            utterance,
+            dataset_lookup_entry["speakerid"],
+            dataset_lookup_entry["chapterid"],
+            idx,
+        )  # type: ignore
 
 
 if __name__ == "__main__":
@@ -392,26 +369,3 @@
     LANGUAGE = "mls_german_opus"
     split = Split.TRAIN
     DOWNLOAD = False
-
-    dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None)
-
-    tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-    dataset.set_tokenizer(tok)
-
-
-def plot(epochs, path):
-    """Plots the losses over the epochs"""
-    losses = list()
-    test_losses = list()
-    cers = list()
-    wers = list()
-    for epoch in range(1, epochs + 1):
-        current_state = torch.load(path + str(epoch))
-        losses.append(current_state["loss"])
-        test_losses.append(current_state["test_loss"])
-        cers.append(current_state["avg_cer"])
-        wers.append(current_state["avg_wer"])
-
-    plt.plot(losses)
-    plt.plot(test_losses)
-    plt.savefig("losses.svg")
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
new file mode 100644
index 0000000..fcddb79
--- /dev/null
+++ b/swr2_asr/utils/decoder.py
@@ -0,0 +1,26 @@
+"""Decoder for CTC-based ASR."""
+import torch
+
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
+# TODO: refactor to use torch CTC decoder class
+def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
+    """Greedily decode a sequence."""
+    blank_label = tokenizer.get_blank_token()
+    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+    decodes = []
+    targets = []
+    for i, args in enumerate(arg_maxes):
+        decode = []
+        targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+        for j, index in enumerate(args):
+            if index != blank_label:
+                if collapse_repeated and j != 0 and index == args[j - 1]:
+                    continue
+                decode.append(index.item())
+        decodes.append(tokenizer.decode(decode))
+    return decodes, targets
+
+
+# TODO: add beam search decoder
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/utils/tokenizer.py
index d92465a..d92465a 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..80f942a
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+    """Plots the losses over the epochs"""
+    losses = list()
+    test_losses = list()
+    cers = list()
+    wers = list()
+    for epoch in range(1, epochs + 1):
+        current_state = torch.load(path + str(epoch))
+        losses.append(current_state["loss"])
+        test_losses.append(current_state["test_loss"])
+        cers.append(current_state["avg_cer"])
+        wers.append(current_state["avg_wer"])
+
+    plt.plot(losses)
+    plt.plot(test_losses)
+    plt.savefig("losses.svg")
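
Continuing the wiring sketch after the diffstat above: the new `decoder.py` consumes a model's per-frame scores together with the padded batches from `DataProcessing`. Below is a minimal greedy-decoding loop; `model` is a hypothetical acoustic model (no model is defined in this commit) whose forward pass is assumed to return scores shaped `(batch, time, n_classes)`.

```python
# Greedy decoding sketch (not part of this commit). `model` is a placeholder;
# spectrograms/labels/label_lengths/tokenizer come from the earlier sketch.
import torch
import torch.nn.functional as F

from swr2_asr.utils.decoder import greedy_decoder

model.eval()
with torch.no_grad():
    # log_softmax does not change the argmax greedy_decoder takes over dim=2,
    # but matches the log-probabilities a CTC loss setup would produce.
    output = F.log_softmax(model(spectrograms), dim=2)  # (batch, time, n_classes)
    decodes, targets = greedy_decoder(output, labels, label_lengths, tokenizer)

for hypothesis, reference in zip(decodes, targets):
    print(f"hyp: {hypothesis}\nref: {reference}")
```

The TODOs in `decoder.py` mark this hand-rolled loop as provisional: the stated plan is to switch to torch's CTC decoder class and to add a beam search decoder.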