Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r--  swr2_asr/utils.py  85
1 file changed, 50 insertions(+), 35 deletions(-)
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 3b9b3ca..efecb56 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -8,7 +8,6 @@ import torch
import torchaudio
from tokenizers import Tokenizer
from torch.utils.data import Dataset
-from tqdm import tqdm
from swr2_asr.tokenizer import TokenizerType
@@ -24,26 +23,26 @@ class MLSSplit(str, Enum):
"""Enum specifying dataset as they are defined in the
Multilingual LibriSpeech dataset"""
- train = "train"
- test = "test"
- dev = "dev"
+ TRAIN = "train"
+ TEST = "test"
+ DEV = "dev"
class Split(str, Enum):
"""Extending the MLSSplit class to allow for a custom validatio split"""
- train = "train"
- valid = "valid"
- test = "test"
- dev = "dev"
+ TRAIN = "train"
+ VALID = "valid"
+ TEST = "test"
+ DEV = "dev"
-def split_to_mls_split(split: Split) -> MLSSplit:
+def split_to_mls_split(split_name: Split) -> MLSSplit:
"""Converts the custom split to a MLSSplit"""
- if split == Split.valid:
- return MLSSplit.train
+ if split_name == Split.VALID:
+ return MLSSplit.TRAIN
else:
- return split # type: ignore
+ return split_name # type: ignore
class Sample(TypedDict):
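
The enum members are renamed to UPPER_CASE but keep their string values, so paths on disk are unchanged. A minimal sanity-check sketch (not part of the commit):

from swr2_asr.utils import MLSSplit, Split, split_to_mls_split

# valid is carved out of the train split, so it maps back to train on disk
assert split_to_mls_split(Split.VALID) == MLSSplit.TRAIN
# every other split maps to itself; str-valued enums compare by value
assert split_to_mls_split(Split.TEST) == MLSSplit.TEST
assert Split.TRAIN == "train"
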
@@ -89,14 +88,21 @@ class MLSDataset(Dataset):
<speakerid>_<bookid>_<chapterid> <utterance>
"""
- def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None):
+ def __init__(
+ self,
+ dataset_path: str,
+ language: str,
+ split: Split,
+ download: bool,
+ spectrogram_hparams: dict | None,
+ ):
"""Initializes the dataset"""
self.dataset_path = dataset_path
self.language = language
self.file_ext = ".opus" if "opus" in language else ".flac"
self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
self.split: Split = split # split used internally
-
+
if spectrogram_hparams is None:
self.spectrogram_hparams = {
"sample_rate": 16000,
@@ -110,7 +116,7 @@ class MLSDataset(Dataset):
}
else:
self.spectrogram_hparams = spectrogram_hparams
-
+
self.dataset_lookup = []
self.tokenizer: type[TokenizerType]
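
These hyperparameters are later unpacked straight into torchaudio.transforms.MelSpectrogram (see the __getitem__ hunk below). Only sample_rate survives the hunk boundary here, so the other keys in this sketch are placeholder assumptions, not the commit's actual defaults:

import torch
import torchaudio

spectrogram_hparams = {
    "sample_rate": 16000,  # the only default visible in this hunk
    "n_fft": 400,          # placeholder values; the real defaults are
    "hop_length": 160,     # elided by the hunk boundary above
    "n_mels": 128,
}
transform = torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)
spec = transform(torch.randn(1, 16000))  # (1, n_mels, time)
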
@@ -118,13 +124,14 @@ class MLSDataset(Dataset):
self._validate_local_directory()
self.initialize()
-
def initialize(self) -> None:
"""Initializes the dataset
-
+
Reads the transcripts.txt file and creates a lookup table
"""
- transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt")
+ transcripts_path = os.path.join(
+ self.dataset_path, self.language, self.mls_split, "transcripts.txt"
+ )
with open(transcripts_path, "r", encoding="utf-8") as script_file:
# read all lines in transcripts.txt
@@ -135,12 +142,12 @@ class MLSDataset(Dataset):
identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore
identifier = [path.split("_") for path in identifier]
- if self.split == Split.valid:
+ if self.split == Split.VALID:
np.random.seed(42)
indices = np.random.choice(len(utterances), int(len(utterances) * 0.2))
utterances = [utterances[i] for i in indices]
identifier = [identifier[i] for i in indices]
- elif self.split == Split.train:
+ elif self.split == Split.TRAIN:
np.random.seed(42)
indices = np.random.choice(len(utterances), int(len(utterances) * 0.8))
utterances = [utterances[i] for i in indices]
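
Both branches reseed NumPy with the same seed, so the VALID and TRAIN subsets are reproducible across runs. A standalone sketch of the selection logic; note that np.random.choice samples with replacement by default, and reusing the seed means both index draws start from the same random stream:

import numpy as np

utterances = [f"utt_{i}" for i in range(10)]  # stand-in transcript lines

np.random.seed(42)
valid_idx = np.random.choice(len(utterances), int(len(utterances) * 0.2))
np.random.seed(42)
train_idx = np.random.choice(len(utterances), int(len(utterances) * 0.8))

valid = [utterances[i] for i in valid_idx]  # 20% of the data
train = [utterances[i] for i in train_idx]  # 80% of the data
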
@@ -212,13 +219,19 @@ class MLSDataset(Dataset):
# resample if necessary
if sample_rate != self.spectrogram_hparams["sample_rate"]:
- resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"])
+ resampler = torchaudio.transforms.Resample(
+ sample_rate, self.spectrogram_hparams["sample_rate"]
+ )
waveform = resampler(waveform)
- spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1)
+ spec = (
+ torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
+ .squeeze(0)
+ .transpose(0, 1)
+ )
input_length = spec.shape[0] // 2
-
+
utterance_length = len(utterance)
utterance = self.tokenizer.encode(utterance)
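
A shape walk-through of that chain, assuming mono 16 kHz audio; halving input_length presumably mirrors a stride-2 time downsampling in the downstream model:

import torch
import torchaudio

waveform = torch.randn(1, 16000)  # one second of mono audio at 16 kHz
spec = (
    torchaudio.transforms.MelSpectrogram(sample_rate=16000)(waveform)
    .squeeze(0)       # (1, n_mels, time) -> (n_mels, time)
    .transpose(0, 1)  # -> (time, n_mels)
)
input_length = spec.shape[0] // 2  # time axis halved
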
@@ -236,11 +249,11 @@ class MLSDataset(Dataset):
book_id=self.dataset_lookup[idx]["bookid"],
chapter_id=self.dataset_lookup[idx]["chapterid"],
)
-
+
def collate_fn(samples: list[Sample]) -> dict:
"""Collate function for the dataloader
-
+
pads all tensors within a batch to the same dimensions
"""
waveforms = []
@@ -248,18 +261,20 @@ def collate_fn(samples: list[Sample]) -> dict:
labels = []
input_lengths = []
label_lengths = []
-
+
for sample in samples:
waveforms.append(sample["waveform"].transpose(0, 1))
spectrograms.append(sample["spectrogram"])
labels.append(sample["utterance"])
input_lengths.append(sample["spectrogram"].shape[0] // 2)
label_lengths.append(len(sample["utterance"]))
-
+
waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
- spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3)
+ spectrograms = (
+ torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+ )
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
-
+
return {
"waveform": waveforms,
"spectrogram": spectrograms,
@@ -267,15 +282,15 @@ def collate_fn(samples: list[Sample]) -> dict:
"utterance": labels,
"utterance_length": label_lengths,
}
-
+
if __name__ == "__main__":
- dataset_path = "/Volumes/pherkel/SWR2-ASR"
- language = "mls_german_opus"
- split = Split.train
- download = False
+ DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
+ LANGUAGE = "mls_german_opus"
+ split = Split.TRAIN
+ DOWNLOAD = False
- dataset = MLSDataset(dataset_path, language, split, download, None)
+ dataset = MLSDataset(DATASET_PATH, LANGUAGE, split, DOWNLOAD, None)
tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
dataset.set_tokenizer(tok)
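
With collate_fn handling the padding, the dataset plugs into a standard DataLoader; a usage sketch continuing the __main__ block above (batch size chosen arbitrarily):

from torch.utils.data import DataLoader

loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)
batch = next(iter(loader))
# padded to the longest sample in the batch: (batch, 1, n_mels, time)
print(batch["spectrogram"].shape)
print(batch["input_length"][:4], batch["utterance_length"][:4])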