diff options
-rw-r--r-- | pyproject.toml | 9 | ||||
-rw-r--r-- | swr2_asr/model_deep_speech.py | 150 | ||||
-rw-r--r-- | swr2_asr/tokenizer.py | 72 | ||||
-rw-r--r-- | swr2_asr/train.py | 298 | ||||
-rw-r--r-- | swr2_asr/utils.py | 207 |
5 files changed, 365 insertions, 371 deletions
diff --git a/pyproject.toml b/pyproject.toml index fabe364..94f7553 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,15 @@ pylint = "^2.17.5" ruff = "^0.0.285" types-tqdm = "^4.66.0.1" +[tool.ruff] +select = ["E", "F", "B", "I"] +fixable = ["ALL"] +line-length = 120 +target-version = "py310" + +[tool.black] +line-length = 120 + [tool.poetry.scripts] train = "swr2_asr.train:run_cli" train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer" diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py new file mode 100644 index 0000000..ea0b667 --- /dev/null +++ b/swr2_asr/model_deep_speech.py @@ -0,0 +1,150 @@ +from torch import nn +import torch.nn.functional as F + + +class CNNLayerNorm(nn.Module): + """Layer normalization built for cnns input""" + + def __init__(self, n_feats: int): + super().__init__() + self.layer_norm = nn.LayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) + data = self.layer_norm(data) + return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) + + +class ResidualCNN(nn.Module): + """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel: int, + stride: int, + dropout: float, + n_feats: int, + ): + super().__init__() + + self.cnn1 = nn.Conv2d( + in_channels, out_channels, kernel, stride, padding=kernel // 2 + ) + self.cnn2 = nn.Conv2d( + out_channels, + out_channels, + kernel, + stride, + padding=kernel // 2, + ) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + self.layer_norm1 = CNNLayerNorm(n_feats) + self.layer_norm2 = CNNLayerNorm(n_feats) + + def forward(self, data): + """x (batch, channel, feature, time)""" + residual = data # (batch, channel, feature, time) + data = self.layer_norm1(data) + data = F.gelu(data) + data = self.dropout1(data) + data = self.cnn1(data) + data = self.layer_norm2(data) + data = F.gelu(data) + data = self.dropout2(data) + data = self.cnn2(data) + data += residual + return data # (batch, channel, feature, time) + + +class BidirectionalGRU(nn.Module): + """BIdirectional GRU with Layer Normalization and Dropout""" + + def __init__( + self, + rnn_dim: int, + hidden_size: int, + dropout: float, + batch_first: bool, + ): + super().__init__() + + self.bi_gru = nn.GRU( + input_size=rnn_dim, + hidden_size=hidden_size, + num_layers=1, + batch_first=batch_first, + bidirectional=True, + ) + self.layer_norm = nn.LayerNorm(rnn_dim) + self.dropout = nn.Dropout(dropout) + + def forward(self, data): + """data (batch, time, feature)""" + data = self.layer_norm(data) + data = F.gelu(data) + data = self.dropout(data) + data, _ = self.bi_gru(data) + return data + + +class SpeechRecognitionModel(nn.Module): + """Speech Recognition Model Inspired by DeepSpeech 2""" + + def __init__( + self, + n_cnn_layers: int, + n_rnn_layers: int, + rnn_dim: int, + n_class: int, + n_feats: int, + stride: int = 2, + dropout: float = 0.1, + ): + super().__init__() + n_feats //= 2 + self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) + # n residual cnn layers with filter size of 32 + self.rescnn_layers = nn.Sequential( + *[ + ResidualCNN( + 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats + ) + for _ in range(n_cnn_layers) + ] + ) + self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) + self.birnn_layers = nn.Sequential( + *[ + BidirectionalGRU( + rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, + hidden_size=rnn_dim, + dropout=dropout, + batch_first=i == 0, + ) + for i in range(n_rnn_layers) + ] + ) + self.classifier = nn.Sequential( + nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 + nn.GELU(), + nn.Dropout(dropout), + nn.Linear(rnn_dim, n_class), + ) + + def forward(self, data): + """data (batch, channel, feature, time)""" + data = self.cnn(data) + data = self.rescnn_layers(data) + sizes = data.size() + data = data.view( + sizes[0], sizes[1] * sizes[2], sizes[3] + ) # (batch, feature, time) + data = data.transpose(1, 2) # (batch, time, feature) + data = self.fully_connected(data) + data = self.birnn_layers(data) + data = self.classifier(data) + return data diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py index 4dbb386..5758da7 100644 --- a/swr2_asr/tokenizer.py +++ b/swr2_asr/tokenizer.py @@ -1,16 +1,50 @@ """Tokenizer for use with Multilingual Librispeech""" -from dataclasses import dataclass import json import os -import click -from tqdm import tqdm +from dataclasses import dataclass +from typing import Type +import click from AudioLoader.speech import MultilingualLibriSpeech - from tokenizers import Tokenizer, normalizers from tokenizers.models import BPE -from tokenizers.trainers import BpeTrainer from tokenizers.pre_tokenizers import Whitespace +from tokenizers.trainers import BpeTrainer +from tqdm import tqdm + + +class TokenizerType: + def encode(self, sequence: str) -> list[int]: + raise NotImplementedError + + def decode(self, labels: list[int], remove_special_tokens: bool) -> str: + raise NotImplementedError + + def decode_batch(self, labels: list[list[int]]) -> list[str]: + raise NotImplementedError + + def get_vocab_size(self) -> int: + raise NotImplementedError + + def enable_padding( + self, + length: int = -1, + direction: str = "right", + pad_id: int = 0, + pad_type_id: int = 0, + pad_token: str = "[PAD]", + ) -> None: + raise NotImplementedError + + def save(self, path: str) -> None: + raise NotImplementedError + + @staticmethod + def from_file(path: str) -> "TokenizerType": + raise NotImplementedError + + +tokenizer_type = Type[TokenizerType] @dataclass @@ -20,7 +54,7 @@ class Encoding: ids: list[int] -class CharTokenizer: +class CharTokenizer(TokenizerType): """Very simple tokenizer for use with Multilingual Librispeech Simply checks what characters are in the dataset and uses them as tokens. @@ -45,9 +79,7 @@ class CharTokenizer: self.char_map[token] = len(self.char_map) self.index_map[len(self.index_map)] = token - def train( - self, dataset_path: str, language: str, split: str, download: bool = True - ): + def train(self, dataset_path: str, language: str, split: str, download: bool = True): """Train the tokenizer on the given dataset Args: @@ -65,9 +97,7 @@ class CharTokenizer: chars: set = set() for s_plit in splits: - transcript_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) + transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") # check if dataset is downloaded, download if not if download and not os.path.exists(transcript_path): @@ -90,7 +120,7 @@ class CharTokenizer: self.char_map[char] = i self.index_map[i] = char - def encode(self, text: str): + def encode(self, sequence: str): """Use a character map and convert text to an integer sequence automatically maps spaces to <SPACE> and makes everything lowercase @@ -98,8 +128,8 @@ class CharTokenizer: """ int_sequence = [] - text = text.lower() - for char in text: + sequence = sequence.lower() + for char in sequence: if char == " ": mapped_char = self.char_map["<SPACE>"] elif char not in self.char_map: @@ -174,9 +204,7 @@ class CharTokenizer: @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use (including all)") @click.option("--download", default=True, help="Whether to download the dataset") -@click.option( - "--out_path", default="tokenizer.json", help="Path to save the tokenizer to" -) +@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to") @click.option("--vocab_size", default=2000, help="Size of the vocabulary") def train_bpe_tokenizer( dataset_path: str, @@ -210,9 +238,7 @@ def train_bpe_tokenizer( lines = [] for s_plit in splits: - transcripts_path = os.path.join( - dataset_path, language, s_plit, "transcripts.txt" - ) + transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt") if download and not os.path.exists(transcripts_path): MultilingualLibriSpeech(dataset_path, language, s_plit, download=True) @@ -296,9 +322,7 @@ def train_bpe_tokenizer( @click.option("--dataset_path", default="data", help="Path to the MLS dataset") @click.option("--language", default="mls_german_opus", help="Language to use") @click.option("--split", default="train", help="Split to use") -@click.option( - "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to" -) +@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to") @click.option("--download", default=True, help="Whether to download the dataset") def train_char_tokenizer( dataset_path: str, diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 6af1e80..53cdac1 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -1,74 +1,44 @@ """Training script for the ASR model.""" import os +from typing import TypedDict + import click import torch import torch.nn.functional as F -import torchaudio -from AudioLoader.speech import MultilingualLibriSpeech +from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader -from tokenizers import Tokenizer -from .tokenizer import CharTokenizer + +from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.tokenizer import train_bpe_tokenizer +from swr2_asr.utils import MLSDataset, Split from .loss_scores import cer, wer -train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), -) -valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - -# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") -text_transform = CharTokenizer() -text_transform.from_file("data/tokenizers/char_tokenizer_german.json") - - -def data_processing(data, data_type="train"): - """Return the spectrograms, labels, and their lengths.""" - spectrograms = [] - labels = [] - input_lengths = [] - label_lengths = [] - for sample in data: - if data_type == "train": - spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - elif data_type == "valid": - spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1) - else: - raise ValueError("data_type should be train or valid") - spectrograms.append(spec) - label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) - labels.append(label) - input_lengths.append(spec.shape[0] // 2) - label_lengths.append(len(label)) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) +class HParams(TypedDict): + """Type for the hyperparameters of the model.""" - return spectrograms, labels, input_lengths, label_lengths + n_cnn_layers: int + n_rnn_layers: int + rnn_dim: int + n_class: int + n_feats: int + stride: int + dropout: float + learning_rate: float + batch_size: int + epochs: int -def greedy_decoder( - output, labels, label_lengths, blank_label=28, collapse_repeated=True -): - # TODO: adopt to support both tokenizers +def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append( - text_transform.decode( - [int(x) for x in labels[i][: label_lengths[i]].tolist()] - ) - ) + targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: @@ -78,155 +48,6 @@ def greedy_decoder( return decodes, targets -# TODO: restructure into own file / class -class CNNLayerNorm(nn.Module): - """Layer normalization built for cnns input""" - - def __init__(self, n_feats: int): - super().__init__() - self.layer_norm = nn.LayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature) - data = self.layer_norm(data) - return data.transpose(2, 3).contiguous() # (batch, channel, feature, time) - - -class ResidualCNN(nn.Module): - """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf""" - - def __init__( - self, - in_channels: int, - out_channels: int, - kernel: int, - stride: int, - dropout: float, - n_feats: int, - ): - super().__init__() - - self.cnn1 = nn.Conv2d( - in_channels, out_channels, kernel, stride, padding=kernel // 2 - ) - self.cnn2 = nn.Conv2d( - out_channels, - out_channels, - kernel, - stride, - padding=kernel // 2, - ) - self.dropout1 = nn.Dropout(dropout) - self.dropout2 = nn.Dropout(dropout) - self.layer_norm1 = CNNLayerNorm(n_feats) - self.layer_norm2 = CNNLayerNorm(n_feats) - - def forward(self, data): - """x (batch, channel, feature, time)""" - residual = data # (batch, channel, feature, time) - data = self.layer_norm1(data) - data = F.gelu(data) - data = self.dropout1(data) - data = self.cnn1(data) - data = self.layer_norm2(data) - data = F.gelu(data) - data = self.dropout2(data) - data = self.cnn2(data) - data += residual - return data # (batch, channel, feature, time) - - -class BidirectionalGRU(nn.Module): - """BIdirectional GRU with Layer Normalization and Dropout""" - - def __init__( - self, - rnn_dim: int, - hidden_size: int, - dropout: float, - batch_first: bool, - ): - super().__init__() - - self.bi_gru = nn.GRU( - input_size=rnn_dim, - hidden_size=hidden_size, - num_layers=1, - batch_first=batch_first, - bidirectional=True, - ) - self.layer_norm = nn.LayerNorm(rnn_dim) - self.dropout = nn.Dropout(dropout) - - def forward(self, data): - """data (batch, time, feature)""" - data = self.layer_norm(data) - data = F.gelu(data) - data = self.dropout(data) - data, _ = self.bi_gru(data) - return data - - -class SpeechRecognitionModel(nn.Module): - """Speech Recognition Model Inspired by DeepSpeech 2""" - - def __init__( - self, - n_cnn_layers: int, - n_rnn_layers: int, - rnn_dim: int, - n_class: int, - n_feats: int, - stride: int = 2, - dropout: float = 0.1, - ): - super().__init__() - n_feats //= 2 - self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2) - # n residual cnn layers with filter size of 32 - self.rescnn_layers = nn.Sequential( - *[ - ResidualCNN( - 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats - ) - for _ in range(n_cnn_layers) - ] - ) - self.fully_connected = nn.Linear(n_feats * 32, rnn_dim) - self.birnn_layers = nn.Sequential( - *[ - BidirectionalGRU( - rnn_dim=rnn_dim if i == 0 else rnn_dim * 2, - hidden_size=rnn_dim, - dropout=dropout, - batch_first=i == 0, - ) - for i in range(n_rnn_layers) - ] - ) - self.classifier = nn.Sequential( - nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2 - nn.GELU(), - nn.Dropout(dropout), - nn.Linear(rnn_dim, n_class), - ) - - def forward(self, data): - """data (batch, channel, feature, time)""" - data = self.cnn(data) - data = self.rescnn_layers(data) - sizes = data.size() - data = data.view( - sizes[0], sizes[1] * sizes[2], sizes[3] - ) # (batch, feature, time) - data = data.transpose(1, 2) # (batch, time, feature) - data = self.fully_connected(data) - data = self.birnn_layers(data) - data = self.classifier(data) - return data - - class IterMeter: """keeps track of total iterations""" @@ -256,9 +77,8 @@ def train( model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): - spectrograms, labels, input_lengths, label_lengths = _data + _, spectrograms, input_lengths, labels, label_lengths, *_ = _data spectrograms, labels = spectrograms.to(device), labels.to(device) - optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) @@ -282,7 +102,6 @@ def train( return loss.item() -# TODO: check how dataloader can be made more efficient def test(model, device, test_loader, criterion): """Test""" print("\nevaluating...") @@ -301,9 +120,7 @@ def test(model, device, test_loader, criterion): loss = criterion(output, labels, input_lengths, label_lengths) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder( - output.transpose(0, 1), labels, label_lengths - ) + decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -324,46 +141,62 @@ def run( load: bool, path: str, dataset_path: str, + language: str, ) -> None: """Runs the training script.""" - hparams = { - "n_cnn_layers": 3, - "n_rnn_layers": 5, - "rnn_dim": 512, - "n_class": 36, # TODO: dynamically determine this from vocab size - "n_feats": 128, - "stride": 2, - "dropout": 0.1, - "learning_rate": learning_rate, - "batch_size": batch_size, - "epochs": epochs, - } - use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member # device = torch.device("mps") - download_dataset = not os.path.isdir(path) - train_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="dev", download=download_dataset - ) - test_dataset = MultilingualLibriSpeech( - dataset_path, "mls_german_opus", split="test", download=False + # load dataset + train_dataset = MLSDataset(dataset_path, language, Split.train, download=True) + valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True) + test_dataset = MLSDataset(dataset_path, language, Split.test, download=True) + + # TODO: add flag to choose tokenizer + # load tokenizer (bpe by default): + if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): + print("There is no tokenizer available. Do you want to train it on the dataset?") + input("Press Enter to continue...") + train_bpe_tokenizer( + dataset_path=dataset_path, + language=language, + split="all", + download=False, + out_path="data/tokenizers/bpe_tokenizer_german_3000.json", + vocab_size=3000, + ) + + tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") + + train_dataset.set_tokenizer(tokenizer) + valid_dataset.set_tokenizer(tokenizer) + test_dataset.set_tokenizer(tokenizer) + + hparams = HParams( + n_cnn_layers=3, + n_rnn_layers=5, + rnn_dim=512, + n_class=tokenizer.get_vocab_size(), + n_feats=128, + stride=2, + dropout=0.1, + learning_rate=learning_rate, + batch_size=batch_size, + epochs=epochs, ) train_loader = DataLoader( train_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), ) - test_loader = DataLoader( - test_dataset, + valid_loader = DataLoader( + valid_dataset, batch_size=hparams["batch_size"], shuffle=True, - collate_fn=lambda x: data_processing(x, "train"), ) # enable flag to find the most compatible algorithms in advance @@ -380,9 +213,7 @@ def run( hparams["dropout"], ).to(device) - print( - "Num Model Parameters", sum((param.nelement() for param in model.parameters())) - ) + print("Num Model Parameters", sum((param.nelement() for param in model.parameters()))) optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) criterion = nn.CTCLoss(blank=28).to(device) if load: @@ -412,7 +243,7 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=test_loader, criterion=criterion) + test(model=model, device=device, test_loader=valid_loader, criterion=criterion) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -452,4 +283,5 @@ def run_cli( load=load, path=path, dataset_path=dataset_path, + language="mls_german_opus", ) diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index c4aeb0b..786fbcf 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -1,16 +1,21 @@ """Class containing utils for the ASR system.""" -from dataclasses import dataclass import os -from AudioLoader.speech import MultilingualLibriSpeech +from enum import Enum +from typing import TypedDict + import numpy as np import torch import torchaudio -from torch import nn -from torch.utils.data import Dataset, DataLoader -from enum import Enum - from tokenizers import Tokenizer -from swr2_asr.tokenizer import CharTokenizer +from torch.utils.data import Dataset + +from swr2_asr.tokenizer import CharTokenizer, TokenizerType + +train_audio_transforms = torch.nn.Sequential( + torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), + torchaudio.transforms.FrequencyMasking(freq_mask_param=30), + torchaudio.transforms.TimeMasking(time_mask_param=100), +) # create enum specifiying dataset splits @@ -40,34 +45,20 @@ def split_to_mls_split(split: Split) -> MLSSplit: return split # type: ignore -@dataclass -class Sample: - """Dataclass for a sample in the dataset""" +class Sample(TypedDict): + """Type for a sample in the dataset""" waveform: torch.Tensor spectrogram: torch.Tensor - utterance: str + input_length: int + utterance: torch.Tensor + utterance_length: int sample_rate: int speaker_id: str book_id: str chapter_id: str -def tokenizer_factory(tokenizer_path: str, tokenizer_type: str = "BPE"): - """Factory for Tokenizer class - - Args: - tokenizer_type (str, optional): Type of tokenizer to use. Defaults to "BPE". - - Returns: - nn.Module: Tokenizer class - """ - if tokenizer_type == "BPE": - return Tokenizer.from_file(tokenizer_path) - elif tokenizer_type == "char": - return CharTokenizer.from_file(tokenizer_path) - - class MLSDataset(Dataset): """Custom Dataset for reading Multilingual LibriSpeech @@ -105,23 +96,33 @@ class MLSDataset(Dataset): self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk self.split: Split = split # split used internally self.dataset_lookup = [] + self.tokenizer: type[TokenizerType] self._handle_download_dataset(download) self._validate_local_directory() - transcripts_path = os.path.join( - dataset_path, language, self.mls_split, "transcripts.txt" - ) + transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt") with open(transcripts_path, "r", encoding="utf-8") as script_file: # read all lines in transcripts.txt transcripts = script_file.readlines() # split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>) - transcripts = [line.strip().split("\t", 1) for line in transcripts] - utterances = [utterance.strip() for _, utterance in transcripts] - identifier = [identifier.strip() for identifier, _ in transcripts] + transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore + utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore + identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore identifier = [path.split("_") for path in identifier] + if self.split == Split.valid: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.2)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + elif self.split == Split.train: + np.random.seed(42) + indices = np.random.choice(len(utterances), int(len(utterances) * 0.8)) + utterances = [utterances[i] for i in indices] + identifier = [identifier[i] for i in indices] + self.dataset_lookup = [ { "speakerid": path[0], @@ -129,27 +130,23 @@ class MLSDataset(Dataset): "chapterid": path[2], "utterance": utterance, } - for path, utterance in zip(identifier, utterances) + for path, utterance in zip(identifier, utterances, strict=False) ] - # save dataset_lookup as list of dicts, where each dict contains - # the speakerid, bookid and chapterid, as well as the utterance - # we can then use this to map the utterance to the audio file + def set_tokenizer(self, tokenizer: type[TokenizerType]): + """Sets the tokenizer""" + self.tokenizer = tokenizer + + self.calc_paddings() def _handle_download_dataset(self, download: bool): """Download the dataset""" - if ( - not os.path.exists(os.path.join(self.dataset_path, self.language)) - and download - ): + if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download: os.makedirs(self.dataset_path) url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz" torch.hub.download_url_to_file(url, self.dataset_path) - elif ( - not os.path.exists(os.path.join(self.dataset_path, self.language)) - and not download - ): + elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download: raise ValueError("Dataset not found. Set download to True to download it") def _validate_local_directory(self): @@ -158,18 +155,32 @@ class MLSDataset(Dataset): raise ValueError("Dataset path does not exist") if not os.path.exists(os.path.join(self.dataset_path, self.language)): raise ValueError("Language not found in dataset") - if not os.path.exists( - os.path.join(self.dataset_path, self.language, self.mls_split) - ): + if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)): raise ValueError("Split not found in dataset") - # checks if the transcripts.txt file exists - if not os.path.exists( - os.path.join(dataset_path, language, split, "transcripts.txt") - ): - raise ValueError("transcripts.txt not found in dataset") - - def __get_len__(self): + def calc_paddings(self): + """Sets the maximum length of the spectrogram""" + # check if dataset has been loaded and tokenizer has been set + if not self.dataset_lookup: + raise ValueError("Dataset not loaded") + if not self.tokenizer: + raise ValueError("Tokenizer not set") + + max_spec_length = 0 + max_uterance_length = 0 + for sample in self.dataset_lookup: + spec_length = sample["spectrogram"].shape[0] + if spec_length > max_spec_length: + max_spec_length = spec_length + + utterance_length = sample["utterance"].shape[0] + if utterance_length > max_uterance_length: + max_uterance_length = utterance_length + + self.max_spec_length = max_spec_length + self.max_utterance_length = max_uterance_length + + def __len__(self): """Returns the length of the dataset""" return len(self.dataset_lookup) @@ -197,13 +208,32 @@ class MLSDataset(Dataset): ) waveform, sample_rate = torchaudio.load(audio_path) # type: ignore + # TODO: figure out if we have to resample or not + # TODO: pad correctly (manually) + spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1) + print(f"spec.shape: {spec.shape}") + input_length = spec.shape[0] // 2 + spec = ( + torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0) + .unsqueeze(1) + .transpose(2, 3) + ) + + utterance_length = len(utterance) + self.tokenizer.enable_padding() + utterance = self.tokenizer.encode( + utterance, + ).ids + + utterance = torch.Tensor(utterance) return Sample( + # TODO: add flag to only return spectrogram or waveform or both waveform=waveform, - spectrogram=torchaudio.transforms.MelSpectrogram( - sample_rate=16000, n_mels=128 - )(waveform), + spectrogram=spec, + input_length=input_length, utterance=utterance, + utterance_length=utterance_length, sample_rate=sample_rate, speaker_id=self.dataset_lookup[idx]["speakerid"], book_id=self.dataset_lookup[idx]["bookid"], @@ -218,62 +248,6 @@ class MLSDataset(Dataset): torch.hub.download_url_to_file(url, dataset_path) -class DataProcessor: - """Factory for DataProcessingclass - - Transforms the dataset into spectrograms and labels, as well as a tokenizer - """ - - def __init__( - self, - dataset: MultilingualLibriSpeech, - tokenizer_path: str, - data_type: str = "train", - tokenizer_type: str = "BPE", - ): - self.dataset = dataset - self.data_type = data_type - - self.train_audio_transforms = nn.Sequential( - torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128), - torchaudio.transforms.FrequencyMasking(freq_mask_param=30), - torchaudio.transforms.TimeMasking(time_mask_param=100), - ) - - self.valid_audio_transforms = torchaudio.transforms.MelSpectrogram() - self.tokenizer = tokenizer_factory( - tokenizer_path=tokenizer_path, tokenizer_type=tokenizer_type - ) - - def __call__(self) -> tuple[np.ndarray, np.ndarray, int, int]: - """Returns spectrograms, labels and their lenghts""" - for sample in self.dataset: - if self.data_type == "train": - spec = ( - self.train_audio_transforms(sample["waveform"]) - .squeeze(0) - .transpose(0, 1) - ) - elif self.data_type == "valid": - spec = ( - self.valid_audio_transforms(sample["waveform"]) - .squeeze(0) - .transpose(0, 1) - ) - else: - raise ValueError("data_type should be train or valid") - label = torch.Tensor(text_transform.encode(sample["utterance"]).ids) - - spectrograms = ( - nn.utils.rnn.pad_sequence(spectrograms, batch_first=True) - .unsqueeze(1) - .transpose(2, 3) - ) - labels = nn.utils.rnn.pad_sequence(labels, batch_first=True) - - yield spec, label, spec.shape[0] // 2, len(labels) - - if __name__ == "__main__": dataset_path = "/Volumes/pherkel/SWR2-ASR" language = "mls_german_opus" @@ -281,4 +255,9 @@ if __name__ == "__main__": download = False dataset = MLSDataset(dataset_path, language, split, download) - print(dataset[0]) + + tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") + dataset.set_tokenizer(tok) + dataset.calc_paddings() + + print(dataset[41]["spectrogram"].shape) |