diff options
author | Pherkel | 2023-09-11 14:49:28 +0200 |
---|---|---|
committer | Pherkel | 2023-09-11 14:49:28 +0200 |
commit | 9dc3bc07424908dd7cf3f052708f506fd58b6e2c (patch) | |
tree | cd45dc9b70977530669c271c09025246ebbb9fef /swr2_asr/utils/visualization.py | |
parent | 01fae2b5e395e84db6a7e9819b6f98777c46e845 (diff) |
refactor utilities (data, vis, tokenizer)
Diffstat (limited to 'swr2_asr/utils/visualization.py')
-rw-r--r-- | swr2_asr/utils/visualization.py | 22 |
1 files changed, 22 insertions, 0 deletions
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py new file mode 100644 index 0000000..80f942a --- /dev/null +++ b/swr2_asr/utils/visualization.py @@ -0,0 +1,22 @@ +"""Utilities for visualizing the training process and results.""" + +import matplotlib.pyplot as plt +import torch + + +def plot(epochs, path): + """Plots the losses over the epochs""" + losses = list() + test_losses = list() + cers = list() + wers = list() + for epoch in range(1, epochs + 1): + current_state = torch.load(path + str(epoch)) + losses.append(current_state["loss"]) + test_losses.append(current_state["test_loss"]) + cers.append(current_state["avg_cer"]) + wers.append(current_state["avg_wer"]) + + plt.plot(losses) + plt.plot(test_losses) + plt.savefig("losses.svg") |