diff options
author | Pherkel | 2023-09-04 22:55:27 +0200 |
---|---|---|
committer | GitHub | 2023-09-04 22:55:27 +0200 |
commit | 93e49e708fa59613406249069d31c2f6c8f2d2ab (patch) | |
tree | 9f413c226283db990116c9559ffffb9124b911d8 /swr2_asr/inference_test.py | |
parent | 14ceeb5ad36beea2f05214aa26260cdd1d86590b (diff) | |
parent | 0d70a19e1fea6eda3f7b16ad0084591613f2de72 (diff) |
Merge pull request #27 from Algo-Boys/refactor_modularize
Refactor modularize
Diffstat (limited to 'swr2_asr/inference_test.py')
-rw-r--r-- | swr2_asr/inference_test.py | 11 |
1 files changed, 6 insertions, 5 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py index a6b0010..96277fd 100644 --- a/swr2_asr/inference_test.py +++ b/swr2_asr/inference_test.py @@ -1,11 +1,12 @@ """Training script for the ASR model.""" -from AudioLoader.speech.mls import MultilingualLibriSpeech import torch import torchaudio import torchaudio.functional as F class GreedyCTCDecoder(torch.nn.Module): + """Greedy CTC decoder for the wav2vec2 model.""" + def __init__(self, labels, blank=0) -> None: super().__init__() self.labels = labels @@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module): return "".join([self.labels[i] for i in indices]) +''' +Sorry marvin, Please fix this to use the new dataset def main() -> None: """Main function.""" # choose between cuda, cpu and mps devices @@ -44,9 +47,7 @@ def main() -> None: print(model.__class__) # only do all things for one single sample - dataset = MultilingualLibriSpeech( - "data", "mls_german_opus", split="train", download=True - ) + dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True) print(dataset[0]) @@ -68,7 +69,7 @@ def main() -> None: transcript = decoder(emission[0]) print(transcript) - +''' if __name__ == "__main__": main() |