diff options
author | Pherkel | 2023-09-18 18:11:33 +0200 |
---|---|---|
committer | Pherkel | 2023-09-18 18:11:33 +0200 |
commit | d44cf7b1cab683a8aa3876619c82226f4e6d6f3b (patch) | |
tree | dc5a0567d5ff939320c9737e1d66e9e83f0f534c /plots.ipynb | |
parent | c09ff76ba6f4c5dd5de64a401efcd27449150aec (diff) |
fix
Diffstat (limited to 'plots.ipynb')
-rw-r--r-- | plots.ipynb | 131 |
1 files changed, 131 insertions, 0 deletions
diff --git a/plots.ipynb b/plots.ipynb new file mode 100644 index 0000000..716834a --- /dev/null +++ b/plots.ipynb @@ -0,0 +1,131 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import pandas as pd" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n", + "# test_loss, cer, wer should not be plotted if they are 0.0\n", + "# plot train_loss and test_loss in one plot\n", + "# plot cer and wer in one plot\n", + " \n", + "# save plots as png\n", + "\n", + "csv_path = \"metrics.csv\"\n", + "\n", + "# load csv\n", + "df = pd.read_csv(csv_path)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# plot train_loss and test_loss\n", + "# do not use colors, distinguis by line style. use solid for train_loss and dashed for test_loss\n", + "plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n", + "\n", + "# create zip with epoch and test_loss for all epochs\n", + "# filter out all test_loss with value 0.0\n", + "# plot test_loss\n", + "epoch_loss = zip(df['epoch'], df['test_loss'])\n", + "epoch_loss = list(filter(lambda x: x[1] != 0.0, epoch_loss))\n", + "plt.plot([x[0] for x in epoch_loss], [x[1] for x in epoch_loss], label='test_loss', linestyle='dashed', color='black')\n", + "\n", + "# add markers for test_loss\n", + "for x, y in epoch_loss:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + "\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('loss')\n", + "plt.legend()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# set y limits to 0\n", + "plt.ylim(bottom=0)\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "# increase resolution\n", + "plt.savefig('train_test_loss.png', dpi=300)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "epoch_cer = zip(df['epoch'], df['cer'])\n", + "epoch_cer = list(filter(lambda x: x[1] != 0.0, epoch_cer))\n", + "plt.plot([x[0] for x in epoch_cer], [x[1] for x in epoch_cer], label='cer', linestyle='solid', color='black')\n", + "\n", + "# add markers for cer\n", + "for x, y in epoch_cer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "epoch_wer = zip(df['epoch'], df['wer'])\n", + "epoch_wer = list(filter(lambda x: x[1] != 0.0, epoch_wer))\n", + "plt.plot([x[0] for x in epoch_wer], [x[1] for x in epoch_wer], label='wer', linestyle='dashed', color='black')\n", + "\n", + "# add markers for wer\n", + "for x, y in epoch_wer:\n", + " plt.plot(x, y, marker='o', markersize=3, color='black')\n", + " \n", + "# set y limits to 0 and 1\n", + "plt.ylim(bottom=0, top=1)\n", + "plt.xlabel('epoch')\n", + "plt.ylabel('error rate')\n", + "plt.legend()\n", + "# reduce margins\n", + "plt.tight_layout()\n", + "\n", + "# add ticks every 5 epochs\n", + "plt.xticks(range(0, 70, 5))\n", + "\n", + "# add ticks every 0.1 \n", + "plt.yticks([x/10 for x in range(0, 11, 1)])\n", + "\n", + "# increase resolution\n", + "plt.savefig('cer_wer.png', dpi=300)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": ".venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} |