diff options
author | Pherkel | 2023-09-18 18:11:33 +0200 |
---|---|---|
committer | Pherkel | 2023-09-18 18:11:33 +0200 |
commit | d44cf7b1cab683a8aa3876619c82226f4e6d6f3b (patch) | |
tree | dc5a0567d5ff939320c9737e1d66e9e83f0f534c /lm_decoder_hparams.ipynb | |
parent | c09ff76ba6f4c5dd5de64a401efcd27449150aec (diff) |
fix
Diffstat (limited to 'lm_decoder_hparams.ipynb')
-rw-r--r-- | lm_decoder_hparams.ipynb | 245 |
1 file changed, 245 insertions, 0 deletions
diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb new file mode 100644 index 0000000..5e56312 --- /dev/null +++ b/lm_decoder_hparams.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "lm_weights = [0, 1.0, 2.5,]\n", + "word_score = [-1.5, 0.0, 1.5]\n", + "beam_sizes = [50, 500]\n", + "beam_thresholds = [50]\n", + "beam_size_token = [10, 38]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n", + "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. 
https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from tqdm.autonotebook import tqdm\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn.functional as F\n", + "\n", + "from swr2_asr.utils.decoder import decoder_factory\n", + "from swr2_asr.utils.tokenizer import CharTokenizer\n", + "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n", + "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n", + "from swr2_asr.utils.loss_scores import cer, wer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "34aafd9aca2541748dc41d8550334536", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Download flag not set, skipping download\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/functional/functional.py:576: UserWarning: At least one mel filterbank has all zero values. The value for `n_mels` (128) may be set too high. 
Or, the value for `n_freqs` (201) may be set too low.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "New best WER: 0.8266228565397248 CER: 0.6048691547202959\n", + "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 10, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n", + "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 10 Beam Size Token: 10\n", + "--------------------------------------------------------------\n", + "New best WER: 0.7900706123452581 CER: 0.49197597466135945\n", + "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 25, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n", + "LM Weight: 0 Word Score: -1.5 Beam Size: 25 Beam Threshold: 50 Beam Size Token: 10\n", + "--------------------------------------------------------------\n", + "New best WER: 0.7877685082828738 CER: 0.48660732878914315\n", + "Config: {'language': 'german', 'language_model_path': 'data', 'n_gram': 3, 'beam_size': 100, 'beam_threshold': 50, 'n_best': 1, 'lm_weight': 0, 'word_score': -1.5, 'beam_size_token': 10}\n", + "LM Weight: 0 Word Score: -1.5 Beam Size: 100 Beam Threshold: 50 Beam Size Token: 10\n", + "--------------------------------------------------------------\n" + ] + } + ], + "source": [ + "\n", + "\n", + "tokenizer = CharTokenizer.from_file(\"data/tokenizers/char_tokenizer_german.json\")\n", + "\n", + "# manually increment tqdm progress bar\n", + "pbar = tqdm(total=len(lm_weights) * len(word_score) * len(beam_sizes) * len(beam_thresholds) * len(beam_size_token))\n", + "\n", + "base_config = {\n", + " \"language\": \"german\",\n", + " \"language_model_path\": \"data\", # path where model and supplementary files are stored\n", + " \"n_gram\": 3, # n-gram size of ,the language model, 3 or 5\n", + " \"beam_size\": 50 ,\n", + " \"beam_threshold\": 
50,\n", + " \"n_best\": 1,\n", + " \"lm_weight\": 2,\n", + " \"word_score\": 0,\n", + " }\n", + "\n", + "dataset_params = {\n", + " \"dataset_path\": \"/Volumes/pherkel 2/SWR2-ASR\",\n", + " \"language\": \"mls_german_opus\",\n", + " \"split\": Split.DEV,\n", + " \"limited\": True,\n", + " \"download\": False,\n", + " \"size\": 0.01,\n", + "}\n", + " \n", + "\n", + "model_params = {\n", + " \"n_cnn_layers\": 3,\n", + " \"n_rnn_layers\": 5,\n", + " \"rnn_dim\": 512,\n", + " \"n_class\": tokenizer.get_vocab_size(),\n", + " \"n_feats\": 128,\n", + " \"stride\": 2,\n", + " \"dropout\": 0.1,\n", + "}\n", + "\n", + "model = SpeechRecognitionModel(**model_params)\n", + "\n", + "checkpoint = torch.load(\"data/epoch67\", map_location=torch.device(\"cpu\"))\n", + "\n", + "state_dict = {\n", + " k[len(\"module.\") :] if k.startswith(\"module.\") else k: v\n", + " for k, v in checkpoint[\"model_state_dict\"].items()\n", + "}\n", + "model.load_state_dict(state_dict, strict=True)\n", + "model.eval()\n", + "\n", + "\n", + "dataset = MLSDataset(**dataset_params,)\n", + "\n", + "data_processing = DataProcessing(\"valid\", tokenizer, {\"n_feats\": model_params[\"n_feats\"]})\n", + "\n", + "dataloader = DataLoader(\n", + " dataset=dataset,\n", + " batch_size=16,\n", + " shuffle = False,\n", + " collate_fn=data_processing,\n", + " num_workers=8,\n", + " pin_memory=True,\n", + ")\n", + "\n", + "best_wer = 1.0\n", + "best_cer = 1.0\n", + "best_config = None\n", + "\n", + "for lm_weight in lm_weights:\n", + " for ws in word_score:\n", + " for beam_size in beam_sizes:\n", + " for beam_threshold in beam_thresholds:\n", + " for beam_size_t in beam_size_token:\n", + " config = base_config.copy()\n", + " config[\"lm_weight\"] = lm_weight\n", + " config[\"word_score\"] = ws\n", + " config[\"beam_size\"] = beam_size\n", + " config[\"beam_threshold\"] = beam_threshold\n", + " config[\"beam_size_token\"] = beam_size_t\n", + " \n", + " decoder = decoder_factory(\"lm\")(tokenizer, {\"lm\": 
config})\n", + " \n", + " test_cer, test_wer = [], []\n", + " with torch.no_grad():\n", + " model.eval()\n", + " for batch in dataloader:\n", + " # perform inference, decode, compute WER and CER\n", + " spectrograms, labels, input_lengths, label_lengths = batch\n", + " \n", + " output = model(spectrograms)\n", + " output = F.log_softmax(output, dim=2)\n", + " \n", + " decoded_preds = decoder(output)\n", + " decoded_targets = tokenizer.decode_batch(labels)\n", + " \n", + " for j, _ in enumerate(decoded_preds):\n", + " if j >= len(decoded_targets):\n", + " break\n", + " pred = \" \".join(decoded_preds[j][0].words).strip()\n", + " target = decoded_targets[j]\n", + " \n", + " test_cer.append(cer(pred, target))\n", + " test_wer.append(wer(pred, target))\n", + "\n", + " avg_cer = sum(test_cer) / len(test_cer)\n", + " avg_wer = sum(test_wer) / len(test_wer)\n", + " \n", + " if avg_wer < best_wer:\n", + " best_wer = avg_wer\n", + " best_cer = avg_cer\n", + " best_config = config\n", + " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n", + " print(\"Config: \", best_config)\n", + " print(\"LM Weight: \", lm_weight, \n", + " \" Word Score: \", ws, \n", + " \" Beam Size: \", beam_size, \n", + " \" Beam Threshold: \", beam_threshold, \n", + " \" Beam Size Token: \", beam_size_t)\n", + " print(\"--------------------------------------------------------------\")\n", + " \n", + " pbar.update(1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} |