aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/inference_test.py')
-rw-r--r--swr2_asr/inference_test.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
index 16bd54b..96277fd 100644
--- a/swr2_asr/inference_test.py
+++ b/swr2_asr/inference_test.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
import torch
import torchaudio
import torchaudio.functional as F
class GreedyCTCDecoder(torch.nn.Module):
+ """Greedy CTC decoder for the wav2vec2 model."""
+
def __init__(self, labels, blank=0) -> None:
super().__init__()
self.labels = labels
@@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module):
return "".join([self.labels[i] for i in indices])
+'''
+Sorry marvin, Please fix this to use the new dataset
def main() -> None:
"""Main function."""
# choose between cuda, cpu and mps devices
@@ -66,7 +69,7 @@ def main() -> None:
transcript = decoder(emission[0])
print(transcript)
-
+'''
if __name__ == "__main__":
main()