path: root/main.py
author     Marvin Borner    2023-07-02 17:31:01 +0200
committer  Marvin Borner    2023-07-02 17:31:01 +0200
commit     09b10b5d7eb82781de7a6452e6227404db73b9c1 (patch)
tree       71a82774757720c3666cf4f596cb8c8c8e56c6a2 /main.py
parent     eb3c44379bd531c9a9254867ff63e90a4e50d405 (diff)
Added basic testing
Diffstat (limited to 'main.py')
-rw-r--r--  main.py  48
1 file changed, 34 insertions(+), 14 deletions(-)
diff --git a/main.py b/main.py
index f1caf26..1c87dd4 100644
--- a/main.py
+++ b/main.py
@@ -12,7 +12,7 @@ import torch.nn.functional as F
 CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
 EPOCHS = 25
-SEGMENTS = 15
+FEATURES = 13  # number of MFCC coefficients per frame
 print("getting datasets...")
@@ -28,7 +28,7 @@ print("got datasets!")
 def preprocess(data):
-    transform = T.MFCC(sample_rate=16000, n_mfcc=SEGMENTS)
+    transform = T.MFCC(sample_rate=16000, n_mfcc=FEATURES)
     inputs = []
     targets = []
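
For reference, a minimal sketch (not part of the commit) of what the MFCC transform above returns; the one-second random waveform and the frame count are illustrative assumptions based on torchaudio's default window settings:

import torch
import torchaudio.transforms as T

FEATURES = 13
transform = T.MFCC(sample_rate=16000, n_mfcc=FEATURES)
waveform = torch.randn(1, 16000)   # one second of fake mono audio at 16 kHz
mfcc = transform(waveform)         # (channel, n_mfcc, time), roughly (1, 13, 81) with the default hop length
print(mfcc.shape)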
@@ -75,12 +75,12 @@ class Model(nn.Module):
         return x
+model = Model(input_size=FEATURES, hidden_size=128, output_size=len(CHARS) + 1)
+
+
 def train():
     print("training model...")
-    model = Model(
-        input_size=SEGMENTS, hidden_size=128, output_size=len(CHARS) + 1
-    )
     criterion = nn.CTCLoss(blank=len(CHARS))
     optimizer = optim.Adam(model.parameters())
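
Side note: nn.CTCLoss(blank=len(CHARS)) treats the extra class added by output_size=len(CHARS) + 1 (index 28) as the CTC blank. A minimal sketch of the shapes the loss expects, with made-up batch and sequence sizes:

import torch
import torch.nn as nn
import torch.nn.functional as F

CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
criterion = nn.CTCLoss(blank=len(CHARS))        # blank is the extra class, index 28

T_steps, N, C = 50, 4, len(CHARS) + 1           # time steps, batch size, number of classes
log_probs = F.log_softmax(torch.randn(T_steps, N, C), dim=2)  # (T, N, C) log-probabilities
targets = torch.randint(0, len(CHARS), (N, 10))               # label indices, never the blank
input_lengths = torch.full((N,), T_steps, dtype=torch.long)
target_lengths = torch.full((N,), 10, dtype=torch.long)
loss = criterion(log_probs, targets, input_lengths, target_lengths)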
@@ -104,19 +104,39 @@ def train():
     print("model trained!")
+def ctc_decode(outputs, labels, label_lengths):
+    arg_maxes = torch.argmax(outputs, dim=2)
+    inferred = ""
+    target = ""
+    for i, args in enumerate(arg_maxes):
+        target += "".join(
+            [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()]
+        )
+
+        decode = []
+        for j, ind in enumerate(args):
+            if ind != len(CHARS):
+                if j != 0 and ind == args[j - 1]:
+                    continue
+                decode.append(ind.item())
+        inferred += "".join([CHARS[c] for c in decode])
+    return inferred, target
+
+
 def test():
-    model = torch.load("target/model-final.ckpt")
-    correct = 0
-    total = 0
+    model.load_state_dict(torch.load("target/model-final.ckpt"))
+    # TODO: Calculate accuracy using string difference functions
     with torch.no_grad():
-        for i, (inputs, targets) in enumerate(test_loader):
+        for i, (inputs, targets, input_lengths, target_lengths) in enumerate(
+            test_loader
+        ):
             outputs = model(inputs)
-            decoded_preds = ctc_decode(outputs)
-            total += targets.size(0)
-            correct += (decoded_preds == targets).sum().item()
-
-    print("Test Accuracy: %d %%" % (100 * correct / total))
+            inferred, target = ctc_decode(outputs, targets, target_lengths)
+            print("\n=========================================\n")
+            print("inferred: ", inferred)
+            print("")
+            print("target: ", target)
 def usage():
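
One way to address the TODO in test() above (a sketch, not part of the commit): score each inferred string against its target with an edit distance and report a character error rate. The levenshtein and cer helpers below are hypothetical names, written inline so no extra dependency is assumed.

def levenshtein(a: str, b: str) -> int:
    # Classic dynamic-programming edit distance between two strings.
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, start=1):
        curr = [i]
        for j, cb in enumerate(b, start=1):
            curr.append(min(prev[j] + 1,                    # deletion
                            curr[j - 1] + 1,                # insertion
                            prev[j - 1] + (ca != cb)))      # substitution
        prev = curr
    return prev[-1]


def cer(inferred: str, target: str) -> float:
    # Character error rate: edits needed to turn the inferred string into
    # the target, normalised by the target length.
    return levenshtein(inferred, target) / max(len(target), 1)

Inside test(), this could replace the plain prints with something like print("CER:", cer(inferred, target)).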