diff options
-rw-r--r-- | main.py | 15 |
1 files changed, 9 insertions, 6 deletions
@@ -51,10 +51,13 @@ def preprocess(data): input_lengths = [] target_lengths = [] - for audio in data: - input = transform(audio[0]).squeeze(0).transpose(0, 1) + for wav, sr, label in data: + if sr != 16000: + resample = T.Resample(orig_freq=sr, new_freq=16000) + wav = resample(wav) + input = transform(wav).squeeze(0).transpose(0, 1) inputs.append(input) - target = torch.Tensor([CHARS.index(c) for c in audio[2]]) + target = torch.Tensor([CHARS.index(c) for c in label]) targets.append(target) input_lengths.append(input.shape[0]) target_lengths.append(len(target)) @@ -163,10 +166,10 @@ def test(): # TODO: Might be buggy? -def infer(file): +def infer(file_): model.load_state_dict(torch.load("target/model-final.ckpt")) - audio, sr = torchaudio.load(file) - inputs = preprocess([(audio, sr, "")])[0] + wav, sr = torchaudio.load(file_) + inputs = preprocess([(wav, sr, "")])[0] with torch.no_grad(): outputs = model(inputs) inferred = ctc_decode(outputs) |