From f3d2ea9a16944434a08e662c5ecfd6ba50e5ea89 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Fri, 1 Sep 2023 22:40:29 +0200 Subject: many todos --- swr2_asr/train.py | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index bae8c7c..f3efd69 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,10 +5,10 @@ from typing import TypedDict import click import torch import torch.nn.functional as F -from AudioLoader.speech import MultilingualLibriSpeech from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader +from tqdm import tqdm from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.tokenizer import train_bpe_tokenizer @@ -16,6 +16,7 @@ from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer +# TODO: improve naming of functions class HParams(TypedDict): """Type for the hyperparameters of the model.""" @@ -32,6 +33,7 @@ class HParams(TypedDict): epochs: int +# TODO: get blank label from tokenizer def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member @@ -76,8 +78,9 @@ def train( ): """Train""" model.train() - data_len = len(train_loader.dataset) - for batch_idx, _data in enumerate(train_loader): + print(f"Epoch: {epoch}") + losses = [] + for _data in tqdm(train_loader, desc="batches"): spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) optimizer.zero_grad() @@ -91,17 +94,15 @@ def train( optimizer.step() scheduler.step() iter_meter.step() - if batch_idx % 100 == 0 or batch_idx == data_len: - print( - f"Train Epoch: \ - {epoch} \ - [{batch_idx * len(spectrograms)}/{data_len} \ - ({100.0 * batch_idx / len(train_loader)}%)]\t \ - Loss: {loss.item()}" - ) - return loss.item() + + losses.append(loss.item()) + print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}") + return sum(losses) / len(losses) +# TODO: profile this function call +# TODO: only calculate wer and cer at the end, or less often +# TODO: change this to only be a sanity check and calculate measures after training def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") @@ -116,6 +117,7 @@ def test(model, device, test_loader, criterion, tokenizer): output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) + # TODO: get rid of this loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) @@ -150,11 +152,12 @@ def run( use_cuda = torch.cuda.is_available() torch.manual_seed(42) device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member - # device = torch.device("mps") + device = torch.device("mps") # load dataset - train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None) - valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None) + # TODO: change this from dev split to train split again (was faster for development) + train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) + valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None) test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None) # load tokenizer (bpe by default): @@ -237,7 +240,7 @@ def run( ) iter_meter = IterMeter() - for epoch in range(1, epochs + 1): + for epoch in range(1, epochs + 1): loss = train( model, device, -- cgit v1.2.3