aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/decoder.py
diff options
context:
space:
mode:
authorPherkel2023-09-18 12:44:34 +0200
committerPherkel2023-09-18 12:44:34 +0200
commit9475900a1085b8277808b0a0b1555c59f7eb6d36 (patch)
treef9e17b2d15ed959bb8f405e648ce7103c1fe708e /swr2_asr/utils/decoder.py
parent0f14789f1c33d55dc270bcd154201cce2c4d516e (diff)
small fixes
Diffstat (limited to 'swr2_asr/utils/decoder.py')
-rw-r--r--swr2_asr/utils/decoder.py103
1 files changed, 76 insertions, 27 deletions
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index e6b5852..098f6a4 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,14 +1,18 @@
"""Decoder for CTC-based ASR.""" ""
-import torch
-
-from swr2_asr.utils.tokenizer import CharTokenizer
-from swr2_asr.utils.data import create_lexicon
import os
+
+import torch
from torchaudio.datasets.utils import _extract_tar
from torchaudio.models.decoder import ctc_decoder
-LEXICON = "lexicon.txt"
+
+from swr2_asr.utils.data import create_lexicon
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
# TODO: refactor to use torch CTC decoder class
-def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
+def greedy_decoder(
+ output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True
+): # pylint: disable=redefined-outer-name
"""Greedily decode a sequence."""
blank_label = tokenizer.get_blank_token()
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
@@ -26,27 +30,72 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll
return decodes, targets
-# TODO: add beam search decoder
+def beam_search_decoder(
+ tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
+ tokens_path: str,
+ lang_model_path: str,
+ language: str,
+ hparams: dict, # pylint: disable=redefined-outer-name
+):
+ """Beam search decoder."""
+
+ n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+ hparams["n_gram"],
+ hparams["beam_size"],
+ hparams["beam_threshold"],
+ hparams["n_best"],
+ hparams["lm_weight"],
+ hparams["word_score"],
+ )
+
+ if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")):
+ url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz"
+ torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz")
+ _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True)
-def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path):
- if not os.path.isdir(lang_model_path):
- url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz"
- torch.hub.download_url_to_file(
- url, "data/mls_lm_german.tar.gz" )
- _extract_tar("data/mls_lm_german.tar.gz", overwrite=True)
- if not os.path.isfile(tokenizer_txt_path):
- tokenizer.create_txt(tokenizer_txt_path)
-
- lexicon_path= os.join(lang_model_path, LEXICON)
+ if not os.path.isfile(tokens_path):
+ tokenizer.create_tokens_txt(tokens_path)
+
+ lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt")
if not os.path.isfile(lexicon_path):
- occurences_path = os.join(lang_model_path,"vocab_counts.txt")
+ occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt")
create_lexicon(occurences_path, lexicon_path)
- lm_path = os.join(lang_model_path,"3-gram_lm.apa")
- decoder = ctc_decoder(lexicon = lexicon_path,
- tokenizer = tokenizer_txt_path,
- lm =lm_path,
- blank_token = '_',
- nbest =1,
- sil_token= '<SPACE>',
- unk_word = '<UNK>')
- return decoder \ No newline at end of file
+
+ lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa")
+
+ decoder = ctc_decoder(
+ lexicon=lexicon_path,
+ tokens=tokens_path,
+ lm=lm_path,
+ blank_token="_",
+ sil_token="<SPACE>",
+ unk_word="<UNK>",
+ nbest=n_best,
+ beam_size=beam_size,
+ beam_threshold=beam_threshold,
+ lm_weight=lm_weight,
+ word_score=word_score,
+ )
+ return decoder
+
+
+if __name__ == "__main__":
+ tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+ tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")
+
+ hparams = {
+ "n_gram": 3,
+ "beam_size": 100,
+ "beam_threshold": 100,
+ "n_best": 1,
+ "lm_weight": 0.5,
+ "word_score": 1.0,
+ }
+
+ beam_search_decoder(
+ tokenizer,
+ "data/tokenizers/tokens_german.txt",
+ "data",
+ "german",
+ hparams,
+ )