aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference.py
diff options
context:
space:
mode:
authorPherkel2023-09-18 14:25:36 +0200
committerPherkel2023-09-18 14:25:36 +0200
commitd5e482b7dc3d8b6acc48a883ae9b53b354fa1715 (patch)
tree580f0ab45784664978d8f24c4831f3eec1bceb2e /swr2_asr/inference.py
parentd5689047fa7062b284d13271bda39013dcf6150f (diff)
decoder changes
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r--swr2_asr/inference.py35
1 files changed, 15 insertions, 20 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 3c58af0..64a6eeb 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -6,25 +6,10 @@ import torchaudio
import yaml
from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.tokenizer import CharTokenizer
-def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
- """Greedily decode a sequence."""
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- blank_label = tokenizer.get_blank_token()
- decodes = []
- for args in arg_maxes:
- decode = []
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes
-
-
@click.command()
@click.option(
"--config_path",
@@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None:
model_config = config_dict.get("model", {})
tokenizer_config = config_dict.get("tokenizer", {})
inference_config = config_dict.get("inference", {})
+ decoder_config = config_dict.get("decoder", {})
- if inference_config["device"] == "cpu":
+ if inference_config.get("device", "") == "cpu":
device = "cpu"
- elif inference_config["device"] == "cuda":
+ elif inference_config.get("device", "") == "cuda":
device = "cuda" if torch.cuda.is_available() else "cpu"
+ elif inference_config.get("device", "") == "mps":
+ device = "mps"
+ else:
+ device = "cpu"
device = torch.device(device) # pylint: disable=no-member
tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
@@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None:
spec = spec.unsqueeze(0)
spec = spec.transpose(1, 2)
spec = spec.unsqueeze(0)
+ spec = spec.to(device)
output = model(spec) # pylint: disable=not-callable
output = F.log_softmax(output, dim=2) # (batch, time, n_class)
- decoded_preds = greedy_decoder(output, tokenizer)
- print(decoded_preds)
+ decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)
+
+ preds = decoder(output)
+ preds = " ".join(preds[0][0].words).strip()
+
+ print(preds)
if __name__ == "__main__":