aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference.py
diff options
context:
space:
mode:
authorPherkel2023-09-11 21:52:42 +0200
committerPherkel2023-09-11 21:52:42 +0200
commit58b30927bd870604a4077a8af9ec3cad7b0be21c (patch)
tree7dd492fa8f14ff61c88545448972022ead324c31 /swr2_asr/inference.py
parent9ca17d8a83369257f4cc42c963e25baf35a28f8f (diff)
changed config to yaml!
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r--swr2_asr/inference.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index c3eec42..f8342f7 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
+from typing import TypedDict
+
import torch
-import torchaudio
import torch.nn.functional as F
-from typing import TypedDict
+import torchaudio
-from swr2_asr.tokenizer import CharTokenizer
from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.tokenizer import CharTokenizer
class HParams(TypedDict):
@@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True):
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
blank_label = tokenizer.encode(" ").ids[0]
decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
+ for _i, args in enumerate(arg_maxes):
decode = []
for j, index in enumerate(args):
if index != blank_label:
@@ -44,7 +44,7 @@ def main() -> None:
"""inference function."""
device = "cuda" if torch.cuda.is_available() else "cpu"
- device = torch.device(device)
+ device = torch.device(device) # pylint: disable=no-member
tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
@@ -90,7 +90,7 @@ def main() -> None:
model.load_state_dict(state_dict)
# waveform, sample_rate = torchaudio.load("test.opus")
- waveform, sample_rate = torchaudio.load("marvin_rede.flac")
+ waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member
if sample_rate != spectrogram_hparams["sample_rate"]:
resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
waveform = resampler(waveform)
@@ -103,7 +103,7 @@ def main() -> None:
specs = [spec]
specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
- output = model(specs)
+ output = model(specs) # pylint: disable=not-callable
output = F.log_softmax(output, dim=2)
output = output.transpose(0, 1) # (time, batch, n_class)
decodes = greedy_decoder(output, tokenizer)