diff options
author | Marvin Borner | 2023-07-02 17:31:01 +0200 |
---|---|---|
committer | Marvin Borner | 2023-07-02 17:31:01 +0200 |
commit | 09b10b5d7eb82781de7a6452e6227404db73b9c1 (patch) | |
tree | 71a82774757720c3666cf4f596cb8c8c8e56c6a2 /main.py | |
parent | eb3c44379bd531c9a9254867ff63e90a4e50d405 (diff) |
Added basic testing
Diffstat (limited to 'main.py')
-rw-r--r-- | main.py | 48 |
1 files changed, 34 insertions, 14 deletions
@@ -12,7 +12,7 @@ import torch.nn.functional as F CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' " EPOCHS = 25 -SEGMENTS = 15 +FEATURES = 13 # idk print("getting datasets...") @@ -28,7 +28,7 @@ print("got datasets!") def preprocess(data): - transform = T.MFCC(sample_rate=16000, n_mfcc=SEGMENTS) + transform = T.MFCC(sample_rate=16000, n_mfcc=FEATURES) inputs = [] targets = [] @@ -75,12 +75,12 @@ class Model(nn.Module): return x +model = Model(input_size=FEATURES, hidden_size=128, output_size=len(CHARS) + 1) + + def train(): print("training model...") - model = Model( - input_size=SEGMENTS, hidden_size=128, output_size=len(CHARS) + 1 - ) criterion = nn.CTCLoss(blank=len(CHARS)) optimizer = optim.Adam(model.parameters()) @@ -104,19 +104,39 @@ def train(): print("model trained!") +def ctc_decode(outputs, labels, label_lengths): + arg_maxes = torch.argmax(outputs, dim=2) + inferred = "" + target = "" + for i, args in enumerate(arg_maxes): + target += "".join( + [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()] + ) + + decode = [] + for j, ind in enumerate(args): + if ind != len(CHARS): + if j != 0 and ind == args[j - 1]: + continue + decode.append(ind.item()) + inferred += "".join([CHARS[c] for c in decode]) + return inferred, target + + def test(): - model = torch.load("target/model-final.ckpt") - correct = 0 - total = 0 + model.load_state_dict(torch.load("target/model-final.ckpt")) + # TODO: Calculate accuracy using string difference functions with torch.no_grad(): - for i, (inputs, targets) in enumerate(test_loader): + for i, (inputs, targets, input_lengths, target_lengths) in enumerate( + test_loader + ): outputs = model(inputs) - decoded_preds = ctc_decode(outputs) - total += targets.size(0) - correct += (decoded_preds == targets).sum().item() - - print("Test Accuracy: %d %%" % (100 * correct / total)) + inferred, target = ctc_decode(outputs, targets, target_lengths) + print("\n=========================================\n") + print("inferred: ", inferred) + print("") + print("target: ", target) def usage(): |