aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py35
1 files changed, 19 insertions, 16 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index bae8c7c..f3efd69 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,10 +5,10 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
-from AudioLoader.speech import MultilingualLibriSpeech
from tokenizers import Tokenizer
from torch import nn, optim
from torch.utils.data import DataLoader
+from tqdm import tqdm
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.tokenizer import train_bpe_tokenizer
@@ -16,6 +16,7 @@ from swr2_asr.utils import MLSDataset, Split, collate_fn
from .loss_scores import cer, wer
+# TODO: improve naming of functions
class HParams(TypedDict):
"""Type for the hyperparameters of the model."""
@@ -32,6 +33,7 @@ class HParams(TypedDict):
epochs: int
+# TODO: get blank label from tokenizer
def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True):
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
@@ -76,8 +78,9 @@ def train(
):
"""Train"""
model.train()
- data_len = len(train_loader.dataset)
- for batch_idx, _data in enumerate(train_loader):
+ print(f"Epoch: {epoch}")
+ losses = []
+ for _data in tqdm(train_loader, desc="batches"):
spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
optimizer.zero_grad()
@@ -91,17 +94,15 @@ def train(
optimizer.step()
scheduler.step()
iter_meter.step()
- if batch_idx % 100 == 0 or batch_idx == data_len:
- print(
- f"Train Epoch: \
- {epoch} \
- [{batch_idx * len(spectrograms)}/{data_len} \
- ({100.0 * batch_idx / len(train_loader)}%)]\t \
- Loss: {loss.item()}"
- )
- return loss.item()
+
+ losses.append(loss.item())
+ print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
+ return sum(losses) / len(losses)
+# TODO: profile this function call
+# TODO: only calculate wer and cer at the end, or less often
+# TODO: change this to only be a sanity check and calculate measures after training
def test(model, device, test_loader, criterion, tokenizer):
"""Test"""
print("\nevaluating...")
@@ -116,6 +117,7 @@ def test(model, device, test_loader, criterion, tokenizer):
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
+ # TODO: get rid of this
loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
test_loss += loss.item() / len(test_loader)
@@ -150,11 +152,12 @@ def run(
use_cuda = torch.cuda.is_available()
torch.manual_seed(42)
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
- # device = torch.device("mps")
+ device = torch.device("mps")
# load dataset
- train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None)
- valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None)
+ # TODO: change this from dev split to train split again (was faster for development)
+ train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
+ valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None)
# load tokenizer (bpe by default):
@@ -237,7 +240,7 @@ def run(
)
iter_meter = IterMeter()
- for epoch in range(1, epochs + 1):
+ for epoch in range(1, epochs + 1):
loss = train(
model,
device,