author    Marvin Borner    2023-09-06 17:00:58 +0200
committer Marvin Borner    2023-09-06 17:01:14 +0200
commit    a451ada2efba6113a6ba562c399b423804875897 (patch)
tree      25c5641ff8b12fa741e01a16b70991b81f0d271f /swr2_asr/inference.py
parent    f0910dfad95ff29e1bdf0a657558e17ff2230a14 (diff)
Added inference
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r--    swr2_asr/inference.py    114
1 file changed, 114 insertions, 0 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
new file mode 100644
index 0000000..c3eec42
--- /dev/null
+++ b/swr2_asr/inference.py
@@ -0,0 +1,114 @@
+"""Training script for the ASR model."""
+import torch
+import torchaudio
+import torch.nn.functional as F
+from typing import TypedDict
+
+from swr2_asr.tokenizer import CharTokenizer
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
+
+
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
+
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
+
+
+def greedy_decoder(output, tokenizer, collapse_repeated=True):
+    """Greedily decode a batch of model outputs (CTC-style argmax decoding)."""
+    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+    # the id of the space token serves as the blank label in this setup
+    blank_label = tokenizer.encode(" ").ids[0]
+    decodes = []
+    for args in arg_maxes:
+        decode = []
+        for j, index in enumerate(args):
+            if index != blank_label:
+                if collapse_repeated and j != 0 and index == args[j - 1]:
+                    # collapse consecutive repeats of the same character
+                    continue
+                decode.append(index.item())
+        decodes.append(tokenizer.decode(decode))
+    return decodes
+
+
+def main() -> None:
+    """Run inference on a single audio file."""
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    device = torch.device(device)
+
+    tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
+
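+    # feature-extraction settings; assumed to match the configuration used at
+    # training time (a mismatch would silently degrade recognition quality)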
+    spectrogram_hparams = {
+        "sample_rate": 16000,
+        "n_fft": 400,
+        "win_length": 400,
+        "hop_length": 160,
+        "n_mels": 128,
+        "f_min": 0,
+        "f_max": 8000,
+        "power": 2.0,
+    }
+
+    hparams = HParams(
+        n_cnn_layers=3,
+        n_rnn_layers=5,
+        rnn_dim=512,
+        n_class=tokenizer.get_vocab_size(),
+        n_feats=128,
+        stride=2,
+        dropout=0.1,
+        learning_rate=0.1,
+        batch_size=30,
+        epochs=100,
+    )
+
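+    # the architecture parameters above must match the ones the checkpoint below
+    # was trained with, otherwise load_state_dict will fail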
+    model = SpeechRecognitionModel(
+        hparams["n_cnn_layers"],
+        hparams["n_rnn_layers"],
+        hparams["rnn_dim"],
+        hparams["n_class"],
+        hparams["n_feats"],
+        hparams["stride"],
+        hparams["dropout"],
+    ).to(device)
+
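+    # checkpoints saved from a DataParallel/DistributedDataParallel model prefix
+    # every parameter name with "module.", so strip that prefix before loading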
+    checkpoint = torch.load("model8", map_location=device)
+    state_dict = {
+        k[len("module.") :] if k.startswith("module.") else k: v
+        for k, v in checkpoint["model_state_dict"].items()
+    }
+    model.load_state_dict(state_dict)
+
+    # waveform, sample_rate = torchaudio.load("test.opus")
+    waveform, sample_rate = torchaudio.load("marvin_rede.flac")
+    if sample_rate != spectrogram_hparams["sample_rate"]:
+        resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
+        waveform = resampler(waveform)
+
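+    # MelSpectrogram yields (channel, n_mels, time); the reshaping below produces
+    # the (batch, 1, n_mels, time) layout that is passed to the model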
+    spec = (
+        torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform)
+        .squeeze(0)
+        .transpose(0, 1)
+    )
+    specs = [spec]
+    specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
+
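+    # run the model and decode the most likely character sequence per utterance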
+    output = model(specs)
+    output = F.log_softmax(output, dim=2)  # (batch, time, n_class)
+    decodes = greedy_decoder(output, tokenizer)
+    print(decodes)
+
+
+if __name__ == "__main__":
+    main()