about | summary | refs | log | tree | commit | diff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- swr2_asr/train.py | 192
1 files changed, 92 insertions, 100 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ac7666b..eb79ee2 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,11 +5,12 @@ from typing import TypedDict
import click
import torch
import torch.nn.functional as F
+import yaml
from torch import nn, optim
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm
-from swr2_asr.model_deep_speech import HParams, SpeechRecognitionModel
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
from swr2_asr.utils.decoder import greedy_decoder
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -17,7 +18,7 @@ from swr2_asr.utils.tokenizer import CharTokenizer
from .utils.loss_scores import cer, wer
-class IterMeter(object):
+class IterMeter:
"""keeps track of total iterations"""
def __init__(self):
@@ -116,6 +117,7 @@ class TestArgs(TypedDict):
def test(test_args: TestArgs) -> tuple[float, float, float]:
+ """Test"""
print("\nevaluating...")
# get values from test_args:
@@ -128,7 +130,7 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
test_loss = 0
test_cer, test_wer = [], []
with torch.no_grad():
- for i, _data in enumerate(tqdm(test_loader, desc="Validation Batches")):
+ for _data in tqdm(test_loader, desc="Validation Batches"):
spectrograms, labels, input_lengths, label_lengths = _data
spectrograms, labels = spectrograms.to(device), labels.to(device)
@@ -142,8 +144,6 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
decoded_preds, decoded_targets = greedy_decoder(
output.transpose(0, 1), labels, label_lengths, tokenizer
)
- if i == 1:
- print(f"decoding first sample: {decoded_preds}")
for j, _ in enumerate(decoded_preds):
test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
test_wer.append(wer(decoded_targets[j], decoded_preds[j]))
@@ -161,157 +161,149 @@ def test(test_args: TestArgs) -> tuple[float, float, float]:
return test_loss, avg_cer, avg_wer
-def main(
- learning_rate: float,
- batch_size: int,
- epochs: int,
- dataset_path: str,
- language: str,
- limited_supervision: bool,
- model_load_path: str,
- model_save_path: str,
- dataset_percentage: float,
- eval_every: int,
- num_workers: int,
-):
+@click.command()
+@click.option(
+ "--config_path",
+ default="config.yaml",
+ help="Path to yaml config file",
+ type=click.Path(exists=True),
+)
+def main(config_path: str):
"""Main function for training the model.
- Args:
- learning_rate: learning rate for the optimizer
- batch_size: batch size
- epochs: number of epochs to train
- dataset_path: path for the dataset
- language: language of the dataset
- limited_supervision: whether to use only limited supervision
- model_load_path: path to load a model from
- model_save_path: path to save the model to
- dataset_percentage: percentage of the dataset to use
- eval_every: evaluate every n epochs
- num_workers: number of workers for the dataloader
+ Gets all configuration arguments from yaml config file.
"""
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu") # pylint: disable=no-member
torch.manual_seed(7)
- if not os.path.isdir(dataset_path):
- os.makedirs(dataset_path)
+ with open(config_path, "r", encoding="utf-8") as yaml_file:
+ config_dict = yaml.safe_load(yaml_file)
+
+ # Create separate dictionaries for each top-level key
+ model_config = config_dict.get("model", {})
+ training_config = config_dict.get("training", {})
+ dataset_config = config_dict.get("dataset", {})
+ tokenizer_config = config_dict.get("tokenizer", {})
+ checkpoints_config = config_dict.get("checkpoints", {})
+
+ print(training_config["learning_rate"])
+
+ if not os.path.isdir(dataset_config["dataset_root_path"]):
+ os.makedirs(dataset_config["dataset_root_path"])
train_dataset = MLSDataset(
- dataset_path,
- language,
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
Split.TEST,
- download=True,
- limited=limited_supervision,
- size=dataset_percentage,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
valid_dataset = MLSDataset(
- dataset_path,
- language,
+ dataset_config["dataset_root_path"],
+ dataset_config["language_name"],
Split.TRAIN,
- download=False,
- limited=Falimited_supervisionlse,
- size=dataset_percentage,
+ download=dataset_config["download"],
+ limited=dataset_config["limited_supervision"],
+ size=dataset_config["dataset_percentage"],
)
- # TODO: initialize and possibly train tokenizer if none found
-
- kwargs = {"num_workers": num_workers, "pin_memory": True} if use_cuda else {}
-
- hparams = HParams(
- n_cnn_layers=3,
- n_rnn_layers=5,
- rnn_dim=512,
- n_class=tokenizer.get_vocab_size(),
- n_feats=128,
- stride=2,
- dropout=0.1,
- learning_rate=learning_rate,
- batch_size=batch_size,
- epochs=epochs,
- )
+ kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {}
+
+ if tokenizer_config["tokenizer_path"] is None:
+ print("Tokenizer not found!")
+ if click.confirm("Do you want to train a new tokenizer?", default=True):
+ pass
+ else:
+ return
+ tokenizer = CharTokenizer.train(
+ dataset_config["dataset_root_path"], dataset_config["language_name"]
+ )
+ tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
train_data_processing = DataProcessing("train", tokenizer)
valid_data_processing = DataProcessing("valid", tokenizer)
train_loader = DataLoader(
dataset=train_dataset,
- batch_size=hparams["batch_size"],
- shuffle=True,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
collate_fn=train_data_processing,
**kwargs,
)
valid_loader = DataLoader(
dataset=valid_dataset,
- batch_size=hparams["batch_size"],
- shuffle=False,
+ batch_size=training_config["batch_size"],
+ shuffle=dataset_config["shuffle"],
collate_fn=valid_data_processing,
**kwargs,
)
model = SpeechRecognitionModel(
- hparams["n_cnn_layers"],
- hparams["n_rnn_layers"],
- hparams["rnn_dim"],
- hparams["n_class"],
- hparams["n_feats"],
- hparams["stride"],
- hparams["dropout"],
+ model_config["n_cnn_layers"],
+ model_config["n_rnn_layers"],
+ model_config["rnn_dim"],
+ tokenizer.get_vocab_size(),
+ model_config["n_feats"],
+ model_config["stride"],
+ model_config["dropout"],
).to(device)
- optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
+ optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"])
criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device)
scheduler = optim.lr_scheduler.OneCycleLR(
optimizer,
- max_lr=hparams["learning_rate"],
+ max_lr=training_config["learning_rate"],
steps_per_epoch=int(len(train_loader)),
- epochs=hparams["epochs"],
+ epochs=training_config["epochs"],
anneal_strategy="linear",
)
prev_epoch = 0
- if model_load_path is not None:
- checkpoint = torch.load(model_load_path)
+ if checkpoints_config["model_load_path"] is not None:
+ checkpoint = torch.load(checkpoints_config["model_load_path"])
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]
iter_meter = IterMeter()
- if not os.path.isdir(os.path.dirname(model_save_path)):
- os.makedirs(os.path.dirname(model_save_path))
- for epoch in range(prev_epoch + 1, epochs + 1):
- train_args: TrainArgs = dict(
- model=model,
- device=device,
- train_loader=train_loader,
- criterion=criterion,
- optimizer=optimizer,
- scheduler=scheduler,
- epoch=epoch,
- iter_meter=iter_meter,
- )
+
+ for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
+ train_args: TrainArgs = {
+ "model": model,
+ "device": device,
+ "train_loader": train_loader,
+ "criterion": criterion,
+ "optimizer": optimizer,
+ "scheduler": scheduler,
+ "epoch": epoch,
+ "iter_meter": iter_meter,
+ }
train_loss = train(train_args)
test_loss, test_cer, test_wer = 0, 0, 0
- test_args: TestArgs = dict(
- model=model,
- device=device,
- test_loader=valid_loader,
- criterion=criterion,
- tokenizer=tokenizer,
- decoder="greedy",
- )
+ test_args: TestArgs = {
+ "model": model,
+ "device": device,
+ "test_loader": valid_loader,
+ "criterion": criterion,
+ "tokenizer": tokenizer,
+ "decoder": "greedy",
+ }
- if epoch % eval_every == 0:
+ if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
test_loss, test_cer, test_wer = test(test_args)
- if model_save_path is None:
+ if checkpoints_config["model_save_path"] is None:
continue
- if not os.path.isdir(os.path.dirname(model_save_path)):
- os.makedirs(os.path.dirname(model_save_path))
+ if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])):
+ os.makedirs(os.path.dirname(checkpoints_config["model_save_path"]))
+
torch.save(
{
"epoch": epoch,
@@ -322,7 +314,7 @@ def main(
"avg_cer": test_cer,
"avg_wer": test_wer,
},
- model_save_path + str(epoch),
+ checkpoints_config["model_save_path"] + str(epoch),
)