"""Training script for the ASR model."""
import click
import torch
import torch.nn.functional as F
import torchaudio
import yaml

from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.tokenizer import CharTokenizer


def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
"""Greedily decode a sequence."""
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
blank_label = tokenizer.get_blank_token()
decodes = []
for args in arg_maxes:
decode = []
for j, index in enumerate(args):
if index != blank_label:
if collapse_repeated and j != 0 and index == args[j - 1]:
continue
decode.append(index.item())
decodes.append(tokenizer.decode(decode))
return decodes
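

# Command-line entry point: reads model, tokenizer, and inference settings from a
# YAML config and prints the transcription of a single audio file.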
@click.command()
@click.option(
    "--config_path",
    default="config.yaml",
    help="Path to yaml config file",
    type=click.Path(exists=True),
)
@click.option(
    "--file_path",
    help="Path to audio file",
    type=click.Path(exists=True),
)
def main(config_path: str, file_path: str) -> None:
    """Run inference on a single audio file."""
    with open(config_path, "r", encoding="utf-8") as yaml_file:
        config_dict = yaml.safe_load(yaml_file)

    # Split the config into its top-level sections
    model_config = config_dict.get("model", {})
    tokenizer_config = config_dict.get("tokenizer", {})
    inference_config = config_dict.get("inference", {})

    # Resolve the compute device; fall back to CPU if CUDA is unavailable
    # or the configured value is unrecognized.
    if inference_config["device"] == "cpu":
        device = "cpu"
    elif inference_config["device"] == "cuda":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = "cpu"
    device = torch.device(device)  # pylint: disable=no-member
    tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])

    model = SpeechRecognitionModel(
        model_config["n_cnn_layers"],
        model_config["n_rnn_layers"],
        model_config["rnn_dim"],
        tokenizer.get_vocab_size(),
        model_config["n_feats"],
        model_config["stride"],
        model_config["dropout"],
    ).to(device)

    # Restore the trained weights and switch to evaluation mode
    checkpoint = torch.load(inference_config["model_load_path"], map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"], strict=True)
    model.eval()
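
    # Load the audio and convert it to 16 kHz mono before feature extraction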
    waveform, sample_rate = torchaudio.load(file_path)  # pylint: disable=no-member
    if waveform.shape[0] != 1:
        # keep only the first channel so the input is mono
        waveform = waveform[0].unsqueeze(0)
    if sample_rate != 16000:
        resampler = torchaudio.transforms.Resample(sample_rate, 16000)
        waveform = resampler(waveform)
        sample_rate = 16000
    # Mel spectrogram features, shaped (batch, channel, n_mels, time) for the model
    data_processing = torchaudio.transforms.MelSpectrogram(n_mels=model_config["n_feats"])
    spec = data_processing(waveform).unsqueeze(0)
    with torch.no_grad():  # gradients are not needed for inference
        output = model(spec)  # pylint: disable=not-callable
        output = F.log_softmax(output, dim=2)  # (batch, time, n_class)

    decoded_preds = greedy_decoder(output, tokenizer)
    print(decoded_preds)


if __name__ == "__main__":
    main()  # pylint: disable=no-value-for-parameter