diff options
author | Pherkel | 2023-09-05 22:26:27 +0200 |
---|---|---|
committer | Pherkel | 2023-09-05 22:26:27 +0200 |
commit | 4bd118da7054f29e70e731ebcef7ad0310742235 (patch) | |
tree | dbf7c6a19f6ec2423cab994e645a42be4f77fe3a | |
parent | 46b23fd90f5ef9c3126ee66473b012fa715da008 (diff) |
add limited supervision training (10hr)
-rw-r--r-- | swr2_asr/train.py | 25 | ||||
-rw-r--r-- | swr2_asr/utils.py | 109 |
2 files changed, 108 insertions, 26 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6f3bc6c..40626e7 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -15,8 +15,6 @@ from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer -# TODO: improve naming of functions - class HParams(TypedDict): """Type for the hyperparameters of the model.""" @@ -157,10 +155,10 @@ def run( # load dataset train_dataset = MLSDataset( - dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None + dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True ) valid_dataset = MLSDataset( - dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None + dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True ) # load tokenizer (bpe by default): @@ -171,7 +169,6 @@ def run( dataset_path=dataset_path, language=language, split="all", - download=False, out_path="data/tokenizers/char_tokenizer_german.json", ) @@ -211,7 +208,7 @@ def run( # enable flag to find the most compatible algorithms in advance if use_cuda: - torch.backends.cudnn.benchmark = True + torch.backends.cudnn.benchmark = True # pylance: disable=no-member model = SpeechRecognitionModel( hparams["n_cnn_layers"], @@ -253,7 +250,7 @@ def run( iter_meter, ) - test_loss,avg_cer,avg_wer = test( + test_loss, avg_cer, avg_wer = test( model=model, device=device, test_loader=valid_loader, @@ -262,12 +259,14 @@ def run( ) print("saving epoch", str(epoch)) torch.save( - {"epoch": epoch, - "model_state_dict": model.state_dict(), - "loss": loss, - "test_loss": test_loss, - "avg_cer": avg_cer, - "avg_wer": avg_wer}, + { + "epoch": epoch, + "model_state_dict": model.state_dict(), + "loss": loss, + "test_loss": test_loss, + "avg_cer": avg_cer, + "avg_wer": avg_wer, + }, path + str(epoch), ) diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 87d4f82..a362b9e 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -3,8 +3,8 @@ import os from enum import Enum from typing import TypedDict -import numpy as np import matplotlib.pyplot as plt +import numpy as np import torch import torchaudio from tokenizers import Tokenizer @@ -95,6 +95,7 @@ class MLSDataset(Dataset): dataset_path: str, language: str, split: Split, + limited: bool, download: bool, spectrogram_hparams: dict | None, ): @@ -124,10 +125,90 @@ class MLSDataset(Dataset): self._handle_download_dataset(download) self._validate_local_directory() - self.initialize() + if limited and (split == Split.TRAIN or split == Split.VALID): + self.initialize_limited() + else: + self.initialize() + + def initialize_limited(self) -> None: + """Initializes the limited supervision dataset""" + # get file handles + # get file paths + # get transcripts + # create train or validation split + + handles = set() + + train_root_path = os.path.join(self.dataset_path, self.language, "train") + + # get file handles for 9h + with open( + os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"), + "r", + encoding="utf-8", + ) as file: + for line in file: + handles.add(line.strip()) + + # get file handles for 1h splits + for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")): + if handle_path not in range(0, 6): + continue + with open( + os.path.join( + train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt" + ), + "r", + encoding="utf-8", + ) as file: + for line in file: + handles.add(line.strip()) + + # get file paths for handles + file_paths = [] + for handle in handles: + file_paths.append( + os.path.join( + train_root_path, + "audio", + handle.split("_")[0], + handle.split("_")[1], + handle + self.file_ext, + ) + ) + + # get transcripts for handles + transcripts = [] + with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file: + for line in file: + if line.split("\t")[0] in handles: + transcripts.append(line.strip()) + + # create train or valid split randomly with seed 42 + if self.split == Split.TRAIN: + np.random.seed(42) + indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8)) + file_paths = [file_paths[i] for i in indices] + transcripts = [transcripts[i] for i in indices] + elif self.split == Split.VALID: + np.random.seed(42) + indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.2)) + file_paths = [file_paths[i] for i in indices] + transcripts = [transcripts[i] for i in indices] + + # create dataset lookup + self.dataset_lookup = [ + { + "speakerid": path.split("/")[-3], + "bookid": path.split("/")[-2], + "chapterid": path.split("/")[-1].split("_")[2].split(".")[0], + "utterance": utterance.split("\t")[1], + } + for path, utterance in zip(file_paths, transcripts, strict=False) + ] def initialize(self) -> None: - """Initializes the dataset + """Initializes the entire dataset Reads the transcripts.txt file and creates a lookup table """ @@ -189,7 +270,8 @@ class MLSDataset(Dataset): # unzip the dataset if not os.path.isdir(os.path.join(self.dataset_path, self.language)): print( - f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" + f"Unzipping the dataset at \ + {os.path.join(self.dataset_path, self.language) + '.tar.gz'}" ) extract_archive( os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True @@ -236,7 +318,7 @@ class MLSDataset(Dataset): + self.file_ext, ) - waveform, sample_rate = torchaudio.load(audio_path) # type: ignore + waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member # resample if necessary if sample_rate != self.spectrogram_hparams["sample_rate"]: @@ -257,7 +339,7 @@ class MLSDataset(Dataset): utterance = self.tokenizer.encode(utterance) - utterance = torch.LongTensor(utterance.ids) + utterance = torch.LongTensor(utterance.ids) # pylint: disable=no-member return Sample( waveform=waveform, @@ -311,24 +393,25 @@ if __name__ == "__main__": split = Split.TRAIN DOWNLOAD = False - dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None) + dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None) tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) -def plot(epochs,path): - losses = list() +def plot(epochs, path): + """Plots the losses over the epochs""" + losses = list() test_losses = list() cers = list() - wers =list() - for epoch in range(1, epochs +1): + wers = list() + for epoch in range(1, epochs + 1): current_state = torch.load(path + str(epoch)) losses.append(current_state["loss"]) test_losses.append(current_state["test_loss"]) cers.append(current_state["avg_cer"]) wers.append(current_state["avg_wer"]) - + plt.plot(losses) plt.plot(test_losses) - plt.savefig("losses.svg")
\ No newline at end of file + plt.savefig("losses.svg") |