aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/utils/visualization.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/utils/visualization.py')
-rw-r--r--swr2_asr/utils/visualization.py8
1 files changed, 4 insertions, 4 deletions
diff --git a/swr2_asr/utils/visualization.py b/swr2_asr/utils/visualization.py
index 80f942a..a55d0d5 100644
--- a/swr2_asr/utils/visualization.py
+++ b/swr2_asr/utils/visualization.py
@@ -6,10 +6,10 @@ import torch
def plot(epochs, path):
"""Plots the losses over the epochs"""
- losses = list()
- test_losses = list()
- cers = list()
- wers = list()
+ losses = []
+ test_losses = []
+ cers = []
+ wers = []
for epoch in range(1, epochs + 1):
current_state = torch.load(path + str(epoch))
losses.append(current_state["loss"])