diff options
Diffstat (limited to 'plots.ipynb')
-rw-r--r-- | plots.ipynb | 42 |
1 files changed, 34 insertions, 8 deletions
diff --git a/plots.ipynb b/plots.ipynb index 716834a..252eb0d 100644 --- a/plots.ipynb +++ b/plots.ipynb @@ -6,8 +6,10 @@ "metadata": {}, "outputs": [], "source": [ + "import os\n", "import matplotlib.pyplot as plt\n", - "import pandas as pd" + "import pandas as pd\n", + "import torch" ] }, { @@ -16,14 +18,38 @@ "metadata": {}, "outputs": [], "source": [ - "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n", - "# test_loss, cer, wer should not be plotted if they are 0.0\n", - "# plot train_loss and test_loss in one plot\n", - "# plot cer and wer in one plot\n", - " \n", - "# save plots as png\n", + "# load model checkpoints to get data\n", + "CHECKPOINTS_DIR = 'data'\n", + "CHECKPOINTS_PREFIX = 'epoch'\n", "\n", - "csv_path = \"metrics.csv\"\n", + "rows = []\n", + "\n", + "epoch = 67\n", + "while True:\n", + " ckp_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(epoch))\n", + " print(ckp_path)\n", + " try:\n", + " current_state = torch.load(ckp_path, map_location=torch.device(\"cpu\"))\n", + " except FileNotFoundError:\n", + " break\n", + " \n", + " rows.append({\"epoch\": epoch, \"train_loss\": current_state[\"train_loss\"].item(), \"test_loss\": current_state[\"test_loss\"], \"avg_cer\": current_state[\"avg_cer\"], \"avg_wer\": current_state[\"avg_wer\"]})\n", + " \n", + " epoch += 1\n", + "\n", + "\n", + "# create dataframe, then csv from dataframe\n", + "df = pd.DataFrame(rows)\n", + "df.to_csv(\"data/losses.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "csv_path = \"losses.csv\"\n", "\n", "# load csv\n", "df = pd.read_csv(csv_path)" |