-rw-r--r--  .gitignore |   1 +
-rw-r--r--  main.py    | 180 ++++++++++++++++++++++++++++++
-rw-r--r--  readme.md  |   8 +++++++-
-rw-r--r--  train.py   |  11 -----------
4 files changed, 188 insertions, 12 deletions
diff --git a/.gitignore b/.gitignore
index f45af34..6315a01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 venv/
 data/
+target/
diff --git a/main.py b/main.py
new file mode 100644
index 0000000..f1caf26
--- /dev/null
+++ b/main.py
@@ -0,0 +1,180 @@
+#!/usr/bin/env python3
+
+import os
+import sys
+
+import torchaudio
+import torchaudio.transforms as T
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.utils.data as data
+import torch.nn.functional as F
+
+# transcript alphabet; the CTC blank gets the extra index len(CHARS)
+CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ' "
+EPOCHS = 25
+# number of MFCC coefficients per frame (the feature size fed to the LSTM)
+SEGMENTS = 15
+
+print("getting datasets...")
+
+# download the datasets
+train_dataset = torchaudio.datasets.LIBRISPEECH(
+    "./data", url="train-clean-100", download=True
+)
+test_dataset = torchaudio.datasets.LIBRISPEECH(
+    "./data", url="test-clean", download=True
+)
+
+print("got datasets!")
+
+
+def preprocess(batch):
+    transform = T.MFCC(sample_rate=16000, n_mfcc=SEGMENTS)
+
+    inputs = []
+    targets = []
+    input_lengths = []
+    target_lengths = []
+
+    for waveform, _sample_rate, transcript, *_ in batch:
+        # MFCC gives (channel, n_mfcc, time); drop the channel and
+        # transpose so each sample is (time, n_mfcc) for the LSTM
+        mfcc = transform(waveform).squeeze(0).transpose(0, 1)
+        inputs.append(mfcc)
+        # CTC targets must be integer class indices, not floats
+        target = torch.tensor(
+            [CHARS.index(c) for c in transcript], dtype=torch.long
+        )
+        targets.append(target)
+        input_lengths.append(mfcc.shape[0])
+        target_lengths.append(len(target))
+
+    inputs = nn.utils.rnn.pad_sequence(inputs, batch_first=True)
+    targets = nn.utils.rnn.pad_sequence(targets, batch_first=True)
+    input_lengths = torch.tensor(input_lengths, dtype=torch.long)
+    target_lengths = torch.tensor(target_lengths, dtype=torch.long)
+    return (inputs, targets, input_lengths, target_lengths)
+
+
+print("preprocessing datasets...")
+
+# load the datasets into batches
+train_loader = data.DataLoader(
+    train_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+)
+test_loader = data.DataLoader(
+    test_dataset, batch_size=10, shuffle=True, collate_fn=preprocess
+)
+
+print("datasets ready!")
+
+
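+# a single-layer LSTM over MFCC frames, followed by a per-frame linear
+# projection to character logits (len(CHARS) classes plus a CTC blank)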
+class Model(nn.Module):
+    def __init__(self, input_size, hidden_size, output_size, num_layers=1):
+        super().__init__()
+        self.lstm = nn.LSTM(
+            input_size, hidden_size, num_layers, batch_first=True
+        )
+        self.fc = nn.Linear(hidden_size, output_size)
+
+    def forward(self, x):
+        x, _ = self.lstm(x)
+        x = self.fc(x)
+        return x
+
+
+def train():
+    print("training model...")
+
+    model = Model(
+        input_size=SEGMENTS, hidden_size=128, output_size=len(CHARS) + 1
+    )
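+    # CUDA is intentionally not used here (see the readme); a minimal
+    # sketch would pick a device and move the model and batches onto it:
+    #   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    #   model = model.to(device)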
+    criterion = nn.CTCLoss(blank=len(CHARS))
+    optimizer = optim.Adam(model.parameters())
+    os.makedirs("target", exist_ok=True)  # checkpoints are saved here
+
+    for epoch in range(EPOCHS):
+        for i, (inputs, targets, input_lengths, target_lengths) in enumerate(
+            train_loader
+        ):
+            print(f"epoch {epoch}/{EPOCHS}, iteration {i}/{len(train_loader)}")
+            optimizer.zero_grad()
+
+            outputs = model(inputs)
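+            # CTCLoss expects log-probabilities shaped (T, N, C), so
+            # normalise over classes and move the time dimension first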
+            outputs = F.log_softmax(outputs, dim=2)
+            outputs = outputs.transpose(0, 1)
+
+            loss = criterion(outputs, targets, input_lengths, target_lengths)
+            loss.backward()
+            optimizer.step()
+        torch.save(model.state_dict(), f"target/model-{epoch}.ckpt")
+
+    torch.save(model.state_dict(), "target/model-final.ckpt")
+    print("model trained!")
+
+
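+# "ctc_decode" is used by test() but was never defined; this is a minimal
+# greedy decoder: argmax each frame, collapse repeats, and drop blanks
+def ctc_decode(outputs):
+    blank = len(CHARS)
+    decoded = []
+    for seq in outputs.argmax(dim=2).tolist():
+        prev = blank
+        chars = []
+        for idx in seq:
+            if idx != prev and idx != blank:
+                chars.append(CHARS[idx])
+            prev = idx
+        decoded.append("".join(chars))
+    return decoded
+
+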
+def test():
+    print("testing model...")
+
+    # train() saves a state dict, so rebuild the model and load into it
+    model = Model(
+        input_size=SEGMENTS, hidden_size=128, output_size=len(CHARS) + 1
+    )
+    model.load_state_dict(torch.load("target/model-final.ckpt"))
+    model.eval()
+    correct = 0
+    total = 0
+
+    with torch.no_grad():
+        for inputs, targets, input_lengths, target_lengths in test_loader:
+            preds = ctc_decode(model(inputs))
+            # count an utterance as correct only on an exact transcript match
+            for pred, target, length in zip(preds, targets, target_lengths):
+                truth = "".join(CHARS[int(i)] for i in target[:int(length)])
+                total += 1
+                correct += int(pred == truth)
+
+    print("Test Accuracy: %d %%" % (100 * correct / total))
+
+
+def usage():
+    print(f"usage: {sys.argv[0]} <train|test>")
+
+
+if __name__ == "__main__":
+    if len(sys.argv) != 2:
+        usage()
+        sys.exit(1)
+
+    if sys.argv[1] == "train":
+        train()
+    elif sys.argv[1] == "test":
+        test()
+    else:
+        usage()
diff --git a/readme.md b/readme.md
index c648cf6..86ab694 100644
--- a/readme.md
+++ b/readme.md
@@ -6,4 +6,10 @@
 - `python -m venv venv`
 - `./venv/bin/pip install -r requirements.txt`
-- `./venv/bin/python train.py`
+- `./venv/bin/python main.py train`
+- `./venv/bin/python main.py test`
+
+## Note
+
+- This is a proof-of-concept
+- Does not use CUDA, but support should be easy to add
diff --git a/train.py b/train.py
deleted file mode 100644
index ebeb4d4..0000000
--- a/train.py
+++ /dev/null
@@ -1,11 +0,0 @@
-#!/bin/env python3
-
-import torchaudio
-
-# download the datasets
-train_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="train-clean-100", download=True
-)
-test_dataset = torchaudio.datasets.LIBRISPEECH(
-    "./data", url="test-clean", download=True
-)