| author    | Pherkel | 2023-09-11 22:58:19 +0200 |
| committer | Pherkel | 2023-09-11 22:58:19 +0200 |
| commit    | 6f5513140f153206cfa91df3077e67ce58043d35 (patch) |
| tree      | 71cee784719a1f9c912a1038824eb6bd26195408 |
| parent    | 64dbb9d32a51b1bce6c9de67069dc8f5943a5399 (diff) |
model loading is broken :(
| -rw-r--r-- | config.philipp.yaml                          |   9 |
| -rw-r--r-- | config.yaml (renamed from config.train.yaml) |  10 |
| -rw-r--r-- | swr2_asr/inference.py                        | 140 |
| -rw-r--r-- | swr2_asr/train.py                            |   2 |
4 files changed, 77 insertions, 84 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 6b905cd..4a723c6 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -12,6 +12,7 @@ training:
   epochs: 3
   eval_every_n: 1 # evaluate every n epochs
   num_workers: 4 # number of workers for dataloader
+  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically

 dataset:
   download: True
@@ -25,5 +26,9 @@ tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.json"

 checkpoints:
-  model_load_path: ~ # path to load model from
-  model_save_path: ~ # path to save model to
\ No newline at end of file
+  model_load_path: "data/runs/epoch30" # path to load model from
+  model_save_path: ~ # path to save model to
+
+inference:
+  model_load_path: "data/runs/epoch30" # path to load model from
+  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file
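The new `device` key implies an automatic CPU fallback. A minimal sketch of that resolution logic, assuming a config shaped like the file above (note the committed code in swr2_asr/inference.py below has no `else` branch, so an unexpected value would leave `device` unbound; the `else` here is an addition):

```python
import torch
import yaml

# Hypothetical standalone sketch: resolve the configured device, falling
# back to CPU when CUDA is requested but not available.
with open("config.philipp.yaml", "r", encoding="utf-8") as yaml_file:
    config = yaml.safe_load(yaml_file)

requested = config["inference"]["device"]
if requested == "cuda" and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    # Covers "cpu", unavailable CUDA, and any unexpected value.
    device = torch.device("cpu")
```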
diff --git a/config.train.yaml b/config.yaml
index c82439d..e5ff43a 100644
--- a/config.train.yaml
+++ b/config.yaml
@@ -4,7 +4,7 @@ model:
   rnn_dim: 512
   n_feats: 128 # number of mel features
   stride: 2
-  dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+  dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets

 training:
   learning_rate: 5e-4
@@ -19,10 +19,16 @@ dataset:
   language_name: "mls_german_opus"
   limited_supervision: False # set to True if you want to use limited supervision
   dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+  shuffle: True

 tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml"

 checkpoints:
   model_load_path: "YOUR/PATH" # path to load model from
-  model_save_path: "YOUR/PATH" # path to save model to
\ No newline at end of file
+  model_save_path: "YOUR/PATH" # path to save model to
+
+inference:
+  model_load_path: "YOUR/PATH" # path to load model from
+  beam_width: 10 # beam width for beam search
+  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file
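The renamed config also gains a `beam_width`, but the inference script in this commit still decodes greedily. For reference, a self-contained sketch of the CTC collapse rule that `greedy_decoder` below applies (argmax per frame, merge immediate repeats, drop blanks); the function name and shapes are illustrative, not repo code:

```python
import torch

def greedy_ctc_collapse(log_probs: torch.Tensor, blank: int) -> list[list[int]]:
    """Collapse (batch, time, n_class) log-probs into label id sequences."""
    best = log_probs.argmax(dim=2)  # (batch, time): best class per frame
    results = []
    for seq in best:
        out, prev = [], None
        for idx in seq.tolist():
            if idx != blank and idx != prev:
                out.append(idx)
            prev = idx  # a blank resets prev, so "a, blank, a" keeps both a's
        results.append(out)
    return results
```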
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index f8342f7..6495a9a 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -1,35 +1,20 @@
 """Training script for the ASR model."""
-from typing import TypedDict
-
+import click
 import torch
 import torch.nn.functional as F
 import torchaudio
+import yaml

 from swr2_asr.model_deep_speech import SpeechRecognitionModel
 from swr2_asr.utils.tokenizer import CharTokenizer


-class HParams(TypedDict):
-    """Type for the hyperparameters of the model."""
-
-    n_cnn_layers: int
-    n_rnn_layers: int
-    rnn_dim: int
-    n_class: int
-    n_feats: int
-    stride: int
-    dropout: float
-    learning_rate: float
-    batch_size: int
-    epochs: int
-
-
-def greedy_decoder(output, tokenizer, collapse_repeated=True):
+def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
     """Greedily decode a sequence."""
     arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
-    blank_label = tokenizer.encode(" ").ids[0]
+    blank_label = tokenizer.get_blank_token()
     decodes = []
-    for _i, args in enumerate(arg_maxes):
+    for args in arg_maxes:
         decode = []
         for j, index in enumerate(args):
             if index != blank_label:
@@ -40,75 +25,72 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True):
     return decodes


-def main() -> None:
+@click.command()
+@click.option(
+    "--config_path",
+    default="config.yaml",
+    help="Path to yaml config file",
+    type=click.Path(exists=True),
+)
+@click.option(
+    "--file_path",
+    help="Path to audio file",
+    type=click.Path(exists=True),
+)
+def main(config_path: str, file_path: str) -> None:
     """inference function."""
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    with open(config_path, "r", encoding="utf-8") as yaml_file:
+        config_dict = yaml.safe_load(yaml_file)
+
+    # Create separate dictionaries for each top-level key
+    model_config = config_dict.get("model", {})
+    tokenizer_config = config_dict.get("tokenizer", {})
+    inference_config = config_dict.get("inference", {})
+
+    if inference_config["device"] == "cpu":
+        device = "cpu"
+    elif inference_config["device"] == "cuda":
+        device = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device)  # pylint: disable=no-member

-    tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
-
-    spectrogram_hparams = {
-        "sample_rate": 16000,
-        "n_fft": 400,
-        "win_length": 400,
-        "hop_length": 160,
-        "n_mels": 128,
-        "f_min": 0,
-        "f_max": 8000,
-        "power": 2.0,
-    }
-
-    hparams = HParams(
-        n_cnn_layers=3,
-        n_rnn_layers=5,
-        rnn_dim=512,
-        n_class=tokenizer.get_vocab_size(),
-        n_feats=128,
-        stride=2,
-        dropout=0.1,
-        learning_rate=0.1,
-        batch_size=30,
-        epochs=100,
-    )
+    tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])

     model = SpeechRecognitionModel(
-        hparams["n_cnn_layers"],
-        hparams["n_rnn_layers"],
-        hparams["rnn_dim"],
-        hparams["n_class"],
-        hparams["n_feats"],
-        hparams["stride"],
-        hparams["dropout"],
+        model_config["n_cnn_layers"],
+        model_config["n_rnn_layers"],
+        model_config["rnn_dim"],
+        tokenizer.get_vocab_size(),
+        model_config["n_feats"],
+        model_config["stride"],
+        model_config["dropout"],
     ).to(device)

-    checkpoint = torch.load("model8", map_location=device)
-    state_dict = {
-        k[len("module.") :] if k.startswith("module.") else k: v
-        for k, v in checkpoint["model_state_dict"].items()
-    }
-    model.load_state_dict(state_dict)
-
-    # waveform, sample_rate = torchaudio.load("test.opus")
-    waveform, sample_rate = torchaudio.load("marvin_rede.flac")  # pylint: disable=no-member
-    if sample_rate != spectrogram_hparams["sample_rate"]:
-        resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
+    checkpoint = torch.load(inference_config["model_load_path"], map_location=device)
+    print(checkpoint["model_state_dict"].keys())
+    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
+    model.eval()
+    waveform, sample_rate = torchaudio.load(file_path)  # pylint: disable=no-member
+    if waveform.shape[0] != 1:
+        waveform = waveform[1]
+        waveform = waveform.unsqueeze(0)
+    if sample_rate != 16000:
+        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
         waveform = resampler(waveform)
+        sample_rate = 16000
+
+    data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"])
+
+    spec = data_processing(waveform).squeeze(0).transpose(0, 1)

-    spec = (
-        torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform)
-        .squeeze(0)
-        .transpose(0, 1)
-    )
-    specs = [spec]
-    specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
+    spec = spec.unsqueeze(0)
+    spec = spec.transpose(1, 2)
+    spec = spec.unsqueeze(0)
+    output = model(spec)  # pylint: disable=not-callable
+    output = F.log_softmax(output, dim=2)  # (batch, time, n_class)
+    decoded_preds = greedy_decoder(output, tokenizer)

-    output = model(specs)  # pylint: disable=not-callable
-    output = F.log_softmax(output, dim=2)
-    output = output.transpose(0, 1)  # (time, batch, n_class)
-    decodes = greedy_decoder(output, tokenizer)
-    print(decodes)
+    print(decoded_preds)


 if __name__ == "__main__":
-    main()
+    main()  # pylint: disable=no-value-for-parameter
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ca70d21..ec25918 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -263,7 +263,7 @@ def main(config_path: str):
     prev_epoch = 0

     if checkpoints_config["model_load_path"] is not None:
-        checkpoint = torch.load(checkpoints_config["model_load_path"])
+        checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
         model.load_state_dict(checkpoint["model_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         prev_epoch = checkpoint["epoch"]
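Given the commit message, one plausible culprit stands out: the deleted inference code stripped the `module.` prefix that `torch.nn.DataParallel` adds to checkpoint keys, while the new code loads the raw state dict with `strict=False`, which silently skips every non-matching key. A sketch of restoring that normalization, reusing `inference_config`, `device`, and `model` from the script above and assuming the checkpoint was saved from a DataParallel-wrapped model:

```python
checkpoint = torch.load(inference_config["model_load_path"], map_location=device)

# Strip the "module." prefix that DataParallel adds, so the keys match
# the unwrapped SpeechRecognitionModel (what the deleted code used to do).
state_dict = {
    (key[len("module."):] if key.startswith("module.") else key): value
    for key, value in checkpoint["model_state_dict"].items()
}

# Keep strict=True (the default) so a real mismatch fails loudly instead
# of silently loading a partially initialized model.
model.load_state_dict(state_dict)
model.eval()
```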