diff options
author | Pherkel | 2023-09-11 14:49:28 +0200 |
---|---|---|
committer | Pherkel | 2023-09-11 14:49:28 +0200 |
commit | 9dc3bc07424908dd7cf3f052708f506fd58b6e2c (patch) | |
tree | cd45dc9b70977530669c271c09025246ebbb9fef /swr2_asr/utils/decoder.py | |
parent | 01fae2b5e395e84db6a7e9819b6f98777c46e845 (diff) |
refactor utilities (data, vis, tokenizer)
Diffstat (limited to 'swr2_asr/utils/decoder.py')
-rw-r--r-- | swr2_asr/utils/decoder.py | 26 |
1 files changed, 26 insertions, 0 deletions
"""Decoder for CTC-based ASR."""
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Needed for the annotation only: the decoder just calls the tokenizer's
    # ``get_blank_token()`` and ``decode()`` methods.
    from swr2_asr.utils.tokenizer import CharTokenizer


# TODO: refactor to use torch CTC decoder class
def greedy_decoder(
    output,
    labels,
    label_lengths,
    tokenizer: "CharTokenizer",
    collapse_repeated=True,
):
    """Greedily decode a batch of CTC outputs together with their targets.

    For every sequence, the best-scoring token id of each step is selected
    (argmax over dimension 2 of ``output``). Blank tokens are dropped and,
    when ``collapse_repeated`` is true, a token id equal to the id of the
    immediately preceding step is dropped as well, so a repeated symbol is
    only kept when a different id (e.g. the blank) separates the repeats.

    Args:
        output: Score tensor with at least three dimensions; the sequences
            are iterated along dimension 0, the steps along dimension 1 and
            the argmax is taken over dimension 2.
        labels: Indexable collection of target id sequences; entry ``i`` is
            sliced to ``label_lengths[i]`` and must provide ``.tolist()``.
        label_lengths: Number of valid ids in each entry of ``labels``.
        tokenizer: Object providing ``get_blank_token()`` (id of the blank
            token) and ``decode(ids)`` (list of ids -> decoded value).
        collapse_repeated: Whether to drop ids repeated on consecutive steps.

    Returns:
        A tuple ``(decodes, targets)`` of two lists holding, per sequence,
        the result of ``tokenizer.decode`` for the greedy prediction and for
        the reference labels.
    """
    blank_label = tokenizer.get_blank_token()
    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
    decodes = []
    targets = []
    # Convert the ids to plain Python ints in one go instead of comparing
    # 0-dim tensors and calling ``.item()`` once per step.
    for i, step_ids in enumerate(arg_maxes.tolist()):
        targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
        decoded_ids = []
        previous_id = None  # the first step never counts as a repetition
        for token_id in step_ids:
            is_repeat = collapse_repeated and token_id == previous_id
            if token_id != blank_label and not is_repeat:
                decoded_ids.append(token_id)
            previous_id = token_id
        decodes.append(tokenizer.decode(decoded_ids))
    return decodes, targets


# TODO: add beam search decoder