{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Training metrics plots\n",
    "\n",
    "Loads per-epoch metrics from `metrics.csv` and saves two black-and-white figures:\n",
    "`train_test_loss.png` (train vs. test loss) and `cer_wer.png` (CER and WER).\n",
    "Rows where `test_loss`, `cer` or `wer` is exactly 0.0 are left out of the corresponding curve."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Input CSV with columns: epoch, train_loss, test_loss, cer, wer.\n",
    "# In test_loss, cer and wer a value of 0.0 means the point should not be plotted.\n",
    "CSV_PATH = 'metrics.csv'\n",
    "TICK_STEP = 5  # x-axis tick spacing, in epochs\n",
    "DPI = 300      # resolution of the saved PNG files\n",
    "\n",
    "metrics = pd.read_csv(CSV_PATH)\n",
    "\n",
    "# One tick every TICK_STEP epochs across the epochs actually present. A fixed\n",
    "# range would stretch the axis for short runs (set_xticks expands the view\n",
    "# limits to show every tick) and leave later epochs unlabeled for long runs.\n",
    "epoch_ticks = range(0, int(metrics['epoch'].max()) + 1, TICK_STEP)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_metric(ax: plt.Axes, frame: pd.DataFrame, column: str, linestyle: str,\n",
    "                drop_zeros: bool = False, markers: bool = False) -> None:\n",
    "    \"\"\"Plot frame[column] against frame['epoch'] on ax as a black line.\n",
    "\n",
    "    Colors are not used; series are distinguished by line style only.\n",
    "\n",
    "    Args:\n",
    "        ax: Axes to draw on.\n",
    "        frame: table with an 'epoch' column and the metric column.\n",
    "        column: metric column to plot; also used as the legend label.\n",
    "        linestyle: matplotlib line style, e.g. 'solid' or 'dashed'.\n",
    "        drop_zeros: skip rows where the metric is exactly 0.0.\n",
    "        markers: draw a small circle at every plotted point.\n",
    "    \"\"\"\n",
    "    if drop_zeros:\n",
    "        frame = frame.loc[frame[column] != 0.0]\n",
    "    ax.plot(\n",
    "        frame['epoch'], frame[column],\n",
    "        label=column, linestyle=linestyle, color='black',\n",
    "        marker='o' if markers else None, markersize=3,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train vs. test loss\n",
    "\n",
    "Train loss is solid; test loss is dashed with a marker at each evaluated epoch."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A dedicated figure per plot so the two outputs never share axes.\n",
    "fig, ax = plt.subplots()\n",
    "plot_metric(ax, metrics, 'train_loss', linestyle='solid')\n",
    "plot_metric(ax, metrics, 'test_loss', linestyle='dashed', drop_zeros=True, markers=True)\n",
    "\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('loss')\n",
    "ax.set_xticks(epoch_ticks)\n",
    "ax.set_ylim(bottom=0)\n",
    "ax.legend()\n",
    "\n",
    "# Finalize the layout only after all axis changes, then save at high resolution.\n",
    "fig.tight_layout()\n",
    "fig.savefig('train_test_loss.png', dpi=DPI)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## CER and WER\n",
    "\n",
    "CER is solid and WER is dashed; both show only evaluated epochs, on a fixed 0-1 scale."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, ax = plt.subplots()\n",
    "plot_metric(ax, metrics, 'cer', linestyle='solid', drop_zeros=True, markers=True)\n",
    "plot_metric(ax, metrics, 'wer', linestyle='dashed', drop_zeros=True, markers=True)\n",
    "\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('error rate')\n",
    "ax.set_xticks(epoch_ticks)\n",
    "ax.set_ylim(bottom=0, top=1)\n",
    "ax.set_yticks([i / 10 for i in range(11)])  # a tick every 0.1\n",
    "ax.legend()\n",
    "\n",
    "# tight_layout must run after the tick changes so their labels are accounted for.\n",
    "fig.tight_layout()\n",
    "fig.savefig('cer_wer.png', dpi=DPI)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}