aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils
diff options
context:
space:
mode:
authorPherkel2023-09-18 12:44:40 +0200
committerPherkel2023-09-18 12:44:40 +0200
commitd5689047fa7062b284d13271bda39013dcf6150f (patch)
treebd1a843abda1929b826d9441df3ddc3db5cbec29 /swr2_asr/utils
parent9475900a1085b8277808b0a0b1555c59f7eb6d36 (diff)
parente06227289ad9d2fa45c736c771d859e9911b9a11 (diff)
Merge branch 'decoder' of github.com:Algo-Boys/SWR2-cool-projekt into decoder
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r--swr2_asr/utils/visualization.py36
1 files changed, 26 insertions, 10 deletions
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index a55d0d5..b288c5a 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -4,19 +4,35 @@ import matplotlib.pyplot as plt
import torch
-def plot(epochs, path):
+def plot(path):
"""Plots the losses over the epochs"""
- losses = []
+ train_losses = []
test_losses = []
cers = []
wers = []
- for epoch in range(1, epochs + 1):
- current_state = torch.load(path + str(epoch))
- losses.append(current_state["loss"])
- test_losses.append(current_state["test_loss"])
- cers.append(current_state["avg_cer"])
- wers.append(current_state["avg_wer"])
- plt.plot(losses)
- plt.plot(test_losses)
+ epoch = 5
+ while True:
+ try:
+ current_state = torch.load(path + str(epoch), map_location=torch.device("cpu"))
+ except FileNotFoundError:
+ break
+ train_losses.append((epoch, current_state["train_loss"].item()))
+ test_losses.append((epoch, current_state["test_loss"]))
+ cers.append((epoch, current_state["avg_cer"]))
+ wers.append((epoch, current_state["avg_wer"]))
+ epoch += 5
+
+ plt.plot(*zip(*train_losses), label="train_loss")
+ plt.plot(*zip(*test_losses), label="test_loss")
+ plt.plot(*zip(*cers), label="cer")
+ plt.plot(*zip(*wers), label="wer")
+ plt.xlabel("epoch")
+ plt.ylabel("score")
+ plt.title("Model performance for 5n epochs")
+ plt.legend()
plt.savefig("losses.svg")
+
+
+if __name__ == "__main__":
+ plot("data/runs/epoch")