aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/visualization.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/utils/visualization.py')
-rw-r--r--swr2_asr/utils/visualization.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
new file mode 100644
index 0000000..80f942a
--- /dev/null
+++ b/swr2_asr/utils/visualization.py
@@ -0,0 +1,22 @@
+"""Utilities for visualizing the training process and results."""
+
+import matplotlib.pyplot as plt
+import torch
+
+
+def plot(epochs, path):
+ """Plots the losses over the epochs"""
+ losses = list()
+ test_losses = list()
+ cers = list()
+ wers = list()
+ for epoch in range(1, epochs + 1):
+ current_state = torch.load(path + str(epoch))
+ losses.append(current_state["loss"])
+ test_losses.append(current_state["test_loss"])
+ cers.append(current_state["avg_cer"])
+ wers.append(current_state["avg_wer"])
+
+ plt.plot(losses)
+ plt.plot(test_losses)
+ plt.savefig("losses.svg")