aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference.py
diff options
context:
space:
mode:
authorPherkel2023-09-18 19:28:16 +0200
committerPherkel2023-09-18 19:28:16 +0200
commit623649ed45b67e6984051b24414b0eec691d7a23 (patch)
treef91c17fc9bfcf19b5c39ee7221c4214392c24cac /swr2_asr/inference.py
parent56594331d07bc5426ed28d6496962c99f73ea675 (diff)
added wer to inference, transcripts
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r--swr2_asr/inference.py31
1 files changed, 29 insertions, 2 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 64a6eeb..511aef1 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -1,4 +1,6 @@
"""Training script for the ASR model."""
+from typing import Union
+
import click
import torch
import torch.nn.functional as F
@@ -7,6 +9,7 @@ import yaml
from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.decoder import decoder_factory
+from swr2_asr.utils.loss_scores import wer
from swr2_asr.utils.tokenizer import CharTokenizer
@@ -22,7 +25,12 @@ from swr2_asr.utils.tokenizer import CharTokenizer
help="Path to audio file",
type=click.Path(exists=True),
)
-def main(config_path: str, file_path: str) -> None:
+# optional arguments
+@click.option(
+ "--target_path",
+ help="Path to target text file",
+)
+def main(config_path: str, file_path: str, target_path: Union[str, None] = None) -> None:
"""inference function."""
with open(config_path, "r", encoding="utf-8") as yaml_file:
config_dict = yaml.safe_load(yaml_file)
@@ -89,7 +97,26 @@ def main(config_path: str, file_path: str) -> None:
preds = decoder(output)
preds = " ".join(preds[0][0].words).strip()
- print(preds)
+ if target_path is not None:
+ with open(target_path, "r", encoding="utf-8") as target_file:
+ target = target_file.read()
+ target = target.lower()
+ target = target.replace("«", "")
+ target = target.replace("»", "")
+ target = target.replace(",", "")
+ target = target.replace(".", "")
+ target = target.replace("?", "")
+ target = target.replace("!", "")
+
+ print("---------")
+ print(f"Prediction:\n\{preds}")
+ print("---------")
+ print(f"Target:\n{target}")
+ print("---------")
+ print(f"WER: {wer(preds, target)}")
+
+ else:
+ print(f"Prediction:\n{preds}")
if __name__ == "__main__":