From e9721154ed94dbedb3443bac83a8e1e3f9d4b734 Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Tue, 5 Sep 2023 10:25:28 +0200 Subject: model now saves validation_loss and cer and wer --- swr2_asr/train.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) (limited to 'swr2_asr') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index 63deb72..6f3bc6c 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -253,7 +253,7 @@ def run( iter_meter, ) - test( + test_loss,avg_cer,avg_wer = test( model=model, device=device, test_loader=valid_loader, @@ -262,7 +262,12 @@ def run( ) print("saving epoch", str(epoch)) torch.save( - {"epoch": epoch, "model_state_dict": model.state_dict(), "loss": loss}, + {"epoch": epoch, + "model_state_dict": model.state_dict(), + "loss": loss, + "test_loss": test_loss, + "avg_cer": avg_cer, + "avg_wer": avg_wer}, path + str(epoch), ) -- cgit v1.2.3 From 3dd8dfcf3dd8d1e127e7f8fd68224c164d9f6fb1 Mon Sep 17 00:00:00 2001 From: JoJoBarthold2 Date: Tue, 5 Sep 2023 11:11:31 +0200 Subject: please add a good path for this --- swr2_asr/utils.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) (limited to 'swr2_asr') diff --git a/swr2_asr/utils.py b/swr2_asr/utils.py index 8a950ab..87d4f82 100644 --- a/swr2_asr/utils.py +++ b/swr2_asr/utils.py @@ -4,6 +4,7 @@ from enum import Enum from typing import TypedDict import numpy as np +import matplotlib.pyplot as plt import torch import torchaudio from tokenizers import Tokenizer @@ -314,3 +315,20 @@ if __name__ == "__main__": tok = Tokenizer.from_file("data/tokenizers/bpe_tokenizer_german_3000.json") dataset.set_tokenizer(tok) + + +def plot(epochs,path): + losses = list() + test_losses = list() + cers = list() + wers =list() + for epoch in range(1, epochs +1): + current_state = torch.load(path + str(epoch)) + losses.append(current_state["loss"]) + test_losses.append(current_state["test_loss"]) + cers.append(current_state["avg_cer"]) + wers.append(current_state["avg_wer"]) + + plt.plot(losses) + plt.plot(test_losses) + plt.savefig("losses.svg") \ No newline at end of file -- cgit v1.2.3