author     Pherkel  2023-09-04 22:55:27 +0200
committer  GitHub   2023-09-04 22:55:27 +0200
commit     93e49e708fa59613406249069d31c2f6c8f2d2ab (patch)
tree       9f413c226283db990116c9559ffffb9124b911d8 /swr2_asr/utils.py
parent     14ceeb5ad36beea2f05214aa26260cdd1d86590b (diff)
parent     0d70a19e1fea6eda3f7b16ad0084591613f2de72 (diff)
Merge pull request #27 from Algo-Boys/refactor_modularize
Refactor modularize
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r--  swr2_asr/utils.py  316
1 file changed, 316 insertions(+), 0 deletions(-)
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
new file mode 100644
index 0000000..8a950ab
--- /dev/null
+++ b/swr2_asr/utils.py
@@ -0,0 +1,316 @@
+"""Class containing utils for the ASR system."""
+import os
+from enum import Enum
+from typing import TypedDict
+
+import numpy as np
+import torch
+import torchaudio
+from tokenizers import Tokenizer
+from torch.utils.data import Dataset
+from torchaudio.datasets.utils import _extract_tar as extract_archive
+
+from swr2_asr.tokenizer import TokenizerType
+
+train_audio_transforms = torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+)
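+# a minimal sketch of how these transforms might be applied during training
+# (the waveform shape is illustrative; this pipeline is not called elsewhere in this file):
+#   waveform = torch.randn(1, 16000)              # 1 s of mono audio at 16 kHz
+#   augmented = train_audio_transforms(waveform)  # masked mel spectrogram, (1, 128, time)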
+
+
+# create enum specifying dataset splits
+class MLSSplit(str, Enum):
+    """Enum specifying the dataset splits as they are defined in the
+    Multilingual LibriSpeech dataset"""
+
+ TRAIN = "train"
+ TEST = "test"
+ DEV = "dev"
+
+
+class Split(str, Enum):
+ """Extending the MLSSplit class to allow for a custom validatio split"""
+
+ TRAIN = "train"
+ VALID = "valid"
+ TEST = "test"
+ DEV = "dev"
+
+
+def split_to_mls_split(split_name: Split) -> MLSSplit:
+ """Converts the custom split to a MLSSplit"""
+ if split_name == Split.VALID:
+ return MLSSplit.TRAIN
+ else:
+ return split_name # type: ignore
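+
+
+# e.g. split_to_mls_split(Split.VALID) returns MLSSplit.TRAIN, since the custom
+# validation split is carved out of the MLS train split (see MLSDataset.initialize)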
+
+
+class Sample(TypedDict):
+ """Type for a sample in the dataset"""
+
+ waveform: torch.Tensor
+ spectrogram: torch.Tensor
+ input_length: int
+ utterance: torch.Tensor
+ utterance_length: int
+ sample_rate: int
+ speaker_id: str
+ book_id: str
+ chapter_id: str
+
+
+class MLSDataset(Dataset):
+ """Custom Dataset for reading Multilingual LibriSpeech
+
+ Attributes:
+ dataset_path (str):
+ path to the dataset
+ language (str):
+ language of the dataset
+ split (Split):
+ split of the dataset
+ mls_split (MLSSplit):
+ split of the dataset as defined in the Multilingual LibriSpeech dataset
+ dataset_lookup (list):
+ list of dicts containing the speakerid, bookid, chapterid and utterance
+
+ directory structure:
+ <dataset_path>
+ ├── <language>
+ │ ├── train
+ │ │ ├── transcripts.txt
+ │ │ └── audio
+ │ │ └── <speakerid>
+ │ │ └── <bookid>
+ │ │ └── <speakerid>_<bookid>_<chapterid>.opus / .flac
+
+ each line in transcripts.txt has the following format:
+ <speakerid>_<bookid>_<chapterid> <utterance>
+ """
+
+ def __init__(
+ self,
+ dataset_path: str,
+ language: str,
+ split: Split,
+ download: bool,
+ spectrogram_hparams: dict | None,
+ ):
+ """Initializes the dataset"""
+ self.dataset_path = dataset_path
+ self.language = language
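+        # e.g. "mls_german_opus" stores .opus audio; variants without "opus" store .flac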
+ self.file_ext = ".opus" if "opus" in language else ".flac"
+ self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
+ self.split: Split = split # split used internally
+
+ if spectrogram_hparams is None:
+ self.spectrogram_hparams = {
+ "sample_rate": 16000,
+ "n_fft": 400,
+ "win_length": 400,
+ "hop_length": 160,
+ "n_mels": 128,
+ "f_min": 0,
+ "f_max": 8000,
+ "power": 2.0,
+ }
+ else:
+ self.spectrogram_hparams = spectrogram_hparams
+
+ self.dataset_lookup = []
+        self.tokenizer: type[TokenizerType] | None = None  # set via set_tokenizer()
+
+ self._handle_download_dataset(download)
+ self._validate_local_directory()
+ self.initialize()
+
+ def initialize(self) -> None:
+ """Initializes the dataset
+
+ Reads the transcripts.txt file and creates a lookup table
+ """
+ transcripts_path = os.path.join(
+ self.dataset_path, self.language, self.mls_split, "transcripts.txt"
+ )
+
+ with open(transcripts_path, "r", encoding="utf-8") as script_file:
+ # read all lines in transcripts.txt
+ transcripts = script_file.readlines()
+ # split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>)
+ transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore
+ utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore
+ identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore
+ identifier = [path.split("_") for path in identifier]
+
+        # deterministically split the MLS train data into train/valid (80/20);
+        # a seeded permutation avoids the duplicates and train/valid overlap
+        # that sampling with replacement would introduce
+        if self.split in (Split.TRAIN, Split.VALID):
+            np.random.seed(42)
+            indices = np.random.permutation(len(utterances))
+            split_idx = int(len(utterances) * 0.8)
+            if self.split == Split.TRAIN:
+                indices = indices[:split_idx]
+            else:
+                indices = indices[split_idx:]
+            utterances = [utterances[i] for i in indices]
+            identifier = [identifier[i] for i in indices]
+
+ self.dataset_lookup = [
+ {
+ "speakerid": path[0],
+ "bookid": path[1],
+ "chapterid": path[2],
+ "utterance": utterance,
+ }
+ for path, utterance in zip(identifier, utterances, strict=False)
+ ]
+
+ def set_tokenizer(self, tokenizer: type[TokenizerType]):
+ """Sets the tokenizer"""
+ self.tokenizer = tokenizer
+
+    def _handle_download_dataset(self, download: bool) -> None:
+        """Downloads and extracts the dataset if necessary"""
+        if not download:
+            print("Download flag not set, skipping download")
+            return
+        # archive already downloaded:
+        if os.path.isfile(os.path.join(self.dataset_path, self.language) + ".tar.gz"):
+            print(f"Found archive at {self.dataset_path}, skipping download")
+        # archive not downloaded yet:
+        else:
+            os.makedirs(self.dataset_path, exist_ok=True)
+            url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
+
+            torch.hub.download_url_to_file(
+                url, os.path.join(self.dataset_path, self.language) + ".tar.gz"
+            )
+
+        # extract the archive
+        if not os.path.isdir(os.path.join(self.dataset_path, self.language)):
+            print(
+                f"Extracting the dataset at {os.path.join(self.dataset_path, self.language) + '.tar.gz'}"
+            )
+            extract_archive(
+                os.path.join(self.dataset_path, self.language) + ".tar.gz", overwrite=True
+            )
+        else:
+            print("Dataset already extracted, validating it now")
+
+ def _validate_local_directory(self):
+ # check if dataset_path exists
+ if not os.path.exists(self.dataset_path):
+ raise ValueError("Dataset path does not exist")
+ if not os.path.exists(os.path.join(self.dataset_path, self.language)):
+ raise ValueError("Language not downloaded!")
+ if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
+ raise ValueError("Split not found in dataset")
+
+ def __len__(self):
+ """Returns the length of the dataset"""
+ return len(self.dataset_lookup)
+
+ def __getitem__(self, idx: int) -> Sample:
+ """One sample"""
+ if self.tokenizer is None:
+ raise ValueError("No tokenizer set")
+ # get the utterance
+ utterance = self.dataset_lookup[idx]["utterance"]
+
+ # get the audio file
+ audio_path = os.path.join(
+ self.dataset_path,
+ self.language,
+ self.mls_split,
+ "audio",
+ self.dataset_lookup[idx]["speakerid"],
+ self.dataset_lookup[idx]["bookid"],
+ "_".join(
+ [
+ self.dataset_lookup[idx]["speakerid"],
+ self.dataset_lookup[idx]["bookid"],
+ self.dataset_lookup[idx]["chapterid"],
+ ]
+ )
+ + self.file_ext,
+ )
+
+ waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
+
+ # resample if necessary
+ if sample_rate != self.spectrogram_hparams["sample_rate"]:
+ resampler = torchaudio.transforms.Resample(
+ sample_rate, self.spectrogram_hparams["sample_rate"]
+ )
+ waveform = resampler(waveform)
+
+ spec = (
+ torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform)
+ .squeeze(0)
+ .transpose(0, 1)
+ )
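+        # spec now has shape (time, n_mels) after squeezing the channel dimension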
+
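+        # time length halved, presumably to match a stride-2 time reduction
+        # in the acoustic model (assumption; see the same factor in collate_fn)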
+ input_length = spec.shape[0] // 2
+
+        utterance = self.tokenizer.encode(utterance)
+
+        # length in tokens (matching the label tensor), not characters
+        utterance_length = len(utterance.ids)
+
+        utterance = torch.LongTensor(utterance.ids)
+
+ return Sample(
+ waveform=waveform,
+ spectrogram=spec,
+ input_length=input_length,
+ utterance=utterance,
+ utterance_length=utterance_length,
+ sample_rate=self.spectrogram_hparams["sample_rate"],
+ speaker_id=self.dataset_lookup[idx]["speakerid"],
+ book_id=self.dataset_lookup[idx]["bookid"],
+ chapter_id=self.dataset_lookup[idx]["chapterid"],
+ )
+
+
+def collate_fn(samples: list[Sample]) -> dict:
+ """Collate function for the dataloader
+
+ pads all tensors within a batch to the same dimensions
+ """
+ waveforms = []
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+
+ for sample in samples:
+ waveforms.append(sample["waveform"].transpose(0, 1))
+ spectrograms.append(sample["spectrogram"])
+ labels.append(sample["utterance"])
+ input_lengths.append(sample["spectrogram"].shape[0] // 2)
+ label_lengths.append(len(sample["utterance"]))
+
+ waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
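+    # pad specs to (batch, max_time, n_mels), then reshape to (batch, 1, n_mels, max_time)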
+ spectrograms = (
+ torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2, 3)
+ )
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return {
+ "waveform": waveforms,
+ "spectrogram": spectrograms,
+ "input_length": input_lengths,
+ "utterance": labels,
+ "utterance_length": label_lengths,
+ }
+
+
+if __name__ == "__main__":
+ DATASET_PATH = "/Volumes/pherkel/SWR2-ASR"
+ LANGUAGE = "mls_german_opus"
+    SPLIT = Split.TRAIN
+ DOWNLOAD = False
+
+    dataset = MLSDataset(DATASET_PATH, LANGUAGE, SPLIT, DOWNLOAD, None)
+
+ tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+ dataset.set_tokenizer(tok)
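+
+    # a minimal sketch of batching with collate_fn (batch size is illustrative)
+    loader = torch.utils.data.DataLoader(
+        dataset, batch_size=4, shuffle=True, collate_fn=collate_fn
+    )
+    batch = next(iter(loader))
+    print(batch["spectrogram"].shape)  # (batch, 1, n_mels, max_time)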