From 58b30927bd870604a4077a8af9ec3cad7b0be21c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 21:52:42 +0200 Subject: changed config to yaml! --- swr2_asr/train.py | 192 ++++++++++++++++++++++++++---------------------------- 1 file changed, 92 insertions(+), 100 deletions(-) (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ac7666b..eb79ee2 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -5,11 +5,12 @@ from typing import TypedDict import click import torch import torch.nn.functional as F +import yaml from torch import nn, optim from torch.utils.data import DataLoader from tqdm.autonotebook import tqdm -from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel +from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.data import DataProcessing, MLSDataset, Split from swr2_asr.utils.decoder import greedy_decoder from swr2_asr.utils.tokenizer import CharTokenizer @@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer from .utils.loss_scores import cer, wer -class IterMeter(object): +class IterMeter: """keeps track of total iterations""" def __init__(self): @@ -116,6 +117,7 @@ class TestArgs(TypedDict): def test(test_args: TestArgs) -> tuple[float, float, float]: + """Test""" print("\nevaluating...") # get values from test_args: @@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: test_loss = 0 test_cer, test_wer = [], [] with torch.no_grad(): - for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")): + for _data in tqdm(test_loader, desc="Validation Batches"): spectrograms, labels, input_lengths, label_lengths = _data spectrograms, labels = spectrograms.to(device), labels.to(device) @@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: decoded_preds, decoded_targets = greedy_decoder( output.transpose(0, 1), labels, label_lengths, tokenizer ) - if i == 1: - print(f"decoding first sample: {decoded_preds}") for j, _ in enumerate(decoded_preds): test_cer.append(cer(decoded_targets[j], decoded_preds[j])) test_wer.append(wer(decoded_targets[j], decoded_preds[j])) @@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]: return test_loss, avg_cer, avg_wer -def main( - learning_rate: float, - batch_size: int, - epochs: int, - dataset_path: str, - language: str, - limited_supervision: bool, - model_load_path: str, - model_save_path: str, - dataset_percentage: float, - eval_every: int, - num_workers: int, -): +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +def main(config_path: str): """Main function for training the model. - Args: - learning_rate: learning rate for the optimizer - batch_size: batch size - epochs: number of epochs to train - dataset_path: path for the dataset - language: language of the dataset - limited_supervision: whether to use only limited supervision - model_load_path: path to load a model from - model_save_path: path to save the model to - dataset_percentage: percentage of the dataset to use - eval_every: evaluate every n epochs - num_workers: number of workers for the dataloader + Gets all configuration arguments from yaml config file. """ use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member torch.manual_seed(7) - if not os.path.isdir(dataset_path): - os.makedirs(dataset_path) + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + training_config = config_dict.get("training", {}) + dataset_config = config_dict.get("dataset", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + checkpoints_config = config_dict.get("checkpoints", {}) + + print(training_config["learning_rate"]) + + if not os.path.isdir(dataset_config["dataset_root_path"]): + os.makedirs(dataset_config["dataset_root_path"]) train_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TEST, - download=True, - limited=limited_supervision, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) valid_dataset = MLSDataset( - dataset_path, - language, + dataset_config["dataset_root_path"], + dataset_config["language_name"], Split.TRAIN, - download=False, - limited=Falimited_supervisionlse, - size=dataset_percentage, + download=dataset_config["download"], + limited=dataset_config["limited_supervision"], + size=dataset_config["dataset_percentage"], ) - # TODO: initialize and possibly train tokenizer if none found - - kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {} - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=learning_rate, - batch_size=batch_size, - epochs=epochs, - ) + kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {} + + if tokenizer_config["tokenizer_path"] is None: + print("Tokenizer not found!") + if click.confirm("Do you want to train a new tokenizer?", default=True): + pass + else: + return + tokenizer = CharTokenizer.train( + dataset_config["dataset_root_path"], dataset_config["language_name"] + ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) train_data_processing = DataProcessing("train", tokenizer) valid_data_processing = DataProcessing("valid", tokenizer) train_loader = DataLoader( dataset=train_dataset, - batch_size=hparams["batch_size"], - shuffle=True, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=train_data_processing, **kwargs, ) valid_loader = DataLoader( dataset=valid_dataset, - batch_size=hparams["batch_size"], - shuffle=False, + batch_size=training_config["batch_size"], + shuffle=dataset_config["shuffle"], collate_fn=valid_data_processing, **kwargs, ) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"]) + optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"]) criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device) scheduler = optim.lr_scheduler.OneCycleLR( optimizer, - max_lr=hparams["learning_rate"], + max_lr=training_config["learning_rate"], steps_per_epoch=int(len(train_loader)), - epochs=hparams["epochs"], + epochs=training_config["epochs"], anneal_strategy="linear", ) prev_epoch = 0 - if model_load_path is not None: - checkpoint = torch.load(model_load_path) + if checkpoints_config["model_load_path"] is not None: + checkpoint = torch.load(checkpoints_config["model_load_path"]) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] iter_meter = IterMeter() - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) - for epoch in range(prev_epoch + 1, epochs + 1): - train_args: TrainArgs = dict( - model=model, - device=device, - train_loader=train_loader, - criterion=criterion, - optimizer=optimizer, - scheduler=scheduler, - epoch=epoch, - iter_meter=iter_meter, - ) + + for epoch in range(prev_epoch + 1, training_config["epochs"] + 1): + train_args: TrainArgs = { + "model": model, + "device": device, + "train_loader": train_loader, + "criterion": criterion, + "optimizer": optimizer, + "scheduler": scheduler, + "epoch": epoch, + "iter_meter": iter_meter, + } train_loss = train(train_args) test_loss, test_cer, test_wer = 0, 0, 0 - test_args: TestArgs = dict( - model=model, - device=device, - test_loader=valid_loader, - criterion=criterion, - tokenizer=tokenizer, - decoder="greedy", - ) + test_args: TestArgs = { + "model": model, + "device": device, + "test_loader": valid_loader, + "criterion": criterion, + "tokenizer": tokenizer, + "decoder": "greedy", + } - if epoch % eval_every == 0: + if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0: test_loss, test_cer, test_wer = test(test_args) - if model_save_path is None: + if checkpoints_config["model_save_path"] is None: continue - if not os.path.isdir(os.path.dirname(model_save_path)): - os.makedirs(os.path.dirname(model_save_path)) + if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])): + os.makedirs(os.path.dirname(checkpoints_config["model_save_path"])) + torch.save( { "epoch": epoch, @@ -322,7 +314,7 @@ def main( "avg_cer": test_cer, "avg_wer": test_wer, }, - model_save_path + str(epoch), + checkpoints_config["model_save_path"] + str(epoch), ) -- cgit v1.2.3