diff options
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- | swr2_asr/utils.py | 109 |
1 files changed, 96 insertions, 13 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 87d4f82..a362b9e 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -3,8 +3,8 @@ import os from enum import Enum from typing import TypedDict -import numpy as np import matplotlib.pyplot as plt +import numpy as np import torch import torchaudio from tokenizers import Tokenizer @@ -95,6 +95,7 @@ class MLSDataset(Dataset): dataset_path: str, language: str, split: Split, + limited: bool, download: bool, spectrogram_hparams: dict | None, ): @@ -124,10 +125,90 @@ class MLSDataset(Dataset): self._handle_download_dataset(download) self._validate_local_directory() - self.initialize() + if limited and (split == Split.TRAIN or split == Split.VALID): + self.initialize_limited() + else: + self.initialize() + + def initialize_limited(self) -> None: + """Initializes the limited supervision dataset""" + # get file handles + # get file paths + # get transcripts + # create train or validation split + + handles = set() + + train_root_path = os.path.join(self.dataset_path, self.language, "train") + + # get file handles for 9h + with open( + os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"), + "r", + encoding="utf-8", + ) as file: + for line in file: + handles.add(line.strip()) + + # get file handles for 1h splits + for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")): + if handle_path not in range(0, 6): + continue + with open( + os.path.join( + train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt" + ), + "r", + encoding="utf-8", + ) as file: + for line in file: + handles.add(line.strip()) + + # get file paths for handles + file_paths = [] + for handle in handles: + file_paths.append( + os.path.join( + train_root_path, + "audio", + handle.split("_")[0], + handle.split("_")[1], + handle + self.file_ext, + ) + ) + + # get transcripts for handles + transcripts = [] + with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file: + for line in file: + if line.split("\t")[0] in handles: + transcripts.append(line.strip()) + + # create train or valid split randomly with seed 42 + if self.split == Split.TRAIN: + np.random.seed(42) + indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8)) + file_paths = [file_paths[i] for i in indices] + transcripts = [transcripts[i] for i in indices] + elif self.split == Split.VALID: + np.random.seed(42) + indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.2)) + file_paths = [file_paths[i] for i in indices] + transcripts = [transcripts[i] for i in indices] + + # create dataset lookup + self.dataset_lookup = [ + { + "speakerid": path.split("/")[-3], + "bookid": path.split("/")[-2], + "chapterid": path.split("/")[-1].split("_")[2].split(".")[0], + "utterance": utterance.split("\t")[1], + } + for path, utterance in zip(file_paths, transcripts, strict=False) + ] def initialize(self) -> None: - """Initializes the dataset + """Initializes the entire dataset Reads the transcripts.txt file and creates a lookup table """ @@ -189,7 +270,8 @@ class MLSDataset(Dataset): # unzip the dataset if not os.path.isdir(os.path.join(self.dataset_path, self.language)): print( - f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" + f"Unzipping the dataset at \ + {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" ) extract_archive( os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True @@ -236,7 +318,7 @@ class MLSDataset(Dataset): + self.file_ext, ) - waveform, sample_rate = torchaudio.load(audio_path) # type: ignore + waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member # resample if necessary if sample_rate != self.spectrogram_hparams["sample_rate"]: @@ -257,7 +339,7 @@ class MLSDataset(Dataset): utterance = self.tokenizer.encode(utterance) - utterance = torch.LongTensor(utterance.ids) + utterance = torch.LongTensor(utterance.ids) # pylint: disable=no-member return Sample( waveform=waveform, @@ -311,24 +393,25 @@ if __name__ == "__main__": split = Split.TRAIN DOWNLOAD = False - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None) + dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None) tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) -def plot(epochs,path): - losses = list() +def plot(epochs, path): + """Plots the losses over the epochs""" + losses = list() test_losses = list() cers = list() - wers =list() - for epoch in range(1, epochs +1): + wers = list() + for epoch in range(1, epochs + 1): current_state = torch.load(path + str(epoch)) losses.append(current_state["loss"]) test_losses.append(current_state["test_loss"]) cers.append(current_state["avg_cer"]) wers.append(current_state["avg_wer"]) - + plt.plot(losses) plt.plot(test_losses) - plt.savefig("losses.svg")
\ No newline at end of file + plt.savefig("losses.svg") |