In [None]:
import os
import matplotlib.pyplot as plt
import pandas as pd
import torch

In [None]:
# Collect per-epoch metrics from the saved training checkpoints and export them as CSV.
# Checkpoint files are looked up as <CHECKPOINTS_DIR>/<CHECKPOINTS_PREFIX><epoch>.
CHECKPOINTS_DIR = 'data'
CHECKPOINTS_PREFIX = 'epoch'
# First epoch to read. The scan walks upwards from here and stops at the first
# epoch whose checkpoint file does not exist.
# NOTE(review): starting at 67 skips every earlier checkpoint — confirm this is intended.
START_EPOCH = 67

rows = []

epoch = START_EPOCH
while True:
    ckp_path = os.path.join(CHECKPOINTS_DIR, CHECKPOINTS_PREFIX + str(epoch))
    print(ckp_path)
    try:
        # NOTE: torch.load relies on pickle; only load checkpoints from a trusted source.
        current_state = torch.load(ckp_path, map_location=torch.device("cpu"))
    except FileNotFoundError:
        # the first missing file marks the end of the available checkpoints
        break

    rows.append({
        "epoch": epoch,
        # only train_loss is unwrapped with .item(); the other metrics are stored as loaded
        "train_loss": current_state["train_loss"].item(),
        "test_loss": current_state["test_loss"],
        "avg_cer": current_state["avg_cer"],
        "avg_wer": current_state["avg_wer"],
    })

    epoch += 1


# create dataframe, then csv from dataframe
df = pd.DataFrame(rows)
losses_csv = os.path.join(CHECKPOINTS_DIR, "losses.csv")
if rows:
    df.to_csv(losses_csv, index=False)
else:
    # Writing an empty frame would overwrite a previously exported CSV with nothing.
    print(f"no checkpoint found at {ckp_path}; {losses_csv} was not written")

In [None]:
# Load the exported metrics.
# "losses.csv" in the working directory is preferred (unchanged behaviour). The
# checkpoint-extraction cell, however, writes its CSV to "data/losses.csv", so fall
# back to that file when no local copy exists instead of raising FileNotFoundError.
csv_path = "losses.csv"
if not os.path.exists(csv_path):
    exported_path = os.path.join("data", "losses.csv")
    if os.path.exists(exported_path):
        csv_path = exported_path

# load csv (raises FileNotFoundError for "losses.csv" when neither file exists)
df = pd.read_csv(csv_path)

In [None]:
# Plot train_loss and test_loss per epoch and save the figure as a PNG.
# No colours are used: the curves are distinguished by line style only
# (solid for train_loss, dashed for test_loss).
TICK_STEP = 5       # x-axis tick spacing, in epochs
MIN_TICK_END = 70   # ticks always cover at least 0..65, matching the previous fixed range

# Explicit figure/axes so this cell never draws onto a figure left open elsewhere.
fig, ax = plt.subplots()

ax.plot(df['epoch'], df['train_loss'], label='train_loss', linestyle='solid', color='black')

# Rows whose test_loss is exactly 0.0 are excluded from the test curve
# (presumably epochs without an evaluation run — TODO confirm).
evaluated = df[df['test_loss'] != 0.0]
ax.plot(evaluated['epoch'], evaluated['test_loss'], label='test_loss', linestyle='dashed', color='black')

# Markers for test_loss: a single unlabelled artist (so it stays out of the legend)
# instead of one plot call per point.
ax.plot(evaluated['epoch'], evaluated['test_loss'], linestyle='none', marker='o', markersize=3, color='black')

ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.legend()

# Ticks every TICK_STEP epochs; the range grows with the data so runs longer than
# MIN_TICK_END epochs are still labelled.
last_epoch = int(df['epoch'].max()) if len(df) else 0
ax.set_xticks(list(range(0, max(MIN_TICK_END, last_epoch + 1), TICK_STEP)))

# y axis starts at 0
ax.set_ylim(bottom=0)
# reduce margins
fig.tight_layout()
# dpi=300 for a high-resolution image
fig.savefig('train_test_loss.png', dpi=300)

In [None]:
# Plot the cer and wer error-rate curves per epoch and save the figure as a PNG.
# No colours are used: solid line for cer, dashed line for wer.
TICK_STEP = 5       # x-axis tick spacing, in epochs
MIN_TICK_END = 70   # ticks always cover at least 0..65, matching the previous fixed range

# Explicit figure/axes so this cell never draws onto a figure left open elsewhere.
fig, ax = plt.subplots()

# (legend label / preferred column, alternative column, line style)
# The checkpoint-extraction cell stores these metrics as 'avg_cer' / 'avg_wer', so
# accept either spelling rather than failing with a KeyError on a fresh export.
curves = (
    ('cer', 'avg_cer', 'solid'),
    ('wer', 'avg_wer', 'dashed'),
)
for label, alt_column, style in curves:
    column = label if label in df.columns else alt_column
    # Values of exactly 0.0 are excluded
    # (presumably epochs without an evaluation run — TODO confirm).
    measured = df[df[column] != 0.0]
    ax.plot(measured['epoch'], measured[column], label=label, linestyle=style, color='black')
    # Markers as one unlabelled artist, so they stay out of the legend.
    ax.plot(measured['epoch'], measured[column], linestyle='none', marker='o', markersize=3, color='black')

# error rates are shown on a fixed 0..1 scale
ax.set_ylim(bottom=0, top=1)
ax.set_xlabel('epoch')
ax.set_ylabel('error rate')
ax.legend()

# Ticks every TICK_STEP epochs; the range grows with the data so runs longer than
# MIN_TICK_END epochs are still labelled.
last_epoch = int(df['epoch'].max()) if len(df) else 0
ax.set_xticks(list(range(0, max(MIN_TICK_END, last_epoch + 1), TICK_STEP)))

# y ticks every 0.1
ax.set_yticks([tick / 10 for tick in range(0, 11, 1)])

# reduce margins; done after the ticks are set so the layout accounts for their labels
fig.tight_layout()

# dpi=300 for a high-resolution image
fig.savefig('cer_wer.png', dpi=300)