diff options
author | Pherkel | 2023-09-12 14:19:15 +0200 |
---|---|---|
committer | GitHub | 2023-09-12 14:19:15 +0200 |
commit | 7a9a6c783e69b5a537a3d3f5bfe8d5fdc656c807 (patch) | |
tree | 0725631b9b68aeb65b292420a15941dcfa3fc04f /swr2_asr/utils/decoder.py | |
parent | f9846193289c81d89342b6a36e951605c2cfa189 (diff) | |
parent | 7b71dab87591e04d874cd636614450b0e65e3f2b (diff) |
Merge pull request #37 from Algo-Boys/fix/ultimate
Fix/ultimate
Diffstat (limited to 'swr2_asr/utils/decoder.py')
-rw-r--r-- | swr2_asr/utils/decoder.py | 26 |
1 files changed, 26 insertions, 0 deletions
"""Decoder for CTC-based ASR."""
from itertools import groupby
from typing import TYPE_CHECKING

import torch

if TYPE_CHECKING:
    # Only referenced in a type annotation, so it is not needed at runtime.
    from swr2_asr.utils.tokenizer import CharTokenizer


# TODO: refactor to use torch CTC decoder class
def greedy_decoder(
    output, labels, label_lengths, tokenizer: "CharTokenizer", collapse_repeated=True
):
    """Greedily decode a batch of CTC outputs and decode the matching targets.

    For every item along the first axis of `output` the highest-scoring index
    along axis 2 is taken for each position. Consecutive identical indices are
    optionally merged, blank indices are dropped, and the remaining indices are
    passed to `tokenizer.decode`.

    Args:
        output: 3-D tensor of scores; argmax is taken over axis 2.
            NOTE(review): presumably (batch, time, classes) -- confirm with caller.
        labels: Indexable per-item target index tensors, aligned with the first
            axis of `output`.
        label_lengths: Per-item number of valid entries at the start of each
            row of `labels`; anything past that length is ignored.
        tokenizer: Provides `get_blank_token()` (the blank index) and
            `decode(list_of_indices)`.
        collapse_repeated: If True, a run of identical consecutive indices is
            emitted only once. Runs are detected before blanks are removed, so
            a blank between two equal indices keeps both of them.

    Returns:
        A tuple `(decodes, targets)` of two lists, each holding one
        `tokenizer.decode` result per batch item.
    """
    blank_label = tokenizer.get_blank_token()
    # Convert once to nested Python ints instead of comparing and unwrapping
    # 0-d tensors frame by frame inside the loop.
    best_paths = torch.argmax(output, dim=2).tolist()  # pylint: disable=no-member
    decodes = []
    targets = []
    for i, path in enumerate(best_paths):
        targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
        if collapse_repeated:
            # groupby yields one key per run of equal consecutive indices.
            path = [index for index, _ in groupby(path)]
        decodes.append(tokenizer.decode([index for index in path if index != blank_label]))
    return decodes, targets


# TODO: add beam search decoder