aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPherkel2023-09-05 22:26:27 +0200
committerPherkel2023-09-05 22:26:27 +0200
commit4bd118da7054f29e70e731ebcef7ad0310742235 (patch)
treedbf7c6a19f6ec2423cab994e645a42be4f77fe3a
parent46b23fd90f5ef9c3126ee66473b012fa715da008 (diff)
add limited supervision training (10hr)
-rw-r--r--swr2_asr/train.py25
-rw-r--r--swr2_asr/utils.py109
2 files changed, 108 insertions, 26 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 6f3bc6c..40626e7 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -15,8 +15,6 @@ from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
-# TODO: improve naming of functions
-
class HParams(TypedDict):
"""Type for the hyperparameters of the model."""
@@ -157,10 +155,10 @@ def run(
# load dataset
train_dataset = MLSDataset(
- dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
+ dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
)
valid_dataset = MLSDataset(
- dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+ dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
)
# load tokenizer (bpe by default):
@@ -171,7 +169,6 @@ def run(
dataset_path=dataset_path,
language=language,
split="all",
- download=False,
out_path="data/tokenizers/char_tokenizer_german.json",
)
@@ -211,7 +208,7 @@ def run(
# enable flag to find the most compatible algorithms in advance
if use_cuda:
- torch.backends.cudnn.benchmark = True
+ torch.backends.cudnn.benchmark = True # pylance: disable=no-member
model = SpeechRecognitionModel(
hparams["n_cnn_layers"],
@@ -253,7 +250,7 @@ def run(
iter_meter,
)
- test_loss,avg_cer,avg_wer = test(
+ test_loss, avg_cer, avg_wer = test(
model=model,
device=device,
test_loader=valid_loader,
@@ -262,12 +259,14 @@ def run(
)
print("saving epoch", str(epoch))
torch.save(
- {"epoch": epoch,
- "model_state_dict": model.state_dict(),
- "loss": loss,
- "test_loss": test_loss,
- "avg_cer": avg_cer,
- "avg_wer": avg_wer},
+ {
+ "epoch": epoch,
+ "model_state_dict": model.state_dict(),
+ "loss": loss,
+ "test_loss": test_loss,
+ "avg_cer": avg_cer,
+ "avg_wer": avg_wer,
+ },
path + str(epoch),
)
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 87d4f82..a362b9e 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -3,8 +3,8 @@ import os
from enum import Enum
from typing import TypedDict
-import numpy as np
import matplotlib.pyplot as plt
+import numpy as np
import torch
import torchaudio
from tokenizers import Tokenizer
@@ -95,6 +95,7 @@ class MLSDataset(Dataset):
dataset_path: str,
language: str,
split: Split,
+ limited: bool,
download: bool,
spectrogram_hparams: dict | None,
):
@@ -124,10 +125,90 @@ class MLSDataset(Dataset):
self._handle_download_dataset(download)
self._validate_local_directory()
- self.initialize()
+ if limited and (split == Split.TRAIN or split == Split.VALID):
+ self.initialize_limited()
+ else:
+ self.initialize()
+
+ def initialize_limited(self) -> None:
+ """Initializes the limited supervision dataset"""
+ # get file handles
+ # get file paths
+ # get transcripts
+ # create train or validation split
+
+ handles = set()
+
+ train_root_path = os.path.join(self.dataset_path, self.language, "train")
+
+ # get file handles for 9h
+ with open(
+ os.path.join(train_root_path, "limited_supervision", "9hr", "handles.txt"),
+ "r",
+ encoding="utf-8",
+ ) as file:
+ for line in file:
+ handles.add(line.strip())
+
+ # get file handles for 1h splits
+ for handle_path in os.listdir(os.path.join(train_root_path, "limited_supervision", "1hr")):
+ if handle_path not in range(0, 6):
+ continue
+ with open(
+ os.path.join(
+ train_root_path, "limited_supervision", "1hr", handle_path, "handles.txt"
+ ),
+ "r",
+ encoding="utf-8",
+ ) as file:
+ for line in file:
+ handles.add(line.strip())
+
+ # get file paths for handles
+ file_paths = []
+ for handle in handles:
+ file_paths.append(
+ os.path.join(
+ train_root_path,
+ "audio",
+ handle.split("_")[0],
+ handle.split("_")[1],
+ handle + self.file_ext,
+ )
+ )
+
+ # get transcripts for handles
+ transcripts = []
+ with open(os.path.join(train_root_path, "transcripts.txt"), "r", encoding="utf-8") as file:
+ for line in file:
+ if line.split("\t")[0] in handles:
+ transcripts.append(line.strip())
+
+ # create train or valid split randomly with seed 42
+ if self.split == Split.TRAIN:
+ np.random.seed(42)
+ indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.8))
+ file_paths = [file_paths[i] for i in indices]
+ transcripts = [transcripts[i] for i in indices]
+ elif self.split == Split.VALID:
+ np.random.seed(42)
+ indices = np.random.choice(len(file_paths), int(len(file_paths) * 0.2))
+ file_paths = [file_paths[i] for i in indices]
+ transcripts = [transcripts[i] for i in indices]
+
+ # create dataset lookup
+ self.dataset_lookup = [
+ {
+ "speakerid": path.split("/")[-3],
+ "bookid": path.split("/")[-2],
+ "chapterid": path.split("/")[-1].split("_")[2].split(".")[0],
+ "utterance": utterance.split("\t")[1],
+ }
+ for path, utterance in zip(file_paths, transcripts, strict=False)
+ ]
def initialize(self) -> None:
- """Initializes the dataset
+ """Initializes the entire dataset
Reads the transcripts.txt file and creates a lookup table
"""
@@ -189,7 +270,8 @@ class MLSDataset(Dataset):
# unzip the dataset
if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
print(
- f"Unzipping the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
+ f"Unzipping the dataset at \
+ {os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
)
extract_archive(
os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True
@@ -236,7 +318,7 @@ class MLSDataset(Dataset):
+ self.file_ext,
)
- waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
+ waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member
# resample if necessary
if sample_rate != self.spectrogram_hparams["sample_rate"]:
@@ -257,7 +339,7 @@ class MLSDataset(Dataset):
utterance = self.tokenizer.encode(utterance)
- utterance = torch.LongTensor(utterance.ids)
+ utterance = torch.LongTensor(utterance.ids) # pylint: disable=no-member
return Sample(
waveform=waveform,
@@ -311,24 +393,25 @@ if __name__ == "__main__":
split = Split.TRAIN
DOWNLOAD = False
- dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None)
+ dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None)
tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
dataset.set_tokenizer(tok)
-def plot(epochs,path):
- losses = list()
+def plot(epochs, path):
+ """Plots the losses over the epochs"""
+ losses = list()
test_losses = list()
cers = list()
- wers =list()
- for epoch in range(1, epochs +1):
+ wers = list()
+ for epoch in range(1, epochs + 1):
current_state = torch.load(path + str(epoch))
losses.append(current_state["loss"])
test_losses.append(current_state["test_loss"])
cers.append(current_state["avg_cer"])
wers.append(current_state["avg_wer"])
-
+
plt.plot(losses)
plt.plot(test_losses)
- plt.savefig("losses.svg") \ No newline at end of file
+ plt.savefig("losses.svg")