aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/visualization.py
diff options
context:
space:
mode:
authorPherkel2023-09-12 14:19:15 +0200
committerGitHub2023-09-12 14:19:15 +0200
commit7a9a6c783e69b5a537a3d3f5bfe8d5fdc656c807 (patch)
tree0725631b9b68aeb65b292420a15941dcfa3fc04f /swr2_asr/utils/visualization.py
parentf9846193289c81d89342b6a36e951605c2cfa189 (diff)
parent7b71dab87591e04d874cd636614450b0e65e3f2b (diff)
Merge pull request #37 from Algo-Boys/fix/ultimate
Fix/ultimate
Diffstat (limited to 'swr2_asr/utils/visualization.py')
-rw-r--r--swr2_asr/utils/visualization.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..a55d0d5
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+ """Plots the losses over the epochs"""
+ losses = []
+ test_losses = []
+ cers = []
+ wers = []
+ for epoch in range(1, epochs + 1):
+ current_state = torch.load(path + str(epoch))
+ losses.append(current_state["loss"])
+ test_losses.append(current_state["test_loss"])
+ cers.append(current_state["avg_cer"])
+ wers.append(current_state["avg_wer"])
+
+ plt.plot(losses)
+ plt.plot(test_losses)
+ plt.savefig("losses.svg")