diff options
author | Marvin Borner | 2023-09-16 15:55:41 +0200 |
---|---|---|
committer | JoJoBarthold2 | 2023-09-18 12:29:46 +0200 |
commit | 0d38af95f0058875d42dd261a287856ba84d3ce6 (patch) | |
tree | 60c7c0affddfb417ce6d66e11d7b4324a4c80e3d /swr2_asr/utils | |
parent | 0f14789f1c33d55dc270bcd154201cce2c4d516e (diff) |
Added better visualization
Diffstat (limited to 'swr2_asr/utils')
-rw-r--r-- | swr2_asr/utils/visualization.py | 36 |
1 files changed, 26 insertions, 10 deletions
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py index a55d0d5..b288c5a 100644 --- a/swr2_asr/utils/visualization.py +++ b/swr2_asr/utils/visualization.py @@ -4,19 +4,35 @@ import matplotlib.pyplot as plt import torch -def plot(epochs, path): +def plot(path): """Plots the losses over the epochs""" - losses = [] + train_losses = [] test_losses = [] cers = [] wers = [] - for epoch in range(1, epochs + 1): - current_state = torch.load(path + str(epoch)) - losses.append(current_state["loss"]) - test_losses.append(current_state["test_loss"]) - cers.append(current_state["avg_cer"]) - wers.append(current_state["avg_wer"]) - plt.plot(losses) - plt.plot(test_losses) + epoch = 5 + while True: + try: + current_state = torch.load(path + str(epoch), map_location=torch.device("cpu")) + except FileNotFoundError: + break + train_losses.append((epoch, current_state["train_loss"].item())) + test_losses.append((epoch, current_state["test_loss"])) + cers.append((epoch, current_state["avg_cer"])) + wers.append((epoch, current_state["avg_wer"])) + epoch += 5 + + plt.plot(*zip(*train_losses), label="train_loss") + plt.plot(*zip(*test_losses), label="test_loss") + plt.plot(*zip(*cers), label="cer") + plt.plot(*zip(*wers), label="wer") + plt.xlabel("epoch") + plt.ylabel("score") + plt.title("Model performance for 5n epochs") + plt.legend() plt.savefig("losses.svg") + + +if __name__ == "__main__": + plot("data/runs/epoch") |