aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils.py
diff options
context:
space:
mode:
author Pherkel 2023-08-30 17:11:51 +0200
committer Pherkel 2023-08-30 17:11:51 +0200
commit 335b8a32f8bba5d37c00af6b4ecd1b9fc520f964 (patch)
tree a972bdf941714bb3281116503acb19aa37ed37af /swr2_asr/utils.py
parent 9b450c685e9ec4e7e74688de3d6dbb719b19fcf6 (diff)
wörks now!°!!
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- swr2_asr/utils.py 164
1 files changed, 65 insertions, 99 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 404661d..4c751d5 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -1,7 +1,6 @@
"""Class containing utils for the ASR system."""
import os
from enum import Enum
-from multiprocessing import Pool
from typing import TypedDict
import numpy as np
@@ -10,9 +9,8 @@ import torchaudio
from tokenizers import Tokenizer
from torch.utils.data import Dataset
from tqdm import tqdm
-import audio_metadata
-from swr2_asr.tokenizer import CharTokenizer, TokenizerType
+from swr2_asr.tokenizer import TokenizerType
train_audio_transforms = torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
@@ -91,20 +89,42 @@ class MLSDataset(Dataset):
<speakerid>_<bookid>_<chapterid> <utterance>
"""
- def __init__(self, dataset_path: str, language: str, split: Split, download: bool):
+ def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None):
"""Initializes the dataset"""
self.dataset_path = dataset_path
self.language = language
self.file_ext = ".opus" if "opus" in language else ".flac"
self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
self.split: Split = split # split used internally
+
+ if spectrogram_hparams is None:
+ self.spectrogram_hparams = {
+ "sample_rate": 16000,
+ "n_fft": 400,
+ "win_length": 400,
+ "hop_length": 160,
+ "n_mels": 128,
+ "f_min": 0,
+ "f_max": 8000,
+ "power": 2.0,
+ }
+ else:
+ self.spectrogram_hparams = spectrogram_hparams
+
self.dataset_lookup = []
self.tokenizer: type[TokenizerType]
self._handle_download_dataset(download)
self._validate_local_directory()
+ self.initialize()
- transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt")
+
+ def initialize(self) -> None:
+ """Initializes the dataset
+
+ Reads the transcripts.txt file and creates a lookup table
+ """
+ transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt")
with open(transcripts_path, "r", encoding="utf-8") as script_file:
# read all lines in transcripts.txt
@@ -136,13 +156,9 @@ class MLSDataset(Dataset):
for path, utterance in zip(identifier, utterances, strict=False)
]
- self.max_spec_length = 0
- self.max_utterance_length = 0
-
def set_tokenizer(self, tokenizer: type[TokenizerType]):
"""Sets the tokenizer"""
self.tokenizer = tokenizer
- # self.calc_paddings()
def _handle_download_dataset(self, download: bool):
"""Download the dataset"""
@@ -163,80 +179,14 @@ class MLSDataset(Dataset):
if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
raise ValueError("Split not found in dataset")
- def _calculate_max_length(self, chunk):
- """Calculates the maximum length of the spectrogram and the utterance
-
- to be called in a multiprocessing pool
- """
- max_spec_length = 0
- max_utterance_length = 0
-
- for sample in chunk:
- audio_path = os.path.join(
- self.dataset_path,
- self.language,
- self.mls_split,
- "audio",
- sample["speakerid"],
- sample["bookid"],
- "_".join(
- [
- sample["speakerid"],
- sample["bookid"],
- sample["chapterid"],
- ]
- )
- + self.file_ext,
- )
- metadata = audio_metadata.load(audio_path)
- audio_duration = metadata.streaminfo.duration
- sample_rate = metadata.streaminfo.sample_rate
-
- max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200))
- max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids))
-
- return max_spec_length, max_utterance_length
-
- def calc_paddings(self) -> None:
- """Sets the maximum length of the spectrogram and the utterance"""
- # check if dataset has been loaded and tokenizer has been set
- if not self.dataset_lookup:
- raise ValueError("Dataset not loaded")
- if not self.tokenizer:
- raise ValueError("Tokenizer not set")
- # check if paddings have been calculated already
- if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")):
- print("Paddings already calculated")
- with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f:
- self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()]
- return
- else:
- print("Calculating paddings...")
-
- thread_count = os.cpu_count()
- if thread_count is None:
- thread_count = 4
- chunk_size = len(self.dataset_lookup) // thread_count
- chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)]
-
- with Pool(thread_count) as p:
- results = list(p.imap(self._calculate_max_length, chunks))
-
- for spec, utterance in results:
- self.max_spec_length = max(self.max_spec_length, spec)
- self.max_utterance_length = max(self.max_utterance_length, utterance)
-
- # write to file
- with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f:
- f.write(f"{self.max_spec_length}\n")
- f.write(f"{self.max_utterance_length}")
-
def __len__(self):
"""Returns the length of the dataset"""
return len(self.dataset_lookup)
def __getitem__(self, idx: int) -> Sample:
"""One sample"""
+ if self.tokenizer is None:
+ raise ValueError("No tokenizer set")
# get the utterance
utterance = self.dataset_lookup[idx]["utterance"]
@@ -261,42 +211,62 @@ class MLSDataset(Dataset):
waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
# resample if necessary
- if sample_rate != 16000:
- resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+ if sample_rate != self.spectrogram_hparams["sample_rate"]:
+ resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"])
waveform = resampler(waveform)
- sample_rate = 16000
- spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1)
+ spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1)
input_length = spec.shape[0] // 2
+
utterance_length = len(utterance)
- self.tokenizer.enable_padding()
- utterance = self.tokenizer.encode(
- utterance,
- ).ids
+ utterance = self.tokenizer.encode(utterance)
- utterance = torch.Tensor(utterance)
+ utterance = torch.LongTensor(utterance.ids)
return Sample(
- # TODO: add flag to only return spectrogram or waveform or both
waveform=waveform,
spectrogram=spec,
input_length=input_length,
utterance=utterance,
utterance_length=utterance_length,
- sample_rate=sample_rate,
+ sample_rate=self.spectrogram_hparams["sample_rate"],
speaker_id=self.dataset_lookup[idx]["speakerid"],
book_id=self.dataset_lookup[idx]["bookid"],
chapter_id=self.dataset_lookup[idx]["chapterid"],
)
- def download(self, dataset_path: str, language: str):
- """Download the dataset"""
- os.makedirs(dataset_path)
- url = f"https://dl.fbaipublicfiles.com/mls/{language}.tar.gz"
-
- torch.hub.download_url_to_file(url, dataset_path)
+def collate_fn(samples: list[Sample]) -> dict:
+ """Collate function for the dataloader
+
+ pads all tensors within a batch to the same dimensions
+ """
+ waveforms = []
+ spectrograms = []
+ labels = []
+ input_lengths = []
+ label_lengths = []
+
+ for sample in samples:
+ waveforms.append(sample["waveform"].transpose(0, 1))
+ spectrograms.append(sample["spectrogram"])
+ labels.append(sample["utterance"])
+ input_lengths.append(sample["spectrogram"].shape[0] // 2)
+ label_lengths.append(len(sample["utterance"]))
+
+ waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True)
+ spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3)
+ labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
+
+ return {
+ "waveform": waveforms,
+ "spectrogram": spectrograms,
+ "input_length": input_lengths,
+ "utterance": labels,
+ "utterance_length": label_lengths,
+ }
+
if __name__ == "__main__":
@@ -305,11 +275,7 @@ if __name__ == "__main__":
split = Split.train
download = False
- dataset = MLSDataset(dataset_path, language, split, download)
+ dataset = MLSDataset(dataset_path, language, split, download, None)
tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
dataset.set_tokenizer(tok)
- dataset.calc_paddings()
-
- print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}")
- print(f"Utterance shape: {dataset[41]['utterance'].shape}")