aboutsummaryrefslogtreecommitdiff
path: root/main.py
diff options
context:
space:
mode:
Diffstat (limited to 'main.py')
-rw-r--r-- main.py 15
1 files changed, 9 insertions, 6 deletions
diff --git a/main.py b/main.py
index 0ef7fcd..940617f 100644
--- a/main.py
+++ b/main.py
@@ -51,10 +51,13 @@ def preprocess(data):
input_lengths = []
target_lengths = []
- for audio in data:
- input = transform(audio[0]).squeeze(0).transpose(0, 1)
+ for wav, sr, label in data:
+ if sr != 16000:
+ resample = T.Resample(orig_freq=sr, new_freq=16000)
+ wav = resample(wav)
+ input = transform(wav).squeeze(0).transpose(0, 1)
inputs.append(input)
- target = torch.Tensor([CHARS.index(c) for c in audio[2]])
+ target = torch.Tensor([CHARS.index(c) for c in label])
targets.append(target)
input_lengths.append(input.shape[0])
target_lengths.append(len(target))
@@ -163,10 +166,10 @@ def test():
# TODO: Might be buggy?
-def infer(file):
+def infer(file_):
model.load_state_dict(torch.load("target/model-final.ckpt"))
- audio, sr = torchaudio.load(file)
- inputs = preprocess([(audio, sr, "")])[0]
+ wav, sr = torchaudio.load(file_)
+ inputs = preprocess([(wav, sr, "")])[0]
with torch.no_grad():
outputs = model(inputs)
inferred = ctc_decode(outputs)