author    | Pherkel | 2023-09-12 14:19:15 +0200
committer | GitHub  | 2023-09-12 14:19:15 +0200
commit    | 7a9a6c783e69b5a537a3d3f5bfe8d5fdc656c807 (patch)
tree      | 0725631b9b68aeb65b292420a15941dcfa3fc04f /swr2_asr/train.py
parent    | f9846193289c81d89342b6a36e951605c2cfa189 (diff)
parent    | 7b71dab87591e04d874cd636614450b0e65e3f2b (diff)
Merge pull request #37 from Algo-Boys/fix/ultimate
Fix/ultimate
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r-- | swr2_asr/train.py | 424
1 file changed, 216 insertions, 208 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 9f12bcb..ffdae73 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -5,49 +5,17 @@ from typing import TypedDict
 import click
 import torch
 import torch.nn.functional as F
+import yaml
 from torch import nn, optim
 from torch.utils.data import DataLoader
-from tqdm import tqdm
+from tqdm.autonotebook import tqdm

 from swr2_asr.model_deep_speech import SpeechRecognitionModel
-from swr2_asr.tokenizer import CharTokenizer, train_char_tokenizer
-from swr2_asr.utils import MLSDataset, Split, collate_fn,plot
-
-from .loss_scores import cer, wer
-
-
-class HParams(TypedDict):
-    """Type for the hyperparameters of the model."""
-
-    n_cnn_layers: int
-    n_rnn_layers: int
-    rnn_dim: int
-    n_class: int
-    n_feats: int
-    stride: int
-    dropout: float
-    learning_rate: float
-    batch_size: int
-    epochs: int
-
-
-def greedy_decoder(output, tokenizer, labels, label_lengths, collapse_repeated=True):
-    """Greedily decode a sequence."""
-    print("output shape", output.shape)
-    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
-    blank_label = tokenizer.encode(" ").ids[0]
-    decodes = []
-    targets = []
-    for i, args in enumerate(arg_maxes):
-        decode = []
-        targets.append(tokenizer.decode([int(x) for x in labels[i][: label_lengths[i]].tolist()]))
-        for j, index in enumerate(args):
-            if index != blank_label:
-                if collapse_repeated and j != 0 and index == args[j - 1]:
-                    continue
-                decode.append(index.item())
-        decodes.append(tokenizer.decode(decode))
-    return decodes, targets
+from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
+from swr2_asr.utils.decoder import greedy_decoder
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+from .utils.loss_scores import cer, wer


 class IterMeter:
@@ -61,252 +29,292 @@ class IterMeter:
         self.val += 1

     def get(self):
-        """get"""
+        """get steps"""
         return self.val


-def train(
-    model,
-    device,
-    train_loader,
-    criterion,
-    optimizer,
-    scheduler,
-    epoch,
-    iter_meter,
-):
-    """Train"""
+class TrainArgs(TypedDict):
+    """Type for the arguments of the training function."""
+
+    model: SpeechRecognitionModel
+    device: torch.device  # pylint: disable=no-member
+    train_loader: DataLoader
+    criterion: nn.CTCLoss
+    optimizer: optim.AdamW
+    scheduler: optim.lr_scheduler.OneCycleLR
+    epoch: int
+    iter_meter: IterMeter
+
+
+def train(train_args) -> float:
+    """Train
+    Args:
+        model: model
+        device: device type
+        train_loader: train dataloader
+        criterion: loss function
+        optimizer: optimizer
+        scheduler: learning rate scheduler
+        epoch: epoch number
+        iter_meter: iteration meter
+
+    Returns:
+        avg_train_loss: avg_train_loss for the epoch
+
+    Information:
+        spectrograms: (batch, time, feature)
+        labels: (batch, label_length)
+
+        model output: (batch,time, n_class)
+
+    """
+    # get values from train_args:
+    (
+        model,
+        device,
+        train_loader,
+        criterion,
+        optimizer,
+        scheduler,
+        epoch,
+        iter_meter,
+    ) = train_args.values()
+
     model.train()
-    print(f"Epoch: {epoch}")
-    losses = []
-    for _data in tqdm(train_loader, desc="batches"):
-        spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
+    print(f"training batch {epoch}")
+    train_losses = []
+    for _data in tqdm(train_loader, desc="Training batches"):
+        spectrograms, labels, input_lengths, label_lengths = _data
+        spectrograms, labels = spectrograms.to(device), labels.to(device)
+
         optimizer.zero_grad()

         output = model(spectrograms)  # (batch, time, n_class)
         output = F.log_softmax(output, dim=2)
         output = output.transpose(0, 1)  # (time, batch, n_class)

-        loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
+        loss = criterion(output, labels, input_lengths, label_lengths)
+        train_losses.append(loss)
         loss.backward()

         optimizer.step()
         scheduler.step()
         iter_meter.step()

+    avg_train_loss = sum(train_losses) / len(train_losses)
+    print(f"Train set: Average loss: {avg_train_loss:.2f}")
+    return avg_train_loss
+

-        losses.append(loss.item())
+class TestArgs(TypedDict):
+    """Type for the arguments of the test function."""

-    print(f"loss in epoch {epoch}: {sum(losses) / len(losses)}")
-    return sum(losses) / len(losses)
+    model: SpeechRecognitionModel
+    device: torch.device  # pylint: disable=no-member
+    test_loader: DataLoader
+    criterion: nn.CTCLoss
+    tokenizer: CharTokenizer
+    decoder: str


-def test(model, device, test_loader, criterion, tokenizer):
+def test(test_args: TestArgs) -> tuple[float, float, float]:
     """Test"""
     print("\nevaluating...")
+
+    # get values from test_args:
+    model, device, test_loader, criterion, tokenizer, decoder = test_args.values()
+
+    if decoder == "greedy":
+        decoder = greedy_decoder
+
     model.eval()
     test_loss = 0
     test_cer, test_wer = [], []
     with torch.no_grad():
-        for _data in test_loader:
-            spectrograms, labels = _data["spectrogram"].to(device), _data["utterance"].to(device)
+        for _data in tqdm(test_loader, desc="Validation Batches"):
+            spectrograms, labels, input_lengths, label_lengths = _data
+            spectrograms, labels = spectrograms.to(device), labels.to(device)

             output = model(spectrograms)  # (batch, time, n_class)
             output = F.log_softmax(output, dim=2)
             output = output.transpose(0, 1)  # (time, batch, n_class)

-            loss = criterion(output, labels, _data["input_length"], _data["utterance_length"])
+            loss = criterion(output, labels, input_lengths, label_lengths)
             test_loss += loss.item() / len(test_loader)

             decoded_preds, decoded_targets = greedy_decoder(
-                output=output.transpose(0, 1),
-                labels=labels,
-                label_lengths=_data["utterance_length"],
-                tokenizer=tokenizer,
+                output.transpose(0, 1), labels, label_lengths, tokenizer
             )
-            for j, pred in enumerate(decoded_preds):
-                test_cer.append(cer(decoded_targets[j], pred))
-                test_wer.append(wer(decoded_targets[j], pred))
+            for j, _ in enumerate(decoded_preds):
+                test_cer.append(cer(decoded_targets[j], decoded_preds[j]))
+                test_wer.append(wer(decoded_targets[j], decoded_preds[j]))

     avg_cer = sum(test_cer) / len(test_cer)
     avg_wer = sum(test_wer) / len(test_wer)

     print(
-        f"Test set: Average loss:\
-            {test_loss}, Average CER: {None} Average WER: {None}\n"
+        f"Test set: \
+            Average loss: {test_loss:.4f}, \
+            Average CER: {avg_cer:4f} \
+            Average WER: {avg_wer:.4f}\n"
     )

     return test_loss, avg_cer, avg_wer


-def run(
-    learning_rate: float,
-    batch_size: int,
-    epochs: int,
-    load: bool,
-    path: str,
-    dataset_path: str,
-    language: str,
-) -> None:
-    """Runs the training script."""
+@click.command()
+@click.option(
+    "--config_path",
+    default="config.yaml",
+    help="Path to yaml config file",
+    type=click.Path(exists=True),
+)
+def main(config_path: str):
+    """Main function for training the model.
+
+    Gets all configuration arguments from yaml config file.
+    """
     use_cuda = torch.cuda.is_available()
-    torch.manual_seed(42)
     device = torch.device("cuda" if use_cuda else "cpu")  # pylint: disable=no-member
-    # device = torch.device("mps")
-
-    # load dataset
+    torch.manual_seed(7)
+
+    with open(config_path, "r", encoding="utf-8") as yaml_file:
+        config_dict = yaml.safe_load(yaml_file)
+
+    # Create separate dictionaries for each top-level key
+    model_config = config_dict.get("model", {})
+    training_config = config_dict.get("training", {})
+    dataset_config = config_dict.get("dataset", {})
+    tokenizer_config = config_dict.get("tokenizer", {})
+    checkpoints_config = config_dict.get("checkpoints", {})
+
+    if not os.path.isdir(dataset_config["dataset_root_path"]):
+        os.makedirs(dataset_config["dataset_root_path"])
+
     train_dataset = MLSDataset(
-        dataset_path, language, Split.TRAIN, download=True, spectrogram_hparams=None, limited=True
+        dataset_config["dataset_root_path"],
+        dataset_config["language_name"],
+        Split.TRAIN,
+        download=dataset_config["download"],
+        limited=dataset_config["limited_supervision"],
+        size=dataset_config["dataset_percentage"],
    )
     valid_dataset = MLSDataset(
-        dataset_path, language, Split.VALID, download=True, spectrogram_hparams=None, limited=True
+        dataset_config["dataset_root_path"],
+        dataset_config["language_name"],
+        Split.TEST,
+        download=dataset_config["download"],
+        limited=dataset_config["limited_supervision"],
+        size=dataset_config["dataset_percentage"],
     )

-    # load tokenizer (bpe by default):
-    if not os.path.isfile("data/tokenizers/char_tokenizer_german.json"):
-        print("There is no tokenizer available. Do you want to train it on the dataset?")
-        input("Press Enter to continue...")
-        train_char_tokenizer(
-            dataset_path=dataset_path,
-            language=language,
-            split="all",
-            out_path="data/tokenizers/char_tokenizer_german.json",
-        )
-
-    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+    kwargs = {"num_workers": training_config["num_workers"], "pin_memory": True} if use_cuda else {}

-    train_dataset.set_tokenizer(tokenizer)  # type: ignore
-    valid_dataset.set_tokenizer(tokenizer)  # type: ignore
-
-    print(f"Waveform shape: {train_dataset[0]['waveform'].shape}")
+    if tokenizer_config["tokenizer_path"] is None:
+        print("Tokenizer not found!")
+        if click.confirm("Do you want to train a new tokenizer?", default=True):
+            pass
+        else:
+            return
+        tokenizer = CharTokenizer.train(
+            dataset_config["dataset_root_path"], dataset_config["language_name"]
+        )
+    tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])

-    hparams = HParams(
-        n_cnn_layers=3,
-        n_rnn_layers=5,
-        rnn_dim=512,
-        n_class=tokenizer.get_vocab_size(),
-        n_feats=128,
-        stride=2,
-        dropout=0.1,
-        learning_rate=learning_rate,
-        batch_size=batch_size,
-        epochs=epochs,
-    )
+    train_data_processing = DataProcessing("train", tokenizer, {"n_feats": model_config["n_feats"]})
+    valid_data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_config["n_feats"]})

     train_loader = DataLoader(
-        train_dataset,
-        batch_size=hparams["batch_size"],
-        shuffle=True,
-        collate_fn=lambda x: collate_fn(x),
+        dataset=train_dataset,
+        batch_size=training_config["batch_size"],
+        shuffle=dataset_config["shuffle"],
+        collate_fn=train_data_processing,
+        **kwargs,
     )
-
     valid_loader = DataLoader(
-        valid_dataset,
-        batch_size=hparams["batch_size"],
-        shuffle=True,
-        collate_fn=lambda x: collate_fn(x),
+        dataset=valid_dataset,
+        batch_size=training_config["batch_size"],
+        shuffle=dataset_config["shuffle"],
+        collate_fn=valid_data_processing,
+        **kwargs,
     )

-    # enable flag to find the most compatible algorithms in advance
-    if use_cuda:
-        torch.backends.cudnn.benchmark = True  # pylance: disable=no-member
-
     model = SpeechRecognitionModel(
-        hparams["n_cnn_layers"],
-        hparams["n_rnn_layers"],
-        hparams["rnn_dim"],
-        hparams["n_class"],
-        hparams["n_feats"],
-        hparams["stride"],
-        hparams["dropout"],
+        model_config["n_cnn_layers"],
+        model_config["n_rnn_layers"],
+        model_config["rnn_dim"],
+        tokenizer.get_vocab_size(),
+        model_config["n_feats"],
+        model_config["stride"],
+        model_config["dropout"],
     ).to(device)
-    print(tokenizer.encode(" "))
-    print("Num Model Parameters", sum((param.nelement() for param in model.parameters())))
-    optimizer = optim.AdamW(model.parameters(), hparams["learning_rate"])
-    criterion = nn.CTCLoss(tokenizer.encode(" ").ids[0]).to(device)
-    if load:
-        checkpoint = torch.load(path)
-        model.load_state_dict(checkpoint["model_state_dict"])
-        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
-        epoch = checkpoint["epoch"]
-        loss = checkpoint["loss"]
+
+    optimizer = optim.AdamW(model.parameters(), training_config["learning_rate"])
+    criterion = nn.CTCLoss(tokenizer.get_blank_token()).to(device)

     scheduler = optim.lr_scheduler.OneCycleLR(
         optimizer,
-        max_lr=hparams["learning_rate"],
+        max_lr=training_config["learning_rate"],
         steps_per_epoch=int(len(train_loader)),
-        epochs=hparams["epochs"],
+        epochs=training_config["epochs"],
         anneal_strategy="linear",
     )
+    prev_epoch = 0
+
+    if checkpoints_config["model_load_path"] is not None:
+        checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
+        model.load_state_dict(checkpoint["model_state_dict"])
+        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
+        prev_epoch = checkpoint["epoch"]

     iter_meter = IterMeter()
-    for epoch in range(1, epochs + 1):
-        loss = train(
-            model,
-            device,
-            train_loader,
-            criterion,
-            optimizer,
-            scheduler,
-            epoch,
-            iter_meter,
-        )
-
-        test_loss, avg_cer, avg_wer = test(
-            model=model,
-            device=device,
-            test_loader=valid_loader,
-            criterion=criterion,
-            tokenizer=tokenizer,
-        )
-        print("saving epoch", str(epoch))
+    for epoch in range(prev_epoch + 1, training_config["epochs"] + 1):
+        train_args: TrainArgs = {
+            "model": model,
+            "device": device,
+            "train_loader": train_loader,
+            "criterion": criterion,
+            "optimizer": optimizer,
+            "scheduler": scheduler,
+            "epoch": epoch,
+            "iter_meter": iter_meter,
+        }
+
+        train_loss = train(train_args)
+
+        test_loss, test_cer, test_wer = 0, 0, 0
+
+        test_args: TestArgs = {
+            "model": model,
+            "device": device,
+            "test_loader": valid_loader,
+            "criterion": criterion,
+            "tokenizer": tokenizer,
+            "decoder": "greedy",
+        }
+
+        if training_config["eval_every_n"] != 0 and epoch % training_config["eval_every_n"] == 0:
+            test_loss, test_cer, test_wer = test(test_args)
+
+        if checkpoints_config["model_save_path"] is None:
+            continue
+
+        if not os.path.isdir(os.path.dirname(checkpoints_config["model_save_path"])):
+            os.makedirs(os.path.dirname(checkpoints_config["model_save_path"]))
+
         torch.save(
             {
                 "epoch": epoch,
                 "model_state_dict": model.state_dict(),
-                "loss": loss,
+                "optimizer_state_dict": optimizer.state_dict(),
+                "train_loss": train_loss,
                 "test_loss": test_loss,
-                "avg_cer": avg_cer,
-                "avg_wer": avg_wer,
+                "avg_cer": test_cer,
+                "avg_wer": test_wer,
             },
-            path + str(epoch),
-            plot(epochs,path)
+            checkpoints_config["model_save_path"] + str(epoch),
         )


-@click.command()
-@click.option("--learning_rate", default=1e-3, help="Learning rate")
-@click.option("--batch_size", default=10, help="Batch size")
-@click.option("--epochs", default=1, help="Number of epochs")
-@click.option("--load", default=False, help="Do you want to load a model?")
-@click.option(
-    "--path",
-    default="model",
-    help="Path where the model will be saved to/loaded from",
-)
-@click.option(
-    "--dataset_path",
-    default="data/",
-    help="Path for the dataset directory",
-)
-def run_cli(
-    learning_rate: float,
-    batch_size: int,
-    epochs: int,
-    load: bool,
-    path: str,
-    dataset_path: str,
-) -> None:
-    """Runs the training script."""
-
-    run(
-        learning_rate=learning_rate,
-        batch_size=batch_size,
-        epochs=epochs,
-        load=load,
-        path=path,
-        dataset_path=dataset_path,
-        language="mls_german_opus",
-    )
-
-
 if __name__ == "__main__":
-    run_cli()  # pylint: disable=no-value-for-parameter
+    main()  # pylint: disable=no-value-for-parameter
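Note on the new configuration flow: the rewritten main() takes no individual CLI flags anymore; it reads everything from a YAML file passed via --config_path and splits it into model, training, dataset, tokenizer, and checkpoints sections. The sketch below is a hypothetical config.yaml that would satisfy the keys accessed in the code above. The key names come from this diff; the values are examples only: several mirror the hard-coded defaults this commit removes (learning rate 1e-3, batch size 10, 3 CNN / 5 RNN layers, rnn_dim 512, n_feats 128, stride 2, dropout 0.1, dataset root "data/", language "mls_german_opus", save path "model"), while the values for eval_every_n, num_workers, dataset_percentage, shuffle, download, and limited_supervision are assumptions, not project defaults.

# Hypothetical config.yaml sketch; key names match what main() reads, values are illustrative only.
model:
  n_cnn_layers: 3
  n_rnn_layers: 5
  rnn_dim: 512
  n_feats: 128
  stride: 2
  dropout: 0.1

training:
  learning_rate: 0.001
  batch_size: 10
  epochs: 1
  eval_every_n: 1          # 0 skips evaluation entirely
  num_workers: 4           # DataLoader workers, only applied when CUDA is available

dataset:
  dataset_root_path: "data/"
  language_name: "mls_german_opus"
  limited_supervision: true
  dataset_percentage: 1.0
  shuffle: true
  download: true

tokenizer:
  tokenizer_path: "data/tokenizers/char_tokenizer_german.json"   # null triggers the tokenizer-training prompt

checkpoints:
  model_load_path: null    # set to a checkpoint file to resume from its stored epoch
  model_save_path: "model" # the epoch number is appended to this path on save

With such a file in place, training would presumably be launched as a module, for example python -m swr2_asr.train --config_path config.yaml, since the file keeps the relative import from .utils.loss_scores; the exact entry point is not shown in this diff.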