aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference.py
blob: f8342f721664ef9f33186287b5488c6d0371e291 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""Inference script for the ASR model."""
from typing import TypedDict

import torch
import torch.nn.functional as F
import torchaudio

from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.tokenizer import CharTokenizer


# Hyperparameter bundle used to construct and train the ASR model.
# Declared with the functional TypedDict form; field set and types are
# identical to the original class-based declaration.
HParams = TypedDict(
    "HParams",
    {
        "n_cnn_layers": int,
        "n_rnn_layers": int,
        "rnn_dim": int,
        "n_class": int,
        "n_feats": int,
        "stride": int,
        "dropout": float,
        "learning_rate": float,
        "batch_size": int,
        "epochs": int,
    },
)


def greedy_decoder(output, tokenizer, collapse_repeated=True):
    """Greedily decode a batch of model outputs into transcriptions.

    Takes the argmax class per time step, optionally collapses consecutive
    repeated labels (CTC-style), drops blank labels, and decodes the
    remaining label ids with the tokenizer.

    Args:
        output: tensor of shape (batch, time, n_class) with per-class scores.
        tokenizer: object providing ``encode(str)`` (returning an object with
            an ``ids`` list) and ``decode(list[int])``.
        collapse_repeated: if True, keep only the first of any run of
            identical consecutive labels before decoding.

    Returns:
        list[str]: one decoded string per batch element.
    """
    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
    # NOTE(review): the token for " " is used as the blank label here —
    # confirm this matches the tokenizer's blank convention.
    blank_label = tokenizer.encode(" ").ids[0]
    decodes = []
    for args in arg_maxes:  # fixed: was enumerate() with an unused index
        decode = []
        for j, index in enumerate(args):
            if index != blank_label:
                # Skip repeats of the previous frame's label (CTC collapse).
                if collapse_repeated and j != 0 and index == args[j - 1]:
                    continue
                decode.append(index.item())
        decodes.append(tokenizer.decode(decode))
    return decodes


def main() -> None:
    """Run inference: load a checkpoint, decode one audio file, print text.

    Reads ``char_tokenizer_german.json``, the ``model8`` checkpoint, and
    ``marvin_rede.flac`` from the working directory.
    """

    device = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device)  # pylint: disable=no-member

    tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")

    # Mel-spectrogram settings; presumably these must match the settings
    # used during training — verify against the training script.
    spectrogram_hparams = {
        "sample_rate": 16000,
        "n_fft": 400,
        "win_length": 400,
        "hop_length": 160,
        "n_mels": 128,
        "f_min": 0,
        "f_max": 8000,
        "power": 2.0,
    }

    hparams = HParams(
        n_cnn_layers=3,
        n_rnn_layers=5,
        rnn_dim=512,
        n_class=tokenizer.get_vocab_size(),
        n_feats=128,
        stride=2,
        dropout=0.1,
        learning_rate=0.1,
        batch_size=30,
        epochs=100,
    )

    model = SpeechRecognitionModel(
        hparams["n_cnn_layers"],
        hparams["n_rnn_layers"],
        hparams["rnn_dim"],
        hparams["n_class"],
        hparams["n_feats"],
        hparams["stride"],
        hparams["dropout"],
    ).to(device)

    # Checkpoints saved from a DataParallel/DDP-wrapped model prefix every
    # key with "module."; strip it so the keys match the unwrapped model.
    checkpoint = torch.load("model8", map_location=device)
    state_dict = {
        k[len("module.") :] if k.startswith("module.") else k: v
        for k, v in checkpoint["model_state_dict"].items()
    }
    model.load_state_dict(state_dict)
    # Bug fix: switch to eval mode — previously the model stayed in training
    # mode, so dropout (0.1) was still active during inference.
    model.eval()

    waveform, sample_rate = torchaudio.load("marvin_rede.flac")  # pylint: disable=no-member
    if sample_rate != spectrogram_hparams["sample_rate"]:
        resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
        waveform = resampler(waveform)

    spec = (
        torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform)
        .squeeze(0)
        .transpose(0, 1)
    )
    # Batch of one; pad_sequence keeps the interface identical to the
    # batched training pipeline.
    specs = [spec]
    specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)

    # No gradients needed for inference; saves memory and time.
    with torch.no_grad():
        output = model(specs)  # pylint: disable=not-callable
        output = F.log_softmax(output, dim=2)
    output = output.transpose(0, 1)  # (time, batch, n_class)
    decodes = greedy_decoder(output, tokenizer)
    print(decodes)


# Script entry point: run inference when executed directly.
if __name__ == "__main__":
    main()