From 5ba152d0b3268ee5a0d5bdf182b1e5d34c570ecf Mon Sep 17 00:00:00 2001 From: Konstantin (Tino) Sering Date: Tue, 4 Jul 2023 13:40:40 +0200 Subject: resamples audio to 16000 Hz in the inference step; some code improvements --- main.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) (limited to 'main.py') diff --git a/main.py b/main.py index 0ef7fcd..940617f 100644 --- a/main.py +++ b/main.py @@ -51,10 +51,13 @@ def preprocess(data): input_lengths = [] target_lengths = [] - for audio in data: - input = transform(audio[0]).squeeze(0).transpose(0, 1) + for wav, sr, label in data: + if sr != 16000: + resample = T.Resample(orig_freq=sr, new_freq=16000) + wav = resample(wav) + input = transform(wav).squeeze(0).transpose(0, 1) inputs.append(input) - target = torch.Tensor([CHARS.index(c) for c in audio[2]]) + target = torch.Tensor([CHARS.index(c) for c in label]) targets.append(target) input_lengths.append(input.shape[0]) target_lengths.append(len(target)) @@ -163,10 +166,10 @@ def test(): # TODO: Might be buggy? -def infer(file): +def infer(file_): model.load_state_dict(torch.load("target/model-final.ckpt")) - audio, sr = torchaudio.load(file) - inputs = preprocess([(audio, sr, "")])[0] + wav, sr = torchaudio.load(file_) + inputs = preprocess([(wav, sr, "")])[0] with torch.no_grad(): outputs = model(inputs) inferred = ctc_decode(outputs) -- cgit v1.2.3