author     Pherkel    2023-09-11 22:58:19 +0200
committer  Pherkel    2023-09-11 22:58:19 +0200
commit     6f5513140f153206cfa91df3077e67ce58043d35 (patch)
tree       71cee784719a1f9c912a1038824eb6bd26195408
parent     64dbb9d32a51b1bce6c9de67069dc8f5943a5399 (diff)
model loading is broken :(
-rw-r--r--    config.philipp.yaml                              9
-rw-r--r--    config.yaml (renamed from config.train.yaml)    10
-rw-r--r--    swr2_asr/inference.py                          140
-rw-r--r--    swr2_asr/train.py                                2
4 files changed, 77 insertions, 84 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
index 6b905cd..4a723c6 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -12,6 +12,7 @@ training:
   epochs: 3
   eval_every_n: 1 # evaluate every n epochs
   num_workers: 4 # number of workers for dataloader
+  device: "cuda" # device to run training on if gpu is available, else "cpu" will be set automatically
 
 dataset:
   download: True
@@ -25,5 +26,9 @@ tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
 
 checkpoints:
-  model_load_path: ~ # path to load model from
-  model_save_path: ~ # path to save model to
\ No newline at end of file
+  model_load_path: "data/runs/epoch30" # path to load model from
+  model_save_path: ~ # path to save model to
+
+inference:
+  model_load_path: "data/runs/epoch30" # path to load model from
+  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file
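Both new device keys promise the same fallback: "cpu" will be set automatically when CUDA is unavailable. A minimal sketch of that resolution logic, assuming the key holds either "cpu" or "cuda"; the helper name resolve_device is hypothetical, not part of this commit:

    import torch

    def resolve_device(requested: str) -> torch.device:
        # Use CUDA only when it was requested and is actually available;
        # anything else resolves to CPU.
        if requested == "cuda" and torch.cuda.is_available():
            return torch.device("cuda")
        return torch.device("cpu")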
diff --git a/config.train.yaml b/config.yaml
index c82439d..e5ff43a 100644
--- a/config.train.yaml
+++ b/config.yaml
@@ -4,7 +4,7 @@ model:
   rnn_dim: 512
   n_feats: 128 # number of mel features
   stride: 2
-  dropout: 0.25 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+  dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
 
 training:
   learning_rate: 5e-4
@@ -19,10 +19,16 @@ dataset:
   language_name: "mls_german_opus"
   limited_supervision: False # set to True if you want to use limited supervision
   dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+  shuffle: True
 
 tokenizer:
   tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml"
 
 checkpoints:
   model_load_path: "YOUR/PATH" # path to load model from
-  model_save_path: "YOUR/PATH" # path to save model to
\ No newline at end of file
+  model_save_path: "YOUR/PATH" # path to save model to
+
+inference:
+  model_load_path: "YOUR/PATH" # path to load model from
+  beam_width: 10 # beam width for beam search
+  device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
\ No newline at end of file
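The renamed config.yaml now carries a top-level inference section alongside model, training, and checkpoints. A short sketch of reading it with PyYAML, using only the keys shown above; note that inference.py below still decodes greedily, so beam_width is declared here but not yet consumed as of this commit:

    import yaml

    with open("config.yaml", "r", encoding="utf-8") as yaml_file:
        config = yaml.safe_load(yaml_file)

    inference_config = config.get("inference", {})
    model_load_path = inference_config["model_load_path"]
    beam_width = inference_config.get("beam_width", 10)  # only used by a beam-search decoder
    device = inference_config.get("device", "cpu")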
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index f8342f7..6495a9a 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -1,35 +1,20 @@
 """Training script for the ASR model."""
-from typing import TypedDict
-
+import click
 import torch
 import torch.nn.functional as F
 import torchaudio
+import yaml
 
 from swr2_asr.model_deep_speech import SpeechRecognitionModel
 from swr2_asr.utils.tokenizer import CharTokenizer
 
 
-class HParams(TypedDict):
-    """Type for the hyperparameters of the model."""
-
-    n_cnn_layers: int
-    n_rnn_layers: int
-    rnn_dim: int
-    n_class: int
-    n_feats: int
-    stride: int
-    dropout: float
-    learning_rate: float
-    batch_size: int
-    epochs: int
-
-
-def greedy_decoder(output, tokenizer, collapse_repeated=True):
+def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
     """Greedily decode a sequence."""
     arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
-    blank_label = tokenizer.encode(" ").ids[0]
+    blank_label = tokenizer.get_blank_token()
     decodes = []
-    for _i, args in enumerate(arg_maxes):
+    for args in arg_maxes:
         decode = []
         for j, index in enumerate(args):
             if index != blank_label:
@@ -40,75 +25,72 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True):
     return decodes
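Worth spelling out what greedy CTC decoding does here: take the argmax token per frame, merge adjacent repeats, and drop blanks, so blanks are what keep genuine double letters apart. A toy illustration with made-up token ids (the real blank id comes from tokenizer.get_blank_token()):

    # frames:  blank, 3, 3, blank, 3, 5  ->  decoded ids [3, 3, 5]
    arg_maxes = [0, 3, 3, 0, 3, 5]
    blank, decoded, prev = 0, [], None
    for idx in arg_maxes:
        if idx != blank and idx != prev:  # skip blanks and adjacent repeats
            decoded.append(idx)
        prev = idx
    assert decoded == [3, 3, 5]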
 
 
-def main() -> None:
+@click.command()
+@click.option(
+    "--config_path",
+    default="config.yaml",
+    help="Path to yaml config file",
+    type=click.Path(exists=True),
+)
+@click.option(
+    "--file_path",
+    help="Path to audio file",
+    type=click.Path(exists=True),
+)
+def main(config_path: str, file_path: str) -> None:
     """inference function."""
-
-    device = "cuda" if torch.cuda.is_available() else "cpu"
+    with open(config_path, "r", encoding="utf-8") as yaml_file:
+        config_dict = yaml.safe_load(yaml_file)
+
+    # Create separate dictionaries for each top-level key
+    model_config = config_dict.get("model", {})
+    tokenizer_config = config_dict.get("tokenizer", {})
+    inference_config = config_dict.get("inference", {})
+
+    if inference_config["device"] == "cpu":
+        device = "cpu"
+    elif inference_config["device"] == "cuda":
+        device = "cuda" if torch.cuda.is_available() else "cpu"
     device = torch.device(device)  # pylint: disable=no-member
-    tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
-
-    spectrogram_hparams = {
-        "sample_rate": 16000,
-        "n_fft": 400,
-        "win_length": 400,
-        "hop_length": 160,
-        "n_mels": 128,
-        "f_min": 0,
-        "f_max": 8000,
-        "power": 2.0,
-    }
-
-    hparams = HParams(
-        n_cnn_layers=3,
-        n_rnn_layers=5,
-        rnn_dim=512,
-        n_class=tokenizer.get_vocab_size(),
-        n_feats=128,
-        stride=2,
-        dropout=0.1,
-        learning_rate=0.1,
-        batch_size=30,
-        epochs=100,
-    )
+    tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
 
     model = SpeechRecognitionModel(
-        hparams["n_cnn_layers"],
-        hparams["n_rnn_layers"],
-        hparams["rnn_dim"],
-        hparams["n_class"],
-        hparams["n_feats"],
-        hparams["stride"],
-        hparams["dropout"],
+        model_config["n_cnn_layers"],
+        model_config["n_rnn_layers"],
+        model_config["rnn_dim"],
+        tokenizer.get_vocab_size(),
+        model_config["n_feats"],
+        model_config["stride"],
+        model_config["dropout"],
     ).to(device)
 
-    checkpoint = torch.load("model8", map_location=device)
-    state_dict = {
-        k[len("module.") :] if k.startswith("module.") else k: v
-        for k, v in checkpoint["model_state_dict"].items()
-    }
-    model.load_state_dict(state_dict)
-
-    # waveform, sample_rate = torchaudio.load("test.opus")
-    waveform, sample_rate = torchaudio.load("marvin_rede.flac")  # pylint: disable=no-member
-    if sample_rate != spectrogram_hparams["sample_rate"]:
-        resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
+    checkpoint = torch.load(inference_config["model_load_path"], map_location=device)
+    print(checkpoint["model_state_dict"].keys())
+    model.load_state_dict(checkpoint["model_state_dict"], strict=False)
+    model.eval()
+
+    waveform, sample_rate = torchaudio.load(file_path)  # pylint: disable=no-member
+    if waveform.shape[0] != 1:
+        waveform = waveform[1]
+        waveform = waveform.unsqueeze(0)
+    if sample_rate != 16000:
+        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
         waveform = resampler(waveform)
+        sample_rate = 16000
+
+    data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"])
+
+    spec = data_processing(waveform).squeeze(0).transpose(0, 1)
 
-    spec = (
-        torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform)
-        .squeeze(0)
-        .transpose(0, 1)
-    )
-    specs = [spec]
-    specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
+    spec = spec.unsqueeze(0)
+    spec = spec.transpose(1, 2)
+    spec = spec.unsqueeze(0)
+    output = model(spec)  # pylint: disable=not-callable
+    output = F.log_softmax(output, dim=2)  # (batch, time, n_class)
+    decoded_preds = greedy_decoder(output, tokenizer)
 
-    output = model(specs)  # pylint: disable=not-callable
-    output = F.log_softmax(output, dim=2)
-    output = output.transpose(0, 1)  # (time, batch, n_class)
-    decodes = greedy_decoder(output, tokenizer)
-    print(decodes)
+    print(decoded_preds)
 
 
 if __name__ == "__main__":
-    main()
+    main()  # pylint: disable=no-value-for-parameter
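Given the commit message ("model loading is broken") and the debug print of the checkpoint keys, the likely culprit is a key mismatch: the old code stripped the "module." prefix that DataParallel puts on parameter names, and this commit drops that stripping while adding strict=False, which silences the mismatch instead of fixing it. load_state_dict returns what failed to match, which is one way to see the actual breakage; a diagnostic sketch, not part of this commit:

    missing, unexpected = model.load_state_dict(checkpoint["model_state_dict"], strict=False)
    print("missing keys:", missing)        # parameters the checkpoint did not provide
    print("unexpected keys:", unexpected)  # e.g. "module."-prefixed names from DataParallel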
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ca70d21..ec25918 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -263,7 +263,7 @@ def main(config_path: str):
     prev_epoch = 0
 
     if checkpoints_config["model_load_path"] is not None:
-        checkpoint = torch.load(checkpoints_config["model_load_path"])
+        checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
         model.load_state_dict(checkpoint["model_state_dict"])
         optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
         prev_epoch = checkpoint["epoch"]
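The added map_location matters whenever a checkpoint written on a GPU machine is loaded on a CPU-only one: without it, torch.load tries to restore tensors onto their original CUDA device and raises. A minimal sketch ("YOUR/PATH" is the placeholder from config.yaml above):

    import torch

    # Remap every tensor in the checkpoint to the target device at load time.
    checkpoint = torch.load("YOUR/PATH", map_location=torch.device("cpu"))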