aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/inference_test.py')
-rw-r--r--swr2_asr/inference_test.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
index a6b0010..96277fd 100644
--- a/swr2_asr/inference_test.py
+++ b/swr2_asr/inference_test.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
import torch
import torchaudio
import torchaudio.functional as F
class GreedyCTCDecoder(torch.nn.Module):
+ """Greedy CTC decoder for the wav2vec2 model."""
+
def __init__(self, labels, blank=0) -> None:
super().__init__()
self.labels = labels
@@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module):
return "".join([self.labels[i] for i in indices])
+'''
+Sorry marvin, Please fix this to use the new dataset
def main() -> None:
"""Main function."""
# choose between cuda, cpu and mps devices
@@ -44,9 +47,7 @@ def main() -> None:
print(model.__class__)
# only do all things for one single sample
- dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="train", download=True
- )
+ dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True)
print(dataset[0])
@@ -68,7 +69,7 @@ def main() -> None:
transcript = decoder(emission[0])
print(transcript)
-
+'''
if __name__ == "__main__":
main()