aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/visualization.py
blob: 80f942aa369053b92af79cb6e3535fcb6e515d16 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""Utilities for visualizing the training process and results."""

import matplotlib.pyplot as plt
import torch


def plot(epochs, path):
    """Plot the training and test losses over the epochs.

    Loads one checkpoint per epoch from ``path + str(epoch)`` for the epochs
    ``1..epochs``, reads its ``"loss"`` and ``"test_loss"`` entries and writes
    both curves to ``losses.svg`` in the current working directory.

    Args:
        epochs: Number of epochs; checkpoints 1 through ``epochs`` are read.
        path: Checkpoint path prefix; the epoch number is appended to it.
    """
    epoch_numbers = range(1, epochs + 1)
    losses = []
    test_losses = []
    for epoch in epoch_numbers:
        # NOTE(review): torch.load unpickles the file, so only point this at
        # checkpoints from a trusted source.
        # map_location="cpu" keeps checkpoints that contain GPU tensors
        # loadable on machines without a GPU.
        current_state = torch.load(path + str(epoch), map_location="cpu")
        # float() accepts plain numbers and single-element tensors alike and
        # yields values matplotlib can always draw.
        # NOTE(review): assumes each stored loss is a scalar - confirm against
        # the code that writes the checkpoints.
        losses.append(float(current_state["loss"]))
        test_losses.append(float(current_state["test_loss"]))

    # Draw on a dedicated figure instead of pyplot's implicit global one so
    # repeated calls do not stack curves from earlier invocations.
    fig, axis = plt.subplots()
    try:
        axis.plot(epoch_numbers, losses, label="loss")
        axis.plot(epoch_numbers, test_losses, label="test_loss")
        axis.set_xlabel("epoch")
        axis.set_ylabel("loss")
        axis.legend()
        fig.savefig("losses.svg")
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)