diff options
author | Pherkel | 2023-09-18 19:28:16 +0200 |
---|---|---|
committer | Pherkel | 2023-09-18 19:28:16 +0200 |
commit | 623649ed45b67e6984051b24414b0eec691d7a23 (patch) | |
tree | f91c17fc9bfcf19b5c39ee7221c4214392c24cac /swr2_asr/inference.py | |
parent | 56594331d07bc5426ed28d6496962c99f73ea675 (diff) |
Added WER (word error rate) and target transcripts to inference
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r-- | swr2_asr/inference.py | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index 64a6eeb..511aef1 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -1,4 +1,6 @@ """Training script for the ASR model.""" +from typing import Union + import click import torch import torch.nn.functional as F @@ -7,6 +9,7 @@ import yaml from swr2_asr.model_deep_speech import SpeechRecognitionModel from swr2_asr.utils.decoder import decoder_factory +from swr2_asr.utils.loss_scores import wer from swr2_asr.utils.tokenizer import CharTokenizer @@ -22,7 +25,12 @@ from swr2_asr.utils.tokenizer import CharTokenizer help="Path to audio file", type=click.Path(exists=True), ) -def main(config_path: str, file_path: str) -> None: +# optional arguments +@click.option( + "--target_path", + help="Path to target text file", +) +def main(config_path: str, file_path: str, target_path: Union[str, None] = None) -> None: """inference function.""" with open(config_path, "r", encoding="utf-8") as yaml_file: config_dict = yaml.safe_load(yaml_file) @@ -89,7 +97,26 @@ def main(config_path: str, file_path: str) -> None: preds = decoder(output) preds = " ".join(preds[0][0].words).strip() - print(preds) + if target_path is not None: + with open(target_path, "r", encoding="utf-8") as target_file: + target = target_file.read() + target = target.lower() + target = target.replace("«", "") + target = target.replace("»", "") + target = target.replace(",", "") + target = target.replace(".", "") + target = target.replace("?", "") + target = target.replace("!", "") + + print("---------") + print(f"Prediction:\n\{preds}") + print("---------") + print(f"Target:\n{target}") + print("---------") + print(f"WER: {wer(preds, target)}") + + else: + print(f"Prediction:\n{preds}") if __name__ == "__main__": |