aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py90
1 files changed, 50 insertions, 40 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index f3efd69..95038c2 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,19 +5,19 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
-from tokenizers import Tokenizer
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
-from swr2_asr.tokenizer import train_bpe_tokenizer
+from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
# TODO: improve naming of functions
+
class HParams(TypedDict):
"""Type for the hyperparameters of the model."""
@@ -33,10 +33,11 @@ class HParams(TypedDict):
epochs: int
-# TODO: get blank label from tokenizer
-def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True):
+def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
"""Greedily decode a sequence."""
+ print("output shape", output.shape)
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = tokenizer.encode(" ").ids[0]
decodes = []
targets = []
for i, args in enumerate(arg_maxes):
@@ -81,28 +82,27 @@ def train(
print(f"Epoch: {epoch}")
losses = []
for _data in tqdm(train_loader, desc="batches"):
- spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
optimizer.zero_grad()
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
loss.backward()
optimizer.step()
scheduler.step()
iter_meter.step()
-
+
losses.append(loss.item())
print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
return sum(losses) / len(losses)
-# TODO: profile this function call
-# TODO: only calculate wer and cer at the end, or less often
-# TODO: change this to only be a sanity check and calculate measures after training
+
+
def test(model, device, test_loader, criterion, tokenizer):
"""Test"""
print("\nevaluating...")
@@ -111,21 +111,21 @@ def test(model, device, test_loader, criterion, tokenizer):
test_cer, test_wer = [], []
with torch.no_grad():
for _data in test_loader:
- spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
+ spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
output = model(spectrograms) # (batch, time, n_class)
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
- # TODO: get rid of this
- loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
+ loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
test_loss += loss.item() / len(test_loader)
decoded_preds, decoded_targets = greedy_decoder(
- output = output.transpose(0, 1),
- labels = labels,
- label_lengths= _data["utterance_length"],
- tokenizer=tokenizer)
+ output=output.transpose(0, 1),
+ labels=labels,
+ label_lengths=_data["utterance_length"],
+ tokenizer=tokenizer,
+ )
for j, pred in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], pred))
test_wer.append(wer(decoded_targets[j], pred))
@@ -135,9 +135,11 @@ def test(model, device, test_loader, criterion, tokenizer):
print(
f"Test set: Average loss:\
- {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
+ {test_loss}, Average CER: {None} Average WER: {None}\n"
)
+ return test_loss, avg_cer, avg_wer
+
def run(
learning_rate: float,
@@ -152,33 +154,34 @@ def run(
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
- device = torch.device("mps")
+ # device = torch.device("mps")
# load dataset
- # TODO: change this from dev split to train split again (was faster for development)
- train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
- valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
- test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None)
+ train_dataset = MLSDataset(
+ dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
+ )
+ valid_dataset = MLSDataset(
+ dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+ )
- # load tokenizer (bpe by default):
- if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"):
+ # load tokenizer (bpe by default):
+ if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
print("There is no tokenizer available. Do you want to train it on the dataset?")
input("Press Enter to continue...")
- train_bpe_tokenizer(
+ train_char_tokenizer(
dataset_path=dataset_path,
language=language,
split="all",
download=False,
- out_path="data/tokenizers/bpe_tokenizer_german_3000.json",
+ out_path="data/tokenizers/char_tokenizer_german.json",
vocab_size=3000,
)
-
- tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-
- train_dataset.set_tokenizer(tokenizer)
- valid_dataset.set_tokenizer(tokenizer)
- test_dataset.set_tokenizer(tokenizer)
-
+
+ tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+
+ train_dataset.set_tokenizer(tokenizer) # type: ignore
+ valid_dataset.set_tokenizer(tokenizer) # type: ignore
+
print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
hparams = HParams(
@@ -221,10 +224,10 @@ def run(
hparams["stride"],
hparams["dropout"],
).to(device)
-
+ print(tokenizer.encode(" "))
print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
- criterion = nn.CTCLoss(blank=28).to(device)
+ criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
if load:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint["model_state_dict"])
@@ -240,7 +243,7 @@ def run(
)
iter_meter = IterMeter()
- for epoch in range(1, epochs + 1):
+ for epoch in range(1, epochs + 1):
loss = train(
model,
device,
@@ -252,7 +255,13 @@ def run(
iter_meter,
)
- test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer)
+ test(
+ model=model,
+ device=device,
+ test_loader=valid_loader,
+ criterion=criterion,
+ tokenizer=tokenizer,
+ )
print("saving epoch", str(epoch))
torch.save(
{"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -295,5 +304,6 @@ def run_cli(
language="mls_german_opus",
)
-if __name__ == "__main__":
- run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") \ No newline at end of file
+
+if __name__ == "__main__":
+ run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")