author     Pherkel  2023-09-18 12:44:34 +0200
committer  Pherkel  2023-09-18 12:44:34 +0200
commit     9475900a1085b8277808b0a0b1555c59f7eb6d36 (patch)
tree       f9e17b2d15ed959bb8f405e648ce7103c1fe708e /swr2_asr/utils
parent     0f14789f1c33d55dc270bcd154201cce2c4d516e (diff)
small fixes
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--  swr2_asr/utils/data.py       |  41
-rw-r--r--  swr2_asr/utils/decoder.py    | 103
-rw-r--r--  swr2_asr/utils/tokenizer.py  |  14
3 files changed, 100 insertions, 58 deletions
diff --git a/swr2_asr/utils/data.py b/swr2_asr/utils/data.py
index 19605f6..74cd572 100644
--- a/swr2_asr/utils/data.py
+++ b/swr2_asr/utils/data.py
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 import torchaudio
 from torch import Tensor, nn
-from torch.utils.data import DataLoader, Dataset
+from torch.utils.data import Dataset
 from torchaudio.datasets.utils import _extract_tar
 
 from swr2_asr.utils.tokenizer import CharTokenizer
@@ -343,26 +343,21 @@ class MLSDataset(Dataset):
             dataset_lookup_entry["chapterid"],
             idx,
         ) # type: ignore
-
+
+
 def create_lexicon(vocab_counts_path, lexicon_path):
-
-    words_list = []
-    with open(vocab_counts_path, 'r') as file:
-        for line in file:
-
-            words = line.split()
-            if len(words) >= 1:
-
-                word = words[0]
-                words_list.append(word)
-
-    with open(lexicon_path, 'w') as file:
-        for word in words_list:
-            file.write(f"{word} ")
-            for char in word:
-                file.write(char + ' ')
-            file.write("<SPACE>")
-
-
-
-
\ No newline at end of file
+    """Create a lexicon from the vocab_counts.txt file"""
+    words_list = []
+    with open(vocab_counts_path, "r", encoding="utf-8") as file:
+        for line in file:
+            words = line.split()
+            if len(words) >= 1:
+                word = words[0]
+                words_list.append(word)
+
+    with open(lexicon_path, "w", encoding="utf-8") as file:
+        for word in words_list:
+            file.write(f"{word} ")
+            for char in word:
+                file.write(char + " ")
+            file.write("<SPACE>")
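
For reference, a minimal usage sketch of the new create_lexicon helper (not part of the commit; the paths are hypothetical, and vocab_counts.txt is assumed to hold one "<word> <count>" pair per line as in the MLS language-model packs, of which only the word column is used):

    from swr2_asr.utils.data import create_lexicon

    # Build the decoder lexicon from the MLS vocab counts (assumed file locations).
    # Each entry is written as the word followed by its characters and a <SPACE> token.
    create_lexicon(
        "data/mls_lm_german/vocab_counts.txt",
        "data/mls_lm_german/lexicon.txt",
    )
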
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index e6b5852..098f6a4 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,14 +1,18 @@
"""Decoder for CTC-based ASR.""" ""
-import torch
-
-from swr2_asr.utils.tokenizer import CharTokenizer
-from swr2_asr.utils.data import create_lexicon
import os
+
+import torch
from torchaudio.datasets.utils import _extract_tar
from torchaudio.models.decoder import ctc_decoder
-LEXICON = "lexicon.txt"
+
+from swr2_asr.utils.data import create_lexicon
+from swr2_asr.utils.tokenizer import CharTokenizer
+
+
# TODO: refactor to use torch CTC decoder class
-def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True):
+def greedy_decoder(
+ output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True
+): # pylint: disable=redefined-outer-name
"""Greedily decode a sequence."""
blank_label = tokenizer.get_blank_token()
arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
@@ -26,27 +30,72 @@ def greedy_decoder(output, labels, label_lengths, tokenizer: CharTokenizer, coll
     return decodes, targets
-# TODO: add beam search decoder
+def beam_search_decoder(
+    tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
+    tokens_path: str,
+    lang_model_path: str,
+    language: str,
+    hparams: dict, # pylint: disable=redefined-outer-name
+):
+    """Beam search decoder."""
+
+    n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+        hparams["n_gram"],
+        hparams["beam_size"],
+        hparams["beam_threshold"],
+        hparams["n_best"],
+        hparams["lm_weight"],
+        hparams["word_score"],
+    )
+
+    if not os.path.isdir(os.path.join(lang_model_path, f"mls_lm_{language}")):
+        url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_{language}.tar.gz"
+        torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz")
+        _extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True)
-def beam_search_decoder(output, tokenizer:CharTokenizer, tokenizer_txt_path,lang_model_path):
-    if not os.path.isdir(lang_model_path):
-        url = f"https://dl.fbaipublicfiles.com/mls/mls_lm_german.tar.gz"
-        torch.hub.download_url_to_file(
-            url, "data/mls_lm_german.tar.gz" )
-        _extract_tar("data/mls_lm_german.tar.gz", overwrite=True)
-    if not os.path.isfile(tokenizer_txt_path):
-        tokenizer.create_txt(tokenizer_txt_path)
-
-    lexicon_path= os.join(lang_model_path, LEXICON)
+    if not os.path.isfile(tokens_path):
+        tokenizer.create_tokens_txt(tokens_path)
+
+    lexicon_path = os.path.join(lang_model_path, f"mls_lm_{language}", "lexicon.txt")
     if not os.path.isfile(lexicon_path):
-        occurences_path = os.join(lang_model_path,"vocab_counts.txt")
+        occurences_path = os.path.join(lang_model_path, f"mls_lm_{language}", "vocab_counts.txt")
         create_lexicon(occurences_path, lexicon_path)
-    lm_path = os.join(lang_model_path,"3-gram_lm.apa")
-    decoder = ctc_decoder(lexicon = lexicon_path,
-                          tokenizer = tokenizer_txt_path,
-                          lm =lm_path,
-                          blank_token = '_',
-                          nbest =1,
-                          sil_token= '<SPACE>',
-                          unk_word = '<UNK>')
-    return decoder
\ No newline at end of file
+
+    lm_path = os.path.join(lang_model_path, f"mls_lm_{language}", f"{n_gram}-gram_lm.arpa")
+
+    decoder = ctc_decoder(
+        lexicon=lexicon_path,
+        tokens=tokens_path,
+        lm=lm_path,
+        blank_token="_",
+        sil_token="<SPACE>",
+        unk_word="<UNK>",
+        nbest=n_best,
+        beam_size=beam_size,
+        beam_threshold=beam_threshold,
+        lm_weight=lm_weight,
+        word_score=word_score,
+    )
+    return decoder
+
+
+if __name__ == "__main__":
+    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
+    tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")
+
+    hparams = {
+        "n_gram": 3,
+        "beam_size": 100,
+        "beam_threshold": 100,
+        "n_best": 1,
+        "lm_weight": 0.5,
+        "word_score": 1.0,
+    }
+
+    beam_search_decoder(
+        tokenizer,
+        "data/tokenizers/tokens_german.txt",
+        "data",
+        "german",
+        hparams,
+    )
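
As a usage sketch (not part of the commit), the CTCDecoder returned by beam_search_decoder is applied to acoustic-model emissions of shape (batch, time, n_tokens) on CPU; the random emissions below only stand in for real model output, and the paths follow the __main__ block above:

    import torch

    from swr2_asr.utils.decoder import beam_search_decoder
    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
    hparams = {"n_gram": 3, "beam_size": 100, "beam_threshold": 100,
               "n_best": 1, "lm_weight": 0.5, "word_score": 1.0}
    decoder = beam_search_decoder(
        tokenizer, "data/tokenizers/tokens_german.txt", "data", "german", hparams
    )

    # Random log-probabilities stand in for the acoustic model's per-frame output;
    # the token dimension must match the number of lines in the tokens file.
    with open("data/tokenizers/tokens_german.txt", encoding="utf-8") as file:
        n_tokens = sum(1 for _ in file)
    emissions = torch.randn(1, 50, n_tokens).log_softmax(dim=-1)

    hypotheses = decoder(emissions)  # list (per batch item) of CTCHypothesis lists
    transcript = " ".join(hypotheses[0][0].words).strip()
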
diff --git a/swr2_asr/utils/tokenizer.py b/swr2_asr/utils/tokenizer.py
index 9abf57d..b1de83e 100644
--- a/swr2_asr/utils/tokenizer.py
+++ b/swr2_asr/utils/tokenizer.py
@@ -120,11 +120,9 @@ class CharTokenizer:
                 load_tokenizer.char_map[char] = int(index)
                 load_tokenizer.index_map[int(index)] = char
         return load_tokenizer
-
-
-    #TO DO check about the weird unknown tokens etc.
-    def create_txt(self,path:str):
-        with open(path, 'w',encoding="utf-8") as file:
-            for key,value in self.char_map():
-                file.write(f"{key}\n")
-
\ No newline at end of file
+
+    def create_tokens_txt(self, path: str):
+        """Create a txt file with all the characters"""
+        with open(path, "w", encoding="utf-8") as file:
+            for char, _ in self.char_map.items():
+                file.write(f"{char}\n")
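
A small companion sketch for the new create_tokens_txt method (not part of the commit): torchaudio's ctc_decoder reads the tokens file line by line and assigns token index i to line i, which is why the method writes the characters in char_map order. Paths match the __main__ block in decoder.py:

    from swr2_asr.utils.tokenizer import CharTokenizer

    tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
    tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")

    # Read the file back; line order is what ctc_decoder uses as the token index.
    with open("data/tokenizers/tokens_german.txt", encoding="utf-8") as file:
        tokens = [line.rstrip("\n") for line in file]
    print(len(tokens), tokens[:5])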