diff options
author | Pherkel | 2023-09-18 15:05:21 +0200 |
---|---|---|
committer | Pherkel | 2023-09-18 15:05:21 +0200 |
commit | c09ff76ba6f4c5dd5de64a401efcd27449150aec (patch) | |
tree | 5538b4b5cab1b2306860e04b7d052ed1435e4318 /swr2_asr/utils | |
parent | d5e482b7dc3d8b6acc48a883ae9b53b354fa1715 (diff) |
added support for lm decoder during training
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r-- | swr2_asr/utils/decoder.py | 2 | ||||
-rw-r--r-- | swr2_asr/utils/tokenizer.py | 8 |
2 files changed, 9 insertions, 1 deletions
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py index 2b6d29b..6ffdef2 100644 --- a/swr2_asr/utils/decoder.py +++ b/swr2_asr/utils/decoder.py @@ -1,6 +1,6 @@ """Decoder for CTC-based ASR.""" "" -from dataclasses import dataclass import os +from dataclasses import dataclass import torch from torchaudio.datasets.utils import _extract_tar diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py index b1de83e..4e3fddd 100644 --- a/swr2_asr/utils/tokenizer.py +++ b/swr2_asr/utils/tokenizer.py @@ -29,9 +29,17 @@ class CharTokenizer: """Use a character map and convert integer labels to an text sequence""" string = [] for i in labels: + i = int(i) string.append(self.index_map[i]) return "".join(string).replace("<SPACE>", " ") + def decode_batch(self, labels: list[list[int]]) -> list[str]: + """Use a character map and convert integer labels to an text sequence""" + string = [] + for label in labels: + string.append(self.decode(label)) + return string + def get_vocab_size(self) -> int: """Get the number of unique characters in the dataset""" return len(self.char_map) |