aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/visualization.py
blob: b288c5aa49838e2b999a2baea1d18cdf887eb7c3 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""Utilities for visualizing the training process and results."""

import matplotlib.pyplot as plt
import torch


def _as_scalar(value):
    """Return ``value.item()`` when available, otherwise ``value`` unchanged."""
    return value.item() if hasattr(value, "item") else value


def plot(path: str, *, step: int = 5, output: str = "losses.svg") -> None:
    """Plot the recorded losses and error rates over the training epochs.

    Checkpoints are loaded from ``path + str(epoch)`` for
    ``epoch = step, 2 * step, 3 * step, ...`` and reading stops at the first
    file that does not exist. From every checkpoint the entries
    ``"train_loss"``, ``"test_loss"``, ``"avg_cer"`` and ``"avg_wer"`` are
    read and drawn as one curve each. If no checkpoint is found at all, a
    chart without data points is still written.

    Args:
        path: Prefix of the checkpoint files; the epoch number is appended
            directly to it, without any separator.
        step: Distance in epochs between two consecutive checkpoints. This is
            also the first epoch that is looked up.
        output: Name of the file the figure is saved to.

    Raises:
        ValueError: If ``step`` is smaller than 1.
        KeyError: If a checkpoint lacks one of the four metric entries.
    """
    if step < 1:
        # A non-positive step would never advance to a missing file.
        raise ValueError(f"step must be a positive integer, got {step!r}")

    # Legend label -> key under which the metric is stored in a checkpoint.
    metric_keys = {
        "train_loss": "train_loss",
        "test_loss": "test_loss",
        "cer": "avg_cer",
        "wer": "avg_wer",
    }
    epochs = []
    history = {label: [] for label in metric_keys}

    epoch = step
    while True:
        try:
            # NOTE(security): torch.load unpickles the file, so only point
            # this at checkpoints from a trusted source.
            state = torch.load(f"{path}{epoch}", map_location=torch.device("cpu"))
        except FileNotFoundError:
            # The first missing checkpoint marks the end of the run.
            break
        epochs.append(epoch)
        for label, key in metric_keys.items():
            # Metrics may be stored as tensors or as plain numbers.
            history[label].append(_as_scalar(state[key]))
        epoch += step

    # Use a dedicated figure instead of the implicit global one so repeated
    # calls do not draw on top of each other, and release it afterwards.
    fig, axes = plt.subplots()
    try:
        for label, values in history.items():
            axes.plot(epochs, values, label=label)
        axes.set_xlabel("epoch")
        axes.set_ylabel("score")
        axes.set_title(f"Model performance for {step}n epochs")
        axes.legend()
        fig.savefig(output)
    finally:
        plt.close(fig)


if __name__ == "__main__":
    # Script entry point: plot the run whose checkpoint files share this
    # prefix (the epoch number is appended to it by plot()).
    checkpoint_prefix = "data/runs/epoch"
    plot(checkpoint_prefix)