From 9dc3bc07424908dd7cf3f052708f506fd58b6e2c Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 14:49:28 +0200 Subject: refactor utilities (data, vis, tokenizer) --- swr2_asr/utils/decoder.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 swr2_asr/utils/decoder.py (limited to 'swr2_asr/utils/decoder.py') diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py new file mode 100644 index 0000000..fcddb79 --- /dev/null +++ b/swr2_asr/utils/decoder.py @@ -0,0 +1,26 @@ +"""Decoder for CTC-based ASR.""" "" +import torch + +from swr2_asr.utils.tokenizer import CharTokenizer + + +# TODO: refactor to use torch CTC decoder class +def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True): + """Greedily decode a sequence.""" + blank_label = tokenizer.get_blank_token() + arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member + decodes = [] + targets = [] + for i, args in enumerate(arg_maxes): + decode = [] + targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist())) + for j, index in enumerate(args): + if index != blank_label: + if collapse_repeated and j != 0 and index == args[j - 1]: + continue + decode.append(index.item()) + decodes.append(tokenizer.decode(decode)) + return decodes, targets + + +# TODO: add beam search decoder -- cgit v1.2.3