{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "lm_weights = [0, 1.0, 2.5,]\n", "word_score = [-1.5, 0.0, 1.5]\n", "beam_sizes = [50, 500]\n", "beam_thresholds = [50]\n", "beam_size_token = [10, 38]" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/var/folders/lm/1zmdkgm91k912l2vgq978z800000gn/T/ipykernel_80481/3805229751.py:1: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n", " from tqdm.autonotebook import tqdm\n", "/Users/philippmerkel/DEV/SWR2-cool-projekt/.venv/lib/python3.10/site-packages/torchaudio/models/decoder/_ctc_decoder.py:62: UserWarning: The built-in flashlight integration is deprecated, and will be removed in future release. Please install flashlight-text. https://pypi.org/project/flashlight-text/ For the detail of CTC decoder migration, please see https://github.com/pytorch/audio/issues/3088.\n", " warnings.warn(\n" ] } ], "source": [ "from tqdm.autonotebook import tqdm\n", "\n", "import torch\n", "from torch.utils.data import DataLoader\n", "import torch.nn.functional as F\n", "\n", "from swr2_asr.utils.decoder import decoder_factory\n", "from swr2_asr.utils.tokenizer import CharTokenizer\n", "from swr2_asr.model_deep_speech import SpeechRecognitionModel\n", "from swr2_asr.utils.data import MLSDataset, Split, DataProcessing\n", "from swr2_asr.utils.loss_scores import cer, wer" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "34aafd9aca2541748dc41d8550334536", "version_major": 2, "version_minor": 0 }, "text/plain": [ " 0%| | 0/144 [00:00= len(decoded_targets):\n", " break\n", " pred = \" \".join(decoded_preds[j][0].words).strip()\n", " target = decoded_targets[j]\n", " \n", " test_cer.append(cer(pred, target))\n", " test_wer.append(wer(pred, target))\n", "\n", " avg_cer = sum(test_cer) / len(test_cer)\n", " avg_wer = sum(test_wer) / len(test_wer)\n", " \n", " if avg_wer < best_wer:\n", " best_wer = avg_wer\n", " best_cer = avg_cer\n", " best_config = config\n", " print(\"New best WER: \", best_wer, \" CER: \", best_cer)\n", " print(\"Config: \", best_config)\n", " print(\"LM Weight: \", lm_weight, \n", " \" Word Score: \", ws, \n", " \" Beam Size: \", beam_size, \n", " \" Beam Threshold: \", beam_threshold, \n", " \" Beam Size Token: \", beam_size_t)\n", " print(\"--------------------------------------------------------------\")\n", " \n", " pbar.update(1)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "orig_nbformat": 4 }, "nbformat": 4, "nbformat_minor": 2 }