aboutsummaryrefslogtreecommitdiff
path: root/lm_decoder_hparams.ipynb
diff options
context:
space:
mode:
authorPherkel2023-09-18 18:11:33 +0200
committerPherkel2023-09-18 18:11:33 +0200
commitd44cf7b1cab683a8aa3876619c82226f4e6d6f3b (patch)
treedc5a0567d5ff939320c9737e1d66e9e83f0f534c /lm_decoder_hparams.ipynb
parentc09ff76ba6f4c5dd5de64a401efcd27449150aec (diff)
fix
Diffstat (limited to 'lm_decoder_hparams.ipynb')
-rw-r--r--lm_decoder_hparams.ipynb245
1 files changed, 245 insertions, 0 deletions
diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb
new file mode 100644
index 0000000..5e56312
--- /dev/null
+++ b/lm_decoder_hparams.ipynb
@@ -0,0 +1,245 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "lm_weights = [0, 1.0, 2.5,]\n",
+ "word_score = [-1.5, 0.0, 1.5]\n",
+ "beam_sizes = [50, 500]\n",
+ "beam_thresholds = [50]\n",
+ "beam_size_token = [10, 38]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
+ " from tqdm.autonotebook import tqdm\n",
+ "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n",
+ " warnings.warn(\n"
+ ]
+ }
+ ],
+ "source": [
+ "from tqdm.autonotebook import tqdm\n",
+ "\n",
+ "import torch\n",
+ "from torch.utils.data import DataLoader\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from swr2_asr.utils.decoder import decoder_factory\n",
+ "from swr2_asr.utils.tokenizer import CharTokenizer\n",
+ "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n",
+ "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n",
+ "from swr2_asr.utils.loss_scores import cer, wer"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.jupyter.widget-view+json": {
+ "model_id": "34aafd9aca2541748dc41d8550334536",
+ "version_major": 2,
+ "version_minor": 0
+ },
+ "text/plain": [
+ " 0%| | 0/144 [00:00<?, ?it/s]"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Download flag not set, skipping download\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/functional/functional.py:576: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. Or, the value for `n_freqs` (201) may be set too low.\n",
+ " warnings.warn(\n"
+ ]
+ },
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "New best WER: 0.8266228565397248 CER: 0.6048691547202959\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 10, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 10 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n",
+ "New best WER: 0.7900706123452581 CER: 0.49197597466135945\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 50 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n",
+ "New best WER: 0.7877685082828738 CER: 0.48660732878914315\n",
+ "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 100, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n",
+ "LM Weight: 0 Word Score: -1.5 Beam Size: 100 Beam Threshold: 50 Beam Size Token: 10\n",
+ "--------------------------------------------------------------\n"
+ ]
+ }
+ ],
+ "source": [
+ "\n",
+ "\n",
+ "tokenizer = CharTokenizer.from_file(\"data/tokenizers/char_tokenizer_german.json\")\n",
+ "\n",
+ "# manually increment tqdm progress bar\n",
+ "pbar = tqdm(total=len(lm_weights) * len(word_score) * len(beam_sizes) * len(beam_thresholds) * len(beam_size_token))\n",
+ "\n",
+ "base_config = {\n",
+ " \"language\": \"german\",\n",
+ " \"language_model_path\": \"data\", # path where model and supplementary files are stored\n",
+ "    \"n_gram\": 3, # n-gram size of the language model, 3 or 5\n",
+ " \"beam_size\": 50 ,\n",
+ " \"beam_threshold\": 50,\n",
+ " \"n_best\": 1,\n",
+ " \"lm_weight\": 2,\n",
+ " \"word_score\": 0,\n",
+ " }\n",
+ "\n",
+ "dataset_params = {\n",
+ " \"dataset_path\": \"/Volumes/pherkel 2/SWR2-ASR\",\n",
+ " \"language\": \"mls_german_opus\",\n",
+ " \"split\": Split.DEV,\n",
+ " \"limited\": True,\n",
+ " \"download\": False,\n",
+ " \"size\": 0.01,\n",
+ "}\n",
+ " \n",
+ "\n",
+ "model_params = {\n",
+ " \"n_cnn_layers\": 3,\n",
+ " \"n_rnn_layers\": 5,\n",
+ " \"rnn_dim\": 512,\n",
+ " \"n_class\": tokenizer.get_vocab_size(),\n",
+ " \"n_feats\": 128,\n",
+ " \"stride\": 2,\n",
+ " \"dropout\": 0.1,\n",
+ "}\n",
+ "\n",
+ "model = SpeechRecognitionModel(**model_params)\n",
+ "\n",
+ "checkpoint = torch.load(\"data/epoch67\", map_location=torch.device(\"cpu\"))\n",
+ "\n",
+ "state_dict = {\n",
+ " k[len(\"module.\") :] if k.startswith(\"module.\") else k: v\n",
+ " for k, v in checkpoint[\"model_state_dict\"].items()\n",
+ "}\n",
+ "model.load_state_dict(state_dict, strict=True)\n",
+ "model.eval()\n",
+ "\n",
+ "\n",
+ "dataset = MLSDataset(**dataset_params,)\n",
+ "\n",
+ "data_processing = DataProcessing(\"valid\", tokenizer, {\"n_feats\": model_params[\"n_feats\"]})\n",
+ "\n",
+ "dataloader = DataLoader(\n",
+ " dataset=dataset,\n",
+ " batch_size=16,\n",
+ " shuffle = False,\n",
+ " collate_fn=data_processing,\n",
+ " num_workers=8,\n",
+ " pin_memory=True,\n",
+ ")\n",
+ "\n",
+ "best_wer = 1.0\n",
+ "best_cer = 1.0\n",
+ "best_config = None\n",
+ "\n",
+ "for lm_weight in lm_weights:\n",
+ " for ws in word_score:\n",
+ " for beam_size in beam_sizes:\n",
+ " for beam_threshold in beam_thresholds:\n",
+ " for beam_size_t in beam_size_token:\n",
+ " config = base_config.copy()\n",
+ " config[\"lm_weight\"] = lm_weight\n",
+ " config[\"word_score\"] = ws\n",
+ " config[\"beam_size\"] = beam_size\n",
+ " config[\"beam_threshold\"] = beam_threshold\n",
+ " config[\"beam_size_token\"] = beam_size_t\n",
+ " \n",
+ " decoder = decoder_factory(\"lm\")(tokenizer, {\"lm\": config})\n",
+ " \n",
+ " test_cer, test_wer = [], []\n",
+ " with torch.no_grad():\n",
+ " model.eval()\n",
+ " for batch in dataloader:\n",
+ " # perform inference, decode, compute WER and CER\n",
+ " spectrograms, labels, input_lengths, label_lengths = batch\n",
+ " \n",
+ " output = model(spectrograms)\n",
+ " output = F.log_softmax(output, dim=2)\n",
+ " \n",
+ " decoded_preds = decoder(output)\n",
+ " decoded_targets = tokenizer.decode_batch(labels)\n",
+ " \n",
+ " for j, _ in enumerate(decoded_preds):\n",
+ " if j >= len(decoded_targets):\n",
+ " break\n",
+ " pred = \" \".join(decoded_preds[j][0].words).strip()\n",
+ " target = decoded_targets[j]\n",
+ " \n",
+ " test_cer.append(cer(pred, target))\n",
+ " test_wer.append(wer(pred, target))\n",
+ "\n",
+ " avg_cer = sum(test_cer) / len(test_cer)\n",
+ " avg_wer = sum(test_wer) / len(test_wer)\n",
+ " \n",
+ " if avg_wer < best_wer:\n",
+ " best_wer = avg_wer\n",
+ " best_cer = avg_cer\n",
+ " best_config = config\n",
+ " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n",
+ " print(\"Config: \", best_config)\n",
+ " print(\"LM Weight: \", lm_weight, \n",
+ " \" Word Score: \", ws, \n",
+ " \" Beam Size: \", beam_size, \n",
+ " \" Beam Threshold: \", beam_threshold, \n",
+ " \" Beam Size Token: \", beam_size_t)\n",
+ " print(\"--------------------------------------------------------------\")\n",
+ " \n",
+ " pbar.update(1)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}