aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils
diff options
context:
space:
mode:
authorPherkel2023-09-18 15:05:21 +0200
committerPherkel2023-09-18 15:05:21 +0200
commitc09ff76ba6f4c5dd5de64a401efcd27449150aec (patch)
tree5538b4b5cab1b2306860e04b7d052ed1435e4318 /swr2_asr/utils
parentd5e482b7dc3d8b6acc48a883ae9b53b354fa1715 (diff)
added support for lm decoder during training
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--swr2_asr/utils/decoder.py2
-rw-r--r--swr2_asr/utils/tokenizer.py8
2 files changed, 9 insertions, 1 deletions
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index 2b6d29b..6ffdef2 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,6 +1,6 @@
"""Decoder for CTC-based ASR.""" ""
-from dataclasses import dataclass
import os
+from dataclasses import dataclass
import torch
from torchaudio.datasets.utils import _extract_tar
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index b1de83e..4e3fddd 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -29,9 +29,17 @@ class CharTokenizer:
"""Use a character map and convert integer labels to an text sequence"""
string = []
for i in labels:
+ i = int(i)
string.append(self.index_map[i])
return "".join(string).replace("<SPACE>", " ")
+ def decode_batch(self, labels: list[list[int]]) -> list[str]:
+ """Use a character map and convert integer labels to an text sequence"""
+ string = []
+ for label in labels:
+ string.append(self.decode(label))
+ return string
+
def get_vocab_size(self) -> int:
"""Get the number of unique characters in the dataset"""
return len(self.char_map)