aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference_test.py
diff options
context:
space:
mode:
authorPherkel2023-09-04 22:55:27 +0200
committerGitHub2023-09-04 22:55:27 +0200
commit93e49e708fa59613406249069d31c2f6c8f2d2ab (patch)
tree9f413c226283db990116c9559ffffb9124b911d8 /swr2_asr/inference_test.py
parent14ceeb5ad36beea2f05214aa26260cdd1d86590b (diff)
parent0d70a19e1fea6eda3f7b16ad0084591613f2de72 (diff)
Merge pull request #27 from Algo-Boys/refactor_modularize
Refactor modularize
Diffstat (limited to 'swr2_asr/inference_test.py')
-rw-r--r--swr2_asr/inference_test.py11
1 files changed, 6 insertions, 5 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
index a6b0010..96277fd 100644
--- a/swr2_asr/inference_test.py
+++ b/swr2_asr/inference_test.py
@@ -1,11 +1,12 @@
"""Training script for the ASR model."""
-from AudioLoader.speech.mls import MultilingualLibriSpeech
import torch
import torchaudio
import torchaudio.functional as F
class GreedyCTCDecoder(torch.nn.Module):
+ """Greedy CTC decoder for the wav2vec2 model."""
+
def __init__(self, labels, blank=0) -> None:
super().__init__()
self.labels = labels
@@ -25,6 +26,8 @@ class GreedyCTCDecoder(torch.nn.Module):
return "".join([self.labels[i] for i in indices])
+'''
+Sorry marvin, Please fix this to use the new dataset
def main() -> None:
"""Main function."""
# choose between cuda, cpu and mps devices
@@ -44,9 +47,7 @@ def main() -> None:
print(model.__class__)
# only do all things for one single sample
- dataset = MultilingualLibriSpeech(
- "data", "mls_german_opus", split="train", download=True
- )
+ dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True)
print(dataset[0])
@@ -68,7 +69,7 @@ def main() -> None:
transcript = decoder(emission[0])
print(transcript)
-
+'''
if __name__ == "__main__":
main()