path: root/swr2_asr/utils/tokenizer.py
author     Pherkel   2023-09-18 18:13:46 +0200
committer  GitHub    2023-09-18 18:13:46 +0200
commitf94506764bde3e4d41dc593e9d11aa7330c00e30 (patch)
tree6fc438536a72e195805c1aea97926f4c9bbd4f85 /swr2_asr/utils/tokenizer.py
parent8b3a0b47813733ef67befa6959a4d24f8518b5b7 (diff)
parent21a3b1d7cc8544fa0031b8934283382bdfd1d8f1 (diff)
Merge pull request #38 from Algo-Boys/decoder
Decoder
Diffstat (limited to 'swr2_asr/utils/tokenizer.py')
-rw-r--r--  swr2_asr/utils/tokenizer.py  14
1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 1cc7b84..4e3fddd 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -29,9 +29,17 @@ class CharTokenizer:
         """Use a character map and convert integer labels to a text sequence"""
         string = []
         for i in labels:
+            i = int(i)
             string.append(self.index_map[i])
         return "".join(string).replace("<SPACE>", " ")
 
+    def decode_batch(self, labels: list[list[int]]) -> list[str]:
+        """Use a character map and convert batches of integer labels to text sequences"""
+        string = []
+        for label in labels:
+            string.append(self.decode(label))
+        return string
+
     def get_vocab_size(self) -> int:
         """Get the number of unique characters in the dataset"""
         return len(self.char_map)
@@ -120,3 +128,9 @@ class CharTokenizer:
                 load_tokenizer.char_map[char] = int(index)
                 load_tokenizer.index_map[int(index)] = char
         return load_tokenizer
+
+    def create_tokens_txt(self, path: str):
+        """Create a txt file with all the characters"""
+        with open(path, "w", encoding="utf-8") as file:
+            for char, _ in self.char_map.items():
+                file.write(f"{char}\n")
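
For context, a minimal usage sketch of the two new methods. It assumes CharTokenizer() can be constructed without arguments and exposes the char_map/index_map dicts seen in the hunks above; the toy vocabulary, label values, and output path are made up for illustration.

    from swr2_asr.utils.tokenizer import CharTokenizer

    # Build a tiny tokenizer by hand (assumes a no-argument constructor;
    # normally the tokenizer would be trained or loaded from a saved file).
    tokenizer = CharTokenizer()
    for index, char in enumerate(["_", "<SPACE>", "a", "b", "c"]):
        tokenizer.char_map[char] = index
        tokenizer.index_map[index] = char

    # decode_batch() decodes each label sequence in the batch via decode(),
    # which now casts every label to int, so tensor elements also work.
    print(tokenizer.decode_batch([[2, 3, 1, 4], [4, 2]]))  # ['ab c', 'ca']

    # create_tokens_txt() writes one character per line, e.g. for an
    # external decoder that expects a plain token list.
    tokenizer.create_tokens_txt("tokens.txt")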