aboutsummaryrefslogtreecommitdiff
path: root/plots.ipynb
diff options
context:
space:
mode:
Diffstat (limited to 'plots.ipynb')
-rw-r--r--plots.ipynb42
1 files changed, 34 insertions, 8 deletions
diff --git a/plots.ipynb b/plots.ipynb
index 716834a..252eb0d 100644
--- a/plots.ipynb
+++ b/plots.ipynb
@@ -6,8 +6,10 @@
"metadata": {},
"outputs": [],
"source": [
+ "import os\n",
"import matplotlib.pyplot as plt\n",
- "import pandas as pd"
+ "import pandas as pd\n",
+ "import torch"
]
},
{
@@ -16,14 +18,38 @@
"metadata": {},
"outputs": [],
"source": [
- "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n",
- "# test_loss, cer, wer should not be plotted if they are 0.0\n",
- "# plot train_loss and test_loss in one plot\n",
- "# plot cer and wer in one plot\n",
- " \n",
- "# save plots as png\n",
+ "# load model checkpoints to get data\n",
+ "CHECKPOINTS_DIR = 'data'\n",
+ "CHECKPOINTS_PREFIX = 'epoch'\n",
"\n",
- "csv_path = \"metrics.csv\"\n",
+ "rows = []\n",
+ "\n",
+ "epoch = 67\n",
+ "while True:\n",
+ " ckp_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(epoch))\n",
+ " print(ckp_path)\n",
+ " try:\n",
+ " current_state = torch.load(ckp_path, map_location=torch.device(\"cpu\"))\n",
+ " except FileNotFoundError:\n",
+ " break\n",
+ " \n",
+ " rows.append({\"epoch\": epoch, \"train_loss\": current_state[\"train_loss\"].item(), \"test_loss\": current_state[\"test_loss\"], \"avg_cer\": current_state[\"avg_cer\"], \"avg_wer\": current_state[\"avg_wer\"]})\n",
+ " \n",
+ " epoch += 1\n",
+ "\n",
+ "\n",
+ "# create dataframe, then csv from dataframe\n",
+ "df = pd.DataFrame(rows)\n",
+ "df.to_csv(\"data/losses.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "csv_path = \"losses.csv\"\n",
"\n",
"# load csv\n",
"df = pd.read_csv(csv_path)"