blob: b288c5aa49838e2b999a2baea1d18cdf887eb7c3 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
|
"""Utilities for visualizing the training process and results."""
import matplotlib.pyplot as plt
import torch
# (legend label, checkpoint key) for every plotted series, in drawing order.
_SERIES = (
    ("train_loss", "train_loss"),
    ("test_loss", "test_loss"),
    ("cer", "avg_cer"),
    ("wer", "avg_wer"),
)


def _as_float(value):
    """Return *value* as a plain Python float.

    Objects exposing ``.item()`` (e.g. 0-d torch tensors or NumPy scalars) are
    unwrapped with it; anything else goes through ``float()``.
    """
    return value.item() if hasattr(value, "item") else float(value)


def _load_metrics(path, step):
    """Collect ``(epoch, value)`` pairs for each series in ``_SERIES``.

    Checkpoints are read from ``f"{path}{epoch}"`` for epoch = step, 2*step,
    3*step, ... and reading stops at the first file that does not exist.
    Returns a dict mapping each legend label to its list of pairs.
    """
    series = {label: [] for label, _ in _SERIES}
    epoch = step
    while True:
        try:
            # NOTE(review): torch.load unpickles the file, so only point this
            # at trusted checkpoints.
            state = torch.load(f"{path}{epoch}", map_location=torch.device("cpu"))
        except FileNotFoundError:
            break
        for label, key in _SERIES:
            series[label].append((epoch, _as_float(state[key])))
        epoch += step
    return series


def plot(path, step=5, output="losses.svg"):
    """Plot losses and CER/WER against epoch and save the figure.

    Args:
        path: Checkpoint filename prefix. The epoch number is appended directly,
            e.g. ``"data/runs/epoch"`` -> ``"data/runs/epoch5"``.
        step: Epoch interval between checkpoints. The first checkpoint read
            is epoch ``step``. Defaults to 5.
        output: File the figure is written to. Defaults to ``"losses.svg"``.

    Raises:
        ValueError: If ``step`` is not positive (``step == 0`` would otherwise
            reload the same checkpoint forever).
        FileNotFoundError: If the first checkpoint is missing, so there is
            nothing to plot.
        KeyError: Presumably, if a loaded checkpoint lacks one of the expected
            metric keys (the checkpoint is indexed by string key).
    """
    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    series = _load_metrics(path, step)
    if not series["train_loss"]:
        raise FileNotFoundError(f"no checkpoint found at {path}{step}")
    # Use a dedicated figure so repeated calls do not accumulate lines on
    # pyplot's global current figure. Close it afterwards to release it.
    fig, ax = plt.subplots()
    try:
        for label, _ in _SERIES:
            ax.plot(*zip(*series[label]), label=label)
        ax.set_xlabel("epoch")
        ax.set_ylabel("score")
        ax.set_title(f"Model performance for {step}n epochs")
        ax.legend()
        fig.savefig(output)
    finally:
        plt.close(fig)
if __name__ == "__main__":
    # Script entry point: plot the data/runs/epoch<N> checkpoints.
    plot(path="data/runs/epoch")
|