diff options
author | Pherkel | 2023-09-11 21:52:42 +0200 |
---|---|---|
committer | Pherkel | 2023-09-11 21:52:42 +0200 |
commit | 58b30927bd870604a4077a8af9ec3cad7b0be21c (patch) | |
tree | 7dd492fa8f14ff61c88545448972022ead324c31 /swr2_asr/inference.py | |
parent | 9ca17d8a83369257f4cc42c963e25baf35a28f8f (diff) |
changed config to yaml!
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r-- | swr2_asr/inference.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index c3eec42..f8342f7 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" +from typing import TypedDict + import torch -import torchaudio import torch.nn.functional as F -from typing import TypedDict +import torchaudio -from swr2_asr.tokenizer import CharTokenizer from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.tokenizer import CharTokenizer class HParams(TypedDict): @@ -28,8 +29,7 @@ def greedy_decoder(output, tokenizer, collapse_repeated=True): arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member blank_label = tokenizer.encode(" ").ids[0] decodes = [] - targets = [] - for i, args in enumerate(arg_maxes): + for _i, args in enumerate(arg_maxes): decode = [] for j, index in enumerate(args): if index != blank_label: @@ -44,7 +44,7 @@ def main() -> None: """inference function.""" device = "cuda" if torch.cuda.is_available() else "cpu" - device = torch.device(device) + device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file("char_tokenizer_german.json") @@ -90,7 +90,7 @@ def main() -> None: model.load_state_dict(state_dict) # waveform, sample_rate = torchaudio.load("test.opus") - waveform, sample_rate = torchaudio.load("marvin_rede.flac") + waveform, sample_rate = torchaudio.load("marvin_rede.flac") # pylint: disable=no-member if sample_rate != spectrogram_hparams["sample_rate"]: resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"]) waveform = resampler(waveform) @@ -103,7 +103,7 @@ def main() -> None: specs = [spec] specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3) - output = model(specs) + output = model(specs) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) output = output.transpose(0, 1) # (time, batch, n_class) decodes = greedy_decoder(output, tokenizer) |