diff options
-rw-r--r-- | swr2_asr/model_deep_speech.py | 2 | ||||
-rw-r--r-- | swr2_asr/train.py | 49 | ||||
-rw-r--r-- | swr2_asr/utils.py | 164 |
3 files changed, 95 insertions, 120 deletions
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py index ea0b667..f00ebd4 100644 --- a/swr2_asr/model_deep_speech.py +++ b/swr2_asr/model_deep_speech.py @@ -1,5 +1,5 @@ -from torch import nn import torch.nn.functional as F +from torch import nn class CNNLayerNorm(nn.Module): diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 53cdac1..bae8c7c 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,13 +5,14 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +from AudioLoader.speech import MultilingualLibriSpeech from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.tokenizer import train_bpe_tokenizer -from swr2_asr.utils import MLSDataset, Split +from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer @@ -31,20 +32,20 @@ class HParams(TypedDict): epochs: int -def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True): +def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) + targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: continue decode.append(index.item()) - decodes.append(text_transform.decode(decode)) + decodes.append(tokenizer.decode(decode)) return decodes, targets @@ -77,15 +78,14 @@ def train( model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): - _, spectrograms, input_lengths, labels, label_lengths, *_ = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) loss.backward() optimizer.step() @@ -102,7 +102,7 @@ def train( return loss.item() -def test(model, device, test_loader, criterion): +def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") model.eval() @@ -110,17 +110,20 @@ def test(model, device, test_loader, criterion): test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths) + decoded_preds, decoded_targets = greedy_decoder( + output = output.transpose(0, 1), + labels = labels, + label_lengths= _data["utterance_length"], + tokenizer=tokenizer) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -150,12 +153,11 @@ def run( # device = torch.device("mps") # load dataset - train_dataset = MLSDataset(dataset_path, language, Split.train, download=True) - valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True) - test_dataset = MLSDataset(dataset_path, language, Split.test, download=True) + train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None) + valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None) + test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None) - # TODO: add flag to choose tokenizer - # load tokenizer (bpe by default): + # load tokenizer (bpe by default): if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): print("There is no tokenizer available. Do you want to train it on the dataset?") input("Press Enter to continue...") @@ -167,12 +169,14 @@ def run( out_path="data/tokenizers/bpe_tokenizer_german_3000.json", vocab_size=3000, ) - + tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - + train_dataset.set_tokenizer(tokenizer) valid_dataset.set_tokenizer(tokenizer) test_dataset.set_tokenizer(tokenizer) + + print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") hparams = HParams( n_cnn_layers=3, @@ -191,12 +195,14 @@ def run( train_dataset, batch_size=hparams["batch_size"], shuffle=True, + collate_fn=lambda x: collate_fn(x), ) valid_loader = DataLoader( valid_dataset, batch_size=hparams["batch_size"], shuffle=True, + collate_fn=lambda x: collate_fn(x), ) # enable flag to find the most compatible algorithms in advance @@ -243,7 +249,7 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=valid_loader, criterion=criterion) + test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -285,3 +291,6 @@ def run_cli( dataset_path=dataset_path, language="mls_german_opus", ) + +if __name__ == "__main__": + run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")
\ No newline at end of file diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 404661d..4c751d5 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,7 +1,6 @@ """Class containing utils for the ASR system.""" import os from enum import Enum -from multiprocessing import Pool from typing import TypedDict import numpy as np @@ -10,9 +9,8 @@ import torchaudio from tokenizers import Tokenizer from torch.utils.data import Dataset from tqdm import tqdm -import audio_metadata -from swr2_asr.tokenizer import CharTokenizer, TokenizerType +from swr2_asr.tokenizer import TokenizerType train_audio_transforms = torch.nn.Sequential( torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), @@ -91,20 +89,42 @@ class MLSDataset(Dataset): <speakerid>_<bookid>_<chapterid> <utterance> """ - def __init__(self, dataset_path: str, language: str, split: Split, download: bool): + def __init__(self, dataset_path: str, language: str, split: Split, download: bool, spectrogram_hparams: dict | None): """Initializes the dataset""" self.dataset_path = dataset_path self.language = language self.file_ext = ".opus" if "opus" in language else ".flac" self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally + + if spectrogram_hparams is None: + self.spectrogram_hparams = { + "sample_rate": 16000, + "n_fft": 400, + "win_length": 400, + "hop_length": 160, + "n_mels": 128, + "f_min": 0, + "f_max": 8000, + "power": 2.0, + } + else: + self.spectrogram_hparams = spectrogram_hparams + self.dataset_lookup = [] self.tokenizer: type[TokenizerType] self._handle_download_dataset(download) self._validate_local_directory() + self.initialize() - transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt") + + def initialize(self) -> None: + """Initializes the dataset + + Reads the transcripts.txt file and creates a lookup table + """ + transcripts_path = os.path.join(self.dataset_path, self.language, self.mls_split, "transcripts.txt") with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt @@ -136,13 +156,9 @@ class MLSDataset(Dataset): for path, utterance in zip(identifier, utterances, strict=False) ] - self.max_spec_length = 0 - self.max_utterance_length = 0 - def set_tokenizer(self, tokenizer: type[TokenizerType]): """Sets the tokenizer""" self.tokenizer = tokenizer - # self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" @@ -163,80 +179,14 @@ class MLSDataset(Dataset): if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - def _calculate_max_length(self, chunk): - """Calculates the maximum length of the spectrogram and the utterance - - to be called in a multiprocessing pool - """ - max_spec_length = 0 - max_utterance_length = 0 - - for sample in chunk: - audio_path = os.path.join( - self.dataset_path, - self.language, - self.mls_split, - "audio", - sample["speakerid"], - sample["bookid"], - "_".join( - [ - sample["speakerid"], - sample["bookid"], - sample["chapterid"], - ] - ) - + self.file_ext, - ) - metadata = audio_metadata.load(audio_path) - audio_duration = metadata.streaminfo.duration - sample_rate = metadata.streaminfo.sample_rate - - max_spec_length = int(max(max_spec_length, (audio_duration * sample_rate) // 200)) - max_utterance_length = max(max_utterance_length, len(self.tokenizer.encode(sample["utterance"]).ids)) - - return max_spec_length, max_utterance_length - - def calc_paddings(self) -> None: - """Sets the maximum length of the spectrogram and the utterance""" - # check if dataset has been loaded and tokenizer has been set - if not self.dataset_lookup: - raise ValueError("Dataset not loaded") - if not self.tokenizer: - raise ValueError("Tokenizer not set") - # check if paddings have been calculated already - if os.path.isfile(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt")): - print("Paddings already calculated") - with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "r") as f: - self.max_spec_length, self.max_utterance_length = [int(line.strip()) for line in f.readlines()] - return - else: - print("Calculating paddings...") - - thread_count = os.cpu_count() - if thread_count is None: - thread_count = 4 - chunk_size = len(self.dataset_lookup) // thread_count - chunks = [self.dataset_lookup[i : i + chunk_size] for i in range(0, len(self.dataset_lookup), chunk_size)] - - with Pool(thread_count) as p: - results = list(p.imap(self._calculate_max_length, chunks)) - - for spec, utterance in results: - self.max_spec_length = max(self.max_spec_length, spec) - self.max_utterance_length = max(self.max_utterance_length, utterance) - - # write to file - with open(os.path.join(self.dataset_path, self.language, self.mls_split, "paddings.txt"), "w") as f: - f.write(f"{self.max_spec_length}\n") - f.write(f"{self.max_utterance_length}") - def __len__(self): """Returns the length of the dataset""" return len(self.dataset_lookup) def __getitem__(self, idx: int) -> Sample: """One sample""" + if self.tokenizer is None: + raise ValueError("No tokenizer set") # get the utterance utterance = self.dataset_lookup[idx]["utterance"] @@ -261,42 +211,62 @@ class MLSDataset(Dataset): waveform, sample_rate = torchaudio.load(audio_path) # type: ignore # resample if necessary - if sample_rate != 16000: - resampler = torchaudio.transforms.Resample(sample_rate, 16000) + if sample_rate != self.spectrogram_hparams["sample_rate"]: + resampler = torchaudio.transforms.Resample(sample_rate, self.spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) - sample_rate = 16000 - spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) + spec = torchaudio.transforms.MelSpectrogram(**self.spectrogram_hparams)(waveform).squeeze(0).transpose(0, 1) input_length = spec.shape[0] // 2 + utterance_length = len(utterance) - self.tokenizer.enable_padding() - utterance = self.tokenizer.encode( - utterance, - ).ids + utterance = self.tokenizer.encode(utterance) - utterance = torch.Tensor(utterance) + utterance = torch.LongTensor(utterance.ids) return Sample( - # TODO: add flag to only return spectrogram or waveform or both waveform=waveform, spectrogram=spec, input_length=input_length, utterance=utterance, utterance_length=utterance_length, - sample_rate=sample_rate, + sample_rate=self.spectrogram_hparams["sample_rate"], speaker_id=self.dataset_lookup[idx]["speakerid"], book_id=self.dataset_lookup[idx]["bookid"], chapter_id=self.dataset_lookup[idx]["chapterid"], ) - def download(self, dataset_path: str, language: str): - """Download the dataset""" - os.makedirs(dataset_path) - url = f"https://dl.fbaipublicfiles.com/mls/{language}.tar.gz" - - torch.hub.download_url_to_file(url, dataset_path) +def collate_fn(samples: list[Sample]) -> dict: + """Collate function for the dataloader + + pads all tensors within a batch to the same dimensions + """ + waveforms = [] + spectrograms = [] + labels = [] + input_lengths = [] + label_lengths = [] + + for sample in samples: + waveforms.append(sample["waveform"].transpose(0, 1)) + spectrograms.append(sample["spectrogram"]) + labels.append(sample["utterance"]) + input_lengths.append(sample["spectrogram"].shape[0] // 2) + label_lengths.append(len(sample["utterance"])) + + waveforms = torch.nn.utils.rnn.pad_sequence(waveforms, batch_first=True) + spectrograms = torch.nn.utils.rnn.pad_sequence(spectrograms, batch_first=True).unsqueeze(1).transpose(2,3) + labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True) + + return { + "waveform": waveforms, + "spectrogram": spectrograms, + "input_length": input_lengths, + "utterance": labels, + "utterance_length": label_lengths, + } + if __name__ == "__main__": @@ -305,11 +275,7 @@ if __name__ == "__main__": split = Split.train download = False - dataset = MLSDataset(dataset_path, language, split, download) + dataset = MLSDataset(dataset_path, language, split, download, None) tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) - dataset.calc_paddings() - - print(f"Spectrogram shape: {dataset[41]['spectrogram'].shape}") - print(f"Utterance shape: {dataset[41]['utterance'].shape}") |