Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r--  swr2_asr/utils.py  108
1 file changed, 80 insertions(+), 28 deletions(-)
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 786fbcf..404661d 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -1,6 +1,7 @@
"""Class containing utils for the ASR system."""
import os
from enum import Enum
+from multiprocessing import Pool
from typing import TypedDict
import numpy as np
@@ -8,6 +9,8 @@ import torch
import torchaudio
from tokenizers import Tokenizer
from torch.utils.data import Dataset
+from tqdm import tqdm
+import audio_metadata
from swr2_asr.tokenizer import CharTokenizer, TokenizerType
@@ -133,11 +136,13 @@ class MLSDataset(Dataset):
for path, utterance in zip(identifier, utterances, strict=False)
]
+ self.max_spec_length = 0
+ self.max_utterance_length = 0
+
def set_tokenizer(self, tokenizer: type[TokenizerType]):
"""Sets the tokenizer"""
self.tokenizer = tokenizer
-
- self.calc_paddings()
+ # self.calc_paddings()
def _handle_download_dataset(self, download: bool):
"""Download the dataset"""
@@ -158,27 +163,73 @@ class MLSDataset(Dataset):
if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
raise ValueError("Split not found in dataset")
- def calc_paddings(self):
- """Sets the maximum length of the spectrogram"""
+ def _calculate_max_length(self, chunk):
+ """Calculates the maximum length of the spectrogram and the utterance
+
+ Intended to be called from a multiprocessing Pool worker.
+ """
+ max_spec_length = 0
+ max_utterance_length = 0
+
+ for sample in chunk:
+ audio_path = os.path.join(
+ self.dataset_path,
+ self.language,
+ self.mls_split,
+ "audio",
+ sample["speakerid"],
+ sample["bookid"],
+ "_".join(
+ [
+ sample["speakerid"],
+ sample["bookid"],
+ sample["chapterid"],
+ ]
+ )
+ + self.file_ext,
+ )
+ metadata = audio_metadata.load(audio_path)
+ audio_duration = metadata.streaminfo.duration
+ sample_rate = metadata.streaminfo.sample_rate
+
+ max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200))
+ max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids))
+
+ return max_spec_length, max_utterance_length
+
+ def calc_paddings(self) -> None:
+ """Sets the maximum length of the spectrogram and the utterance"""
# check if dataset has been loaded and tokenizer has been set
if not self.dataset_lookup:
raise ValueError("Dataset not loaded")
if not self.tokenizer:
raise ValueError("Tokenizer not set")
-
- max_spec_length = 0
- max_uterance_length = 0
- for sample in self.dataset_lookup:
- spec_length = sample["spectrogram"].shape[0]
- if spec_length > max_spec_length:
- max_spec_length = spec_length
-
- utterance_length = sample["utterance"].shape[0]
- if utterance_length > max_uterance_length:
- max_uterance_length = utterance_length
-
- self.max_spec_length = max_spec_length
- self.max_utterance_length = max_uterance_length
+ # check if paddings have been calculated already
+ if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")):
+ print("Paddings already calculated")
+ with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f:
+ self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()]
+ return
+ else:
+ print("Calculating paddings...")
+
+ thread_count = os.cpu_count()
+ if thread_count is None:
+ thread_count = 4
+ chunk_size = max(1, len(self.dataset_lookup) // thread_count)  # avoid a zero chunk size for small datasets
+ chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)]
+
+ with Pool(thread_count) as p:
+ results = list(tqdm(p.imap(self._calculate_max_length, chunks), total=len(chunks)))
+
+ for spec, utterance in results:
+ self.max_spec_length = max(self.max_spec_length, spec)
+ self.max_utterance_length = max(self.max_utterance_length, utterance)
+
+ # write to file
+ with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f:
+ f.write(f"{self.max_spec_length}\n")
+ f.write(f"{self.max_utterance_length}")
def __len__(self):
"""Returns the length of the dataset"""
@@ -208,18 +259,18 @@ class MLSDataset(Dataset):
)
waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
- # TODO: figure out if we have to resample or not
- # TODO: pad correctly (manually)
+
+ # resample if necessary
+ if sample_rate != 16000:
+ resampler = torchaudio.transforms.Resample(sample_rate, 16000)
+ waveform = resampler(waveform)
+ sample_rate = 16000
+
spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1)
- print(f"spec.shape: {spec.shape}")
- input_length = spec.shape[0] // 2
- spec = (
- torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0)
- .unsqueeze(1)
- .transpose(2, 3)
- )
+ input_length = spec.shape[0] // 2
utterance_length = len(utterance)
+
self.tokenizer.enable_padding()
utterance = self.tokenizer.encode(
utterance,
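Note: the __getitem__ change resamples every clip to 16 kHz before the Mel transform, so the frame estimate used in calc_paddings stays consistent with torchaudio's defaults (n_fft=400, hop_length=200, i.e. roughly samples // 200 frames). A minimal sketch of that load/resample/transform pipeline, assuming the same 16 kHz target and 128 mel bins as in the diff:

import torchaudio

TARGET_SAMPLE_RATE = 16000


def load_mel(audio_path):
    """Load a clip, resample to 16 kHz if needed, return a (frames, n_mels) mel spectrogram."""
    waveform, sample_rate = torchaudio.load(audio_path)
    if sample_rate != TARGET_SAMPLE_RATE:
        waveform = torchaudio.transforms.Resample(sample_rate, TARGET_SAMPLE_RATE)(waveform)
    # default n_fft=400 and hop_length=200, so the frame count is roughly samples // 200
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=TARGET_SAMPLE_RATE, n_mels=128)(waveform)
    return mel.squeeze(0).transpose(0, 1)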
@@ -260,4 +311,5 @@ if __name__ == "__main__":
dataset.set_tokenizer(tok)
dataset.calc_paddings()
- print(dataset[41]["spectrogram"].shape)
+ print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}")
+ print(f"Utterance shape: {dataset[41]['utterance'].shape}")