diff options
Diffstat (limited to 'swr2_asr/inference_test.py')
-rw-r--r-- | swr2_asr/inference_test.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py index 16bd54b..96277fd 100644 --- a/swr2_asr/inference_test.py +++ b/swr2_asr/inference_test.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" -from AudioLoader.speech.mls import MultilingualLibriSpeech import torch import torchaudio import torchaudio.functional as F class GreedyCTCDecoder(torch.nn.Module): + """Greedy CTC decoder for the wav2vec2 model.""" + def __init__(self, labels, blank=0) -> None: super().__init__() self.labels = labels @@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module): return "".join([self.labels[i] for i in indices]) +''' +Sorry marvin, Please fix this to use the new dataset def main() -> None: """Main function.""" # choose between cuda, cpu and mps devices @@ -66,7 +69,7 @@ def main() -> None: transcript = decoder(emission[0]) print(transcript) - +''' if __name__ == "__main__": main() |