-rw-r--r--  main.py   | 99
-rw-r--r--  readme.md | 20
2 files changed, 76 insertions, 43 deletions
diff --git a/main.py b/main.py
--- a/main.py
+++ b/main.py
@@ -12,17 +12,33 @@
 import torch.nn.functional as F
 
 CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
 EPOCHS = 25
-FEATURES = 13  # idk
+FEATURES = 15  # idk
+
+
+def usage():
+    print(f"usage: {sys.argv[0]} <train|test|infer> [file]")
+    sys.exit(1)
+
+
+if __name__ != "__main__":
+    sys.exit(1)  # this isn't a library bro
+
+if len(sys.argv) < 2 or sys.argv[1] not in ["train", "test", "infer"]:
+    usage()
+
+MODE = sys.argv[1]
 
 print("getting datasets...")
 
 # download the datasets
-train_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="train-clean-100", download=True
-)
-test_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="test-clean", download=True
-)
+if MODE == "train":
+    train_dataset = torchaudio.datasets.LIBRISPEECH(
+        "./data", url="train-clean-100", download=True
+    )
+if MODE == "test":
+    test_dataset = torchaudio.datasets.LIBRISPEECH(
+        "./data", url="test-clean", download=True
+    )
 
 print("got datasets!")
@@ -51,12 +67,14 @@ def preprocess(data):
 print("preprocessing datasets...")
 
 # load the datasets into batches
-train_loader = data.DataLoader(
-    train_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
-)
-test_loader = data.DataLoader(
-    test_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
-)
+if MODE == "train":
+    train_loader = data.DataLoader(
+        train_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+    )
+if MODE == "test":
+    test_loader = data.DataLoader(
+        test_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+    )
 
 print("datasets ready!")
 
@@ -104,15 +122,10 @@ def train():
     print("model trained!")
 
 
-def ctc_decode(outputs, labels, label_lengths):
+def ctc_decode(outputs):
     arg_maxes = torch.argmax(outputs, dim=2)
     inferred = ""
-    target = ""
     for i, args in enumerate(arg_maxes):
-        target += "".join(
-            [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()]
-        )
-
         decode = []
         for j, ind in enumerate(args):
             if ind != len(CHARS):
@@ -120,7 +133,16 @@
                 continue
             decode.append(ind.item())
         inferred += "".join([CHARS[c] for c in decode])
-    return inferred, target
+    return inferred
+
+
+def ctc_decode_target(labels, label_lengths):
+    target = ""
+    for i, args in enumerate(labels):
+        target += "".join(
+            [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()]
+        )
+    return target
 
 
 def test():
@@ -132,25 +154,30 @@
         test_loader
     ):
         outputs = model(inputs)
-        inferred, target = ctc_decode(outputs, targets, target_lengths)
+        inferred = ctc_decode(outputs)
+        target = ctc_decode_target(targets, target_lengths)
         print("\n=========================================\n")
         print("inferred: ", inferred)
         print("")
        print("target: ", target)
 
 
-def usage():
-    print(f"usage: {sys.argv[0]} <train|test>")
-
-
-if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        usage()
-        sys.exit(1)
-
-    if sys.argv[1] == "train":
-        train()
-    elif sys.argv[1] == "test":
-        test()
-    else:
-        usage()
+# TODO: Might be buggy?
+def infer(file):
+    model.load_state_dict(torch.load("target/model-final.ckpt"))
+    audio, sr = torchaudio.load(file)
+    inputs = preprocess([(audio, sr, "")])[0]
+    with torch.no_grad():
+        outputs = model(inputs)
+    inferred = ctc_decode(outputs)
+    print("inferred: ", inferred)
+
+
+if MODE == "train":
+    train()
+elif MODE == "test":
+    test()
+elif MODE == "infer" and len(sys.argv) == 3:
+    infer(sys.argv[2])
+else:
+    usage()
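
The refactored `ctc_decode` above is a greedy (best-path) CTC decoder: take the argmax class per time step, collapse consecutive repeats, and drop the blank symbol (index `len(CHARS)`). A minimal standalone sketch of the same idea, assuming model outputs of shape `(batch, time, len(CHARS) + 1)`:

```python
import torch

CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
BLANK = len(CHARS)  # the CTC blank is the extra class after the 28 characters


def greedy_ctc_decode(outputs: torch.Tensor) -> str:
    """Best-path CTC decode: argmax per frame, collapse repeats, drop blanks."""
    decoded = ""
    for frames in torch.argmax(outputs, dim=2):  # one row of (batch, time) per utterance
        prev = None
        for idx in frames.tolist():
            if idx != BLANK and idx != prev:
                decoded += CHARS[idx]
            prev = idx
    return decoded

# e.g. with the blank written as '_': frames "HH_EE_LL_LL_OO" decode to "HELLO"
```
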
diff --git a/readme.md b/readme.md
--- a/readme.md
+++ b/readme.md
@@ -2,14 +2,20 @@
 
 > spoken word recognition using CTC LSTMs
 
-## Installation
+## Instructions
 
-- `python -m venv venv`
-- `./venv/bin/pip install -r requirements.txt`
-- `./venv/bin/python main.py train`
-- `./venv/bin/python main.py test`
+- Create a virtual environment: `python -m venv venv`
+- Install the required packages:
+  `./venv/bin/pip install -r requirements.txt`
+- Train the model: `./venv/bin/python main.py train` (takes a few hours
+  and needs around 20 GB of disk and 5 GB of memory)
+  - or download my pre-trained model (25 epochs, **not good**) from
+    [here](https://marvinborner.de/model-final.ckpt) and move it to
+    `target/model-final.ckpt`
+- Test the final model: `./venv/bin/python main.py test`
+- Infer text from a flac file: `./venv/bin/python main.py infer audio.flac`
 
 ## Note
 
-- This is a proof-of-concept
-- Does not use CUDA but should be easy to implement
+- This is a proof-of-concept
+- Does not use CUDA, but support should be easy to add
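
For readers trying the new `infer` mode: the diff does not show `preprocess` or the model itself, but a torchaudio MFCC front end consistent with the new `FEATURES = 15` constant would look roughly like this. Illustrative only; the `n_mfcc` value, sample rate, and output shapes here are assumptions, not taken from `main.py`:

```python
import torchaudio

FEATURES = 15  # matches the FEATURES constant bumped in this commit

# Hypothetical front end -- the real preprocess() is not part of this diff.
mfcc = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=FEATURES)

waveform, sr = torchaudio.load("audio.flac")  # LibriSpeech ships 16 kHz flac
feats = mfcc(waveform)      # (channel, n_mfcc, time)
feats = feats.squeeze(0).T  # (time, n_mfcc): one feature vector per frame
print(feats.shape)
```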