aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.yaml5
-rw-r--r--swr2_asr/utils/data.py6
-rw-r--r--swr2_asr/utils/decoder.py29
3 files changed, 35 insertions, 5 deletions
diff --git a/config.yaml b/config.yaml
index e5ff43a..41b473c 100644
--- a/config.yaml
+++ b/config.yaml
@@ -31,4 +31,7 @@ checkpoints:
inference:
model_load_path: "YOUR/PATH" # path to load model from
beam_width: 10 # beam width for beam search
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file
+ device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
+
+lang_model:
+ path: "data/mls_lm_german" # path where model and supplementary files are stored
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index f484bdd..19605f6 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -344,7 +344,7 @@ class MLSDataset(Dataset):
idx,
) # type: ignore
- def create_lexicon(vocab_counts_path, lexicon_path):
+def create_lexicon(vocab_counts_path, lexicon_path):
words_list = []
with open(vocab_counts_path, 'r') as file:
@@ -361,6 +361,8 @@ class MLSDataset(Dataset):
file.write(f"{word} ")
for char in word:
file.write(char + ' ')
- file.write("|")
+ file.write("<SPACE>")
+
+
\ No newline at end of file
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index ef8de49..e6b5852 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -2,8 +2,11 @@
import torch
from swr2_asr.utils.tokenizer import CharTokenizer
-
-
+from swr2_asr.utils.data import create_lexicon
+import os
+from torchaudio.datasets.utils import _extract_tar
+from torchaudio.models.decoder import ctc_decoder
+LEXICON = "lexicon.txt"
# TODO: refactor to use torch CTC decoder class
def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
"""Greedily decode a sequence."""
@@ -25,3 +28,25 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll
# TODO: add beam search decoder
+def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path):
+ if not os.path.isdir(lang_model_path):
+ url = "https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz"
+ torch.hub.download_url_to_file(
+ url, "data/mls_lm_german.tar.gz" )
+ _extract_tar("data/mls_lm_german.tar.gz", overwrite=True)
+ if not os.path.isfile(tokenizer_txt_path):
+ tokenizer.create_txt(tokenizer_txt_path)
+
+ lexicon_path = os.path.join(lang_model_path, LEXICON)
+ if not os.path.isfile(lexicon_path):
+ occurences_path = os.path.join(lang_model_path, "vocab_counts.txt")
+ create_lexicon(occurences_path, lexicon_path)
+ lm_path = os.path.join(lang_model_path, "3-gram_lm.arpa")
+ decoder = ctc_decoder(lexicon = lexicon_path,
+ tokenizer = tokenizer_txt_path,
+ lm =lm_path,
+ blank_token = '_',
+ nbest =1,
+ sil_token= '<SPACE>',
+ unk_word = '<UNK>')
+ return decoder \ No newline at end of file