aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference_test.py
diff options
context:
space:
mode:
authorPherkel2023-09-03 19:44:27 +0200
committerPherkel2023-09-03 19:44:27 +0200
commit6a73a1ed117d9ace8546aa5114daada91e69b069 (patch)
treec217e74e59f229672044f9f96c04cbdfdc2dd51f /swr2_asr/inference_test.py
parent33f09080aee10bddb4797a557d676ee1f7b8de31 (diff)
reformat, remove deps
Diffstat (limited to 'swr2_asr/inference_test.py')
-rw-r--r--swr2_asr/inference_test.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
index 16bd54b..96277fd 100644
--- a/swr2_asr/inference_test.py
+++ b/swr2_asr/inference_test.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
import torch
import torchaudio
import torchaudio.functional as F
class GreedyCTCDecoder(torch.nn.Module):
+ """Greedy CTC decoder for the wav2vec2 model."""
+
def __init__(self, labels, blank=0) -> None:
super().__init__()
self.labels = labels
@@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module):
return "".join([self.labels[i] for i in indices])
+'''
+Sorry marvin, Please fix this to use the new dataset
def main() -> None:
"""Main function."""
# choose between cuda, cpu and mps devices
@@ -66,7 +69,7 @@ def main() -> None:
transcript = decoder(emission[0])
print(transcript)
-
+'''
if __name__ == "__main__":
main()