diff options
author | Pherkel | 2023-09-18 12:44:40 +0200 |
---|---|---|
committer | Pherkel | 2023-09-18 12:44:40 +0200 |
commit | d5689047fa7062b284d13271bda39013dcf6150f (patch) | |
tree | bd1a843abda1929b826d9441df3ddc3db5cbec29 /swr2_asr/utils | |
parent | 9475900a1085b8277808b0a0b1555c59f7eb6d36 (diff) | |
parent | e06227289ad9d2fa45c736c771d859e9911b9a11 (diff) |
Merge branch 'decoder' of github.com:Algo-Boys/SWR2-cool-projekt into decoder
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r-- | swr2_asr/utils/visualization.py | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py index a55d0d5..b288c5a 100644 --- a/swr2_asr/utils/visualization.py +++ b/swr2_asr/utils/visualization.py @@ -4,19 +4,35 @@ import matplotlib.pyplot as plt import torch -def plot(epochs, path): +def plot(path): """Plots the losses over the epochs""" - losses = [] + train_losses = [] test_losses = [] cers = [] wers = [] - for epoch in range(1, epochs + 1): - current_state = torch.load(path + str(epoch)) - losses.append(current_state["loss"]) - test_losses.append(current_state["test_loss"]) - cers.append(current_state["avg_cer"]) - wers.append(current_state["avg_wer"]) - plt.plot(losses) - plt.plot(test_losses) + epoch = 5 + while True: + try: + current_state = torch.load(path + str(epoch), map_location=torch.device("cpu")) + except FileNotFoundError: + break + train_losses.append((epoch, current_state["train_loss"].item())) + test_losses.append((epoch, current_state["test_loss"])) + cers.append((epoch, current_state["avg_cer"])) + wers.append((epoch, current_state["avg_wer"])) + epoch += 5 + + plt.plot(*zip(*train_losses), label="train_loss") + plt.plot(*zip(*test_losses), label="test_loss") + plt.plot(*zip(*cers), label="cer") + plt.plot(*zip(*wers), label="wer") + plt.xlabel("epoch") + plt.ylabel("score") + plt.title("Model performance for 5n epochs") + plt.legend() plt.savefig("losses.svg") + + +if __name__ == "__main__": + plot("data/runs/epoch") |