aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference.py
blob: d3543a959a5dc57b92218dd14c6c5c742e55a23f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
"""Training script for the ASR model."""
from typing import Union

import click
import torch
import torch.nn.functional as F
import torchaudio
import yaml

from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.loss_scores import wer
from swr2_asr.utils.tokenizer import CharTokenizer


@click.command()
@click.option(
    "--config_path",
    default="config.yaml",
    help="Path to yaml config file",
    type=click.Path(exists=True),
)
@click.option(
    "--file_path",
    help="Path to audio file",
    type=click.Path(exists=True),
)
# optional arguments
@click.option(
    "--target_path",
    help="Path to target text file",
)
def main(config_path: str, file_path: str, target_path: Union[str, None] = None) -> None:
    """inference function."""
    with open(config_path, "r", encoding="utf-8") as yaml_file:
        config_dict = yaml.safe_load(yaml_file)

    # Create separate dictionaries for each top-level key
    model_config = config_dict.get("model", {})
    tokenizer_config = config_dict.get("tokenizer", {})
    inference_config = config_dict.get("inference", {})
    decoder_config = config_dict.get("decoder", {})

    if inference_config.get("device", "") == "cpu":
        device = "cpu"
    elif inference_config.get("device", "") == "cuda":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    elif inference_config.get("device", "") == "mps":
        device = "mps"
    else:
        device = "cpu"
    device = torch.device(device)  # pylint: disable=no-member

    tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])

    model = SpeechRecognitionModel(
        model_config["n_cnn_layers"],
        model_config["n_rnn_layers"],
        model_config["rnn_dim"],
        tokenizer.get_vocab_size(),
        model_config["n_feats"],
        model_config["stride"],
        model_config["dropout"],
    ).to(device)

    checkpoint = torch.load(inference_config["model_load_path"], map_location=device)

    state_dict = {
        k[len("module.") :] if k.startswith("module.") else k: v
        for k, v in checkpoint["model_state_dict"].items()
    }
    model.load_state_dict(state_dict, strict=True)
    model.eval()

    waveform, sample_rate = torchaudio.load(file_path)  # pylint: disable=no-member
    if waveform.shape[0] != 1:
        waveform = waveform[1]
        waveform = waveform.unsqueeze(0)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
        sample_rate = 16000

    data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"])

    spec = data_processing(waveform).squeeze(0).transpose(0, 1)

    spec = spec.unsqueeze(0)
    spec = spec.transpose(1, 2)
    spec = spec.unsqueeze(0)
    spec = spec.to(device)
    output = model(spec)  # pylint: disable=not-callable
    output = F.log_softmax(output, dim=2)  # (batch, time, n_class)

    decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)

    preds = decoder(output)
    preds = " ".join(preds[0][0].words).strip()

    if target_path is not None:
        with open(target_path, "r", encoding="utf-8") as target_file:
            target = target_file.read()
            target = target.lower()
            target = target.replace("«", "")
            target = target.replace("»", "")
            target = target.replace(",", "")
            target = target.replace(".", "")
            target = target.replace("?", "")
            target = target.replace("!", "")

        print("---------")
        print(f"Prediction:\n{preds}")
        print("---------")
        print(f"Target:\n{target}")
        print("---------")
        print(f"WER: {wer(preds, target)}")

    else:
        print(f"Prediction:\n{preds}")


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter