aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pyproject.toml9
-rw-r--r--swr2_asr/model_deep_speech.py150
-rw-r--r--swr2_asr/tokenizer.py72
-rw-r--r--swr2_asr/train.py298
-rw-r--r--swr2_asr/utils.py207
5 files changed, 365 insertions, 371 deletions
diff --git a/pyproject.toml b/pyproject.toml
index fabe364..94f7553 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,6 +25,15 @@ pylint = "^2.17.5"
ruff = "^0.0.285"
types-tqdm = "^4.66.0.1"
+[tool.ruff]
+select = ["E", "F", "B", "I"]
+fixable = ["ALL"]
+line-length = 120
+target-version = "py310"
+
+[tool.black]
+line-length = 120
+
[tool.poetry.scripts]
train = "swr2_asr.train:run_cli"
train-bpe-tokenizer = "swr2_asr.tokenizer:train_bpe_tokenizer"
diff --git a/swr2_asr/model_deep_speech.py b/swr2_asr/model_deep_speech.py
new file mode 100644
index 0000000..ea0b667
--- /dev/null
+++ b/swr2_asr/model_deep_speech.py
@@ -0,0 +1,150 @@
+from torch import nn
+import torch.nn.functional as F
+
+
+class CNNLayerNorm(nn.Module):
+ """Layer normalization built for CNN inputs"""
+
+ def __init__(self, n_feats: int):
+ super().__init__()
+ self.layer_norm = nn.LayerNorm(n_feats)
+
+ def forward(self, data):
+ """x (batch, channel, feature, time)"""
+ data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
+ data = self.layer_norm(data)
+ return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
+
+
+class ResidualCNN(nn.Module):
+ """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel: int,
+ stride: int,
+ dropout: float,
+ n_feats: int,
+ ):
+ super().__init__()
+
+ self.cnn1 = nn.Conv2d(
+ in_channels, out_channels, kernel, stride, padding=kernel // 2
+ )
+ self.cnn2 = nn.Conv2d(
+ out_channels,
+ out_channels,
+ kernel,
+ stride,
+ padding=kernel // 2,
+ )
+ self.dropout1 = nn.Dropout(dropout)
+ self.dropout2 = nn.Dropout(dropout)
+ self.layer_norm1 = CNNLayerNorm(n_feats)
+ self.layer_norm2 = CNNLayerNorm(n_feats)
+
+ def forward(self, data):
+ """x (batch, channel, feature, time)"""
+ residual = data # (batch, channel, feature, time)
+ data = self.layer_norm1(data)
+ data = F.gelu(data)
+ data = self.dropout1(data)
+ data = self.cnn1(data)
+ data = self.layer_norm2(data)
+ data = F.gelu(data)
+ data = self.dropout2(data)
+ data = self.cnn2(data)
+ data += residual
+ return data # (batch, channel, feature, time)
+
+
+class BidirectionalGRU(nn.Module):
+ """Bidirectional GRU with Layer Normalization and Dropout"""
+
+ def __init__(
+ self,
+ rnn_dim: int,
+ hidden_size: int,
+ dropout: float,
+ batch_first: bool,
+ ):
+ super().__init__()
+
+ self.bi_gru = nn.GRU(
+ input_size=rnn_dim,
+ hidden_size=hidden_size,
+ num_layers=1,
+ batch_first=batch_first,
+ bidirectional=True,
+ )
+ self.layer_norm = nn.LayerNorm(rnn_dim)
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(self, data):
+ """data (batch, time, feature)"""
+ data = self.layer_norm(data)
+ data = F.gelu(data)
+ data = self.dropout(data)
+ data, _ = self.bi_gru(data)
+ return data
+
+
+class SpeechRecognitionModel(nn.Module):
+ """Speech Recognition Model Inspired by DeepSpeech 2"""
+
+ def __init__(
+ self,
+ n_cnn_layers: int,
+ n_rnn_layers: int,
+ rnn_dim: int,
+ n_class: int,
+ n_feats: int,
+ stride: int = 2,
+ dropout: float = 0.1,
+ ):
+ super().__init__()
+ n_feats //= 2
+ self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
+ # n residual cnn layers with filter size of 32
+ self.rescnn_layers = nn.Sequential(
+ *[
+ ResidualCNN(
+ 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
+ )
+ for _ in range(n_cnn_layers)
+ ]
+ )
+ self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
+ self.birnn_layers = nn.Sequential(
+ *[
+ BidirectionalGRU(
+ rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
+ hidden_size=rnn_dim,
+ dropout=dropout,
+ batch_first=i == 0,
+ )
+ for i in range(n_rnn_layers)
+ ]
+ )
+ self.classifier = nn.Sequential(
+ nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
+ nn.GELU(),
+ nn.Dropout(dropout),
+ nn.Linear(rnn_dim, n_class),
+ )
+
+ def forward(self, data):
+ """data (batch, channel, feature, time)"""
+ data = self.cnn(data)
+ data = self.rescnn_layers(data)
+ sizes = data.size()
+ data = data.view(
+ sizes[0], sizes[1] * sizes[2], sizes[3]
+ ) # (batch, feature, time)
+ data = data.transpose(1, 2) # (batch, time, feature)
+ data = self.fully_connected(data)
+ data = self.birnn_layers(data)
+ data = self.classifier(data)
+ return data
diff --git a/swr2_asr/tokenizer.py b/swr2_asr/tokenizer.py
index 4dbb386..5758da7 100644
--- a/swr2_asr/tokenizer.py
+++ b/swr2_asr/tokenizer.py
@@ -1,16 +1,50 @@
"""Tokenizer for use with Multilingual Librispeech"""
-from dataclasses import dataclass
import json
import os
-import click
-from tqdm import tqdm
+from dataclasses import dataclass
+from typing import Type
+import click
from AudioLoader.speech import MultilingualLibriSpeech
-
from tokenizers import Tokenizer, normalizers
from tokenizers.models import BPE
-from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
+from tokenizers.trainers import BpeTrainer
+from tqdm import tqdm
+
+
+class TokenizerType:
+ def encode(self, sequence: str) -> list[int]:
+ raise NotImplementedError
+
+ def decode(self, labels: list[int], remove_special_tokens: bool) -> str:
+ raise NotImplementedError
+
+ def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ raise NotImplementedError
+
+ def get_vocab_size(self) -> int:
+ raise NotImplementedError
+
+ def enable_padding(
+ self,
+ length: int = -1,
+ direction: str = "right",
+ pad_id: int = 0,
+ pad_type_id: int = 0,
+ pad_token: str = "[PAD]",
+ ) -> None:
+ raise NotImplementedError
+
+ def save(self, path: str) -> None:
+ raise NotImplementedError
+
+ @staticmethod
+ def from_file(path: str) -> "TokenizerType":
+ raise NotImplementedError
+
+
+tokenizer_type = Type[TokenizerType]
@dataclass
@@ -20,7 +54,7 @@ class Encoding:
ids: list[int]
-class CharTokenizer:
+class CharTokenizer(TokenizerType):
"""Very simple tokenizer for use with Multilingual Librispeech
Simply checks what characters are in the dataset and uses them as tokens.
@@ -45,9 +79,7 @@ class CharTokenizer:
self.char_map[token] = len(self.char_map)
self.index_map[len(self.index_map)] = token
- def train(
- self, dataset_path: str, language: str, split: str, download: bool = True
- ):
+ def train(self, dataset_path: str, language: str, split: str, download: bool = True):
"""Train the tokenizer on the given dataset
Args:
@@ -65,9 +97,7 @@ class CharTokenizer:
chars: set = set()
for s_plit in splits:
- transcript_path = os.path.join(
- dataset_path, language, s_plit, "transcripts.txt"
- )
+ transcript_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
# check if dataset is downloaded, download if not
if download and not os.path.exists(transcript_path):
@@ -90,7 +120,7 @@ class CharTokenizer:
self.char_map[char] = i
self.index_map[i] = char
- def encode(self, text: str):
+ def encode(self, sequence: str):
"""Use a character map and convert text to an integer sequence
automatically maps spaces to <SPACE> and makes everything lowercase
@@ -98,8 +128,8 @@ class CharTokenizer:
"""
int_sequence = []
- text = text.lower()
- for char in text:
+ sequence = sequence.lower()
+ for char in sequence:
if char == " ":
mapped_char = self.char_map["<SPACE>"]
elif char not in self.char_map:
@@ -174,9 +204,7 @@ class CharTokenizer:
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use (including all)")
@click.option("--download", default=True, help="Whether to download the dataset")
-@click.option(
- "--out_path", default="tokenizer.json", help="Path to save the tokenizer to"
-)
+@click.option("--out_path", default="tokenizer.json", help="Path to save the tokenizer to")
@click.option("--vocab_size", default=2000, help="Size of the vocabulary")
def train_bpe_tokenizer(
dataset_path: str,
@@ -210,9 +238,7 @@ def train_bpe_tokenizer(
lines = []
for s_plit in splits:
- transcripts_path = os.path.join(
- dataset_path, language, s_plit, "transcripts.txt"
- )
+ transcripts_path = os.path.join(dataset_path, language, s_plit, "transcripts.txt")
if download and not os.path.exists(transcripts_path):
MultilingualLibriSpeech(dataset_path, language, s_plit, download=True)
@@ -296,9 +322,7 @@ def train_bpe_tokenizer(
@click.option("--dataset_path", default="data", help="Path to the MLS dataset")
@click.option("--language", default="mls_german_opus", help="Language to use")
@click.option("--split", default="train", help="Split to use")
-@click.option(
- "--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to"
-)
+@click.option("--out_path", default="tokenizer_chars.txt", help="Path to save the tokenizer to")
@click.option("--download", default=True, help="Whether to download the dataset")
def train_char_tokenizer(
dataset_path: str,
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 6af1e80..53cdac1 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,74 +1,44 @@
"""Training script for the ASR model."""
import os
+from typing import TypedDict
+
import click
import torch
import torch.nn.functional as F
-import torchaudio
-from AudioLoader.speech import MultilingualLibriSpeech
+from tokenizers import Tokenizer
from torch import nn, optim
from torch.utils.data import DataLoader
-from tokenizers import Tokenizer
-from .tokenizer import CharTokenizer
+
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.tokenizer import train_bpe_tokenizer
+from swr2_asr.utils import MLSDataset, Split
from .loss_scores import cer, wer
-train_audio_transforms = nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
- torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
- torchaudio.transforms.TimeMasking(time_mask_param=100),
-)
-valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-
-# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-text_transform = CharTokenizer()
-text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
-
-
-def data_processing(data, data_type="train"):
- """Return the spectrograms, labels, and their lengths."""
- spectrograms = []
- labels = []
- input_lengths = []
- label_lengths = []
- for sample in data:
- if data_type == "train":
- spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- elif data_type == "valid":
- spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- else:
- raise ValueError("data_type should be train or valid")
- spectrograms.append(spec)
- label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
- labels.append(label)
- input_lengths.append(spec.shape[0] // 2)
- label_lengths.append(len(label))
-
- spectrograms = (
- nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
- .unsqueeze(1)
- .transpose(2, 3)
- )
- labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
- return spectrograms, labels, input_lengths, label_lengths
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
-def greedy_decoder(
- output, labels, label_lengths, blank_label=28, collapse_repeated=True
-):
- # TODO: adopt to support both tokenizers
+def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(
- text_transform.decode(
- [int(x) for x in labels[i][: label_lengths[i]].tolist()]
- )
- )
+ targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
@@ -78,155 +48,6 @@ def greedy_decoder(
return decodes, targets
-# TODO: restructure into own file / class
-class CNNLayerNorm(nn.Module):
- """Layer normalization built for cnns input"""
-
- def __init__(self, n_feats: int):
- super().__init__()
- self.layer_norm = nn.LayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
- data = self.layer_norm(data)
- return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
-
-
-class ResidualCNN(nn.Module):
- """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel: int,
- stride: int,
- dropout: float,
- n_feats: int,
- ):
- super().__init__()
-
- self.cnn1 = nn.Conv2d(
- in_channels, out_channels, kernel, stride, padding=kernel // 2
- )
- self.cnn2 = nn.Conv2d(
- out_channels,
- out_channels,
- kernel,
- stride,
- padding=kernel // 2,
- )
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.layer_norm1 = CNNLayerNorm(n_feats)
- self.layer_norm2 = CNNLayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- residual = data # (batch, channel, feature, time)
- data = self.layer_norm1(data)
- data = F.gelu(data)
- data = self.dropout1(data)
- data = self.cnn1(data)
- data = self.layer_norm2(data)
- data = F.gelu(data)
- data = self.dropout2(data)
- data = self.cnn2(data)
- data += residual
- return data # (batch, channel, feature, time)
-
-
-class BidirectionalGRU(nn.Module):
- """BIdirectional GRU with Layer Normalization and Dropout"""
-
- def __init__(
- self,
- rnn_dim: int,
- hidden_size: int,
- dropout: float,
- batch_first: bool,
- ):
- super().__init__()
-
- self.bi_gru = nn.GRU(
- input_size=rnn_dim,
- hidden_size=hidden_size,
- num_layers=1,
- batch_first=batch_first,
- bidirectional=True,
- )
- self.layer_norm = nn.LayerNorm(rnn_dim)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, data):
- """data (batch, time, feature)"""
- data = self.layer_norm(data)
- data = F.gelu(data)
- data = self.dropout(data)
- data, _ = self.bi_gru(data)
- return data
-
-
-class SpeechRecognitionModel(nn.Module):
- """Speech Recognition Model Inspired by DeepSpeech 2"""
-
- def __init__(
- self,
- n_cnn_layers: int,
- n_rnn_layers: int,
- rnn_dim: int,
- n_class: int,
- n_feats: int,
- stride: int = 2,
- dropout: float = 0.1,
- ):
- super().__init__()
- n_feats //= 2
- self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
- # n residual cnn layers with filter size of 32
- self.rescnn_layers = nn.Sequential(
- *[
- ResidualCNN(
- 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
- )
- for _ in range(n_cnn_layers)
- ]
- )
- self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
- self.birnn_layers = nn.Sequential(
- *[
- BidirectionalGRU(
- rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
- hidden_size=rnn_dim,
- dropout=dropout,
- batch_first=i == 0,
- )
- for i in range(n_rnn_layers)
- ]
- )
- self.classifier = nn.Sequential(
- nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
- nn.GELU(),
- nn.Dropout(dropout),
- nn.Linear(rnn_dim, n_class),
- )
-
- def forward(self, data):
- """data (batch, channel, feature, time)"""
- data = self.cnn(data)
- data = self.rescnn_layers(data)
- sizes = data.size()
- data = data.view(
- sizes[0], sizes[1] * sizes[2], sizes[3]
- ) # (batch, feature, time)
- data = data.transpose(1, 2) # (batch, time, feature)
- data = self.fully_connected(data)
- data = self.birnn_layers(data)
- data = self.classifier(data)
- return data
-
-
class IterMeter:
"""keeps track of total iterations"""
@@ -256,9 +77,8 @@ def train(
model.train()
data_len = len(train_loader.dataset)
for batch_idx, _data in enumerate(train_loader):
- spectrograms, labels, input_lengths, label_lengths = _data
+ _, spectrograms, input_lengths, labels, label_lengths, *_ = _data
spectrograms, labels = spectrograms.to(device), labels.to(device)
-
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
@@ -282,7 +102,6 @@ def train(
return loss.item()
-# TODO: check how dataloader can be made more efficient
def test(model, device, test_loader, criterion):
"""Test"""
print("\nevaluating...")
@@ -301,9 +120,7 @@ def test(model, device, test_loader, criterion):
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths
- )
+ decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths)
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
test_wer.append(wer(decoded_targets[j], pred))
@@ -324,46 +141,62 @@ def run(
load: bool,
path: str,
dataset_path: str,
+ language: str,
) -> None:
"""Runs the training script."""
- hparams = {
- "n_cnn_layers": 3,
- "n_rnn_layers": 5,
- "rnn_dim": 512,
- "n_class": 36, # TODO: dynamically determine this from vocab size
- "n_feats": 128,
- "stride": 2,
- "dropout": 0.1,
- "learning_rate": learning_rate,
- "batch_size": batch_size,
- "epochs": epochs,
- }
-
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
- download_dataset = not os.path.isdir(path)
- train_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="dev", download=download_dataset
- )
- test_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="test", download=False
+ # load dataset
+ train_dataset = MLSDataset(dataset_path, language, Split.train, download=True)
+ valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True)
+ test_dataset = MLSDataset(dataset_path, language, Split.test, download=True)
+
+ # TODO: add flag to choose tokenizer
+ # load tokenizer (bpe by default):
+ if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"):
+ print("There is no tokenizer available. Do you want to train it on the dataset?")
+ input("Press Enter to continue...")
+ train_bpe_tokenizer(
+ dataset_path=dataset_path,
+ language=language,
+ split="all",
+ download=False,
+ out_path="data/tokenizers/bpe_tokenizer_german_3000.json",
+ vocab_size=3000,
+ )
+
+ tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+
+ train_dataset.set_tokenizer(tokenizer)
+ valid_dataset.set_tokenizer(tokenizer)
+ test_dataset.set_tokenizer(tokenizer)
+
+ hparams = HParams(
+ n_cnn_layers=3,
+ n_rnn_layers=5,
+ rnn_dim=512,
+ n_class=tokenizer.get_vocab_size(),
+ n_feats=128,
+ stride=2,
+ dropout=0.1,
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
)
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
)
- test_loader = DataLoader(
- test_dataset,
+ valid_loader = DataLoader(
+ valid_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
)
# enable flag to find the most compatible algorithms in advance
@@ -380,9 +213,7 @@ def run(
hparams["dropout"],
).to(device)
- print(
- "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
- )
+ print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
criterion = nn.CTCLoss(blank=28).to(device)
if load:
@@ -412,7 +243,7 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ test(model=model, device=device, test_loader=valid_loader, criterion=criterion)
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -452,4 +283,5 @@ def run_cli(
load=load,
path=path,
dataset_path=dataset_path,
+ language="mls_german_opus",
)
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index c4aeb0b..786fbcf 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -1,16 +1,21 @@
"""Class containing utils for the ASR system."""
-from dataclasses import dataclass
import os
-from AudioLoader.speech import MultilingualLibriSpeech
+from enum import Enum
+from typing import TypedDict
+
import numpy as np
import torch
import torchaudio
-from torch import nn
-from torch.utils.data import Dataset, DataLoader
-from enum import Enum
-
from tokenizers import Tokenizer
-from swr2_asr.tokenizer import CharTokenizer
+from torch.utils.data import Dataset
+
+from swr2_asr.tokenizer import CharTokenizer, TokenizerType
+
+train_audio_transforms = torch.nn.Sequential(
+ torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
+ torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
+ torchaudio.transforms.TimeMasking(time_mask_param=100),
+)
# create enum specifiying dataset splits
@@ -40,34 +45,20 @@ def split_to_mls_split(split: Split) -> MLSSplit:
return split # type: ignore
-@dataclass
-class Sample:
- """Dataclass for a sample in the dataset"""
+class Sample(TypedDict):
+ """Type for a sample in the dataset"""
waveform: torch.Tensor
spectrogram: torch.Tensor
- utterance: str
+ input_length: int
+ utterance: torch.Tensor
+ utterance_length: int
sample_rate: int
speaker_id: str
book_id: str
chapter_id: str
-def tokenizer_factory(tokenizer_path: str, tokenizer_type: str = "BPE"):
- """Factory for Tokenizer class
-
- Args:
- tokenizer_type (str, optional): Type of tokenizer to use. Defaults to "BPE".
-
- Returns:
- nn.Module: Tokenizer class
- """
- if tokenizer_type == "BPE":
- return Tokenizer.from_file(tokenizer_path)
- elif tokenizer_type == "char":
- return CharTokenizer.from_file(tokenizer_path)
-
-
class MLSDataset(Dataset):
"""Custom Dataset for reading Multilingual LibriSpeech
@@ -105,23 +96,33 @@ class MLSDataset(Dataset):
self.mls_split: MLSSplit = split_to_mls_split(split) # split path on disk
self.split: Split = split # split used internally
self.dataset_lookup = []
+ self.tokenizer: type[TokenizerType]
self._handle_download_dataset(download)
self._validate_local_directory()
- transcripts_path = os.path.join(
- dataset_path, language, self.mls_split, "transcripts.txt"
- )
+ transcripts_path = os.path.join(dataset_path, language, self.mls_split, "transcripts.txt")
with open(transcripts_path, "r", encoding="utf-8") as script_file:
# read all lines in transcripts.txt
transcripts = script_file.readlines()
# split each line into (<speakerid>_<bookid>_<chapterid>, <utterance>)
- transcripts = [line.strip().split("\t", 1) for line in transcripts]
- utterances = [utterance.strip() for _, utterance in transcripts]
- identifier = [identifier.strip() for identifier, _ in transcripts]
+ transcripts = [line.strip().split("\t", 1) for line in transcripts] # type: ignore
+ utterances = [utterance.strip() for _, utterance in transcripts] # type: ignore
+ identifier = [identifier.strip() for identifier, _ in transcripts] # type: ignore
identifier = [path.split("_") for path in identifier]
+ if self.split == Split.valid:
+ np.random.seed(42)
+ indices = np.random.choice(len(utterances), int(len(utterances) * 0.2))
+ utterances = [utterances[i] for i in indices]
+ identifier = [identifier[i] for i in indices]
+ elif self.split == Split.train:
+ np.random.seed(42)
+ indices = np.random.choice(len(utterances), int(len(utterances) * 0.8))
+ utterances = [utterances[i] for i in indices]
+ identifier = [identifier[i] for i in indices]
+
self.dataset_lookup = [
{
"speakerid": path[0],
@@ -129,27 +130,23 @@ class MLSDataset(Dataset):
"chapterid": path[2],
"utterance": utterance,
}
- for path, utterance in zip(identifier, utterances)
+ for path, utterance in zip(identifier, utterances, strict=False)
]
- # save dataset_lookup as list of dicts, where each dict contains
- # the speakerid, bookid and chapterid, as well as the utterance
- # we can then use this to map the utterance to the audio file
+ def set_tokenizer(self, tokenizer: type[TokenizerType]):
+ """Sets the tokenizer"""
+ self.tokenizer = tokenizer
+
+ self.calc_paddings()
def _handle_download_dataset(self, download: bool):
"""Download the dataset"""
- if (
- not os.path.exists(os.path.join(self.dataset_path, self.language))
- and download
- ):
+ if not os.path.exists(os.path.join(self.dataset_path, self.language)) and download:
os.makedirs(self.dataset_path)
url = f"https://dl.fbaipublicfiles.com/mls/{self.language}.tar.gz"
torch.hub.download_url_to_file(url, self.dataset_path)
- elif (
- not os.path.exists(os.path.join(self.dataset_path, self.language))
- and not download
- ):
+ elif not os.path.exists(os.path.join(self.dataset_path, self.language)) and not download:
raise ValueError("Dataset not found. Set download to True to download it")
def _validate_local_directory(self):
@@ -158,18 +155,32 @@ class MLSDataset(Dataset):
raise ValueError("Dataset path does not exist")
if not os.path.exists(os.path.join(self.dataset_path, self.language)):
raise ValueError("Language not found in dataset")
- if not os.path.exists(
- os.path.join(self.dataset_path, self.language, self.mls_split)
- ):
+ if not os.path.exists(os.path.join(self.dataset_path, self.language, self.mls_split)):
raise ValueError("Split not found in dataset")
- # checks if the transcripts.txt file exists
- if not os.path.exists(
- os.path.join(dataset_path, language, split, "transcripts.txt")
- ):
- raise ValueError("transcripts.txt not found in dataset")
-
- def __get_len__(self):
+ def calc_paddings(self):
+ """Computes and stores the maximum spectrogram and utterance lengths, used for padding"""
+ # check if dataset has been loaded and tokenizer has been set
+ if not self.dataset_lookup:
+ raise ValueError("Dataset not loaded")
+ if not self.tokenizer:
+ raise ValueError("Tokenizer not set")
+
+ max_spec_length = 0
+ max_uterance_length = 0
+ for sample in self.dataset_lookup:
+ spec_length = sample["spectrogram"].shape[0]
+ if spec_length > max_spec_length:
+ max_spec_length = spec_length
+
+ utterance_length = sample["utterance"].shape[0]
+ if utterance_length > max_uterance_length:
+ max_uterance_length = utterance_length
+
+ self.max_spec_length = max_spec_length
+ self.max_utterance_length = max_uterance_length
+
+ def __len__(self):
"""Returns the length of the dataset"""
return len(self.dataset_lookup)
@@ -197,13 +208,32 @@ class MLSDataset(Dataset):
)
waveform, sample_rate = torchaudio.load(audio_path) # type: ignore
+ # TODO: figure out if we have to resample or not
+ # TODO: pad correctly (manually)
+ spec = torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128)(waveform).squeeze(0).transpose(0, 1)
+ print(f"spec.shape: {spec.shape}")
+ input_length = spec.shape[0] // 2
+ spec = (
+ torch.nn.functional.pad(spec, pad=(0, self.max_spec_length), mode="constant", value=0)
+ .unsqueeze(1)
+ .transpose(2, 3)
+ )
+
+ utterance_length = len(utterance)
+ self.tokenizer.enable_padding()
+ utterance = self.tokenizer.encode(
+ utterance,
+ ).ids
+
+ utterance = torch.Tensor(utterance)
return Sample(
+ # TODO: add flag to only return spectrogram or waveform or both
waveform=waveform,
- spectrogram=torchaudio.transforms.MelSpectrogram(
- sample_rate=16000, n_mels=128
- )(waveform),
+ spectrogram=spec,
+ input_length=input_length,
utterance=utterance,
+ utterance_length=utterance_length,
sample_rate=sample_rate,
speaker_id=self.dataset_lookup[idx]["speakerid"],
book_id=self.dataset_lookup[idx]["bookid"],
@@ -218,62 +248,6 @@ class MLSDataset(Dataset):
torch.hub.download_url_to_file(url, dataset_path)
-class DataProcessor:
- """Factory for DataProcessingclass
-
- Transforms the dataset into spectrograms and labels, as well as a tokenizer
- """
-
- def __init__(
- self,
- dataset: MultilingualLibriSpeech,
- tokenizer_path: str,
- data_type: str = "train",
- tokenizer_type: str = "BPE",
- ):
- self.dataset = dataset
- self.data_type = data_type
-
- self.train_audio_transforms = nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
- torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
- torchaudio.transforms.TimeMasking(time_mask_param=100),
- )
-
- self.valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
- self.tokenizer = tokenizer_factory(
- tokenizer_path=tokenizer_path, tokenizer_type=tokenizer_type
- )
-
- def __call__(self) -> tuple[np.ndarray, np.ndarray, int, int]:
- """Returns spectrograms, labels and their lenghts"""
- for sample in self.dataset:
- if self.data_type == "train":
- spec = (
- self.train_audio_transforms(sample["waveform"])
- .squeeze(0)
- .transpose(0, 1)
- )
- elif self.data_type == "valid":
- spec = (
- self.valid_audio_transforms(sample["waveform"])
- .squeeze(0)
- .transpose(0, 1)
- )
- else:
- raise ValueError("data_type should be train or valid")
- label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
-
- spectrograms = (
- nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
- .unsqueeze(1)
- .transpose(2, 3)
- )
- labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
-
- yield spec, label, spec.shape[0] // 2, len(labels)
-
-
if __name__ == "__main__":
dataset_path = "/Volumes/pherkel/SWR2-ASR"
language = "mls_german_opus"
@@ -281,4 +255,9 @@ if __name__ == "__main__":
download = False
dataset = MLSDataset(dataset_path, language, split, download)
- print(dataset[0])
+
+ tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+ dataset.set_tokenizer(tok)
+ dataset.calc_paddings()
+
+ print(dataset[41]["spectrogram"].shape)