From 335b8a32f8bba5d37c00af6b4ecd1b9fc520f964 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Wed, 30 Aug 2023 17:11:51 +0200 Subject: wörks now!°!! --- swr2_asr/train.py | 49 +++++++++++++++++++++++++++++-------------------- 1 file changed, 29 insertions(+), 20 deletions(-) (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 53cdac1..bae8c7c 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,13 +5,14 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +from AudioLoader.speech import MultilingualLibriSpeech from tokenizers import Tokenizer from torch import nn, optim from torch.utils.data import DataLoader from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.tokenizer import train_bpe_tokenizer -from swr2_asr.utils import MLSDataset, Split +from swr2_asr.utils import MLSDataset, Split, collate_fn from .loss_scores import cer, wer @@ -31,20 +32,20 @@ class HParams(TypedDict): epochs: int -def greedy_decoder(output, labels, label_lengths, blank_label=28, collapse_repeated=True): +def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member decodes = [] targets = [] for i, args in enumerate(arg_maxes): decode = [] - targets.append(text_transform.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) + targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()])) for j, index in enumerate(args): if index != blank_label: if collapse_repeated and j != 0 and index == args[j - 1]: continue decode.append(index.item()) - decodes.append(text_transform.decode(decode)) + decodes.append(tokenizer.decode(decode)) return decodes, targets @@ -77,15 +78,14 @@ def train( model.train() data_len = len(train_loader.dataset) for batch_idx, _data in enumerate(train_loader): - _, spectrograms, input_lengths, labels, label_lengths, *_ = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) optimizer.zero_grad() output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) loss.backward() optimizer.step() @@ -102,7 +102,7 @@ def train( return loss.item() -def test(model, device, test_loader, criterion): +def test(model, device, test_loader, criterion, tokenizer): """Test""" print("\nevaluating...") model.eval() @@ -110,17 +110,20 @@ def test(model, device, test_loader, criterion): test_cer, test_wer = [], [] with torch.no_grad(): for _data in test_loader: - spectrograms, labels, input_lengths, label_lengths = _data - spectrograms, labels = spectrograms.to(device), labels.to(device) + spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device) output = model(spectrograms) # (batch, time, n_class) output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) - loss = criterion(output, labels, input_lengths, label_lengths) + loss = criterion(output, labels, _data['input_length'], _data["utterance_length"]) test_loss += loss.item() / len(test_loader) - decoded_preds, decoded_targets = greedy_decoder(output.transpose(0, 1), labels, label_lengths) + decoded_preds, decoded_targets = greedy_decoder( + output = output.transpose(0, 1), + labels = labels, + label_lengths= _data["utterance_length"], + tokenizer=tokenizer) for j, pred in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], pred)) test_wer.append(wer(decoded_targets[j], pred)) @@ -150,12 +153,11 @@ def run( # device = torch.device("mps") # load dataset - train_dataset = MLSDataset(dataset_path, language, Split.train, download=True) - valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True) - test_dataset = MLSDataset(dataset_path, language, Split.test, download=True) + train_dataset = MLSDataset(dataset_path, language, Split.train, download=True, spectrogram_hparams=None) + valid_dataset = MLSDataset(dataset_path, language, Split.valid, download=True, spectrogram_hparams=None) + test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None) - # TODO: add flag to choose tokenizer - # load tokenizer (bpe by default): + # load tokenizer (bpe by default): if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"): print("There is no tokenizer available. Do you want to train it on the dataset?") input("Press Enter to continue...") @@ -167,12 +169,14 @@ def run( out_path="data/tokenizers/bpe_tokenizer_german_3000.json", vocab_size=3000, ) - + tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") - + train_dataset.set_tokenizer(tokenizer) valid_dataset.set_tokenizer(tokenizer) test_dataset.set_tokenizer(tokenizer) + + print(f"Waveform shape: {train_dataset[0]['waveform'].shape}") hparams = HParams( n_cnn_layers=3, @@ -191,12 +195,14 @@ def run( train_dataset, batch_size=hparams["batch_size"], shuffle=True, + collate_fn=lambda x: collate_fn(x), ) valid_loader = DataLoader( valid_dataset, batch_size=hparams["batch_size"], shuffle=True, + collate_fn=lambda x: collate_fn(x), ) # enable flag to find the most compatible algorithms in advance @@ -243,7 +249,7 @@ def run( iter_meter, ) - test(model=model, device=device, test_loader=valid_loader, criterion=criterion) + test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer) print("saving epoch", str(epoch)) torch.save( {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, @@ -285,3 +291,6 @@ def run_cli( dataset_path=dataset_path, language="mls_german_opus", ) + +if __name__ == "__main__": + run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus") \ No newline at end of file -- cgit v1.2.3