aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr
diff options
context:
space:
mode:
authorMarvin2023-09-05 13:33:13 +0200
committerGitHub2023-09-05 13:33:13 +0200
commit46b23fd90f5ef9c3126ee66473b012fa715da008 (patch)
tree158fbd0db3541d9c0c9497fe0b1bf6f7221cc7ff /swr2_asr
parent93e49e708fa59613406249069d31c2f6c8f2d2ab (diff)
parent3dd8dfcf3dd8d1e127e7f8fd68224c164d9f6fb1 (diff)
Merge pull request #29 from Algo-Boys/plot
Plot
Diffstat (limited to 'swr2_asr')
-rw-r--r--swr2_asr/train.py9
-rw-r--r--swr2_asr/utils.py18
2 files changed, 25 insertions, 2 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index 63deb72..6f3bc6c 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -253,7 +253,7 @@ def run(
iter_meter,
)
- test(
+ test_loss,avg_cer,avg_wer = test(
model=model,
device=device,
test_loader=valid_loader,
@@ -262,7 +262,12 @@ def run(
)
print("saving epoch", str(epoch))
torch.save(
- {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss},
+ {"epoch": epoch,
+ "model_state_dict": model.state_dict(),
+ "loss": loss,
+ "test_loss": test_loss,
+ "avg_cer": avg_cer,
+ "avg_wer": avg_wer},
path + str(epoch),
)
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py
index 8a950ab..87d4f82 100644
--- a/swr2_asr/utils.py
+++ b/swr2_asr/utils.py
@@ -4,6 +4,7 @@ from enum import Enum
from typing import TypedDict
import numpy as np
+import matplotlib.pyplot as plt
import torch
import torchaudio
from tokenizers import Tokenizer
@@ -314,3 +315,20 @@ if __name__ == "__main__":
tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json")
dataset.set_tokenizer(tok)
+
+
+def plot(epochs,path):
+ losses = list()
+ test_losses = list()
+ cers = list()
+ wers =list()
+ for epoch in range(1, epochs +1):
+ current_state = torch.load(path + str(epoch))
+ losses.append(current_state["loss"])
+ test_losses.append(current_state["test_loss"])
+ cers.append(current_state["avg_cer"])
+ wers.append(current_state["avg_wer"])
+
+ plt.plot(losses)
+ plt.plot(test_losses)
+ plt.savefig("losses.svg") \ No newline at end of file