aboutsummaryrefslogtreecommitdiff
path: root/plots.ipynb
diff options
context:
space:
mode:
authorPherkel2023-09-18 18:11:33 +0200
committerPherkel2023-09-18 18:11:33 +0200
commitd44cf7b1cab683a8aa3876619c82226f4e6d6f3b (patch)
treedc5a0567d5ff939320c9737e1d66e9e83f0f534c /plots.ipynb
parentc09ff76ba6f4c5dd5de64a401efcd27449150aec (diff)
fix
Diffstat (limited to 'plots.ipynb')
-rw-r--r--plots.ipynb131
1 files changed, 131 insertions, 0 deletions
diff --git a/plots.ipynb b/plots.ipynb
new file mode 100644
index 0000000..716834a
--- /dev/null
+++ b/plots.ipynb
@@ -0,0 +1,131 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import matplotlib.pyplot as plt\n",
+ "import pandas as pd"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n",
+ "# test_loss, cer, wer should not be plotted if they are 0.0\n",
+ "# plot train_loss and test_loss in one plot\n",
+ "# plot cer and wer in one plot\n",
+ " \n",
+ "# save plots as png\n",
+ "\n",
+ "csv_path = \"metrics.csv\"\n",
+ "\n",
+ "# load csv\n",
+ "df = pd.read_csv(csv_path)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# plot train_loss and test_loss\n",
+ "# do not use colors, distinguis by line style. use solid for train_loss and dashed for test_loss\n",
+ "plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n",
+ "\n",
+ "# create zip with epoch and test_loss for all epochs\n",
+ "# filter out all test_loss with value 0.0\n",
+ "# plot test_loss\n",
+ "epoch_loss = zip(df['epoch'], df['test_loss'])\n",
+ "epoch_loss = list(filter(lambda x: x[1] != 0.0, epoch_loss))\n",
+ "plt.plot([x[0] for x in epoch_loss], [x[1] for x in epoch_loss], label='test_loss', linestyle='dashed', color='black')\n",
+ "\n",
+ "# add markers for test_loss\n",
+ "for x, y in epoch_loss:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ "\n",
+ "plt.xlabel('epoch')\n",
+ "plt.ylabel('loss')\n",
+ "plt.legend()\n",
+ "\n",
+ "# add ticks every 5 epochs\n",
+ "plt.xticks(range(0, 70, 5))\n",
+ "\n",
+ "# set y limits to 0\n",
+ "plt.ylim(bottom=0)\n",
+ "# reduce margins\n",
+ "plt.tight_layout()\n",
+ "# increase resolution\n",
+ "plt.savefig('train_test_loss.png', dpi=300)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "epoch_cer = zip(df['epoch'], df['cer'])\n",
+ "epoch_cer = list(filter(lambda x: x[1] != 0.0, epoch_cer))\n",
+ "plt.plot([x[0] for x in epoch_cer], [x[1] for x in epoch_cer], label='cer', linestyle='solid', color='black')\n",
+ "\n",
+ "# add markers for cer\n",
+ "for x, y in epoch_cer:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ " \n",
+ "epoch_wer = zip(df['epoch'], df['wer'])\n",
+ "epoch_wer = list(filter(lambda x: x[1] != 0.0, epoch_wer))\n",
+ "plt.plot([x[0] for x in epoch_wer], [x[1] for x in epoch_wer], label='wer', linestyle='dashed', color='black')\n",
+ "\n",
+ "# add markers for wer\n",
+ "for x, y in epoch_wer:\n",
+ " plt.plot(x, y, marker='o', markersize=3, color='black')\n",
+ " \n",
+ "# set y limits to 0 and 1\n",
+ "plt.ylim(bottom=0, top=1)\n",
+ "plt.xlabel('epoch')\n",
+ "plt.ylabel('error rate')\n",
+ "plt.legend()\n",
+ "# reduce margins\n",
+ "plt.tight_layout()\n",
+ "\n",
+ "# add ticks every 5 epochs\n",
+ "plt.xticks(range(0, 70, 5))\n",
+ "\n",
+ "# add ticks every 0.1 \n",
+ "plt.yticks([x/10 for x in range(0, 11, 1)])\n",
+ "\n",
+ "# increase resolution\n",
+ "plt.savefig('cer_wer.png', dpi=300)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": ".venv",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.13"
+ },
+ "orig_nbformat": 4
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}