diff options
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/inference.py | 7 | ||||
-rw-r--r-- | swr2_asr/utils/visualization.py | 36 |
2 files changed, 32 insertions, 11 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index 3f6a44e..3c58af0 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -66,7 +66,12 @@ def main(config_path: str, file_path: str) -> None: ).to(device) checkpoint = torch.load(inference_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"], strict=True) + + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + model.load_state_dict(state_dict, strict=True) model.eval() waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py index a55d0d5..b288c5a 100644 --- a/swr2_asr/utils/visualization.py +++ b/swr2_asr/utils/visualization.py @@ -4,19 +4,35 @@ import matplotlib.pyplot as plt import torch -def plot(epochs, path): +def plot(path): """Plots the losses over the epochs""" - losses = [] + train_losses = [] test_losses = [] cers = [] wers = [] - for epoch in range(1, epochs + 1): - current_state = torch.load(path + str(epoch)) - losses.append(current_state["loss"]) - test_losses.append(current_state["test_loss"]) - cers.append(current_state["avg_cer"]) - wers.append(current_state["avg_wer"]) - plt.plot(losses) - plt.plot(test_losses) + epoch = 5 + while True: + try: + current_state = torch.load(path + str(epoch), map_location=torch.device("cpu")) + except FileNotFoundError: + break + train_losses.append((epoch, current_state["train_loss"].item())) + test_losses.append((epoch, current_state["test_loss"])) + cers.append((epoch, current_state["avg_cer"])) + wers.append((epoch, current_state["avg_wer"])) + epoch += 5 + + plt.plot(*zip(*train_losses), label="train_loss") + plt.plot(*zip(*test_losses), label="test_loss") + plt.plot(*zip(*cers), label="cer") + plt.plot(*zip(*wers), label="wer") + plt.xlabel("epoch") + plt.ylabel("score") + plt.title("Model performance for 5n epochs") + plt.legend() plt.savefig("losses.svg") + + +if __name__ == "__main__": + plot("data/runs/epoch") |