diff options
author | Pherkel | 2023-09-18 14:25:36 +0200 |
---|---|---|
committer | Pherkel | 2023-09-18 14:25:36 +0200 |
commit | d5e482b7dc3d8b6acc48a883ae9b53b354fa1715 (patch) | |
tree | 580f0ab45784664978d8f24c4831f3eec1bceb2e /swr2_asr/inference.py | |
parent | d5689047fa7062b284d13271bda39013dcf6150f (diff) |
decoder changes
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r-- | swr2_asr/inference.py | 35 |
1 file changed, 15 insertions, 20 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index 3c58af0..64a6eeb 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -6,25 +6,10 @@ import torchaudio import yaml from swr2_asr.model_deep_speech import SpeechRecognitionModel +from swr2_asr.utils.decoder import decoder_factory from swr2_asr.utils.tokenizer import CharTokenizer -def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True): - """Greedily decode a sequence.""" - arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member - blank_label = tokenizer.get_blank_token() - decodes = [] - for args in arg_maxes: - decode = [] - for j, index in enumerate(args): - if index != blank_label: - if collapse_repeated and j != 0 and index == args[j - 1]: - continue - decode.append(index.item()) - decodes.append(tokenizer.decode(decode)) - return decodes - - @click.command() @click.option( "--config_path", @@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None: model_config = config_dict.get("model", {}) tokenizer_config = config_dict.get("tokenizer", {}) inference_config = config_dict.get("inference", {}) + decoder_config = config_dict.get("decoder", {}) - if inference_config["device"] == "cpu": + if inference_config.get("device", "") == "cpu": device = "cpu" - elif inference_config["device"] == "cuda": + elif inference_config.get("device", "") == "cuda": device = "cuda" if torch.cuda.is_available() else "cpu" + elif inference_config.get("device", "") == "mps": + device = "mps" + else: + device = "cpu" device = torch.device(device) # pylint: disable=no-member tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"]) @@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None: spec = spec.unsqueeze(0) spec = spec.transpose(1, 2) spec = spec.unsqueeze(0) + spec = spec.to(device) output = model(spec) # pylint: disable=not-callable output = F.log_softmax(output, dim=2) # (batch, time, n_class) - decoded_preds = 
greedy_decoder(output, tokenizer) - print(decoded_preds) + decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config) + + preds = decoder(output) + preds = " ".join(preds[0][0].words).strip() + + print(preds) if __name__ == "__main__": |