diff options
-rw-r--r-- | config.philipp.yaml | 4 | ||||
-rw-r--r-- | data/metrics.csv (renamed from metrics.csv) | 0 | ||||
-rw-r--r-- | plots.ipynb | 42 |
3 files changed, 36 insertions, 10 deletions
diff --git a/config.philipp.yaml b/config.philipp.yaml index cbeabfe..e329508 100644 --- a/config.philipp.yaml +++ b/config.philipp.yaml @@ -27,8 +27,8 @@ decoder: beam_size: 500 beam_threshold: 150 n_best: 1 - lm_weight: 2 - word_score: -2 + lm_weight: 1 + word_score: 1 training: learning_rate: 0.0005 diff --git a/metrics.csv b/data/metrics.csv index 22b8cec..22b8cec 100644 --- a/metrics.csv +++ b/data/metrics.csv diff --git a/plots.ipynb b/plots.ipynb index 716834a..252eb0d 100644 --- a/plots.ipynb +++ b/plots.ipynb @@ -6,8 +6,10 @@ "metadata": {}, "outputs": [], "source": [ + "import os\n", "import matplotlib.pyplot as plt\n", - "import pandas as pd" + "import pandas as pd\n", + "import torch" ] }, { @@ -16,14 +18,38 @@ "metadata": {}, "outputs": [], "source": [ - "# load csv with colmns epoch, train_loss, test_loss, cer, wer\n", - "# test_loss, cer, wer should not be plotted if they are 0.0\n", - "# plot train_loss and test_loss in one plot\n", - "# plot cer and wer in one plot\n", - " \n", - "# save plots as png\n", + "# load model checkpoints to get data\n", + "CHECKPOINTS_DIR = 'data'\n", + "CHECKPOINTS_PREFIX = 'epoch'\n", "\n", - "csv_path = \"metrics.csv\"\n", + "rows = []\n", + "\n", + "epoch = 67\n", + "while True:\n", + " ckp_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(epoch))\n", + " print(ckp_path)\n", + " try:\n", + " current_state = torch.load(ckp_path, map_location=torch.device(\"cpu\"))\n", + " except FileNotFoundError:\n", + " break\n", + " \n", + " rows.append({\"epoch\": epoch, \"train_loss\": current_state[\"train_loss\"].item(), \"test_loss\": current_state[\"test_loss\"], \"avg_cer\": current_state[\"avg_cer\"], \"avg_wer\": current_state[\"avg_wer\"]})\n", + " \n", + " epoch += 1\n", + "\n", + "\n", + "# create dataframe, then csv from dataframe\n", + "df = pd.DataFrame(rows)\n", + "df.to_csv(\"data/losses.csv\", index=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "csv_path = \"losses.csv\"\n", "\n", "# load csv\n", "df = pd.read_csv(csv_path)" |