aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py49
1 files changed, 29 insertions, 20 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 53cdac1..bae8c7c 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,13 +5,14 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
+from AudioLoader.speech import MultilingualLibriSpeech
from tokenizers import Tokenizer
from torch import nn, optim
from torch.utils.data import DataLoader
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.tokenizer import train_bpe_tokenizer
-from swr2_asr.utils import MLSDataset, Split
+from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
@@ -31,20 +32,20 @@ class HParams(TypedDict):
epochs: int
-def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True):
+def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True):
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
decode = []
- targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
+ targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
- decodes.append(text_transform.decode(decode))
+ decodes.append(tokenizer.decode(decode))
return decodes, targets
@@ -77,15 +78,14 @@ def train(
model.train()
data_len = len(train_loader.dataset)
for batch_idx, _data in enumerate(train_loader):
- _, spectrograms, input_lengths, labels, label_lengths, *_ = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
+ spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
loss.backward()
optimizer.step()
@@ -102,7 +102,7 @@ def train(
return loss.item()
-def test(model, device, test_loader, criterion):
+def test(model, device, test_loader, criterion, tokenizer):
"""Test"""
print("\nevaluating...")
model.eval()
@@ -110,17 +110,20 @@ def test(model, device, test_loader, criterion):
test_cer, test_wer = [], []
with torch.no_grad():
for _data in test_loader:
- spectrograms, labels, input_lengths, label_lengths = _data
- spectrograms, labels = spectrograms.to(device), labels.to(device)
+ spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, input_lengths, label_lengths)
+ loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
test_loss += loss.item() / len(test_loader)
- decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths)
+ decoded_preds, decoded_targets = greedy_decoder(
+ output = output.transpose(0, 1),
+ labels = labels,
+ label_lengths= _data["utterance_length"],
+ tokenizer=tokenizer)
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
test_wer.append(wer(decoded_targets[j], pred))
@@ -150,12 +153,11 @@ def run(
# device = torch.device("mps")
# load dataset
- train_dataset = MLSDataset(dataset_path, language, Split.train, download=True)
- valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True)
- test_dataset = MLSDataset(dataset_path, language, Split.test, download=True)
+ train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None)
+ valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None)
+ test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None)
- # TODO: add flag to choose tokenizer
- # load tokenizer (bpe by default):
+ # load tokenizer (bpe by default):
if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"):
print("There is no tokenizer available. Do you want to train it on the dataset?")
input("Press Enter to continue...")
@@ -167,12 +169,14 @@ def run(
out_path="data/tokenizers/bpe_tokenizer_german_3000.json",
vocab_size=3000,
)
-
+
tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-
+
train_dataset.set_tokenizer(tokenizer)
valid_dataset.set_tokenizer(tokenizer)
test_dataset.set_tokenizer(tokenizer)
+
+ print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
hparams = HParams(
n_cnn_layers=3,
@@ -191,12 +195,14 @@ def run(
train_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
+ collate_fn=lambda x: collate_fn(x),
)
valid_loader = DataLoader(
valid_dataset,
batch_size=hparams["batch_size"],
shuffle=True,
+ collate_fn=lambda x: collate_fn(x),
)
# enable flag to find the most compatible algorithms in advance
@@ -243,7 +249,7 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=valid_loader, criterion=criterion)
+ test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer)
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -285,3 +291,6 @@ def run_cli(
dataset_path=dataset_path,
language="mls_german_opus",
)
+
+if __name__ == "__main__":
+ run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") \ No newline at end of file