diff options
Diffstat (limited to 'main.py')
-rw-r--r-- | main.py | 99 |
1 files changed, 63 insertions, 36 deletions
@@ -12,17 +12,33 @@ import torch.nn.functional as F
 
 CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
 EPOCHS = 25
-FEATURES = 13  # idk
+FEATURES = 15  # idk
+
+
+def usage():
+    print(f"usage: {sys.argv[0]} <train|test|infer> [file]")
+    sys.exit(1)
+
+
+if __name__ != "__main__":
+    sys.exit(1)  # this isn't a library bro
+
+if len(sys.argv) < 2 or sys.argv[1] not in ["train", "test", "infer"]:
+    usage()
+
+MODE = sys.argv[1]
 
 print("getting datasets...")
 
 # download the datasets
-train_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="train-clean-100", download=True
-)
-test_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="test-clean", download=True
-)
+if MODE == "train":
+    train_dataset = torchaudio.datasets.LIBRISPEECH(
+        "./data", url="train-clean-100", download=True
+    )
+if MODE == "test":
+    test_dataset = torchaudio.datasets.LIBRISPEECH(
+        "./data", url="test-clean", download=True
+    )
 
 print("got datasets!")
 
@@ -51,12 +67,14 @@ def preprocess(data):
 print("preprocessing datasets...")
 
 # load the datasets into batches
-train_loader = data.DataLoader(
-    train_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
-)
-test_loader = data.DataLoader(
-    test_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
-)
+if MODE == "train":
+    train_loader = data.DataLoader(
+        train_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+    )
+if MODE == "test":
+    test_loader = data.DataLoader(
+        test_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+    )
 
 print("datasets ready!")
 
@@ -104,15 +122,10 @@ def train():
     print("model trained!")
 
 
-def ctc_decode(outputs, labels, label_lengths):
+def ctc_decode(outputs):
     arg_maxes = torch.argmax(outputs, dim=2)
     inferred = ""
-    target = ""
     for i, args in enumerate(arg_maxes):
-        target += "".join(
-            [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()]
-        )
-
         decode = []
         for j, ind in enumerate(args):
             if ind != len(CHARS):
@@ -120,7 +133,16 @@ def ctc_decode(outputs, labels, label_lengths):
                     continue
                 decode.append(ind.item())
         inferred += "".join([CHARS[c] for c in decode])
-    return inferred, target
+    return inferred
+
+
+def ctc_decode_target(labels, label_lengths):
+    target = ""
+    for i, args in enumerate(labels):
+        target += "".join(
+            [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()]
+        )
+    return target
 
 
 def test():
@@ -132,25 +154,30 @@ def test():
             test_loader
         ):
             outputs = model(inputs)
-            inferred, target = ctc_decode(outputs, targets, target_lengths)
+            inferred = ctc_decode(outputs)
+            target = ctc_decode_target(targets, target_lengths)
             print("\n=========================================\n")
             print("inferred: ", inferred)
             print("")
             print("target: ", target)
 
 
-def usage():
-    print(f"usage: {sys.argv[0]} <train|test>")
-
-
-if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        usage()
-        sys.exit(1)
-
-    if sys.argv[1] == "train":
-        train()
-    elif sys.argv[1] == "test":
-        test()
-    else:
-        usage()
+# TODO: Might be buggy?
+def infer(file):
+    model.load_state_dict(torch.load("target/model-final.ckpt"))
+    audio, sr = torchaudio.load(file)
+    inputs = preprocess([(audio, sr, "")])[0]
+    with torch.no_grad():
+        outputs = model(inputs)
+        inferred = ctc_decode(outputs)
+    print("inferred: ", inferred)
+
+
+if MODE == "train":
+    train()
+elif MODE == "test":
+    test()
+elif MODE == "infer" and len(sys.argv) == 3:
+    infer(sys.argv[2])
+else:
+    usage()