author     Pherkel  2023-08-24 00:03:56 +0200
committer  Pherkel  2023-08-24 00:03:56 +0200
commit     403472ca4e65e8ed404e8a73fb9b3fbafe3f2a53 (patch)
tree       e5bfaca7d1982f1fadb1abe1da023d2020151363 /swr2_asr/utils.py
parent     d65e728575e07a54cec52ccb57af3cafedaac1a2 (diff)
wip: commit before going on vacation :)
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r--  swr2_asr/utils.py  207
1 file changed, 93 insertions, 114 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index c4aeb0b..786fbcf 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -1,16 +1,21 @@
"""Class containing utils for the ASR system."""
-from dataclasses import dataclass
import os
-from AudioLoader.speech import MultilingualLibriSpeech
+from enum import Enum
+from typing import TypedDict
+
import numpy as np
import torch
import torchaudio
-from torch import nn
-from torch.utils.data import Dataset, DataLoader
-from enum import Enum
-
from tokenizers import Tokenizer
-from swr2_asr.tokenizer import CharTokenizer
+from torch.utils.data import Dataset
+
+from swr2_asr.tokenizer import CharTokenizer, TokenizerType
+
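+# SpecAugment-style training transforms: mel spectrogram plus frequency and time masking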
+train_audio_transforms = torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+)
# create enum specifying dataset splits
@@ -40,34 +45,20 @@ def split_to_mls_split(split: Split) -> MLSSplit:
return split # type: ignore
-@dataclass
-class Sample:
- """Dataclass for a sample in the dataset"""
+class Sample(TypedDict):
+ """Type for a sample in the dataset"""
waveform: torch.Tensor
spectrogram: torch.Tensor
- utterance: str
+    input_length: int  # spectrogram time frames // 2, see __getitem__
+    utterance: torch.Tensor  # padded token ids
+    utterance_length: int  # unpadded token count
sample_rate: int
speaker_id: str
book_id: str
chapter_id: str
-def tokenizer_factory(tokenizer_path: str, tokenizer_type: str = "BPE"):
- """Factory for Tokenizer class
-
- Args:
- tokenizer_type (str, optional): Type of tokenizer to use. Defaults to "BPE".
-
- Returns:
- nn.Module: Tokenizer class
- """
- if tokenizer_type == "BPE":
- return Tokenizer.from_file(tokenizer_path)
- elif tokenizer_type == "char":
- return CharTokenizer.from_file(tokenizer_path)
-
-
class MLSDataset(Dataset):
"""Custom Dataset for reading Multilingual LibriSpeech
@@ -105,23 +96,33 @@ class MLSDataset(Dataset):
self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
self.split: Split = split # split used internally
self.dataset_lookup = []
+        self.tokenizer: TokenizerType
self._handle_download_dataset(download)
self._validate_local_directory()
- transcripts_path = os.path.join(
- dataset_path, language, self.mls_split, "transcripts.txt"
- )
+ transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt")
with open(transcripts_path, "r", encoding="utf-8") as script_file:
# read all lines in transcripts.txt
transcripts = script_file.readlines()
# split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>)
- transcripts = [line.strip().split("\t", 1) for line in transcripts]
- utterances = [utterance.strip() for _, utterance in transcripts]
- identifier = [identifier.strip() for identifier, _ in transcripts]
+ transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore
+ utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore
+ identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore
identifier = [path.split("_") for path in identifier]
+        if self.split in (Split.train, Split.valid):
+            # deterministic, disjoint 80/20 train/valid split
+            np.random.seed(42)
+            indices = np.random.permutation(len(utterances))
+            split_idx = int(len(utterances) * 0.8)
+            if self.split == Split.train:
+                indices = indices[:split_idx]
+            else:
+                indices = indices[split_idx:]
+            utterances = [utterances[i] for i in indices]
+            identifier = [identifier[i] for i in indices]
+
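+        # each lookup entry maps an utterance to its speaker, book and
+        # chapter ids, which identify the matching audio file on disk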
self.dataset_lookup = [
{
"speakerid": path[0],
@@ -129,27 +130,23 @@ class MLSDataset(Dataset):
"chapterid": path[2],
"utterance": utterance,
}
- for path, utterance in zip(identifier, utterances)
+ for path, utterance in zip(identifier, utterances, strict=False)
]
- # save dataset_lookup as list of dicts, where each dict contains
- # the speakerid, bookid and chapterid, as well as the utterance
- # we can then use this to map the utterance to the audio file
+    def set_tokenizer(self, tokenizer: TokenizerType):
+        """Sets the tokenizer and recomputes the padding lengths"""
+        self.tokenizer = tokenizer
+        self.calc_paddings()
def _handle_download_dataset(self, download: bool):
"""Download the dataset"""
- if (
- not os.path.exists(os.path.join(self.dataset_path, self.language))
- and download
- ):
+ if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download:
os.makedirs(self.dataset_path)
url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
torch.hub.download_url_to_file(url, self.dataset_path)
- elif (
- not os.path.exists(os.path.join(self.dataset_path, self.language))
- and not download
- ):
+ elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download:
raise ValueError("Dataset not found. Set download to True to download it")
def _validate_local_directory(self):
@@ -158,18 +155,32 @@ class MLSDataset(Dataset):
raise ValueError("Dataset path does not exist")
if not os.path.exists(os.path.join(self.dataset_path, self.language)):
raise ValueError("Language not found in dataset")
- if not os.path.exists(
- os.path.join(self.dataset_path, self.language, self.mls_split)
- ):
+ if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
raise ValueError("Split not found in dataset")
- # checks if the transcripts.txt file exists
- if not os.path.exists(
- os.path.join(dataset_path, language, split, "transcripts.txt")
- ):
- raise ValueError("transcripts.txt not found in dataset")
-
- def __get_len__(self):
+ def calc_paddings(self):
+        """Computes the maximum spectrogram and utterance lengths used for padding"""
+ # check if dataset has been loaded and tokenizer has been set
+ if not self.dataset_lookup:
+ raise ValueError("Dataset not loaded")
+ if not self.tokenizer:
+ raise ValueError("Tokenizer not set")
+
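+        # TODO: dataset_lookup entries only hold ids and raw utterance text,
+        # so the spectrogram / token tensors below still need to be computed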
+        max_spec_length = 0
+        max_utterance_length = 0
+        for sample in self.dataset_lookup:
+            spec_length = sample["spectrogram"].shape[0]
+            if spec_length > max_spec_length:
+                max_spec_length = spec_length
+
+            utterance_length = sample["utterance"].shape[0]
+            if utterance_length > max_utterance_length:
+                max_utterance_length = utterance_length
+
+        self.max_spec_length = max_spec_length
+        self.max_utterance_length = max_utterance_length
+
+ def __len__(self):
"""Returns the length of the dataset"""
return len(self.dataset_lookup)
@@ -197,13 +208,32 @@ class MLSDataset(Dataset):
)
waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
+ # TODO: figure out if we have to resample or not
+ # TODO: pad correctly (manually)
+        spec = (
+            torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform)
+            .squeeze(0)
+            .transpose(0, 1)
+        )
+        # half the frame count, assuming the model downsamples time by 2
+        input_length = spec.shape[0] // 2
+        # pad the time dimension to the global maximum, then reshape to
+        # (1, n_mels, time) so batches stack to (N, 1, n_mels, time)
+        spec = torch.nn.functional.pad(
+            spec, pad=(0, 0, 0, self.max_spec_length - spec.shape[0]), mode="constant", value=0
+        )
+        spec = spec.unsqueeze(0).transpose(1, 2)
+
+        # tokenize first so the length refers to tokens, not characters
+        utterance = self.tokenizer.encode(utterance).ids
+        utterance_length = len(utterance)
+        # pad the token ids to the global maximum (assumes pad id 0)
+        utterance += [0] * (self.max_utterance_length - utterance_length)
+
+        utterance = torch.Tensor(utterance)
return Sample(
+ # TODO: add flag to only return spectrogram or waveform or both
waveform=waveform,
- spectrogram=torchaudio.transforms.MelSpectrogram(
- sample_rate=16000, n_mels=128
- )(waveform),
+ spectrogram=spec,
+ input_length=input_length,
utterance=utterance,
+ utterance_length=utterance_length,
sample_rate=sample_rate,
speaker_id=self.dataset_lookup[idx]["speakerid"],
book_id=self.dataset_lookup[idx]["bookid"],
@@ -218,62 +248,6 @@ class MLSDataset(Dataset):
torch.hub.download_url_to_file(url, dataset_path)
-class DataProcessor:
- """Factory for DataProcessingclass
-
- Transforms the dataset into spectrograms and labels, as well as a tokenizer
- """
-
- def __init__(
- self,
- dataset: MultilingualLibriSpeech,
- tokenizer_path: str,
- data_type: str = "train",
- tokenizer_type: str = "BPE",
- ):
- self.dataset = dataset
- self.data_type = data_type
-
- self.train_audio_transforms = nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
- torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
- torchaudio.transforms.TimeMasking(time_mask_param=100),
- )
-
- self.valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
- self.tokenizer = tokenizer_factory(
- tokenizer_path=tokenizer_path, tokenizer_type=tokenizer_type
- )
-
- def __call__(self) -> tuple[np.ndarray, np.ndarray, int, int]:
- """Returns spectrograms, labels and their lenghts"""
- for sample in self.dataset:
- if self.data_type == "train":
- spec = (
- self.train_audio_transforms(sample["waveform"])
- .squeeze(0)
- .transpose(0, 1)
- )
- elif self.data_type == "valid":
- spec = (
- self.valid_audio_transforms(sample["waveform"])
- .squeeze(0)
- .transpose(0, 1)
- )
- else:
- raise ValueError("data_type should be train or valid")
- label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
-
- spectrograms = (
- nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
- .unsqueeze(1)
- .transpose(2, 3)
- )
- labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
-
- yield spec, label, spec.shape[0] // 2, len(labels)
-
-
if __name__ == "__main__":
dataset_path = "/Volumes/pherkel/SWR2-ASR"
language = "mls_german_opus"
@@ -281,4 +255,9 @@ if __name__ == "__main__":
download = False
dataset = MLSDataset(dataset_path, language, split, download)
- print(dataset[0])
+
+ tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+ dataset.set_tokenizer(tok)
+
+ print(dataset[41]["spectrogram"].shape)
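+    # TODO: wrap in a torch.utils.data.DataLoader; waveforms still vary in
+    # length, so a custom collate_fn is needed despite the padding above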