1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
|
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import torch"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Collect per-epoch metrics from the saved checkpoints and write them to a CSV.\n",
"# Checkpoint files are named CHECKPOINTS_DIR/CHECKPOINTS_PREFIX<epoch> (e.g. data/epoch67);\n",
"# loading starts at FIRST_EPOCH and stops at the first epoch whose file is missing.\n",
"CHECKPOINTS_DIR = 'data'\n",
"CHECKPOINTS_PREFIX = 'epoch'\n",
"# NOTE(review): hard-coded start epoch; presumably earlier checkpoints are unavailable -- TODO confirm.\n",
"FIRST_EPOCH = 67\n",
"LOSSES_CSV = os.path.join(CHECKPOINTS_DIR, 'losses.csv')\n",
"COLUMNS = ['epoch', 'train_loss', 'test_loss', 'avg_cer', 'avg_wer']\n",
"\n",
"\n",
"def _to_scalar(value):\n",
"    \"\"\"Unwrap a single-element tensor into a Python number; return anything else unchanged.\"\"\"\n",
"    return value.item() if isinstance(value, torch.Tensor) else value\n",
"\n",
"\n",
"rows = []\n",
"epoch = FIRST_EPOCH\n",
"while True:\n",
"    ckp_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(epoch))\n",
"    try:\n",
"        # torch.load unpickles the file, so only point this at trusted checkpoints.\n",
"        current_state = torch.load(ckp_path, map_location=torch.device('cpu'))\n",
"    except FileNotFoundError:\n",
"        break\n",
"    print(ckp_path)\n",
"\n",
"    # Unwrap every metric (not only train_loss) so no tensor reaches the CSV as 'tensor(...)' text.\n",
"    rows.append({\n",
"        'epoch': epoch,\n",
"        'train_loss': _to_scalar(current_state['train_loss']),\n",
"        'test_loss': _to_scalar(current_state['test_loss']),\n",
"        'avg_cer': _to_scalar(current_state['avg_cer']),\n",
"        'avg_wer': _to_scalar(current_state['avg_wer']),\n",
"    })\n",
"    epoch += 1\n",
"\n",
"if not rows:\n",
"    print('no checkpoints found; first path tried:',\n",
"          os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(FIRST_EPOCH)))\n",
"\n",
"# Create a dataframe, then a csv from it. Explicit columns mean the file always\n",
"# gets a header row, even when no checkpoint was found.\n",
"df = pd.DataFrame(rows, columns=COLUMNS)\n",
"df.to_csv(LOSSES_CSV, index=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the metrics table. The checkpoint cell writes data/losses.csv; fall back to a\n",
"# losses.csv next to the notebook so the plots can be redrawn without the checkpoints.\n",
"csv_candidates = ['data/losses.csv', 'losses.csv']\n",
"csv_path = next((path for path in csv_candidates if os.path.exists(path)), csv_candidates[0])\n",
"\n",
"# load csv (raises FileNotFoundError naming data/losses.csv if neither file exists)\n",
"df = pd.read_csv(csv_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot train_loss and test_loss\n",
"# do not use colors; distinguish by line style: solid for train_loss, dashed for test_loss\n",
"plt.figure()  # start a fresh figure so this plot is never drawn onto an earlier one\n",
"plt.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')\n",
"\n",
"# skip test_loss values of 0.0 (presumably epochs without an evaluation) and plot the rest\n",
"tested = df[df['test_loss'] != 0.0]\n",
"plt.plot(tested['epoch'], tested['test_loss'], label='test_loss', linestyle='dashed', color='black')\n",
"# dot on every plotted test point; unlabeled so the legend still shows only the dashed line\n",
"plt.plot(tested['epoch'], tested['test_loss'], linestyle='none', marker='o', markersize=3, color='black')\n",
"\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('loss')\n",
"plt.legend()\n",
"\n",
"# ticks every 5 epochs, from 0 up to the last recorded epoch instead of a fixed 0..65\n",
"if not df.empty:\n",
"    plt.xticks(range(0, int(df['epoch'].max()) + 1, 5))\n",
"\n",
"# set y limits to 0\n",
"plt.ylim(bottom=0)\n",
"# reduce margins; called after labels and ticks are final so the layout accounts for them\n",
"plt.tight_layout()\n",
"# increase resolution\n",
"plt.savefig('train_test_loss.png', dpi=300)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# plot the cer and wer error rates per epoch\n",
"plt.figure()  # fresh figure so these curves are not drawn onto the loss plot\n",
"\n",
"# the checkpoint cell writes 'avg_cer'/'avg_wer'; also accept plain 'cer'/'wer' column names\n",
"cer_col = 'avg_cer' if 'avg_cer' in df.columns else 'cer'\n",
"wer_col = 'avg_wer' if 'avg_wer' in df.columns else 'wer'\n",
"\n",
"# solid line for cer, dashed for wer, black only, with a dot on every plotted epoch.\n",
"# Values of 0.0 are skipped (presumably 'not measured', as for test_loss).\n",
"for column, label, style in [(cer_col, 'cer', 'solid'), (wer_col, 'wer', 'dashed')]:\n",
"    measured = df[df[column] != 0.0]\n",
"    plt.plot(measured['epoch'], measured[column], label=label, linestyle=style, color='black')\n",
"    plt.plot(measured['epoch'], measured[column], linestyle='none', marker='o', markersize=3, color='black')\n",
"\n",
"# set y limits to 0 and 1\n",
"plt.ylim(bottom=0, top=1)\n",
"plt.xlabel('epoch')\n",
"plt.ylabel('error rate')\n",
"plt.legend()\n",
"\n",
"# ticks every 5 epochs, from 0 up to the last recorded epoch\n",
"if not df.empty:\n",
"    plt.xticks(range(0, int(df['epoch'].max()) + 1, 5))\n",
"\n",
"# ticks every 0.1 from 0 to 1\n",
"plt.yticks([x / 10 for x in range(0, 11)])\n",
"\n",
"# reduce margins; done last so the layout accounts for the final ticks and labels\n",
"plt.tight_layout()\n",
"# increase resolution\n",
"plt.savefig('cer_wer.png', dpi=300)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
|