aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/decoder.py
blob: fcddb790755def28cdc18a6e3fab5bdeaa0125f1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
"""Decoder for CTC-based ASR.""" ""
import torch

from swr2_asr.utils.tokenizer import CharTokenizer


# TODO: refactor to use torch CTC decoder class
def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
    """Greedily decode a batch of CTC outputs and their reference labels.

    For every frame the highest-scoring class is taken. Blank tokens are
    dropped and, when ``collapse_repeated`` is set, a class that merely repeats
    the immediately preceding frame is emitted only once. A blank between two
    identical classes therefore keeps both of them.

    Args:
        output: Tensor of per-frame class scores, indexed as
            ``(batch, time, classes)``; the argmax is taken over dimension 2.
            NOTE(review): presumably logits or log-probabilities -- only their
            ordering matters here.
        labels: Per-sample reference token ids. ``labels[i]`` must support
            slicing and ``.tolist()`` (e.g. a 2-D tensor).
        label_lengths: Number of valid entries in each ``labels[i]``; anything
            beyond that length is ignored (presumably padding).
        tokenizer: Supplies the blank id via ``get_blank_token()`` and maps id
            lists to text via ``decode()``.
        collapse_repeated: Merge consecutive identical predictions into one.

    Returns:
        Tuple ``(decodes, targets)``: the decoded hypothesis and the decoded
        reference for every sample, in batch order.
    """
    blank_label = tokenizer.get_blank_token()
    # Convert the whole batch to plain Python ints once. Iterating the tensor
    # element by element would create a 0-d tensor per frame and force a
    # device synchronisation on each comparison and ``.item()`` call.
    best_paths = torch.argmax(output, dim=2).tolist()  # pylint: disable=no-member
    decodes = []
    targets = []
    for i, path in enumerate(best_paths):
        targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
        decode = []
        previous = None  # no predecessor for the first frame, so it is never a repeat
        for index in path:
            is_repeat = collapse_repeated and index == previous
            if index != blank_label and not is_repeat:
                decode.append(index)
            # Track every frame, blanks included, so that "a, blank, a"
            # still yields two separate tokens.
            previous = index
        decodes.append(tokenizer.decode(decode))
    return decodes, targets


# TODO: add beam search decoder