{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Language-model decoder grid search\n",
    "\n",
    "Evaluates every combination of the decoder settings `lm_weight`, `word_score`, `beam_size`, `beam_threshold` and `beam_size_token` and reports the combination with the lowest average WER (together with its average CER) on the configured dataset split."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import product\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n",
    "from swr2_asr.utils.data import DataProcessing, MLSDataset, Split\n",
    "from swr2_asr.utils.decoder import decoder_factory\n",
    "from swr2_asr.utils.loss_scores import cer, wer\n",
    "from swr2_asr.utils.tokenizer import CharTokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameter grid: every combination of the values below is evaluated.\n",
    "lm_weights = [0, 1.0, 2.5]\n",
    "word_scores = [-1.5, 0.0, 1.5]\n",
    "beam_sizes = [50, 500]\n",
    "beam_thresholds = [50]\n",
    "beam_size_tokens = [10, 38]\n",
    "\n",
    "# File locations, collected in one place so they are easy to adapt.\n",
    "# NOTE(review): DATASET_PATH is an absolute, machine-specific location.\n",
    "TOKENIZER_PATH = \"data/tokenizers/char_tokenizer_german.json\"\n",
    "CHECKPOINT_PATH = \"data/epoch67\"\n",
    "DATASET_PATH = \"/Volumes/pherkel 2/SWR2-ASR\"\n",
    "\n",
    "# Decoder settings shared by all runs; the swept keys are overridden per run.\n",
    "base_config = {\n",
    "    \"language\": \"german\",\n",
    "    \"language_model_path\": \"data\",  # path where model and supplementary files are stored\n",
    "    \"n_gram\": 3,  # n-gram size of the language model, 3 or 5\n",
    "    \"beam_size\": 50,\n",
    "    \"beam_threshold\": 50,\n",
    "    \"n_best\": 1,\n",
    "    \"lm_weight\": 2,\n",
    "    \"word_score\": 0,\n",
    "}\n",
    "\n",
    "dataset_params = {\n",
    "    \"dataset_path\": DATASET_PATH,\n",
    "    \"language\": \"mls_german_opus\",\n",
    "    \"split\": Split.DEV,\n",
    "    \"limited\": True,\n",
    "    \"download\": False,\n",
    "    \"size\": 0.01,\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load tokenizer, model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = CharTokenizer.from_file(TOKENIZER_PATH)\n",
    "\n",
    "model_params = {\n",
    "    \"n_cnn_layers\": 3,\n",
    "    \"n_rnn_layers\": 5,\n",
    "    \"rnn_dim\": 512,\n",
    "    \"n_class\": tokenizer.get_vocab_size(),\n",
    "    \"n_feats\": 128,\n",
    "    \"stride\": 2,\n",
    "    \"dropout\": 0.1,\n",
    "}\n",
    "\n",
    "model = SpeechRecognitionModel(**model_params)\n",
    "\n",
    "# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.\n",
    "checkpoint = torch.load(CHECKPOINT_PATH, map_location=torch.device(\"cpu\"))\n",
    "\n",
    "# Drop a leading \"module.\" from parameter names so the keys match the bare model\n",
    "# (presumably left behind by a parallel wrapper during training; TODO confirm).\n",
    "state_dict = {\n",
    "    (key[len(\"module.\") :] if key.startswith(\"module.\") else key): value\n",
    "    for key, value in checkpoint[\"model_state_dict\"].items()\n",
    "}\n",
    "model.load_state_dict(state_dict, strict=True)\n",
    "model.eval()\n",
    "\n",
    "dataset = MLSDataset(**dataset_params)\n",
    "\n",
    "data_processing = DataProcessing(\"valid\", tokenizer, {\"n_feats\": model_params[\"n_feats\"]})\n",
    "\n",
    "dataloader = DataLoader(\n",
    "    dataset=dataset,\n",
    "    batch_size=16,\n",
    "    shuffle=False,\n",
    "    collate_fn=data_processing,\n",
    "    num_workers=8,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cache acoustic model outputs\n",
    "\n",
    "The model is never given the decoder settings, so its log-probabilities and the reference transcripts are computed once here and reused for every grid point instead of being recomputed per configuration."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Each entry holds (log-probabilities, decoded reference transcripts) for one batch.\n",
    "cached_batches = []\n",
    "with torch.no_grad():\n",
    "    for spectrograms, labels, _input_lengths, _label_lengths in tqdm(dataloader, desc=\"inference\"):\n",
    "        log_probs = F.log_softmax(model(spectrograms), dim=2)\n",
    "        cached_batches.append((log_probs, tokenizer.decode_batch(labels)))\n",
    "\n",
    "if not cached_batches:\n",
    "    raise ValueError(\"The dataloader produced no batches; check dataset_params.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Grid search"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# product() varies the right-most list fastest, i.e. the same visiting order as\n",
    "# nesting one for-loop per list in the order given.\n",
    "grid = list(product(lm_weights, word_scores, beam_sizes, beam_thresholds, beam_size_tokens))\n",
    "\n",
    "# Start from infinity rather than 1.0 so the first evaluated configuration is\n",
    "# always recorded, even if its WER is 1.0 or higher.\n",
    "best_wer = float(\"inf\")\n",
    "best_cer = float(\"inf\")\n",
    "best_config = None\n",
    "results = []  # one dict per evaluated configuration, kept for later inspection\n",
    "\n",
    "for lm_weight, ws, beam_size, beam_threshold, beam_size_t in tqdm(grid, desc=\"grid search\"):\n",
    "    # Fresh dict per run so a stored best_config is never mutated afterwards.\n",
    "    config = {\n",
    "        **base_config,\n",
    "        \"lm_weight\": lm_weight,\n",
    "        \"word_score\": ws,\n",
    "        \"beam_size\": beam_size,\n",
    "        \"beam_threshold\": beam_threshold,\n",
    "        \"beam_size_token\": beam_size_t,\n",
    "    }\n",
    "    decoder = decoder_factory(\"lm\")(tokenizer, {\"lm\": config})\n",
    "\n",
    "    test_cer, test_wer = [], []\n",
    "    for log_probs, decoded_targets in cached_batches:\n",
    "        decoded_preds = decoder(log_probs)\n",
    "        # zip stops at the shorter sequence, so predictions without a matching\n",
    "        # target are ignored.\n",
    "        for hypotheses, target in zip(decoded_preds, decoded_targets):\n",
    "            # Score only the first hypothesis returned for each sample.\n",
    "            pred = \" \".join(hypotheses[0].words).strip()\n",
    "            test_cer.append(cer(pred, target))\n",
    "            test_wer.append(wer(pred, target))\n",
    "\n",
    "    if not test_wer:\n",
    "        # Nothing was scored; skip instead of dividing by zero.\n",
    "        print(\"No predictions produced for config: \", config)\n",
    "        continue\n",
    "\n",
    "    avg_cer = sum(test_cer) / len(test_cer)\n",
    "    avg_wer = sum(test_wer) / len(test_wer)\n",
    "    results.append({**config, \"wer\": avg_wer, \"cer\": avg_cer})\n",
    "\n",
    "    if avg_wer < best_wer:\n",
    "        best_wer = avg_wer\n",
    "        best_cer = avg_cer\n",
    "        best_config = config\n",
    "        print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n",
    "        print(\"Config: \", best_config)\n",
    "        print(\"LM Weight: \", lm_weight,\n",
    "              \" Word Score: \", ws,\n",
    "              \" Beam Size: \", beam_size,\n",
    "              \" Beam Threshold: \", beam_threshold,\n",
    "              \" Beam Size Token: \", beam_size_t)\n",
    "        print(\"--------------------------------------------------------------\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Best WER: \", best_wer, \" CER: \", best_cer)\n",
    "best_config"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}