blob: a55d0d59163f1fcbd4aeeae9c147eaf4346a4797 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
|
"""Utilities for visualizing the training process and results."""
import matplotlib.pyplot as plt
import torch
def plot(epochs, path, output="losses.svg"):
    """Plot the training and test loss per epoch from saved checkpoints.

    One checkpoint per epoch is read from ``f"{path}{epoch}"`` for ``epoch``
    in ``1..epochs``. For example, ``path="checkpoints/epoch_"`` reads
    ``checkpoints/epoch_1``, ``checkpoints/epoch_2``, and so on. Each
    checkpoint must be a mapping containing the keys ``"loss"``,
    ``"test_loss"``, ``"avg_cer"`` and ``"avg_wer"``.

    Args:
        epochs: Number of epochs (checkpoints) to read, starting at 1.
        path: Checkpoint path prefix; the epoch number is appended to it.
            May be a ``str`` or an ``os.PathLike``.
        output: File the figure is saved to. Defaults to ``"losses.svg"``.

    Returns:
        A dict mapping each checkpoint key (``"loss"``, ``"test_loss"``,
        ``"avg_cer"``, ``"avg_wer"``) to the list of its values in epoch
        order.

    Raises:
        KeyError: If a checkpoint lacks one of the expected keys.
        Errors raised by ``torch.load`` (e.g. for a missing checkpoint)
        propagate unchanged.
    """
    keys = ("loss", "test_loss", "avg_cer", "avg_wer")
    history = {key: [] for key in keys}
    for epoch in range(1, epochs + 1):
        # map_location="cpu" lets checkpoints that were saved from a GPU be
        # loaded on a CPU-only machine.
        state = torch.load(f"{path}{epoch}", map_location="cpu")
        for key in keys:
            history[key].append(state[key])

    # NOTE(review): assumes the stored values are plain numbers (or CPU scalar
    # tensors) that matplotlib can convert - confirm against the training loop.
    # Epochs are 1-based, so plot against epoch numbers, not list indices.
    epoch_axis = range(1, epochs + 1)
    # Use a fresh figure and always close it, so repeated calls neither stack
    # curves on pyplot's shared current figure nor leak open figures.
    fig, ax = plt.subplots()
    try:
        ax.plot(epoch_axis, history["loss"], label="train loss")
        ax.plot(epoch_axis, history["test_loss"], label="test loss")
        ax.set_xlabel("epoch")
        ax.set_ylabel("loss")
        ax.legend()
        fig.savefig(output)
    finally:
        plt.close(fig)
    return history
|