author     Pherkel  2023-09-11 14:49:28 +0200
committer  Pherkel  2023-09-11 14:49:28 +0200
commit     9dc3bc07424908dd7cf3f052708f506fd58b6e2c (patch)
tree       cd45dc9b70977530669c271c09025246ebbb9fef
parent     01fae2b5e395e84db6a7e9819b6f98777c46e845 (diff)
refactor utilities (data, vis, tokenizer)
-rw-r--r--  swr2_asr/utils/__init__.py                                          0
-rw-r--r--  swr2_asr/utils/data.py (renamed from swr2_asr/utils.py)           206
-rw-r--r--  swr2_asr/utils/decoder.py                                          26
-rw-r--r--  swr2_asr/utils/tokenizer.py (renamed from swr2_asr/tokenizer.py)    0
-rw-r--r--  swr2_asr/utils/visualization.py                                    22
5 files changed, 128 insertions, 126 deletions
diff --git a/swr2_asr/utils/__init__.py b/swr2_asr/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/swr2_asr/utils/__init__.py
diff --git a/swr2_asr/utils.py b/swr2_asr/utils/data.py
index a362b9e..93f4a9a 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils/data.py
@@ -3,21 +3,51 @@ import os
from enum import Enum
from typing import TypedDict
-import matplotlib.pyplot as plt
import numpy as np
import torch
import torchaudio
-from tokenizers import Tokenizer
+from torch import Tensor, nn
from torch.utils.data import Dataset
-from torchaudio.datasets.utils import _extract_tar as extract_archive
+from torchaudio.datasets.utils import _extract_tar
-from swr2_asr.tokenizer import TokenizerType
+from swr2_asr.utils.tokenizer import CharTokenizer
-train_audio_transforms = torch.nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
- torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
- torchaudio.transforms.TimeMasking(time_mask_param=100),
-)
+
+class DataProcessing:
+ """Data processing class for the dataloader"""
+
+ def __init__(self, data_type: str, tokenizer: CharTokenizer):
+ self.data_type = data_type
+ self.tokenizer = tokenizer
+
+ if data_type == "train":
+ self.audio_transform = torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+ )
+ elif data_type == "valid":
+ self.audio_transform = torchaudio.transforms.MelSpectrogram()
+
+ def __call__(self, data) -> tuple[Tensor, Tensor, list, list]:
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+ for waveform, _, utterance, _, _, _ in data:
+ spec = self.audio_transform(waveform).squeeze(0).transpose(0, 1)
+ spectrograms.append(spec)
+ label = torch.Tensor(self.tokenizer.encode(utterance.lower()))
+ labels.append(label)
+ input_lengths.append(spec.shape[0] // 2)
+ label_lengths.append(len(label))
+
+ spectrograms = (
+ nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+ )
+ labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return spectrograms, labels, input_lengths, label_lengths
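A minimal sketch of how DataProcessing is meant to be wired into a DataLoader as the collate_fn. The tokenizer file path and the CharTokenizer.from_file constructor are assumptions, mirroring the Tokenizer.from_file call removed further down:

    # Sketch (assumptions: CharTokenizer.from_file exists; paths are hypothetical).
    from torch.utils.data import DataLoader

    from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
    train_set = MLSDataset("data", "mls_german_opus", Split.TRAIN, False, False)
    train_loader = DataLoader(
        train_set,
        batch_size=8,
        shuffle=True,
        collate_fn=DataProcessing("train", tokenizer),  # pads specs and labels per batch
    )
    spectrograms, labels, input_lengths, label_lengths = next(iter(train_loader))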
# create enum specifying dataset splits
@@ -31,7 +61,7 @@ class MLSSplit(str, Enum):
class Split(str, Enum):
- """Extending the MLSSplit class to allow for a custom validatio split"""
+ """Extending the MLSSplit class to allow for a custom validation split"""
TRAIN = "train"
VALID = "valid"
@@ -43,8 +73,7 @@ def split_to_mls_split(split_name: Split) -> MLSSplit:
"""Converts the custom split to a MLSSplit"""
if split_name == Split.VALID:
return MLSSplit.TRAIN
- else:
- return split_name # type: ignore
+ return split_name # type: ignore
class Sample(TypedDict):
@@ -97,7 +126,7 @@ class MLSDataset(Dataset):
split: Split,
limited: bool,
download: bool,
- spectrogram_hparams: dict | None,
+ size: float = 0.2,
):
"""Initializes the dataset"""
self.dataset_path = dataset_path
@@ -106,22 +135,7 @@ class MLSDataset(Dataset):
self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
self.split: Split = split # split used internally
- if spectrogram_hparams is None:
- self.spectrogram_hparams = {
- "sample_rate": 16000,
- "n_fft": 400,
- "win_length": 400,
- "hop_length": 160,
- "n_mels": 128,
- "f_min": 0,
- "f_max": 8000,
- "power": 2.0,
- }
- else:
- self.spectrogram_hparams = spectrogram_hparams
-
self.dataset_lookup = []
- self.tokenizer: type[TokenizerType]
self._handle_download_dataset(download)
self._validate_local_directory()
@@ -130,6 +144,8 @@ class MLSDataset(Dataset):
else:
self.initialize()
+ self.dataset_lookup = self.dataset_lookup[: int(len(self.dataset_lookup) * size)]
+
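Note that with the default size=0.2 this keeps only the first 20 % of the lookup, e.g. int(1000 * 0.2) == 200 samples, and it applies to every split, not just train.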
def initialize_limited(self) -> None:
"""Initializes the limited supervision dataset"""
# get file handles
@@ -246,10 +262,6 @@ class MLSDataset(Dataset):
for path, utterance in zip(identifier, utterances, strict=False)
]
- def set_tokenizer(self, tokenizer: type[TokenizerType]):
- """Sets the tokenizer"""
- self.tokenizer = tokenizer
-
def _handle_download_dataset(self, download: bool) -> None:
"""Download the dataset"""
if not download:
@@ -258,7 +270,9 @@ class MLSDataset(Dataset):
# zip exists:
if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz") and download:
print(f"Found dataset at {self.dataset_path}. Skipping download")
- # zip does not exist:
+ # path exists:
+ elif os.path.isdir(os.path.join(self.dataset_path, self.language)) and download:
+ return
else:
os.makedirs(self.dataset_path, exist_ok=True)
url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
@@ -273,9 +287,7 @@ class MLSDataset(Dataset):
f"Unzipping the dataset at \
{os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
)
- extract_archive(
- os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True
- )
+ _extract_tar(os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True)
else:
print("Dataset is already unzipped, validating it now")
return
@@ -293,12 +305,29 @@ class MLSDataset(Dataset):
"""Returns the length of the dataset"""
return len(self.dataset_lookup)
- def __getitem__(self, idx: int) -> Sample:
- """One sample"""
- if self.tokenizer is None:
- raise ValueError("No tokenizer set")
+ def __getitem__(self, idx: int) -> tuple[Tensor, int, str, int, int, int]:
+ """One sample
+
+ Returns:
+ Tuple of the following items;
+
+ Tensor:
+ Waveform
+ int:
+ Sample rate
+ str:
+ Transcript
+ int:
+ Speaker ID
+ int:
+ Chapter ID
+ int:
+ Utterance ID
+ """
# get the utterance
- utterance = self.dataset_lookup[idx]["utterance"]
+ dataset_lookup_entry = self.dataset_lookup[idx]
+
+ utterance = dataset_lookup_entry["utterance"]
# get the audio file
audio_path = os.path.join(
@@ -321,70 +350,18 @@ class MLSDataset(Dataset):
waveform, sample_rate = torchaudio.load(audio_path) # pylint: disable=no-member
# resample if necessary
- if sample_rate != self.spectrogram_hparams["sample_rate"]:
- resampler = torchaudio.transforms.Resample(
- sample_rate, self.spectrogram_hparams["sample_rate"]
- )
+ if sample_rate != 16000:
+ resampler = torchaudio.transforms.Resample(sample_rate, 16000)
waveform = resampler(waveform)
- spec = (
- torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
- .squeeze(0)
- .transpose(0, 1)
- )
-
- input_length = spec.shape[0] // 2
-
- utterance_length = len(utterance)
-
- utterance = self.tokenizer.encode(utterance)
-
- utterance = torch.LongTensor(utterance.ids) # pylint: disable=no-member
-
- return Sample(
- waveform=waveform,
- spectrogram=spec,
- input_length=input_length,
- utterance=utterance,
- utterance_length=utterance_length,
- sample_rate=self.spectrogram_hparams["sample_rate"],
- speaker_id=self.dataset_lookup[idx]["speakerid"],
- book_id=self.dataset_lookup[idx]["bookid"],
- chapter_id=self.dataset_lookup[idx]["chapterid"],
- )
-
-
-def collate_fn(samples: list[Sample]) -> dict:
- """Collate function for the dataloader
-
- pads all tensors within a batch to the same dimensions
- """
- waveforms = []
- spectrograms = []
- labels = []
- input_lengths = []
- label_lengths = []
-
- for sample in samples:
- waveforms.append(sample["waveform"].transpose(0, 1))
- spectrograms.append(sample["spectrogram"])
- labels.append(sample["utterance"])
- input_lengths.append(sample["spectrogram"].shape[0] // 2)
- label_lengths.append(len(sample["utterance"]))
-
- waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
- spectrograms = (
- torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
- )
- labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
-
- return {
- "waveform": waveforms,
- "spectrogram": spectrograms,
- "input_length": input_lengths,
- "utterance": labels,
- "utterance_length": label_lengths,
- }
+ return (
+ waveform,
+ sample_rate,
+ utterance,
+ dataset_lookup_entry["speakerid"],
+ dataset_lookup_entry["chapterid"],
+ idx,
+ ) # type: ignore
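The tuple now follows the (waveform, sample_rate, transcript, speaker, chapter, utterance) layout that DataProcessing.__call__ unpacks. A minimal consumption sketch, reusing a dataset built as in __main__ below:

    # note: the waveform is resampled to 16 kHz, while the returned
    # sample_rate is still the rate reported by torchaudio.load
    waveform, sample_rate, utterance, speaker_id, chapter_id, utt_id = dataset[0]
    print(utterance.lower(), waveform.shape, sample_rate)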
if __name__ == "__main__":
@@ -392,26 +369,3 @@ if __name__ == "__main__":
LANGUAGE = "mls_german_opus"
split = Split.TRAIN
DOWNLOAD = False
-
- dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, False, DOWNLOAD, None)
-
- tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
- dataset.set_tokenizer(tok)
-
-
-def plot(epochs, path):
- """Plots the losses over the epochs"""
- losses = list()
- test_losses = list()
- cers = list()
- wers = list()
- for epoch in range(1, epochs + 1):
- current_state = torch.load(path + str(epoch))
- losses.append(current_state["loss"])
- test_losses.append(current_state["test_loss"])
- cers.append(current_state["avg_cer"])
- wers.append(current_state["avg_wer"])
-
- plt.plot(losses)
- plt.plot(test_losses)
- plt.savefig("losses.svg")
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
new file mode 100644
index 0000000..fcddb79
--- /dev/null
+++ b/swr2_asr/utils/decoder.py
@@ -0,0 +1,26 @@
+"""Decoder for CTC-based ASR.""" ""
+import torch
+
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
+# TODO: refactor to use torch CTC decoder class
+def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
+ """Greedily decode a sequence."""
+ blank_label = tokenizer.get_blank_token()
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(tokenizer.decode(decode))
+ return decodes, targets
+
+
+# TODO: add beam search decoder
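A quick smoke test for greedy_decoder on random emissions; get_vocab_size and the tokenizer path are assumptions, everything else follows the signature above:

    import torch

    from swr2_asr.utils.decoder import greedy_decoder
    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")  # hypothetical path
    vocab = tokenizer.get_vocab_size()  # assumed helper
    output = torch.randn(2, 50, vocab).log_softmax(2)  # (batch, time, vocab)
    labels = torch.randint(1, vocab, (2, 20))
    decodes, targets = greedy_decoder(output, labels, [20, 20], tokenizer)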
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/utils/tokenizer.py
index d92465a..d92465a 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..80f942a
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+ """Plots the losses over the epochs"""
+ losses = list()
+ test_losses = list()
+ cers = list()
+ wers = list()
+ for epoch in range(1, epochs + 1):
+ current_state = torch.load(path + str(epoch))
+ losses.append(current_state["loss"])
+ test_losses.append(current_state["test_loss"])
+ cers.append(current_state["avg_cer"])
+ wers.append(current_state["avg_wer"])
+
+ plt.plot(losses)
+ plt.plot(test_losses)
+ plt.savefig("losses.svg")
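plot assumes one checkpoint per epoch saved at path + str(epoch), each a dict carrying loss, test_loss, avg_cer and avg_wer; only the two loss curves end up in losses.svg, while the CER/WER lists are collected but not plotted yet. Example call, with a hypothetical checkpoint prefix:

    plot(epochs=10, path="checkpoints/model_epoch_")  # reads checkpoints/model_epoch_1 .. model_epoch_10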