From 9dc3bc07424908dd7cf3f052708f506fd58b6e2c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 14:49:28 +0200 Subject: refactor utilities (data, vis, tokenizer) --- swr2_asr/tokenizer.py | 126 ------------ swr2_asr/utils.py | 417 ---------------------------------------- swr2_asr/utils/__init__.py | 0 swr2_asr/utils/data.py | 371 +++++++++++++++++++++++++++++++++++ swr2_asr/utils/decoder.py | 26 +++ swr2_asr/utils/tokenizer.py | 126 ++++++++++++ swr2_asr/utils/visualization.py | 22 +++ 7 files changed, 545 insertions(+), 543 deletions(-) delete mode 100644 swr2_asr/tokenizer.py delete mode 100644 swr2_asr/utils.py create mode 100644 swr2_asr/utils/__init__.py create mode 100644 swr2_asr/utils/data.py create mode 100644 swr2_asr/utils/decoder.py create mode 100644 swr2_asr/utils/tokenizer.py create mode 100644 swr2_asr/utils/visualization.py diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py deleted file mode 100644 index d92465a..0000000 --- a/swr2_asr/tokenizer.py +++ /dev/null @@ -1,126 +0,0 @@ -"""Tokenizer for Multilingual Librispeech datasets""" - - -class CharTokenizer: - """Maps characters to integers and vice versa""" - - def __init__(self): - char_map_str = """ - _ - - - - a - b - c - d - e - f - g - h - i - j - k - l - m - n - o - p - q - r - s - t - u - v - w - x - y - z - é - à - ä - ö - ß - ü - - - ' - - """ - - self.char_map = {} - self.index_map = {} - for idx, char in enumerate(char_map_str.strip().split("\n")): - char = char.strip() - self.char_map[char] = idx - self.index_map[idx] = char - self.index_map[1] = " " - - def encode(self, text: str) -> list[int]: - """Use a character map and convert text to an integer sequence""" - int_sequence = [] - for char in text: - if char == " ": - char = self.char_map[""] - elif char not in self.char_map: - char = self.char_map[""] - else: - char = self.char_map[char] - int_sequence.append(char) - return int_sequence - - def decode(self, labels: list[int]) -> str: - """Use a character map and convert integer labels to an text sequence""" - string = [] - for i in labels: - string.append(self.index_map[i]) - return "".join(string).replace("", " ") - - def get_vocab_size(self) -> int: - """Get the number of unique characters in the dataset""" - return len(self.char_map) - - def get_blank_token(self) -> int: - """Get the integer representation of the character""" - return self.char_map[""] - - def get_unk_token(self) -> int: - """Get the integer representation of the character""" - return self.char_map[""] - - def get_space_token(self) -> int: - """Get the integer representation of the character""" - return self.char_map[""] - - # TODO: add train function - - def save(self, path: str) -> None: - """Save the tokenizer to a file""" - with open(path, "w", encoding="utf-8") as file: - for char, index in self.char_map.items(): - file.write(f"{char} {index}\n") - - @staticmethod - def from_file(tokenizer_file: str) -> "CharTokenizer": - """Instantiate a CharTokenizer from a file""" - load_tokenizer = CharTokenizer() - with open(tokenizer_file, "r", encoding="utf-8") as file: - for line in file: - line = line.strip() - if line: - char, index = line.split() - tokenizer.char_map[char] = int(index) - tokenizer.index_map[int(index)] = char - return load_tokenizer - - -if __name__ == "__main__": - tokenizer = CharTokenizer() - tokenizer.save("data/tokenizers/char_tokenizer_german.json") - print(tokenizer.char_map) - print(tokenizer.index_map) - print(tokenizer.get_vocab_size()) - print(tokenizer.get_blank_token()) - 
print(tokenizer.get_unk_token()) - print(tokenizer.get_space_token()) - print(tokenizer.encode("hallo welt")) - print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py deleted file mode 100644 index a362b9e..0000000 --- a/swr2_asr/utils.py +++ /dev/null @@ -1,417 +0,0 @@ -"""Class containing utils for the ASR system.""" -import os -from enum import Enum -from typing import TypedDict - -import matplotlib.pyplot as plt -import numpy as np -import torch -import torchaudio -from tokenizers import Tokenizer -from torch.utils.data import Dataset -from torchaudio.datasets.utils import _extract_tar as extract_archive - -from swr2_asr.tokenizer import TokenizerType - -train_audio_transforms = torch.nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) - - -# create enum specifiying dataset splits -class MLSSplit(str, Enum): - """Enum specifying dataset as they are defined in the - Multilingual LibriSpeech dataset""" - - TRAIN = "train" - TEST = "test" - DEV = "dev" - - -class Split(str, Enum): - """Extending the MLSSplit class to allow for a custom validatio split""" - - TRAIN = "train" - VALID = "valid" - TEST = "test" - DEV = "dev" - - -def split_to_mls_split(split_name: Split) -> MLSSplit: - """Converts the custom split to a MLSSplit""" - if split_name == Split.VALID: - return MLSSplit.TRAIN - else: - return split_name # type: ignore - - -class Sample(TypedDict): - """Type for a sample in the dataset""" - - waveform: torch.Tensor - spectrogram: torch.Tensor - input_length: int - utterance: torch.Tensor - utterance_length: int - sample_rate: int - speaker_id: str - book_id: str - chapter_id: str - - -class MLSDataset(Dataset): - """Custom Dataset for reading Multilingual LibriSpeech - - Attributes: - dataset_path (str): - path to the dataset - language (str): - language of the dataset - split (Split): - split of the dataset - mls_split (MLSSplit): - split of the dataset as defined in the Multilingual LibriSpeech dataset - dataset_lookup (list): - list of dicts containing the speakerid, bookid, chapterid and utterance - - directory structure: - - ├── - │ ├── train - │ │ ├── transcripts.txt - │ │ └── audio - │ │ └── - │ │ └── - │ │ └── __.opus / .flac - - each line in transcripts.txt has the following format: - __ - """ - - def __init__( - self, - dataset_path: str, - language: str, - split: Split, - limited: bool, - download: bool, - spectrogram_hparams: dict | None, - ): - """Initializes the dataset""" - self.dataset_path = dataset_path - self.language = language - self.file_ext = ".opus" if "opus" in language else ".flac" - self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk - self.split: Split = split # split used internally - - if spectrogram_hparams is None: - self.spectrogram_hparams = { - "sample_rate": 16000, - "n_fft": 400, - "win_length": 400, - "hop_length": 160, - "n_mels": 128, - "f_min": 0, - "f_max": 8000, - "power": 2.0, - } - else: - self.spectrogram_hparams = spectrogram_hparams - - self.dataset_lookup = [] - self.tokenizer: type[TokenizerType] - - self._handle_download_dataset(download) - self._validate_local_directory() - if limited and (split == Split.TRAIN or split == Split.VALID): - self.initialize_limited() - else: - self.initialize() - - def initialize_limited(self) -> None: - """Initializes the limited supervision dataset""" - # get file handles - 
# get file paths - # get transcripts - # create train or validation split - - handles = set() - - train_root_path = os.path.join(self.dataset_path, self.language, "train") - - # get file handles for 9h - with open( - os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"), - "r", - encoding="utf-8", - ) as file: - for line in file: - handles.add(line.strip()) - - # get file handles for 1h splits - for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")): - if handle_path not in range(0, 6): - continue - with open( - os.path.join( - train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt" - ), - "r", - encoding="utf-8", - ) as file: - for line in file: - handles.add(line.strip()) - - # get file paths for handles - file_paths = [] - for handle in handles: - file_paths.append( - os.path.join( - train_root_path, - "audio", - handle.split("_")[0], - handle.split("_")[1], - handle + self.file_ext, - ) - ) - - # get transcripts for handles - transcripts = [] - with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file: - for line in file: - if line.split("\t")[0] in handles: - transcripts.append(line.strip()) - - # create train or valid split randomly with seed 42 - if self.split == Split.TRAIN: - np.random.seed(42) - indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8)) - file_paths = [file_paths[i] for i in indices] - transcripts = [transcripts[i] for i in indices] - elif self.split == Split.VALID: - np.random.seed(42) - indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.2)) - file_paths = [file_paths[i] for i in indices] - transcripts = [transcripts[i] for i in indices] - - # create dataset lookup - self.dataset_lookup = [ - { - "speakerid": path.split("/")[-3], - "bookid": path.split("/")[-2], - "chapterid": path.split("/")[-1].split("_")[2].split(".")[0], - "utterance": utterance.split("\t")[1], - } - for path, utterance in zip(file_paths, transcripts, strict=False) - ] - - def initialize(self) -> None: - """Initializes the entire dataset - - Reads the transcripts.txt file and creates a lookup table - """ - transcripts_path = os.path.join( - self.dataset_path, self.language, self.mls_split, "transcripts.txt" - ) - - with open(transcripts_path, "r", encoding="utf-8") as script_file: - # read all lines in transcripts.txt - transcripts = script_file.readlines() - # split each line into (__, ) - transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore - utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore - identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore - identifier = [path.split("_") for path in identifier] - - if self.split == Split.VALID: - np.random.seed(42) - indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) - utterances = [utterances[i] for i in indices] - identifier = [identifier[i] for i in indices] - elif self.split == Split.TRAIN: - np.random.seed(42) - indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) - utterances = [utterances[i] for i in indices] - identifier = [identifier[i] for i in indices] - - self.dataset_lookup = [ - { - "speakerid": path[0], - "bookid": path[1], - "chapterid": path[2], - "utterance": utterance, - } - for path, utterance in zip(identifier, utterances, strict=False) - ] - - def set_tokenizer(self, tokenizer: type[TokenizerType]): - """Sets the tokenizer""" - self.tokenizer = tokenizer - - def 
_handle_download_dataset(self, download: bool) -> None: - """Download the dataset""" - if not download: - print("Download flag not set, skipping download") - return - # zip exists: - if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: - print(f"Found dataset at {self.dataset_path}. Skipping download") - # zip does not exist: - else: - os.makedirs(self.dataset_path, exist_ok=True) - url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" - - torch.hub.download_url_to_file( - url, os.path.join(self.dataset_path, self.language) + ".tar.gz" - ) - - # unzip the dataset - if not os.path.isdir(os.path.join(self.dataset_path, self.language)): - print( - f"Unzipping the dataset at \ - {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" - ) - extract_archive( - os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True - ) - else: - print("Dataset is already unzipped, validating it now") - return - - def _validate_local_directory(self): - # check if dataset_path exists - if not os.path.exists(self.dataset_path): - raise ValueError("Dataset path does not exist") - if not os.path.exists(os.path.join(self.dataset_path, self.language)): - raise ValueError("Language not downloaded!") - if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): - raise ValueError("Split not found in dataset") - - def __len__(self): - """Returns the length of the dataset""" - return len(self.dataset_lookup) - - def __getitem__(self, idx: int) -> Sample: - """One sample""" - if self.tokenizer is None: - raise ValueError("No tokenizer set") - # get the utterance - utterance = self.dataset_lookup[idx]["utterance"] - - # get the audio file - audio_path = os.path.join( - self.dataset_path, - self.language, - self.mls_split, - "audio", - self.dataset_lookup[idx]["speakerid"], - self.dataset_lookup[idx]["bookid"], - "_".join( - [ - self.dataset_lookup[idx]["speakerid"], - self.dataset_lookup[idx]["bookid"], - self.dataset_lookup[idx]["chapterid"], - ] - ) - + self.file_ext, - ) - - waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member - - # resample if necessary - if sample_rate != self.spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample( - sample_rate, self.spectrogram_hparams["sample_rate"] - ) - waveform = resampler(waveform) - - spec = ( - torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform) - .squeeze(0) - .transpose(0, 1) - ) - - input_length = spec.shape[0] // 2 - - utterance_length = len(utterance) - - utterance = self.tokenizer.encode(utterance) - - utterance = torch.LongTensor(utterance.ids) # pylint: disable=no-member - - return Sample( - waveform=waveform, - spectrogram=spec, - input_length=input_length, - utterance=utterance, - utterance_length=utterance_length, - sample_rate=self.spectrogram_hparams["sample_rate"], - speaker_id=self.dataset_lookup[idx]["speakerid"], - book_id=self.dataset_lookup[idx]["bookid"], - chapter_id=self.dataset_lookup[idx]["chapterid"], - ) - - -def collate_fn(samples: list[Sample]) -> dict: - """Collate function for the dataloader - - pads all tensors within a batch to the same dimensions - """ - waveforms = [] - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - - for sample in samples: - waveforms.append(sample["waveform"].transpose(0, 1)) - spectrograms.append(sample["spectrogram"]) - labels.append(sample["utterance"]) - input_lengths.append(sample["spectrogram"].shape[0] // 2) - 
label_lengths.append(len(sample["utterance"])) - - waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) - spectrograms = ( - torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) - ) - labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) - - return { - "waveform": waveforms, - "spectrogram": spectrograms, - "input_length": input_lengths, - "utterance": labels, - "utterance_length": label_lengths, - } - - -if __name__ == "__main__": - DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" - LANGUAGE = "mls_german_opus" - split = Split.TRAIN - DOWNLOAD = False - - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None) - - tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - dataset.set_tokenizer(tok) - - -def plot(epochs, path): - """Plots the losses over the epochs""" - losses = list() - test_losses = list() - cers = list() - wers = list() - for epoch in range(1, epochs + 1): - current_state = torch.load(path + str(epoch)) - losses.append(current_state["loss"]) - test_losses.append(current_state["test_loss"]) - cers.append(current_state["avg_cer"]) - wers.append(current_state["avg_wer"]) - - plt.plot(losses) - plt.plot(test_losses) - plt.savefig("losses.svg") diff --git a/swr2_asr/utils/__init__.py b/swr2_asr/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py new file mode 100644 index 0000000..93f4a9a --- /dev/null +++ b/swr2_asr/utils/data.py @@ -0,0 +1,371 @@ +"""Class containing utils for the ASR system.""" +import os +from enum import Enum +from typing import TypedDict + +import numpy as np +import torch +import torchaudio +from torch import Tensor, nn +from torch.utils.data import Dataset +from torchaudio.datasets.utils import _extract_tar + +from swr2_asr.utils.tokenizer import CharTokenizer + + +class DataProcessing: + """Data processing class for the dataloader""" + + def __init__(self, data_type: str, tokenizer: CharTokenizer): + self.data_type = data_type + self.tokenizer = tokenizer + + if data_type == "train": + self.audio_transform = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), + ) + elif data_type == "valid": + self.audio_transform = torchaudio.transforms.MelSpectrogram() + + def __call__(self, data) -> tuple[Tensor, Tensor, list, list]: + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + for waveform, _, utterance, _, _, _ in data: + spec = self.audio_transform(waveform).squeeze(0).transpose(0, 1) + spectrograms.append(spec) + label = torch.Tensor(self.tokenizer.encode(utterance.lower())) + labels.append(label) + input_lengths.append(spec.shape[0] // 2) + label_lengths.append(len(label)) + + spectrograms = ( + nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3) + ) + labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return spectrograms, labels, input_lengths, label_lengths + + +# create enum specifiying dataset splits +class MLSSplit(str, Enum): + """Enum specifying dataset as they are defined in the + Multilingual LibriSpeech dataset""" + + TRAIN = "train" + TEST = "test" + DEV = "dev" + + +class Split(str, Enum): + """Extending the MLSSplit class to allow for a custom validation split""" + + TRAIN = "train" + VALID = "valid" + TEST = "test" + DEV = "dev" + + 
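The DataProcessing class above replaces the old module-level collate_fn from swr2_asr/utils.py: an instance is passed to a DataLoader as its collate_fn, and its __call__ pads each raw MLSDataset batch into (spectrograms, labels, input_lengths, label_lengths). A minimal wiring sketch follows; the dataset path, batch size, and the tokenizer file location are placeholder assumptions, not part of this commit.

    from torch.utils.data import DataLoader

    from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
    from swr2_asr.utils.tokenizer import CharTokenizer

    # Placeholder locations; adjust to the local dataset and tokenizer files.
    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
    train_set = MLSDataset("data", "mls_german_opus", Split.TRAIN, limited=False, download=False)

    train_loader = DataLoader(
        train_set,
        batch_size=8,                                   # assumed batch size
        shuffle=True,
        collate_fn=DataProcessing("train", tokenizer),  # pads and tokenizes each batch
    )

    for spectrograms, labels, input_lengths, label_lengths in train_loader:
        pass  # feed the padded batch to the acoustic model / CTC loss
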
+def split_to_mls_split(split_name: Split) -> MLSSplit:
+    """Converts the custom split to an MLSSplit"""
+    if split_name == Split.VALID:
+        return MLSSplit.TRAIN
+    return split_name  # type: ignore
+
+
+class Sample(TypedDict):
+    """Type for a sample in the dataset"""
+
+    waveform: torch.Tensor
+    spectrogram: torch.Tensor
+    input_length: int
+    utterance: torch.Tensor
+    utterance_length: int
+    sample_rate: int
+    speaker_id: str
+    book_id: str
+    chapter_id: str
+
+
+class MLSDataset(Dataset):
+    """Custom Dataset for reading Multilingual LibriSpeech
+
+    Attributes:
+        dataset_path (str):
+            path to the dataset
+        language (str):
+            language of the dataset
+        split (Split):
+            split of the dataset
+        mls_split (MLSSplit):
+            split of the dataset as defined in the Multilingual LibriSpeech dataset
+        dataset_lookup (list):
+            list of dicts containing the speakerid, bookid, chapterid and utterance
+
+    directory structure:
+        <dataset_path>
+        ├── <language>
+        │  ├── train
+        │  │  ├── transcripts.txt
+        │  │  └── audio
+        │  │     └── <speakerid>
+        │  │        └── <bookid>
+        │  │           └── <speakerid>_<bookid>_<chapterid>.opus / .flac
+
+        each line in transcripts.txt has the following format:
+        <speakerid>_<bookid>_<chapterid>  <utterance>
+    """
+
+    def __init__(
+        self,
+        dataset_path: str,
+        language: str,
+        split: Split,
+        limited: bool,
+        download: bool,
+        size: float = 0.2,
+    ):
+        """Initializes the dataset"""
+        self.dataset_path = dataset_path
+        self.language = language
+        self.file_ext = ".opus" if "opus" in language else ".flac"
+        self.mls_split: MLSSplit = split_to_mls_split(split)  # split path on disk
+        self.split: Split = split  # split used internally
+
+        self.dataset_lookup = []
+
+        self._handle_download_dataset(download)
+        self._validate_local_directory()
+        if limited and (split == Split.TRAIN or split == Split.VALID):
+            self.initialize_limited()
+        else:
+            self.initialize()
+
+        self.dataset_lookup = self.dataset_lookup[: int(len(self.dataset_lookup) * size)]
+
+    def initialize_limited(self) -> None:
+        """Initializes the limited supervision dataset"""
+        # get file handles
+        # get file paths
+        # get transcripts
+        # create train or validation split
+
+        handles = set()
+
+        train_root_path = os.path.join(self.dataset_path, self.language, "train")
+
+        # get file handles for 9h
+        with open(
+            os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"),
+            "r",
+            encoding="utf-8",
+        ) as file:
+            for line in file:
+                handles.add(line.strip())
+
+        # get file handles for 1h splits
+        for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")):
+            if handle_path not in {str(i) for i in range(6)}:  # folders are named "0" to "5"
+                continue
+            with open(
+                os.path.join(
+                    train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt"
+                ),
+                "r",
+                encoding="utf-8",
+            ) as file:
+                for line in file:
+                    handles.add(line.strip())
+
+        # get file paths for handles
+        file_paths = []
+        for handle in handles:
+            file_paths.append(
+                os.path.join(
+                    train_root_path,
+                    "audio",
+                    handle.split("_")[0],
+                    handle.split("_")[1],
+                    handle + self.file_ext,
+                )
+            )
+
+        # get transcripts for handles
+        transcripts = []
+        with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file:
+            for line in file:
+                if line.split("\t")[0] in handles:
+                    transcripts.append(line.strip())
+
+        # create train or valid split randomly with seed 42
+        if self.split == Split.TRAIN:
+            np.random.seed(42)
+            indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8))
+            file_paths = [file_paths[i] for i in indices]
+            transcripts = [transcripts[i] for i in indices]
+        elif self.split == Split.VALID:
+            np.random.seed(42)
+            indices =
np.random.choice(len(file_paths), int(len(file_paths) * 0.2)) + file_paths = [file_paths[i] for i in indices] + transcripts = [transcripts[i] for i in indices] + + # create dataset lookup + self.dataset_lookup = [ + { + "speakerid": path.split("/")[-3], + "bookid": path.split("/")[-2], + "chapterid": path.split("/")[-1].split("_")[2].split(".")[0], + "utterance": utterance.split("\t")[1], + } + for path, utterance in zip(file_paths, transcripts, strict=False) + ] + + def initialize(self) -> None: + """Initializes the entire dataset + + Reads the transcripts.txt file and creates a lookup table + """ + transcripts_path = os.path.join( + self.dataset_path, self.language, self.mls_split, "transcripts.txt" + ) + + with open(transcripts_path, "r", encoding="utf-8") as script_file: + # read all lines in transcripts.txt + transcripts = script_file.readlines() + # split each line into (__, ) + transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore + utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore + identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore + identifier = [path.split("_") for path in identifier] + + if self.split == Split.VALID: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + elif self.split == Split.TRAIN: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + + self.dataset_lookup = [ + { + "speakerid": path[0], + "bookid": path[1], + "chapterid": path[2], + "utterance": utterance, + } + for path, utterance in zip(identifier, utterances, strict=False) + ] + + def _handle_download_dataset(self, download: bool) -> None: + """Download the dataset""" + if not download: + print("Download flag not set, skipping download") + return + # zip exists: + if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download: + print(f"Found dataset at {self.dataset_path}. 
Skipping download") + # path exists: + elif os.path.isdir(os.path.join(self.dataset_path, self.language)) and download: + return + else: + os.makedirs(self.dataset_path, exist_ok=True) + url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" + + torch.hub.download_url_to_file( + url, os.path.join(self.dataset_path, self.language) + ".tar.gz" + ) + + # unzip the dataset + if not os.path.isdir(os.path.join(self.dataset_path, self.language)): + print( + f"Unzipping the dataset at \ + {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" + ) + _extract_tar(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True) + else: + print("Dataset is already unzipped, validating it now") + return + + def _validate_local_directory(self): + # check if dataset_path exists + if not os.path.exists(self.dataset_path): + raise ValueError("Dataset path does not exist") + if not os.path.exists(os.path.join(self.dataset_path, self.language)): + raise ValueError("Language not downloaded!") + if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): + raise ValueError("Split not found in dataset") + + def __len__(self): + """Returns the length of the dataset""" + return len(self.dataset_lookup) + + def __getitem__(self, idx: int) -> tuple[Tensor, int, str, int, int, int]: + """One sample + + Returns: + Tuple of the following items; + + Tensor: + Waveform + int: + Sample rate + str: + Transcript + int: + Speaker ID + int: + Chapter ID + int: + Utterance ID + """ + # get the utterance + dataset_lookup_entry = self.dataset_lookup[idx] + + utterance = dataset_lookup_entry["utterance"] + + # get the audio file + audio_path = os.path.join( + self.dataset_path, + self.language, + self.mls_split, + "audio", + self.dataset_lookup[idx]["speakerid"], + self.dataset_lookup[idx]["bookid"], + "_".join( + [ + self.dataset_lookup[idx]["speakerid"], + self.dataset_lookup[idx]["bookid"], + self.dataset_lookup[idx]["chapterid"], + ] + ) + + self.file_ext, + ) + + waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member + + # resample if necessary + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) + waveform = resampler(waveform) + + return ( + waveform, + sample_rate, + utterance, + dataset_lookup_entry["speakerid"], + dataset_lookup_entry["chapterid"], + idx, + ) # type: ignore + + +if __name__ == "__main__": + DATASET_PATH = "/Volumes/pherkel/SWR2-ASR" + LANGUAGE = "mls_german_opus" + split = Split.TRAIN + DOWNLOAD = False diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py new file mode 100644 index 0000000..fcddb79 --- /dev/null +++ b/swr2_asr/utils/decoder.py @@ -0,0 +1,26 @@ +"""Decoder for CTC-based ASR.""" "" +import torch + +from swr2_asr.utils.tokenizer import CharTokenizer + + +# TODO: refactor to use torch CTC decoder class +def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): + """Greedily decode a sequence.""" + blank_label = tokenizer.get_blank_token() + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist())) + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(tokenizer.decode(decode)) + return decodes, targets + + +# TODO: add beam 
search decoder
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
new file mode 100644
index 0000000..d92465a
--- /dev/null
+++ b/swr2_asr/utils/tokenizer.py
@@ -0,0 +1,126 @@
+"""Tokenizer for Multilingual Librispeech datasets"""
+
+
+class CharTokenizer:
+    """Maps characters to integers and vice versa"""
+
+    def __init__(self):
+        char_map_str = """
+        _
+        <BLANK>
+        <UNK>
+        <SPACE>
+        a
+        b
+        c
+        d
+        e
+        f
+        g
+        h
+        i
+        j
+        k
+        l
+        m
+        n
+        o
+        p
+        q
+        r
+        s
+        t
+        u
+        v
+        w
+        x
+        y
+        z
+        é
+        à
+        ä
+        ö
+        ß
+        ü
+        -
+        '
+
+        """
+
+        self.char_map = {}
+        self.index_map = {}
+        for idx, char in enumerate(char_map_str.strip().split("\n")):
+            char = char.strip()
+            self.char_map[char] = idx
+            self.index_map[idx] = char
+        self.index_map[1] = " "
+
+    def encode(self, text: str) -> list[int]:
+        """Use a character map and convert text to an integer sequence"""
+        int_sequence = []
+        for char in text:
+            if char == " ":
+                char = self.char_map["<SPACE>"]
+            elif char not in self.char_map:
+                char = self.char_map["<UNK>"]
+            else:
+                char = self.char_map[char]
+            int_sequence.append(char)
+        return int_sequence
+
+    def decode(self, labels: list[int]) -> str:
+        """Use a character map and convert integer labels to a text sequence"""
+        string = []
+        for i in labels:
+            string.append(self.index_map[i])
+        return "".join(string).replace("<SPACE>", " ")
+
+    def get_vocab_size(self) -> int:
+        """Get the number of unique characters in the dataset"""
+        return len(self.char_map)
+
+    def get_blank_token(self) -> int:
+        """Get the integer representation of the <BLANK> character"""
+        return self.char_map["<BLANK>"]
+
+    def get_unk_token(self) -> int:
+        """Get the integer representation of the <UNK> character"""
+        return self.char_map["<UNK>"]
+
+    def get_space_token(self) -> int:
+        """Get the integer representation of the <SPACE> character"""
+        return self.char_map["<SPACE>"]
+
+    # TODO: add train function
+
+    def save(self, path: str) -> None:
+        """Save the tokenizer to a file"""
+        with open(path, "w", encoding="utf-8") as file:
+            for char, index in self.char_map.items():
+                file.write(f"{char} {index}\n")
+
+    @staticmethod
+    def from_file(tokenizer_file: str) -> "CharTokenizer":
+        """Instantiate a CharTokenizer from a file"""
+        load_tokenizer = CharTokenizer()
+        with open(tokenizer_file, "r", encoding="utf-8") as file:
+            for line in file:
+                line = line.strip()
+                if line:
+                    char, index = line.split()
+                    load_tokenizer.char_map[char] = int(index)
+                    load_tokenizer.index_map[int(index)] = char
+        return load_tokenizer
+
+
+if __name__ == "__main__":
+    tokenizer = CharTokenizer()
+    tokenizer.save("data/tokenizers/char_tokenizer_german.json")
+    print(tokenizer.char_map)
+    print(tokenizer.index_map)
+    print(tokenizer.get_vocab_size())
+    print(tokenizer.get_blank_token())
+    print(tokenizer.get_unk_token())
+    print(tokenizer.get_space_token())
+    print(tokenizer.encode("hallo welt"))
+    print(tokenizer.decode([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]))
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..80f942a
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+    """Plots the losses over the epochs"""
+    losses = list()
+    test_losses = list()
+    cers = list()
+    wers = list()
+    for epoch in range(1, epochs + 1):
+        current_state = torch.load(path + str(epoch))
+        losses.append(current_state["loss"])
+        test_losses.append(current_state["test_loss"])
+        cers.append(current_state["avg_cer"])
+
wers.append(current_state["avg_wer"]) + + plt.plot(losses) + plt.plot(test_losses) + plt.savefig("losses.svg") -- cgit v1.2.3
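The new decoder module only provides greedy decoding and leaves beam search as a TODO. One possible follow-up is sketched below using torchaudio's built-in lexicon-free CTC beam-search decoder (torchaudio.models.decoder.ctc_decoder); the special-token names, beam size, and emission layout are assumptions, not part of this commit.

    from torchaudio.models.decoder import ctc_decoder

    from swr2_asr.utils.tokenizer import CharTokenizer


    def build_beam_search_decoder(tokenizer: CharTokenizer, beam_size: int = 50):
        """Build a lexicon-free CTC beam-search decoder from the character tokenizer."""
        # Token list in index order; assumes the char map uses <BLANK>, <UNK>, <SPACE>.
        tokens = list(tokenizer.char_map.keys())
        return ctc_decoder(
            lexicon=None,          # lexicon-free (character-level) decoding
            tokens=tokens,
            beam_size=beam_size,   # assumed beam width
            blank_token="<BLANK>",
            sil_token="<SPACE>",
            unk_word="<UNK>",
        )


    # Usage sketch: `emissions` is a CPU float32 tensor of log-softmax outputs with
    # shape (batch, time, n_tokens) from the acoustic model.
    # decoder = build_beam_search_decoder(CharTokenizer.from_file("path/to/tokenizer"))
    # best_hypothesis = decoder(emissions)[0][0]
    # transcript = tokenizer.decode(best_hypothesis.tokens.tolist())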