In [None]:
# Decoder hyper-parameter grid. The evaluation cell below builds one decoder
# configuration for every combination of these lists.
lm_weights = [0, 1.0, 2.5]       # values tried for config["lm_weight"]
word_score = [-1.5, 0.0, 1.5]    # values tried for config["word_score"]
beam_sizes = [50, 500]           # values tried for config["beam_size"]
beam_thresholds = [50]           # values tried for config["beam_threshold"]
beam_size_token = [10, 38]       # values tried for config["beam_size_token"]

In [None]:
import itertools
import os

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm.autonotebook import tqdm

from swr2_asr.model_deep_speech import SpeechRecognitionModel
from swr2_asr.utils.data import DataProcessing, MLSDataset, Split
from swr2_asr.utils.decoder import decoder_factory
from swr2_asr.utils.loss_scores import cer, wer
from swr2_asr.utils.tokenizer import CharTokenizer

In [None]:


# Character-level German tokenizer; its vocabulary size also sets the model's
# output dimension (model_params["n_class"]).
tokenizer = CharTokenizer.from_file("data/tokenizers/char_tokenizer_german.json")

# Number of decoder configurations in the search grid; the progress bar gets
# one manual tick per configuration from inside the grid-search loop.
n_grid_configs = (
    len(lm_weights)
    * len(word_score)
    * len(beam_sizes)
    * len(beam_thresholds)
    * len(beam_size_token)
)
pbar = tqdm(total=n_grid_configs)

# Baseline settings for the "lm" decoder. The grid-search loop copies this
# dict and overwrites lm_weight, word_score, beam_size, beam_threshold and
# beam_size_token for every configuration, so the values given here for
# those keys are only defaults.
base_config = {
    "language": "german",
    "language_model_path": "data",  # path where model and supplementary files are stored
    "n_gram": 3,  # n-gram size of the language model, 3 or 5
    "beam_size": 50,
    "beam_threshold": 50,
    "n_best": 1,  # only the top hypothesis ([0]) is scored below
    "lm_weight": 2,
    "word_score": 0,
}

# Evaluation data: the German MLS dev split, loaded without downloading.
# The dataset root defaults to the original author's machine-specific path;
# set SWR2_DATASET_PATH to point at a local copy instead of editing this cell.
dataset_params = {
    "dataset_path": os.environ.get("SWR2_DATASET_PATH", "/Volumes/pherkel 2/SWR2-ASR"),
    "language": "mls_german_opus",
    "split": Split.DEV,
    "limited": True,
    "download": False,
    "size": 0.01,  # presumably the fraction of the split to load — confirm in MLSDataset
}
 

# Acoustic-model hyper-parameters. They must describe the same architecture
# the checkpoint was trained with, because the weights are loaded strictly.
model_params = {
    "n_cnn_layers": 3,
    "n_rnn_layers": 5,
    "rnn_dim": 512,
    "n_class": tokenizer.get_vocab_size(),
    "n_feats": 128,
    "stride": 2,
    "dropout": 0.1,
}

model = SpeechRecognitionModel(**model_params)

# NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code — only load checkpoints from a trusted source.
checkpoint = torch.load("data/epoch67", map_location=torch.device("cpu"))

# Drop a leading "module." from parameter names (presumably added when the
# model was trained inside a DataParallel/DDP wrapper) so they match the
# bare model's keys; names without the prefix are kept unchanged.
_wrapper_prefix = "module."
state_dict = {}
for _param_name, _param_value in checkpoint["model_state_dict"].items():
    if _param_name.startswith(_wrapper_prefix):
        _param_name = _param_name[len(_wrapper_prefix):]
    state_dict[_param_name] = _param_value

model.load_state_dict(state_dict, strict=True)
model.eval()


# Evaluation dataset plus the "valid" preprocessing pipeline; the feature
# count handed to the collate function must match the model's n_feats.
dataset = MLSDataset(**dataset_params)

data_processing = DataProcessing("valid", tokenizer, {"n_feats": model_params["n_feats"]})

# shuffle=False keeps batch order deterministic between runs.
# pin_memory is disabled: pinned host memory only speeds up host->GPU copies,
# and this notebook runs the model on the CPU (the checkpoint is loaded with
# map_location="cpu" and batches are never moved to a device). Without an
# accelerator, pin_memory=True only produces a DataLoader warning.
dataloader = DataLoader(
    dataset=dataset,
    batch_size=16,
    shuffle=False,
    collate_fn=data_processing,
    num_workers=8,
    pin_memory=False,
)

# Grid search over decoder settings.
#
# Only the decoding step depends on the decoder configuration: the acoustic
# model, its weights and the data are the same for every combination. So the
# network is run once here, and its per-batch log-probabilities and reference
# transcripts are cached; each configuration then only re-runs decoding.
# (The cache holds the whole evaluation subset in memory.)
cached_batches = []
with torch.no_grad():
    model.eval()
    for spectrograms, labels, _input_lengths, _label_lengths in dataloader:
        log_probs = F.log_softmax(model(spectrograms), dim=2)
        cached_batches.append((log_probs, tokenizer.decode_batch(labels)))
        # NOTE(review): _input_lengths is not passed to the decoder, so padded
        # frames are decoded too — confirm whether the decoder accepts lengths.

# WER can exceed 1.0 (insertions count as errors), so start from +inf to make
# sure the first configuration is always recorded as the current best.
best_wer = float("inf")
best_cer = float("inf")
best_config = None

# itertools.product yields combinations in the same order as the equivalent
# nested loops (last list varies fastest).
for lm_weight, ws, beam_size, beam_threshold, beam_size_t in itertools.product(
    lm_weights, word_score, beam_sizes, beam_thresholds, beam_size_token
):
    config = base_config.copy()
    config["lm_weight"] = lm_weight
    config["word_score"] = ws
    config["beam_size"] = beam_size
    config["beam_threshold"] = beam_threshold
    config["beam_size_token"] = beam_size_t

    decoder = decoder_factory("lm")(tokenizer, {"lm": config})

    test_cer, test_wer = [], []
    for log_probs, decoded_targets in cached_batches:
        decoded_preds = decoder(log_probs)
        # zip stops at the shorter sequence, so extra predictions without a
        # matching target are ignored.
        for hypotheses, target in zip(decoded_preds, decoded_targets):
            pred = " ".join(hypotheses[0].words).strip()
            test_cer.append(cer(pred, target))
            test_wer.append(wer(pred, target))

    if not test_wer:
        raise RuntimeError("No utterances were scored; is the evaluation dataset empty?")

    avg_cer = sum(test_cer) / len(test_cer)
    avg_wer = sum(test_wer) / len(test_wer)

    if avg_wer < best_wer:
        best_wer = avg_wer
        best_cer = avg_cer
        best_config = config
        print("New best WER: ", best_wer, " CER: ", best_cer)
        print("Config: ", best_config)
        print("LM Weight: ", lm_weight,
              " Word Score: ", ws,
              " Beam Size: ", beam_size,
              " Beam Threshold: ", beam_threshold,
              " Beam Size Token: ", beam_size_t)
        print("--------------------------------------------------------------")

    pbar.update(1)

pbar.close()
print("Best WER: ", best_wer, " CER: ", best_cer)
print("Best config: ", best_config)