From d44cf7b1cab683a8aa3876619c82226f4e6d6f3b Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 18 Sep 2023 18:11:33 +0200 Subject: fix --- lm_decoder_hparams.ipynb | 245 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 245 insertions(+) create mode 100644 lm_decoder_hparams.ipynb (limited to 'lm_decoder_hparams.ipynb') diff --git a/lm_decoder_hparams.ipynb b/lm_decoder_hparams.ipynb new file mode 100644 index 0000000..5e56312 --- /dev/null +++ b/lm_decoder_hparams.ipynb @@ -0,0 +1,245 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "lm_weights = [0, 1.0, 2.5,]\n", + "word_score = [-1.5, 0.0, 1.5]\n", + "beam_sizes = [50, 500]\n", + "beam_thresholds = [50]\n", + "beam_size_token = [10, 38]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", + " from tqdm.autonotebook import tqdm\n", + "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from tqdm.autonotebook import tqdm\n", + "\n", + "import torch\n", + "from torch.utils.data import DataLoader\n", + "import torch.nn.functional as F\n", + "\n", + "from swr2_asr.utils.decoder import decoder_factory\n", + "from swr2_asr.utils.tokenizer import CharTokenizer\n", + "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n", + "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n", + "from swr2_asr.utils.loss_scores import cer, wer" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "34aafd9aca2541748dc41d8550334536", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/144 [00:00= len(decoded_targets):\n", + " break\n", + " pred = \" \".join(decoded_preds[j][0].words).strip()\n", + " target = decoded_targets[j]\n", + " \n", + " test_cer.append(cer(pred, target))\n", + " test_wer.append(wer(pred, target))\n", + "\n", + " avg_cer = sum(test_cer) / len(test_cer)\n", + " avg_wer = sum(test_wer) / len(test_wer)\n", + " \n", + " if avg_wer < best_wer:\n", + " best_wer = avg_wer\n", + " best_cer = avg_cer\n", + " best_config = config\n", + " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n", + " print(\"Config: \", best_config)\n", + " print(\"LM Weight: \", lm_weight, \n", + " \" Word Score: \", ws, \n", + " \" Beam Size: \", beam_size, \n", + " \" Beam Threshold: \", beam_threshold, \n", + " \" Beam Size Token: \", beam_size_t)\n", + " print(\"--------------------------------------------------------------\")\n", + " \n", + " pbar.update(1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} -- cgit v1.2.3