aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils
diff options
context:
space:
mode:
authorJoJoBarthold22023-09-18 12:19:31 +0200
committerJoJoBarthold22023-09-18 12:19:31 +0200
commit0f14789f1c33d55dc270bcd154201cce2c4d516e (patch)
tree7d3e80fc80233eba2005399dc85afaa52452a854 /swr2_asr/utils
parentec8bfe9df205608282e5297635363fc8fc8fe55b (diff)
reset commit history
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--swr2_asr/utils/data.py6
-rw-r--r--swr2_asr/utils/decoder.py29
2 files changed, 31 insertions, 4 deletions
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index f484bdd..19605f6 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -344,7 +344,7 @@ class MLSDataset(Dataset):
idx,
) # type: ignore
- def create_lexicon(vocab_counts_path, lexicon_path):
+def create_lexicon(vocab_counts_path, lexicon_path):
words_list = []
with open(vocab_counts_path, 'r') as file:
@@ -361,6 +361,8 @@ class MLSDataset(Dataset):
file.write(f"{word} ")
for char in word:
file.write(char + ' ')
- file.write("|")
+ file.write("<SPACE>")
+
+
\ No newline at end of file
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index ef8de49..e6b5852 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -2,8 +2,11 @@
import torch
from swr2_asr.utils.tokenizer import CharTokenizer
-
-
+from swr2_asr.utils.data import create_lexicon
+import os
+from torchaudio.datasets.utils import _extract_tar
+from torchaudio.models.decoder import ctc_decoder
+LEXICON = "lexicon.txt"
# TODO: refactor to use torch CTC decoder class
def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
"""Greedily decode a sequence."""
@@ -25,3 +28,25 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll
# TODO: add beam search decoder
+def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path):
+ if not os.path.isdir(lang_model_path):
+ url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz"
+ torch.hub.download_url_to_file(
+ url, "data/mls_lm_german.tar.gz" )
+ _extract_tar("data/mls_lm_german.tar.gz", overwrite=True)
+ if not os.path.isfile(tokenizer_txt_path):
+ tokenizer.create_txt(tokenizer_txt_path)
+
+ lexicon_path= os.join(lang_model_path, LEXICON)
+ if not os.path.isfile(lexicon_path):
+ occurences_path = os.join(lang_model_path,"vocab_counts.txt")
+ create_lexicon(occurences_path, lexicon_path)
+ lm_path = os.join(lang_model_path,"3-gram_lm.apa")
+ decoder = ctc_decoder(lexicon = lexicon_path,
+ tokenizer = tokenizer_txt_path,
+ lm =lm_path,
+ blank_token = '_',
+ nbest =1,
+ sil_token= '<SPACE>',
+ unk_word = '<UNK>')
+ return decoder \ No newline at end of file