aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
authorPherkel2023-09-04 22:55:27 +0200
committerGitHub2023-09-04 22:55:27 +0200
commit93e49e708fa59613406249069d31c2f6c8f2d2ab (patch)
tree9f413c226283db990116c9559ffffb9124b911d8 /swr2_asr/train.py
parent14ceeb5ad36beea2f05214aa26260cdd1d86590b (diff)
parent0d70a19e1fea6eda3f7b16ad0084591613f2de72 (diff)
Merge pull request #27 from Algo-Boys/refactor_modularize
Refactor modularize
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py354
1 file changed, 103 insertions, 251 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 6af1e80..63deb72 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -1,232 +1,57 @@
"""Training script for the ASR model."""
import os
+from typing import TypedDict
+
import click
import torch
import torch.nn.functional as F
-import torchaudio
-from AudioLoader.speech import MultilingualLibriSpeech
from torch import nn, optim
from torch.utils.data import DataLoader
-from tokenizers import Tokenizer
-from .tokenizer import CharTokenizer
+from tqdm import tqdm
+
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
+from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
-train_audio_transforms = nn.Sequential(
- torchaudio.transforms.MelSpectrogram(sample_rate=16000, n_mels=128),
- torchaudio.transforms.FrequencyMasking(freq_mask_param=30),
- torchaudio.transforms.TimeMasking(time_mask_param=100),
-)
+# TODO: improve naming of functions
-valid_audio_transforms = torchaudio.transforms.MelSpectrogram()
-
-# text_transform = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-text_transform = CharTokenizer()
-text_transform.from_file("data/tokenizers/char_tokenizer_german.json")
-
-
-def data_processing(data, data_type="train"):
- """Return the spectrograms, labels, and their lengths."""
- spectrograms = []
- labels = []
- input_lengths = []
- label_lengths = []
- for sample in data:
- if data_type == "train":
- spec = train_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- elif data_type == "valid":
- spec = valid_audio_transforms(sample["waveform"]).squeeze(0).transpose(0, 1)
- else:
- raise ValueError("data_type should be train or valid")
- spectrograms.append(spec)
- label = torch.Tensor(text_transform.encode(sample["utterance"]).ids)
- labels.append(label)
- input_lengths.append(spec.shape[0] // 2)
- label_lengths.append(len(label))
-
- spectrograms = (
- nn.utils.rnn.pad_sequence(spectrograms, batch_first=True)
- .unsqueeze(1)
- .transpose(2, 3)
- )
- labels = nn.utils.rnn.pad_sequence(labels, batch_first=True)
- return spectrograms, labels, input_lengths, label_lengths
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
-def greedy_decoder(
- output, labels, label_lengths, blank_label=28, collapse_repeated=True
-):
- # TODO: adopt to support both tokenizers
+
+def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
"""Greedily decode a sequence."""
+ print("output shape", output.shape)
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = tokenizer.encode(" ").ids[0]
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(
- text_transform.decode(
- [int(x) for x in labels[i][: label_lengths[i]].tolist()]
- )
- )
+ targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
- decodes.append(text_transform.decode(decode))
+ decodes.append(tokenizer.decode(decode))
return decodes, targets
-# TODO: restructure into own file / class
-class CNNLayerNorm(nn.Module):
- """Layer normalization built for cnns input"""
-
- def __init__(self, n_feats: int):
- super().__init__()
- self.layer_norm = nn.LayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- data = data.transpose(2, 3).contiguous() # (batch, channel, time, feature)
- data = self.layer_norm(data)
- return data.transpose(2, 3).contiguous() # (batch, channel, feature, time)
-
-
-class ResidualCNN(nn.Module):
- """Residual CNN inspired by https://arxiv.org/pdf/1603.05027.pdf"""
-
- def __init__(
- self,
- in_channels: int,
- out_channels: int,
- kernel: int,
- stride: int,
- dropout: float,
- n_feats: int,
- ):
- super().__init__()
-
- self.cnn1 = nn.Conv2d(
- in_channels, out_channels, kernel, stride, padding=kernel // 2
- )
- self.cnn2 = nn.Conv2d(
- out_channels,
- out_channels,
- kernel,
- stride,
- padding=kernel // 2,
- )
- self.dropout1 = nn.Dropout(dropout)
- self.dropout2 = nn.Dropout(dropout)
- self.layer_norm1 = CNNLayerNorm(n_feats)
- self.layer_norm2 = CNNLayerNorm(n_feats)
-
- def forward(self, data):
- """x (batch, channel, feature, time)"""
- residual = data # (batch, channel, feature, time)
- data = self.layer_norm1(data)
- data = F.gelu(data)
- data = self.dropout1(data)
- data = self.cnn1(data)
- data = self.layer_norm2(data)
- data = F.gelu(data)
- data = self.dropout2(data)
- data = self.cnn2(data)
- data += residual
- return data # (batch, channel, feature, time)
-
-
-class BidirectionalGRU(nn.Module):
- """BIdirectional GRU with Layer Normalization and Dropout"""
-
- def __init__(
- self,
- rnn_dim: int,
- hidden_size: int,
- dropout: float,
- batch_first: bool,
- ):
- super().__init__()
-
- self.bi_gru = nn.GRU(
- input_size=rnn_dim,
- hidden_size=hidden_size,
- num_layers=1,
- batch_first=batch_first,
- bidirectional=True,
- )
- self.layer_norm = nn.LayerNorm(rnn_dim)
- self.dropout = nn.Dropout(dropout)
-
- def forward(self, data):
- """data (batch, time, feature)"""
- data = self.layer_norm(data)
- data = F.gelu(data)
- data = self.dropout(data)
- data, _ = self.bi_gru(data)
- return data
-
-
-class SpeechRecognitionModel(nn.Module):
- """Speech Recognition Model Inspired by DeepSpeech 2"""
-
- def __init__(
- self,
- n_cnn_layers: int,
- n_rnn_layers: int,
- rnn_dim: int,
- n_class: int,
- n_feats: int,
- stride: int = 2,
- dropout: float = 0.1,
- ):
- super().__init__()
- n_feats //= 2
- self.cnn = nn.Conv2d(1, 32, 3, stride=stride, padding=3 // 2)
- # n residual cnn layers with filter size of 32
- self.rescnn_layers = nn.Sequential(
- *[
- ResidualCNN(
- 32, 32, kernel=3, stride=1, dropout=dropout, n_feats=n_feats
- )
- for _ in range(n_cnn_layers)
- ]
- )
- self.fully_connected = nn.Linear(n_feats * 32, rnn_dim)
- self.birnn_layers = nn.Sequential(
- *[
- BidirectionalGRU(
- rnn_dim=rnn_dim if i == 0 else rnn_dim * 2,
- hidden_size=rnn_dim,
- dropout=dropout,
- batch_first=i == 0,
- )
- for i in range(n_rnn_layers)
- ]
- )
- self.classifier = nn.Sequential(
- nn.Linear(rnn_dim * 2, rnn_dim), # birnn returns rnn_dim*2
- nn.GELU(),
- nn.Dropout(dropout),
- nn.Linear(rnn_dim, n_class),
- )
-
- def forward(self, data):
- """data (batch, channel, feature, time)"""
- data = self.cnn(data)
- data = self.rescnn_layers(data)
- sizes = data.size()
- data = data.view(
- sizes[0], sizes[1] * sizes[2], sizes[3]
- ) # (batch, feature, time)
- data = data.transpose(1, 2) # (batch, time, feature)
- data = self.fully_connected(data)
- data = self.birnn_layers(data)
- data = self.classifier(data)
- return data
-
-
class IterMeter:
"""keeps track of total iterations"""
@@ -254,36 +79,30 @@ def train(
):
"""Train"""
model.train()
- data_len = len(train_loader.dataset)
- for batch_idx, _data in enumerate(train_loader):
- spectrograms, labels, input_lengths, label_lengths = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
-
+ print(f"Epoch: {epoch}")
+ losses = []
+ for _data in tqdm(train_loader, desc="batches"):
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
loss.backward()
optimizer.step()
scheduler.step()
iter_meter.step()
- if batch_idx % 100 == 0 or batch_idx == data_len:
- print(
- f"Train Epoch: \
- {epoch} \
- [{batch_idx * len(spectrograms)}/{data_len} \
- ({100.0 * batch_idx / len(train_loader)}%)]\t \
- Loss: {loss.item()}"
- )
- return loss.item()
+ losses.append(loss.item())
+
+ print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
+ return sum(losses) / len(losses)
-# TODO: check how dataloader can be made more efficient
-def test(model, device, test_loader, criterion):
+
+def test(model, device, test_loader, criterion, tokenizer):
"""Test"""
print("\nevaluating...")
model.eval()
@@ -291,18 +110,20 @@ def test(model, device, test_loader, criterion):
test_cer, test_wer = [], []
with torch.no_grad():
for _data in test_loader:
- spectrograms, labels, input_lengths, label_lengths = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
test_loss += loss.item() / len(test_loader)
decoded_preds, decoded_targets = greedy_decoder(
- output.transpose(0, 1), labels, label_lengths
+ output=output.transpose(0, 1),
+ labels=labels,
+ label_lengths=_data["utterance_length"],
+ tokenizer=tokenizer,
)
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
@@ -313,9 +134,11 @@ def test(model, device, test_loader, criterion):
print(
f"Test set: Average loss:\
- {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
+             {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
)
+ return test_loss, avg_cer, avg_wer
+
def run(
learning_rate: float,
@@ -324,46 +147,66 @@ def run(
load: bool,
path: str,
dataset_path: str,
+ language: str,
) -> None:
"""Runs the training script."""
- hparams = {
- "n_cnn_layers": 3,
- "n_rnn_layers": 5,
- "rnn_dim": 512,
- "n_class": 36, # TODO: dynamically determine this from vocab size
- "n_feats": 128,
- "stride": 2,
- "dropout": 0.1,
- "learning_rate": learning_rate,
- "batch_size": batch_size,
- "epochs": epochs,
- }
-
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
# device = torch.device("mps")
- download_dataset = not os.path.isdir(path)
- train_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="dev", download=download_dataset
+ # load dataset
+ train_dataset = MLSDataset(
+ dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
)
- test_dataset = MultilingualLibriSpeech(
- dataset_path, "mls_german_opus", split="test", download=False
+ valid_dataset = MLSDataset(
+ dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+ )
+
+    # load tokenizer (char-level; trained on the dataset if missing):
+ if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
+ print("There is no tokenizer available. Do you want to train it on the dataset?")
+ input("Press Enter to continue...")
+ train_char_tokenizer(
+ dataset_path=dataset_path,
+ language=language,
+ split="all",
+ download=False,
+ out_path="data/tokenizers/char_tokenizer_german.json",
+ )
+
+ tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+
+ train_dataset.set_tokenizer(tokenizer) # type: ignore
+ valid_dataset.set_tokenizer(tokenizer) # type: ignore
+
+ print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
+
+ hparams = HParams(
+ n_cnn_layers=3,
+ n_rnn_layers=5,
+ rnn_dim=512,
+ n_class=tokenizer.get_vocab_size(),
+ n_feats=128,
+ stride=2,
+ dropout=0.1,
+ learning_rate=learning_rate,
+ batch_size=batch_size,
+ epochs=epochs,
)
train_loader = DataLoader(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
+ collate_fn=lambda x: collate_fn(x),
)
- test_loader = DataLoader(
- test_dataset,
+ valid_loader = DataLoader(
+ valid_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
- collate_fn=lambda x: data_processing(x, "train"),
+ collate_fn=lambda x: collate_fn(x),
)
# enable flag to find the most compatible algorithms in advance
@@ -379,12 +222,10 @@ def run(
hparams["stride"],
hparams["dropout"],
).to(device)
-
- print(
- "Num Model Parameters", sum((param.nelement() for param in model.parameters()))
- )
+ print(tokenizer.encode(" "))
+ print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
- criterion = nn.CTCLoss(blank=28).to(device)
+ criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
if load:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
@@ -412,7 +253,13 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=test_loader, criterion=criterion)
+ test(
+ model=model,
+ device=device,
+ test_loader=valid_loader,
+ criterion=criterion,
+ tokenizer=tokenizer,
+ )
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -452,4 +299,9 @@ def run_cli(
load=load,
path=path,
dataset_path=dataset_path,
+ language="mls_german_opus",
)
+
+
+if __name__ == "__main__":
+ run_cli() # pylint: disable=no-value-for-parameter