aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/utils/decoder.py')
-rw-r--r--swr2_asr/utils/decoder.py26
1 files changed, 26 insertions, 0 deletions
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
new file mode 100644
index 0000000..fcddb79
--- /dev/null
+++ b/swr2_asr/utils/decoder.py
@@ -0,0 +1,26 @@
+"""Decoder for CTC-based ASR.""" ""
+import torch
+
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
+# TODO: refactor to use torch CTC decoder class
+def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
+ """Greedily decode a sequence."""
+ blank_label = tokenizer.get_blank_token()
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(tokenizer.decode(decode))
+ return decodes, targets
+
+
+# TODO: add beam search decoder