diff options
Diffstat (limited to 'swr2_asr/utils.py')
-rw-r--r-- | swr2_asr/utils.py | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 8a950ab..87d4f82 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -4,6 +4,7 @@ from enum import Enum from typing import TypedDict import numpy as np +import matplotlib.pyplot as plt import torch import torchaudio from tokenizers import Tokenizer @@ -314,3 +315,20 @@ if __name__ == "__main__": tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) + + +def plot(epochs,path): + losses = list() + test_losses = list() + cers = list() + wers =list() + for epoch in range(1, epochs +1): + current_state = torch.load(path + str(epoch)) + losses.append(current_state["loss"]) + test_losses.append(current_state["test_loss"]) + cers.append(current_state["avg_cer"]) + wers.append(current_state["avg_wer"]) + + plt.plot(losses) + plt.plot(test_losses) + plt.savefig("losses.svg")
\ No newline at end of file |