 main.py   | 99
 readme.md | 20
 2 files changed, 76 insertions(+), 43 deletions(-)
diff --git a/main.py b/main.py
index 1c87dd4..0ef7fcd 100644
--- a/main.py
+++ b/main.py
@@ -12,17 +12,33 @@ import torch.nn.functional as F
 CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
 EPOCHS = 25
-FEATURES = 13  # idk
+FEATURES = 15  # number of features per audio frame
+
+
+def usage():
+    print(f"usage: {sys.argv[0]} <train|test|infer> [file]")
+    sys.exit(1)
+
+
+if __name__ != "__main__":
+    sys.exit(1)  # this isn't a library bro
+
+if len(sys.argv) < 2 or sys.argv[1] not in ["train", "test", "infer"]:
+    usage()
+
+MODE = sys.argv[1]
 print("getting datasets...")
 # download the datasets
-train_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="train-clean-100", download=True
-)
-test_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="test-clean", download=True
-)
+if MODE == "train":
+    train_dataset = torchaudio.datasets.LIBRISPEECH(
+        "./data", url="train-clean-100", download=True
+    )
+if MODE == "test":
+    test_dataset = torchaudio.datasets.LIBRISPEECH(
+        "./data", url="test-clean", download=True
+    )
 print("got datasets!")
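
For reference, each LIBRISPEECH item is a 6-tuple of (waveform, sample_rate, transcript, speaker_id, chapter_id, utterance_id); a minimal standalone sketch for inspecting one sample:

    import torchaudio

    dataset = torchaudio.datasets.LIBRISPEECH("./data", url="test-clean", download=True)
    waveform, sample_rate, transcript, *_ = dataset[0]
    print(waveform.shape, sample_rate)  # e.g. torch.Size([1, n_samples]) 16000
    print(transcript)  # upper-case text drawn from the CHARS alphabet
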
@@ -51,12 +67,14 @@ def preprocess(data):
 print("preprocessing datasets...")
 # load the datasets into batches
-train_loader = data.DataLoader(
-    train_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
-)
-test_loader = data.DataLoader(
-    test_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
-)
+if MODE == "train":
+    train_loader = data.DataLoader(
+        train_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+    )
+if MODE == "test":
+    test_loader = data.DataLoader(
+        test_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+    )
 print("datasets ready!")
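
The body of preprocess lies outside these hunks; a minimal sketch of what such a CTC collate_fn could look like, assuming MFCC features with FEATURES coefficients and CHARS-indexed targets (this sketch is an assumption, not the repo's actual code):

    import torch
    import torchaudio

    CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "  # as in main.py
    FEATURES = 15  # as in main.py

    mfcc = torchaudio.transforms.MFCC(sample_rate=16000, n_mfcc=FEATURES)

    def preprocess_sketch(batch):
        # batch: list of LIBRISPEECH-style tuples (waveform, sample_rate, transcript, ...)
        inputs, targets, input_lengths, target_lengths = [], [], [], []
        for waveform, sample_rate, transcript, *_ in batch:
            feats = mfcc(waveform).squeeze(0).transpose(0, 1)  # (time, FEATURES)
            inputs.append(feats)
            target = torch.tensor([CHARS.index(c) for c in transcript], dtype=torch.long)
            targets.append(target)
            input_lengths.append(feats.shape[0])
            target_lengths.append(len(target))
        inputs = torch.nn.utils.rnn.pad_sequence(inputs, batch_first=True)
        targets = torch.nn.utils.rnn.pad_sequence(targets, batch_first=True)
        return inputs, targets, input_lengths, target_lengths
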
@@ -104,15 +122,10 @@ def train():
     print("model trained!")
-def ctc_decode(outputs, labels, label_lengths):
+def ctc_decode(outputs):
     arg_maxes = torch.argmax(outputs, dim=2)
     inferred = ""
-    target = ""
     for i, args in enumerate(arg_maxes):
-        target += "".join(
-            [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()]
-        )
-
         decode = []
         for j, ind in enumerate(args):
             if ind != len(CHARS):
@@ -120,7 +133,16 @@ def ctc_decode(outputs, labels, label_lengths):
                     continue
                 decode.append(ind.item())
         inferred += "".join([CHARS[c] for c in decode])
-    return inferred, target
+    return inferred
+
+
+def ctc_decode_target(labels, label_lengths):
+    target = ""
+    for i in range(len(labels)):
+        target += "".join(
+            [CHARS[int(c)] for c in labels[i][: label_lengths[i]].tolist()]
+        )
+    return target
 def test():
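
The split-out ctc_decode keeps the usual greedy CTC rule: take the per-frame argmax, collapse consecutive repeats, and drop the blank (the extra class at index len(CHARS)). A tiny standalone illustration of that rule (toy input, not repo code):

    CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "  # as in main.py
    BLANK = len(CHARS)  # CTC reserves one extra class index for the blank

    def greedy_ctc(frame_indices):
        # collapse consecutive repeats, then drop blanks
        out, prev = [], None
        for ind in frame_indices:
            if ind != prev and ind != BLANK:
                out.append(CHARS[ind])
            prev = ind
        return "".join(out)

    # H H E E <blank> L L <blank> L O  ->  "HELLO"
    print(greedy_ctc([7, 7, 4, 4, BLANK, 11, 11, BLANK, 11, 14]))
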
@@ -132,25 +154,30 @@ def test():
         test_loader
     ):
         outputs = model(inputs)
-        inferred, target = ctc_decode(outputs, targets, target_lengths)
+        inferred = ctc_decode(outputs)
+        target = ctc_decode_target(targets, target_lengths)
         print("\n=========================================\n")
         print("inferred: ", inferred)
         print("")
         print("target: ", target)
-def usage():
-    print(f"usage: {sys.argv[0]} <train|test>")
-
-
-if __name__ == "__main__":
-    if len(sys.argv) != 2:
-        usage()
-        sys.exit(1)
-
-    if sys.argv[1] == "train":
-        train()
-    elif sys.argv[1] == "test":
-        test()
-    else:
-        usage()
+# TODO: Might be buggy?
+def infer(file):
+    model.load_state_dict(torch.load("target/model-final.ckpt"))
+    audio, sr = torchaudio.load(file)
+    inputs = preprocess([(audio, sr, "")])[0]
+    with torch.no_grad():
+        outputs = model(inputs)
+        inferred = ctc_decode(outputs)
+    print("inferred: ", inferred)
+
+
+if MODE == "train":
+    train()
+elif MODE == "test":
+    test()
+elif MODE == "infer" and len(sys.argv) == 3:
+    infer(sys.argv[2])
+else:
+    usage()
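
One caveat in the new infer path (possibly the "might be buggy" TODO): LibriSpeech audio is 16 kHz, so a file recorded at a different rate would produce mismatched features. A sketch of a resampling guard that could precede preprocess (TARGET_SR and the literal file name are assumptions):

    import torchaudio

    TARGET_SR = 16000  # LibriSpeech's sample rate

    audio, sr = torchaudio.load("audio.flac")
    if sr != TARGET_SR:
        audio = torchaudio.functional.resample(audio, orig_freq=sr, new_freq=TARGET_SR)
        sr = TARGET_SR
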
diff --git a/readme.md b/readme.md
index 86ab694..1bf726f 100644
--- a/readme.md
+++ b/readme.md
@@ -2,14 +2,20 @@
 > spoken word recognition using CTC LSTMs
-## Installation
+## Instructions
-- `python -m venv venv`
-- `./venv/bin/pip install -r requirements.txt`
-- `./venv/bin/python main.py train`
-- `./venv/bin/python main.py test`
+- Create a virtual environment: `python -m venv venv`
+- Install the required packages:
+  `./venv/bin/pip install -r requirements.txt`
+- Train the model: `./venv/bin/python main.py train` (takes a few hours
+  and needs around 20 GB of disk space and 5 GB of memory)
+  - or download my pre-trained model (25 epochs, **not good**) from
+    [here](https://marvinborner.de/model-final.ckpt) and move it to
+    `target/model-final.ckpt`
+- Test the final model: `./venv/bin/python main.py test`
+- Infer text from a FLAC file: `./venv/bin/python main.py infer audio.flac`
 ## Note
-- This is a proof-of-concept
-- Does not use CUDA but should be easy to implement
+- This is a proof-of-concept
+- Does not use CUDA but should be easy to implement
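
If you use the downloaded checkpoint instead of training, a quick sanity check that it loads (a sketch; assumes the file is a plain state_dict, as the torch.load call in main.py suggests):

    import torch

    state = torch.load("target/model-final.ckpt", map_location="cpu")
    print(len(state), "entries")  # a state_dict maps parameter names to tensors
    print(list(state)[:5])        # first few parameter names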