{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training curves from model checkpoints\n",
    "\n",
    "Reads `train_loss`, `test_loss`, `avg_cer` and `avg_wer` from the saved checkpoints, exports them to a CSV file and plots them as two black-and-white figures (`train_test_loss.png`, `cer_wer.png`)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Checkpoints are read from CHECKPOINTS_DIR/CHECKPOINTS_PREFIX + epoch number,\n",
    "# e.g. data/epoch67.\n",
    "CHECKPOINTS_DIR = 'data'\n",
    "CHECKPOINTS_PREFIX = 'epoch'\n",
    "\n",
    "# First epoch to read; consecutive epochs are read until a file is missing.\n",
    "# NOTE(review): 67 is the value the notebook was saved with -- lower it if\n",
    "# earlier checkpoints exist and should be exported as well.\n",
    "START_EPOCH = 67\n",
    "\n",
    "# One location for the metrics CSV, shared by the export and the plots, so\n",
    "# that the plots always read the file the export wrote. Point this at another\n",
    "# file to plot a previously exported CSV.\n",
    "CSV_PATH = os.path.join(CHECKPOINTS_DIR, 'losses.csv')\n",
    "\n",
    "# x-axis ticks of both figures: every 5 epochs\n",
    "EPOCH_TICKS = range(0, 70, 5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export metrics from the checkpoints"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load model checkpoints to get data\n",
    "rows = []\n",
    "\n",
    "epoch = START_EPOCH\n",
    "while True:\n",
    "    ckp_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(epoch))\n",
    "    print(ckp_path)\n",
    "    try:\n",
    "        # NOTE: torch.load unpickles the file -- only load checkpoints you trust.\n",
    "        current_state = torch.load(ckp_path, map_location=torch.device(\"cpu\"))\n",
    "    except FileNotFoundError:\n",
    "        # the first missing epoch ends the scan\n",
    "        break\n",
    "\n",
    "    rows.append({\n",
    "        \"epoch\": epoch,\n",
    "        # .item() turns the stored train_loss into a plain Python number\n",
    "        \"train_loss\": current_state[\"train_loss\"].item(),\n",
    "        \"test_loss\": current_state[\"test_loss\"],\n",
    "        \"avg_cer\": current_state[\"avg_cer\"],\n",
    "        \"avg_wer\": current_state[\"avg_wer\"],\n",
    "    })\n",
    "\n",
    "    epoch += 1\n",
    "\n",
    "# create dataframe, then csv from dataframe\n",
    "if rows:\n",
    "    pd.DataFrame(rows).to_csv(CSV_PATH, index=False)\n",
    "else:\n",
    "    # do not replace an existing CSV with an empty, unreadable one\n",
    "    print(f'no checkpoints found starting at epoch {START_EPOCH}; {CSV_PATH} left unchanged')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the exported metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# load csv\n",
    "# The export names the error-rate columns avg_cer / avg_wer while the plots\n",
    "# use cer / wer. rename() ignores missing labels, so a CSV that already uses\n",
    "# the short names loads unchanged.\n",
    "df = pd.read_csv(CSV_PATH).rename(columns={\"avg_cer\": \"cer\", \"avg_wer\": \"wer\"})\n",
    "df.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_nonzero(ax, frame, column, linestyle):\n",
    "    \"\"\"Plot `column` over `epoch` as a black line with a dot on every point.\n",
    "\n",
    "    Rows whose value in `column` is exactly 0.0 are left out (presumably\n",
    "    epochs without an evaluation run -- confirm against the training code).\n",
    "    The legend label is the column name.\n",
    "    \"\"\"\n",
    "    shown = frame[frame[column] != 0.0]\n",
    "    ax.plot(shown['epoch'], shown[column], label=column, linestyle=linestyle, color='black')\n",
    "    # markers go into a second, unlabeled artist so the legend shows only the line style\n",
    "    ax.plot(shown['epoch'], shown[column], linestyle='none', marker='o', markersize=3, color='black')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train and test loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "\n",
    "# do not use colors, distinguish by line style:\n",
    "# solid for train_loss and dashed for test_loss\n",
    "ax.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n",
    "plot_nonzero(ax, df, 'test_loss', linestyle='dashed')\n",
    "\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('loss')\n",
    "ax.legend()\n",
    "ax.set_xticks(EPOCH_TICKS)\n",
    "\n",
    "# start the y axis at 0\n",
    "ax.set_ylim(bottom=0)\n",
    "\n",
    "# reduce margins, then save with increased resolution\n",
    "fig.tight_layout()\n",
    "fig.savefig('train_test_loss.png', dpi=300)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Character and word error rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "\n",
    "plot_nonzero(ax, df, 'cer', linestyle='solid')\n",
    "plot_nonzero(ax, df, 'wer', linestyle='dashed')\n",
    "\n",
    "# fixed 0..1 y axis with ticks every 0.1\n",
    "ax.set_ylim(bottom=0, top=1)\n",
    "ax.set_yticks([x / 10 for x in range(0, 11, 1)])\n",
    "ax.set_xticks(EPOCH_TICKS)\n",
    "\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('error rate')\n",
    "ax.legend()\n",
    "\n",
    "# lay out after all ticks and labels are set, then save with increased resolution\n",
    "fig.tight_layout()\n",
    "fig.savefig('cer_wer.png', dpi=300)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}