blob: 80f942aa369053b92af79cb6e3535fcb6e515d16 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
"""Utilities for visualizing the training process and results."""
import matplotlib.pyplot as plt
import torch
def _scalar(value):
    """Return a 0-d tensor as a plain Python number; pass other values through."""
    # .item() also detaches, so tensors that require grad plot cleanly.
    return value.item() if torch.is_tensor(value) else value


def _load_history(epochs, path, keys):
    """Collect ``keys`` from checkpoints ``f"{path}1"`` .. ``f"{path}{epochs}"``.

    Returns a dict mapping each key to a list with one entry per epoch.
    Raises KeyError if a checkpoint lacks one of the requested keys.
    """
    history = {key: [] for key in keys}
    for epoch in range(1, epochs + 1):
        # NOTE: torch.load unpickles the file -- only point this at trusted
        # checkpoints. map_location="cpu" lets checkpoints saved on a GPU be
        # read on CPU-only machines; plotting only needs the scalar values.
        state = torch.load(f"{path}{epoch}", map_location="cpu")
        for key in keys:
            history[key].append(_scalar(state[key]))
    return history


def _save_line_plot(epoch_axis, series, ylabel, out_path):
    """Draw each ``(label, values)`` pair on a new figure and save it to ``out_path``."""
    # A dedicated figure keeps this plot from drawing onto (or inheriting
    # lines from) whatever pyplot figure happens to be current.
    fig, ax = plt.subplots()
    try:
        for label, values in series:
            ax.plot(epoch_axis, values, label=label)
        ax.set_xlabel("epoch")
        ax.set_ylabel(ylabel)
        ax.legend()
        fig.savefig(out_path)
    finally:
        # Release the figure so repeated calls don't accumulate open figures.
        plt.close(fig)


def plot(epochs, path, out_path="losses.svg", metrics_path=None):
    """Plot train/test loss per epoch from saved checkpoints and save the figure.

    Args:
        epochs: Number of checkpoints to read; epochs are numbered from 1.
        path: Filename prefix; the checkpoint for epoch N is ``f"{path}{N}"``.
            Each checkpoint must contain ``"loss"`` and ``"test_loss"``
            (and ``"avg_cer"`` / ``"avg_wer"`` when ``metrics_path`` is set).
        out_path: Where to write the loss figure (default ``"losses.svg"``).
        metrics_path: If given, also write a figure of average CER and WER
            per epoch to this path.

    Raises:
        KeyError: If a checkpoint is missing one of the required keys.
    """
    keys = ["loss", "test_loss"]
    if metrics_path is not None:
        keys += ["avg_cer", "avg_wer"]
    history = _load_history(epochs, path, keys)

    # Checkpoints are numbered from 1, so label the x-axis the same way.
    epoch_axis = list(range(1, epochs + 1))

    _save_line_plot(
        epoch_axis,
        [("train loss", history["loss"]), ("test loss", history["test_loss"])],
        "loss",
        out_path,
    )
    if metrics_path is not None:
        _save_line_plot(
            epoch_axis,
            [("avg CER", history["avg_cer"]), ("avg WER", history["avg_wer"])],
            "error rate",
            metrics_path,
        )
|