diff options
author | Pherkel | 2023-09-11 22:58:19 +0200 |
---|---|---|
committer | Pherkel | 2023-09-11 22:58:19 +0200 |
commit | 6f5513140f153206cfa91df3077e67ce58043d35 (patch) | |
tree | 71cee784719a1f9c912a1038824eb6bd26195408 /swr2_asr/inference.py | |
parent | 64dbb9d32a51b1bce6c9de67069dc8f5943a5399 (diff) |
model loading is broken :(
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r-- | swr2_asr/inference.py | 140 |
1 files changed, 61 insertions, 79 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index f8342f7..6495a9a 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,35 +1,20 @@ """Training script for the ASR model.""" -from typing import TypedDict - +import click import torch import torch.nn.functional as F import torchaudio +import yaml from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.tokenizer import CharTokenizer -class HParams(TypedDict): - """Type for the hyperparameters of the model.""" - - n_cnn_layers: int - n_rnn_layers: int - rnn_dim: int - n_class: int - n_feats: int - stride: int - dropout: float - learning_rate: float - batch_size: int - epochs: int - - -def greedy_decoder(output, tokenizer, collapse_repeated=True): +def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True): """Greedily decode a sequence.""" arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.encode(" ").ids[0] + blank_label = tokenizer.get_blank_token() decodes = [] - for _i, args in enumerate(arg_maxes): + for args in arg_maxes: decode = [] for j, index in enumerate(args): if index != blank_label: @@ -40,75 +25,72 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): return decodes -def main() -> None: +@click.command() +@click.option( + "--config_path", + default="config.yaml", + help="Path to yaml config file", + type=click.Path(exists=True), +) +@click.option( + "--file_path", + help="Path to audio file", + type=click.Path(exists=True), +) +def main(config_path: str, file_path: str) -> None: """inference function.""" - - device = "cuda" if torch.cuda.is_available() else "cpu" + with open(config_path, "r", encoding="utf-8") as yaml_file: + config_dict = yaml.safe_load(yaml_file) + + # Create separate dictionaries for each top-level key + model_config = config_dict.get("model", {}) + tokenizer_config = config_dict.get("tokenizer", {}) + inference_config = config_dict.get("inference", {}) + + if inference_config["device"] == "cpu": + device = "cpu" + elif inference_config["device"] == "cuda": + device = "cuda" if torch.cuda.is_available() else "cpu" device = torch.device(device) # pylint: disable=no-member - tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") - - spectrogram_hparams = { - "sample_rate": 16000, - "n_fft": 400, - "win_length": 400, - "hop_length": 160, - "n_mels": 128, - "f_min": 0, - "f_max": 8000, - "power": 2.0, - } - - hparams = HParams( - n_cnn_layers=3, - n_rnn_layers=5, - rnn_dim=512, - n_class=tokenizer.get_vocab_size(), - n_feats=128, - stride=2, - dropout=0.1, - learning_rate=0.1, - batch_size=30, - epochs=100, - ) + tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) model = SpeechRecognitionModel( - hparams["n_cnn_layers"], - hparams["n_rnn_layers"], - hparams["rnn_dim"], - hparams["n_class"], - hparams["n_feats"], - hparams["stride"], - hparams["dropout"], + model_config["n_cnn_layers"], + model_config["n_rnn_layers"], + model_config["rnn_dim"], + tokenizer.get_vocab_size(), + model_config["n_feats"], + model_config["stride"], + model_config["dropout"], ).to(device) - checkpoint = torch.load("model8", map_location=device) - state_dict = { - k[len("module.") :] if k.startswith("module.") else k: v - for k, v in checkpoint["model_state_dict"].items() - } - model.load_state_dict(state_dict) - - # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member - if sample_rate != spectrogram_hparams["sample_rate"]: - resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) + checkpoint = torch.load(inference_config["model_load_path"], map_location=device) + print(checkpoint["model_state_dict"].keys()) + model.load_state_dict(checkpoint["model_state_dict"], strict=False) + model.eval() + waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member + if waveform.shape[0] != 1: + waveform = waveform[1] + waveform = waveform.unsqueeze(0) + if sample_rate != 16000: + resampler = torchaudio.transforms.Resample(sample_rate, 16000) waveform = resampler(waveform) + sample_rate = 16000 + + data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"]) + + spec = data_processing(waveform).squeeze(0).transpose(0, 1) - spec = ( - torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform) - .squeeze(0) - .transpose(0, 1) - ) - specs = [spec] - specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) + spec = spec.unsqueeze(0) + spec = spec.transpose(1, 2) + spec = spec.unsqueeze(0) + output = model(spec) # pylint: disable=not-callable + output = F.log_softmax(output, dim=2) # (batch, time, n_class) + decoded_preds = greedy_decoder(output, tokenizer) - output = model(specs) # pylint: disable=not-callable - output = F.log_softmax(output, dim=2) - output = output.transpose(0, 1) # (time, batch, n_class) - decodes = greedy_decoder(output, tokenizer) - print(decodes) + print(decoded_preds) if __name__ == "__main__": - main() + main() # pylint: disable=no-value-for-parameter |