Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 90
1 file changed, 50 insertions, 40 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index f3efd69..95038c2 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,19 +5,19 @@
 from typing import TypedDict
 
 import click
 import torch
 import torch.nn.functional as F
-from tokenizers import Tokenizer
 from torch import nn, optim
 from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from swr2_asr.model_deep_speech import SpeechRecognitionModel
-from swr2_asr.tokenizer import train_bpe_tokenizer
+from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
 from swr2_asr.utils import MLSDataset, Split, collate_fn
 
 from .loss_scores import cer, wer
 # TODO: improve naming of functions
+
 
 class HParams(TypedDict):
     """Type for the hyperparameters of the model."""
@@ -33,10 +33,11 @@ class HParams(TypedDict):
     epochs: int
 
 
-# TODO: get blank label from tokenizer
-def greedy_decoder(output, tokenizer, labels, label_lengths, blank_label=28, collapse_repeated=True):
+def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
     """Greedily decode a sequence."""
+    print("output shape", output.shape)
     arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+    blank_label = tokenizer.encode(" ").ids[0]
     decodes = []
     targets = []
     for i, args in enumerate(arg_maxes):
@@ -81,28 +82,27 @@ def train(
     print(f"Epoch: {epoch}")
     losses = []
     for _data in tqdm(train_loader, desc="batches"):
-        spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
+        spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
 
         optimizer.zero_grad()
 
         output = model(spectrograms)  # (batch, time, n_class)
         output = F.log_softmax(output, dim=2)
         output = output.transpose(0, 1)  # (time, batch, n_class)
 
-        loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
+        loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
         loss.backward()
         optimizer.step()
         scheduler.step()
         iter_meter.step()
-
+
         losses.append(loss.item())
 
     print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
     return sum(losses) / len(losses)
 
-# TODO: profile this function call
-# TODO: only calculate wer and cer at the end, or less often
-# TODO: change this to only be a sanity check and calculate measures after training
+
+
 def test(model, device, test_loader, criterion, tokenizer):
     """Test"""
     print("\nevaluating...")
@@ -111,21 +111,21 @@ def test(model, device, test_loader, criterion, tokenizer):
     test_cer, test_wer = [], []
     with torch.no_grad():
         for _data in test_loader:
-            spectrograms, labels = _data['spectrogram'].to(device), _data['utterance'].to(device)
+            spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
 
             output = model(spectrograms)  # (batch, time, n_class)
             output = F.log_softmax(output, dim=2)
             output = output.transpose(0, 1)  # (time, batch, n_class)
 
-            # TODO: get rid of this
-            loss = criterion(output, labels, _data['input_length'], _data["utterance_length"])
+            loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
             test_loss += loss.item() / len(test_loader)
 
             decoded_preds, decoded_targets = greedy_decoder(
-                output = output.transpose(0, 1),
-                labels = labels,
-                label_lengths= _data["utterance_length"],
-                tokenizer=tokenizer)
+                output=output.transpose(0, 1),
+                labels=labels,
+                label_lengths=_data["utterance_length"],
+                tokenizer=tokenizer,
+            )
             for j, pred in enumerate(decoded_preds):
                 test_cer.append(cer(decoded_targets[j], pred))
                 test_wer.append(wer(decoded_targets[j], pred))
@@ -135,9 +135,11 @@ def test(model, device, test_loader, criterion, tokenizer):
 
     print(
         f"Test set: Average loss:\
-            {test_loss}, Average CER: {avg_cer} Average WER: {avg_wer}\n"
+            {test_loss}, Average CER: {None} Average WER: {None}\n"
     )
 
+    return test_loss, avg_cer, avg_wer
+
 
 def run(
     learning_rate: float,
@@ -152,33 +154,34 @@ def run(
     use_cuda = torch.cuda.is_available()
     torch.manual_seed(42)
     device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member
-    device = torch.device("mps")
+    # device = torch.device("mps")
 
     # load dataset
-    # TODO: change this from dev split to train split again (was faster for development)
-    train_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
-    valid_dataset = MLSDataset(dataset_path, language, Split.dev, download=True, spectrogram_hparams=None)
-    test_dataset = MLSDataset(dataset_path, language, Split.test, download=True, spectrogram_hparams=None)
+    train_dataset = MLSDataset(
+        dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None
+    )
+    valid_dataset = MLSDataset(
+        dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None
+    )
 
-    # load tokenizer (bpe by default):
-    if not os.path.isfile("data/tokenizers/bpe_tokenizer_german_3000.json"):
+    # load tokenizer (bpe by default):
+    if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
         print("There is no tokenizer available. Do you want to train it on the dataset?")
         input("Press Enter to continue...")
-        train_bpe_tokenizer(
+        train_char_tokenizer(
            dataset_path=dataset_path,
            language=language,
            split="all",
            download=False,
-            out_path="data/tokenizers/bpe_tokenizer_german_3000.json",
+            out_path="data/tokenizers/char_tokenizer_german.json",
            vocab_size=3000,
        )
-
-    tokenizer = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
-
-    train_dataset.set_tokenizer(tokenizer)
-    valid_dataset.set_tokenizer(tokenizer)
-    test_dataset.set_tokenizer(tokenizer)
-
+
+    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+
+    train_dataset.set_tokenizer(tokenizer)  # type: ignore
+    valid_dataset.set_tokenizer(tokenizer)  # type: ignore
+
     print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
 
     hparams = HParams(
@@ -221,10 +224,10 @@ def run(
         hparams["stride"],
         hparams["dropout"],
     ).to(device)
-
+    print(tokenizer.encode(" "))
     print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
     optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
-    criterion = nn.CTCLoss(blank=28).to(device)
+    criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
     if load:
         checkpoint = torch.load(path)
         model.load_state_dict(checkpoint["model_state_dict"])
@@ -240,7 +243,7 @@ def run(
     )
 
     iter_meter = IterMeter()
-    for epoch in range(1, epochs + 1):
+    for epoch in range(1, epochs + 1):
         loss = train(
             model,
             device,
@@ -252,7 +255,13 @@ def run(
             iter_meter,
         )
 
-        test(model=model, device=device, test_loader=valid_loader, criterion=criterion, tokenizer = tokenizer)
+        test(
+            model=model,
+            device=device,
+            test_loader=valid_loader,
+            criterion=criterion,
+            tokenizer=tokenizer,
+        )
         print("saving epoch", str(epoch))
         torch.save(
             {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
@@ -295,5 +304,6 @@ def run_cli(
         language="mls_german_opus",
     )
 
-if __name__ == "__main__":
-    run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")
\ No newline at end of file
+
+if __name__ == "__main__":
+    run(1e-3, 10, 1, False, "", "/Volumes/pherkel/SWR2-ASR", "mls_german_opus")
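
Note on the change above: the commit stops hard-coding the CTC blank label (previously blank=28) and instead derives it from the tokenizer via tokenizer.encode(" ").ids[0], passing the same id to both nn.CTCLoss and greedy_decoder so that training loss and decoding agree on which class is blank. Below is a minimal sketch of the greedy (best-path) CTC decoding rule the patched greedy_decoder relies on: take the argmax class per frame, merge consecutive repeats, then drop blanks. The helper name greedy_ctc_decode and its signature are illustrative, not part of the patch.

import torch

def greedy_ctc_decode(log_probs: torch.Tensor, blank_id: int) -> list[list[int]]:
    """Greedy (best-path) CTC decoding for (batch, time, n_class) log-probs."""
    best_paths = torch.argmax(log_probs, dim=2)  # (batch, time)
    decoded = []
    for path in best_paths:
        labels, prev = [], None
        for idx in path.tolist():
            # CTC rule: merge consecutive repeats, then drop the blank symbol
            if idx != prev and idx != blank_id:
                labels.append(idx)
            prev = idx
        decoded.append(labels)
    return decoded

# Usage sketch, assuming a tokenizer with the API used in the patch:
#   blank_id = tokenizer.encode(" ").ids[0]
#   ids = greedy_ctc_decode(F.log_softmax(model(spectrograms), dim=2), blank_id)
#   texts = [tokenizer.decode(seq) for seq in ids]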