aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--config.philipp.yaml4
-rw-r--r--data/metrics.csv (renamed from metrics.csv)0
-rw-r--r--plots.ipynb42
3 files changed, 36 insertions, 10 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml
index cbeabfe..e329508 100644
--- a/config.philipp.yaml
+++ b/config.philipp.yaml
@@ -27,8 +27,8 @@ decoder:
beam_size: 500
beam_threshold: 150
n_best: 1
- lm_weight: 2
- word_score: -2
+ lm_weight: 1
+ word_score: 1
training:
learning_rate: 0.0005
diff --git a/metrics.csv b/data/metrics.csv
index 22b8cec..22b8cec 100644
--- a/metrics.csv
+++ b/data/metrics.csv
diff --git a/plots.ipynb b/plots.ipynb
index 716834a..252eb0d 100644
--- a/plots.ipynb
+++ b/plots.ipynb
@@ -6,8 +6,10 @@
"metadata": {},
"outputs": [],
"source": [
+ "import os\n",
"import matplotlib.pyplot as plt\n",
- "import pandas as pd"
+ "import pandas as pd\n",
+ "import torch"
]
},
{
@@ -16,14 +18,38 @@
"metadata": {},
"outputs": [],
"source": [
- "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n",
- "# test_loss, cer, wer should not be plotted if they are 0.0\n",
- "# plot train_loss and test_loss in one plot\n",
- "# plot cer and wer in one plot\n",
- " \n",
- "# save plots as png\n",
+ "# load model checkpoints to get data\n",
+ "CHECKPOINTS_DIR = 'data'\n",
+ "CHECKPOINTS_PREFIX = 'epoch'\n",
"\n",
- "csv_path = \"metrics.csv\"\n",
+ "rows = []\n",
+ "\n",
+ "epoch = 67\n",
+ "while True:\n",
+ " ckp_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(epoch))\n",
+ " print(ckp_path)\n",
+ " try:\n",
+ " current_state = torch.load(ckp_path, map_location=torch.device(\"cpu\"))\n",
+ " except FileNotFoundError:\n",
+ " break\n",
+ " \n",
+ " rows.append({\"epoch\": epoch, \"train_loss\": current_state[\"train_loss\"].item(), \"test_loss\": current_state[\"test_loss\"], \"avg_cer\": current_state[\"avg_cer\"], \"avg_wer\": current_state[\"avg_wer\"]})\n",
+ " \n",
+ " epoch += 1\n",
+ "\n",
+ "\n",
+ "# create dataframe, then csv from dataframe\n",
+ "df = pd.DataFrame(rows)\n",
+ "df.to_csv(\"data/losses.csv\", index=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "csv_path = \"losses.csv\"\n",
"\n",
"# load csv\n",
"df = pd.read_csv(csv_path)"