aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py298
1 files changed, 65 insertions, 233 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 6af1e80..53cdac1 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,74 +1,44 @@
"""Training script for the ASR model."""
import os
+from typing import TypedDict
+
import click
import torch
import torch.nn.functional as F
-import torchaudio
-from AudioLoader.speech import MultilingualLibriSpeech
+from tokenizers import Tokenizer
from torch import nn, optim
from torch.utils.data import DataLoader
-from tokenizers import Tokenizer
-from .tokenizer import CharTokenizer
+
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.tokenizer import train_bpe_tokenizer
+from swr2_asr.utils import MLSDataset, Split
from .loss_scores import cer, wer
-train_audio_transforms = nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
- torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
- torchaudio.transforms.TimeMasking(time_mask_param=100),
-)
-valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-
-# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-text_transform = CharTokenizer()
-text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
-
-
-def data_processing(data, data_type="train"):
- """Return the spectrograms, labels, and their lengths."""
- spectrograms = []
- labels = []
- input_lengths = []
- label_lengths = []
- for sample in data:
- if data_type == "train":
- spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- elif data_type == "valid":
- spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- else:
- raise ValueError("data_type should be train or valid")
- spectrograms.append(spec)
- label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
- labels.append(label)
- input_lengths.append(spec.shape[0] // 2)
- label_lengths.append(len(label))
-
- spectrograms = (
- nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
- .unsqueeze(1)
- .transpose(2, 3)
- )
- labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
- return spectrograms, labels, input_lengths, label_lengths
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
-def greedy_decoder(
- output, labels, label_lengths, blank_label=28, collapse_repeated=True
-):
- # TODO: adopt to support both tokenizers
+def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(
- text_transform.decode(
- [int(x) for x in labels[i][: label_lengths[i]].tolist()]
- )
- )
+ targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
@@ -78,155 +48,6 @@ def greedy_decoder(
return decodes, targets
-# TODO: restructure into own file / class
-class CNNLayerNorm(nn.Module):
- """Layer normalization built for cnns input"""
-
- def __init__(self, n_feats: int):
- super().__init__()
- self.layer_norm = nn.LayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
- data = self.layer_norm(data)
- return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
-
-
-class ResidualCNN(nn.Module):
- """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel: int,
- stride: int,
- dropout: float,
- n_feats: int,
- ):
- super().__init__()
-
- self.cnn1 = nn.Conv2d(
- in_channels, out_channels, kernel, stride, padding=kernel // 2
- )
- self.cnn2 = nn.Conv2d(
- out_channels,
- out_channels,
- kernel,
- stride,
- padding=kernel // 2,
- )
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.layer_norm1 = CNNLayerNorm(n_feats)
- self.layer_norm2 = CNNLayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- residual = data # (batch, channel, feature, time)
- data = self.layer_norm1(data)
- data = F.gelu(data)
- data = self.dropout1(data)
- data = self.cnn1(data)
- data = self.layer_norm2(data)
- data = F.gelu(data)
- data = self.dropout2(data)
- data = self.cnn2(data)
- data += residual
- return data # (batch, channel, feature, time)
-
-
-class BidirectionalGRU(nn.Module):
- """BIdirectional GRU with Layer Normalization and Dropout"""
-
- def __init__(
- self,
- rnn_dim: int,
- hidden_size: int,
- dropout: float,
- batch_first: bool,
- ):
- super().__init__()
-
- self.bi_gru = nn.GRU(
- input_size=rnn_dim,
- hidden_size=hidden_size,
- num_layers=1,
- batch_first=batch_first,
- bidirectional=True,
- )
- self.layer_norm = nn.LayerNorm(rnn_dim)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, data):
- """data (batch, time, feature)"""
- data = self.layer_norm(data)
- data = F.gelu(data)
- data = self.dropout(data)
- data, _ = self.bi_gru(data)
- return data
-
-
-class SpeechRecognitionModel(nn.Module):
- """Speech Recognition Model Inspired by DeepSpeech 2"""
-
- def __init__(
- self,
- n_cnn_layers: int,
- n_rnn_layers: int,
- rnn_dim: int,
- n_class: int,
- n_feats: int,
- stride: int = 2,
- dropout: float = 0.1,
- ):
- super().__init__()
- n_feats //= 2
- self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
- # n residual cnn layers with filter size of 32
- self.rescnn_layers = nn.Sequential(
- *[
- ResidualCNN(
- 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
- )
- for _ in range(n_cnn_layers)
- ]
- )
- self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
- self.birnn_layers = nn.Sequential(
- *[
- BidirectionalGRU(
- rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
- hidden_size=rnn_dim,
- dropout=dropout,
- batch_first=i == 0,
- )
- for i in range(n_rnn_layers)
- ]
- )
- self.classifier = nn.Sequential(
- nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
- nn.GELU(),
- nn.Dropout(dropout),
- nn.Linear(rnn_dim, n_class),
- )
-
- def forward(self, data):
- """data (batch, channel, feature, time)"""
- data = self.cnn(data)
- data = self.rescnn_layers(data)
- sizes = data.size()
- data = data.view(
- sizes[0], sizes[1] * sizes[2], sizes[3]
- ) # (batch, feature, time)
- data = data.transpose(1, 2) # (batch, time, feature)
- data = self.fully_connected(data)
- data = self.birnn_layers(data)
- data = self.classifier(data)
- return data
-
-
class IterMeter:
"""keeps track of total iterations"""
@@ -256,9 +77,8 @@ def train(
model.train()
data_len = len(train_loader.dataset)
for batch_idx, _data in enumerate(train_loader):
- spectrograms, labels, input_lengths, label_lengths = _data
+ _, spectrograms, input_lengths, labels, label_lengths, *_ = _data
spectrograms, labels = spectrograms.to(device), labels.to(device)
-
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
@@ -282,7 +102,6 @@ def train(
return loss.item()
-# TODO: check how dataloader can be made more efficient
def test(model, device, test_loader, criterion):
"""Test"""
print("\nevaluating...")
@@ -301,9 +120,7 @@ def test(model, device, test_loader, criterion):
loss = criterion(output, labels, input_lengths, label_lengths)
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths
- )
+ decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths)
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
test_wer.append(wer(decoded_targets[j], pred))
@@ -324,46 +141,62 @@ def run(
load: bool,
path: str,
dataset_path: str,
+ language: str,
) -> None:
"""Runs the training script."""
- hparams = {
- "n_cnn_layers": 3,
- "n_rnn_layers": 5,
- "rnn_dim": 512,
- "n_class": 36, # TODO: dynamically determine this from vocab size
- "n_feats": 128,
- "stride": 2,
- "dropout": 0.1,
- "learning_rate": learning_rate,
- "batch_size": batch_size,
- "epochs": epochs,
- }
-
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
- download_dataset = not os.path.isdir(path)
- train_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="dev", download=download_dataset
- )
- test_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="test", download=False
+ # load dataset
+ train_dataset = MLSDataset(dataset_path, language, Split.train, download=True)
+ valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True)
+ test_dataset = MLSDataset(dataset_path, language, Split.test, download=True)
+
+ # TODO: add flag to choose tokenizer
+ # load tokenizer (bpe by default):
+ if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"):
+ print("There is no tokenizer available. Do you want to train it on the dataset?")
+ input("Press Enter to continue...")
+ train_bpe_tokenizer(
+ dataset_path=dataset_path,
+ language=language,
+ split="all",
+ download=False,
+ out_path="data/tokenizers/bpe_tokenizer_german_3000.json",
+ vocab_size=3000,
+ )
+
+ tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
+
+ train_dataset.set_tokenizer(tokenizer)
+ valid_dataset.set_tokenizer(tokenizer)
+ test_dataset.set_tokenizer(tokenizer)
+
+ hparams = HParams(
+ n_cnn_layers=3,
+ n_rnn_layers=5,
+ rnn_dim=512,
+ n_class=tokenizer.get_vocab_size(),
+ n_feats=128,
+ stride=2,
+ dropout=0.1,
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
)
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
)
- test_loader = DataLoader(
- test_dataset,
+ valid_loader = DataLoader(
+ valid_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
)
# enable flag to find the most compatible algorithms in advance
@@ -380,9 +213,7 @@ def run(
hparams["dropout"],
).to(device)
- print(
- "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
- )
+ print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
criterion = nn.CTCLoss(blank=28).to(device)
if load:
@@ -412,7 +243,7 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ test(model=model, device=device, test_loader=valid_loader, criterion=criterion)
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -452,4 +283,5 @@ def run_cli(
load=load,
path=path,
dataset_path=dataset_path,
+ language="mls_german_opus",
)