aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.philipp.yaml55
-rw-r--r--config.yaml43
-rw-r--r--swr2_asr/inference.py35
-rw-r--r--swr2_asr/utils/decoder.py121
4 files changed, 161 insertions, 93 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
index f72ce2e..38a68f8 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -1,34 +1,45 @@
+dataset:
+ download: True
+ dataset_root_path: "/Volumes/pherkel 2/SWR2-ASR" # files will be downloaded into this dir
+ language_name: "mls_german_opus"
+ limited_supervision: True # set to True if you want to use limited supervision
+ dataset_percentage: 0.15 # percentage of dataset to use (1.0 = 100%)
+ shuffle: True
+
model:
n_cnn_layers: 3
n_rnn_layers: 5
rnn_dim: 512
n_feats: 128 # number of mel features
stride: 2
- dropout: 0.2 # recommended to be around 0.4-0.6 for smaller datasets, 0.1 for really large datasets
-
-training:
- learning_rate: 0.0005
- batch_size: 32 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
- epochs: 150
- eval_every_n: 5 # evaluate every n epochs
- num_workers: 4 # number of workers for dataloader
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-
-dataset:
- download: true
- dataset_root_path: "data" # files will be downloaded into this dir
- language_name: "mls_german_opus"
- limited_supervision: false # set to True if you want to use limited supervision
- dataset_percentage: 1 # percentage of dataset to use (1.0 = 100%)
- shuffle: true
+ dropout: 0.6 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
tokenizer:
tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
-checkpoints:
- model_load_path: "data/runs/epoch31" # path to load model from
- model_save_path: "data/runs/epoch" # path to save model to
+decoder:
+ type: "lm" # greedy, or lm (beam search)
+
+ lm: # config for lm decoder
+ language_model_path: "data" # path where model and supplementary files are stored
+ language: "german"
+ n_gram: 3 # n-gram size of the language model, 3 or 5
+ beam_size: 50
+ beam_threshold: 50
+ n_best: 1
+ lm_weight: 2
+ word_score: 0
+
+training:
+ learning_rate: 0.0005
+ batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
+ epochs: 3
+ eval_every_n: 3 # evaluate every n epochs
+ num_workers: 8 # number of workers for dataloader
+
+checkpoints: # use "~" to disable saving/loading
+ model_load_path: "YOUR/PATH" # path to load model from
+ model_save_path: "YOUR/PATH" # path to save model to
inference:
- model_load_path: "data/runs/epoch30" # path to load model from
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically \ No newline at end of file
+ model_load_path: "data/epoch67" # path to load model from \ No newline at end of file
diff --git a/config.yaml b/config.yaml
index 41b473c..d248d43 100644
--- a/config.yaml
+++ b/config.yaml
@@ -1,3 +1,11 @@
+dataset:
+ download: True
+ dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
+ language_name: "mls_german_opus"
+ limited_supervision: False # set to True if you want to use limited supervision
+ dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
+ shuffle: True
+
model:
n_cnn_layers: 3
n_rnn_layers: 5
@@ -6,32 +14,33 @@ model:
stride: 2
dropout: 0.3 # recommended to be around 0.4 for smaller datasets, 0.1 for really large datasets
+tokenizer:
+ tokenizer_path: "data/tokenizers/char_tokenizer_german.json"
+
+decoder:
+ type: "greedy" # greedy, or lm (beam search)
+
+ lm: # config for lm decoder
+ language_model_path: "data" # path where model and supplementary files are stored
+ language: "german"
+ n_gram: 3 # n-gram size of the language model, 3 or 5
+ beam_size: 50
+ beam_threshold: 50
+ n_best: 1
+ lm_weight: 2,
+ word_score: 0,
+
training:
- learning_rate: 5e-4
+ learning_rate: 0.0005
batch_size: 8 # recommended to maximum number that fits on the GPU (batch size of 32 fits on a 12GB GPU)
epochs: 3
eval_every_n: 3 # evaluate every n epochs
num_workers: 8 # number of workers for dataloader
-dataset:
- download: True
- dataset_root_path: "YOUR/PATH" # files will be downloaded into this dir
- language_name: "mls_german_opus"
- limited_supervision: False # set to True if you want to use limited supervision
- dataset_percentage: 1.0 # percentage of dataset to use (1.0 = 100%)
- shuffle: True
-
-tokenizer:
- tokenizer_path: "data/tokenizers/char_tokenizer_german.yaml"
-
-checkpoints:
+checkpoints: # use "~" to disable saving/loading
model_load_path: "YOUR/PATH" # path to load model from
model_save_path: "YOUR/PATH" # path to save model to
inference:
model_load_path: "YOUR/PATH" # path to load model from
- beam_width: 10 # beam width for beam search
- device: "cuda" # device to run inference on if gpu is available, else "cpu" will be set automatically
-lang_model:
- path: "data/mls_lm_german" #path where model and supplementary files are stored
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 3c58af0..64a6eeb 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -6,25 +6,10 @@ import torchaudio
import yaml
from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.tokenizer import CharTokenizer
-def greedy_decoder(output, tokenizer: CharTokenizer, collapse_repeated=True):
- """Greedily decode a sequence."""
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- blank_label = tokenizer.get_blank_token()
- decodes = []
- for args in arg_maxes:
- decode = []
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes
-
-
@click.command()
@click.option(
"--config_path",
@@ -46,11 +31,16 @@ def main(config_path: str, file_path: str) -> None:
model_config = config_dict.get("model", {})
tokenizer_config = config_dict.get("tokenizer", {})
inference_config = config_dict.get("inference", {})
+ decoder_config = config_dict.get("decoder", {})
- if inference_config["device"] == "cpu":
+ if inference_config.get("device", "") == "cpu":
device = "cpu"
- elif inference_config["device"] == "cuda":
+ elif inference_config.get("device", "") == "cuda":
device = "cuda" if torch.cuda.is_available() else "cpu"
+ elif inference_config.get("device", "") == "mps":
+ device = "mps"
+ else:
+ device = "cpu"
device = torch.device(device) # pylint: disable=no-member
tokenizer = CharTokenizer.from_file(tokenizer_config["tokenizer_path"])
@@ -90,11 +80,16 @@ def main(config_path: str, file_path: str) -> None:
spec = spec.unsqueeze(0)
spec = spec.transpose(1, 2)
spec = spec.unsqueeze(0)
+ spec = spec.to(device)
output = model(spec) # pylint: disable=not-callable
output = F.log_softmax(output, dim=2) # (batch, time, n_class)
- decoded_preds = greedy_decoder(output, tokenizer)
- print(decoded_preds)
+ decoder = decoder_factory(decoder_config["type"])(tokenizer, decoder_config)
+
+ preds = decoder(output)
+ preds = " ".join(preds[0][0].words).strip()
+
+ print(preds)
if __name__ == "__main__":
diff --git a/swr2_asr/utils/decoder.py b/swr2_asr/utils/decoder.py
index 098f6a4..2b6d29b 100644
--- a/swr2_asr/utils/decoder.py
+++ b/swr2_asr/utils/decoder.py
@@ -1,4 +1,5 @@
"""Decoder for CTC-based ASR.""" ""
+from dataclasses import dataclass
import os
import torch
@@ -9,37 +10,39 @@ from swr2_asr.utils.data import create_lexicon
from swr2_asr.utils.tokenizer import CharTokenizer
-# TODO: refactor to use torch CTC decoder class
-def greedy_decoder(
- output, labels, label_lengths, tokenizer: CharTokenizer, collapse_repeated=True
-): # pylint: disable=redefined-outer-name
- """Greedily decode a sequence."""
- blank_label = tokenizer.get_blank_token()
- arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
- decodes = []
- targets = []
- for i, args in enumerate(arg_maxes):
- decode = []
- targets.append(tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
- for j, index in enumerate(args):
- if index != blank_label:
- if collapse_repeated and j != 0 and index == args[j - 1]:
- continue
- decode.append(index.item())
- decodes.append(tokenizer.decode(decode))
- return decodes, targets
-
-
-def beam_search_decoder(
+@dataclass
+class DecoderOutput:
+ """Decoder output."""
+
+ words: list[str]
+
+
+def decoder_factory(decoder_type: str = "greedy") -> callable:
+ """Decoder factory."""
+ if decoder_type == "greedy":
+ return get_greedy_decoder
+ if decoder_type == "lm":
+ return get_beam_search_decoder
+ raise NotImplementedError
+
+
+def get_greedy_decoder(
+ tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
+ *_,
+):
+ """Greedy decoder."""
+ return GreedyDecoder(tokenizer)
+
+
+def get_beam_search_decoder(
tokenizer: CharTokenizer, # pylint: disable=redefined-outer-name
- tokens_path: str,
- lang_model_path: str,
- language: str,
hparams: dict, # pylint: disable=redefined-outer-name
):
"""Beam search decoder."""
-
- n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+ hparams = hparams.get("lm", {})
+ language, lang_model_path, n_gram, beam_size, beam_threshold, n_best, lm_weight, word_score = (
+ hparams["language"],
+ hparams["language_model_path"],
hparams["n_gram"],
hparams["beam_size"],
hparams["beam_threshold"],
@@ -53,6 +56,7 @@ def beam_search_decoder(
torch.hub.download_url_to_file(url, f"data/mls_lm_{language}.tar.gz")
_extract_tar("data/mls_lm_{language}.tar.gz", overwrite=True)
+ tokens_path = os.path.join(lang_model_path, f"mls_lm_{language}", "tokens.txt")
if not os.path.isfile(tokens_path):
tokenizer.create_tokens_txt(tokens_path)
@@ -79,11 +83,66 @@ def beam_search_decoder(
return decoder
+class GreedyDecoder:
+ """Greedy decoder."""
+
+ def __init__(self, tokenizer: CharTokenizer): # pylint: disable=redefined-outer-name
+ self.tokenizer = tokenizer
+
+ def __call__(
+ self, output, greedy_type: str = "inference", labels=None, label_lengths=None
+ ): # pylint: disable=redefined-outer-name
+ """Greedily decode a sequence."""
+ if greedy_type == "train":
+ res = self.train(output, labels, label_lengths)
+ if greedy_type == "inference":
+ res = self.inference(output)
+
+ res = [[DecoderOutput(words=res)]]
+ return res
+
+ def train(self, output, labels, label_lengths):
+ """Greedily decode a sequence with known labels."""
+ blank_label = tokenizer.get_blank_token()
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ targets.append(self.tokenizer.decode(labels[i][: label_lengths[i]].tolist()))
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(self.tokenizer.decode(decode))
+ return decodes, targets
+
+ def inference(self, output):
+ """Greedily decode a sequence."""
+ collapse_repeated = True
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = self.tokenizer.get_blank_token()
+ decodes = []
+ for args in arg_maxes:
+ decode = []
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(self.tokenizer.decode(decode))
+
+ return decodes
+
+
if __name__ == "__main__":
tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")
tokenizer.create_tokens_txt("data/tokenizers/tokens_german.txt")
hparams = {
+ "language": "german",
+ "lang_model_path": "data",
"n_gram": 3,
"beam_size": 100,
"beam_threshold": 100,
@@ -92,10 +151,4 @@ if __name__ == "__main__":
"word_score": 1.0,
}
- beam_search_decoder(
- tokenizer,
- "data/tokenizers/tokens_german.txt",
- "data",
- "german",
- hparams,
- )
+ get_beam_search_decoder(tokenizer, hparams)