path: root/swr2_asr/inference_test.py
author     Marvin Borner    2023-09-06 17:00:58 +0200
committer  Marvin Borner    2023-09-06 17:01:14 +0200
commit     a451ada2efba6113a6ba562c399b423804875897 (patch)
tree       25c5641ff8b12fa741e01a16b70991b81f0d271f /swr2_asr/inference_test.py
parent     f0910dfad95ff29e1bdf0a657558e17ff2230a14 (diff)
Added inference
Diffstat (limited to 'swr2_asr/inference_test.py')
-rw-r--r--  swr2_asr/inference_test.py  |  75
1 file changed, 0 insertions, 75 deletions
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
deleted file mode 100644
index 96277fd..0000000
--- a/swr2_asr/inference_test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Training script for the ASR model."""
-import torch
-import torchaudio
-import torchaudio.functional as F
-
-
-class GreedyCTCDecoder(torch.nn.Module):
- """Greedy CTC decoder for the wav2vec2 model."""
-
- def __init__(self, labels, blank=0) -> None:
- super().__init__()
- self.labels = labels
- self.blank = blank
-
- def forward(self, emission: torch.Tensor) -> str:
- """Given a sequence emission over labels, get the best path string
- Args:
- emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
-
- Returns:
- str: The resulting transcript
- """
- indices = torch.argmax(emission, dim=-1) # [num_seq,]
- indices = torch.unique_consecutive(indices, dim=-1)
- indices = [i for i in indices if i != self.blank]
- return "".join([self.labels[i] for i in indices])
-
-
-'''
-Sorry Marvin, please fix this to use the new dataset.
-def main() -> None:
- """Main function."""
- # choose between cuda, cpu and mps devices
- device = "cuda" if torch.cuda.is_available() else "cpu"
- # device = "mps"
- device = torch.device(device)
-
- torch.random.manual_seed(42)
-
- bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
-
- print(f"Sample rate (model): {bundle.sample_rate}")
- print(f"Labels (model): {bundle.get_labels()}")
-
- model = bundle.get_model().to(device)
-
- print(model.__class__)
-
- # only do all things for one single sample
- dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True)
-
- print(dataset[0])
-
- # load waveforms and sample rate from dataset
- waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"]
-
- if sample_rate != bundle.sample_rate:
- waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate))
-
-    waveform = waveform.to(device)
-
- with torch.inference_mode():
- features, _ = model.extract_features(waveform)
-
- with torch.inference_mode():
- emission, _ = model(waveform)
-
- decoder = GreedyCTCDecoder(labels=bundle.get_labels())
- transcript = decoder(emission[0])
-
- print(transcript)
-'''
-
-if __name__ == "__main__":
- main()
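
The deleted script only exercised the decoder through the commented-out main(). As a standalone illustration of the greedy CTC decoding step implemented above (argmax per frame, collapse repeats, drop blanks, join labels), here is a minimal sketch that feeds a hand-built emission tensor through GreedyCTCDecoder. The label set and logit values are made up for illustration and are not taken from the repository or from the wav2vec2 bundle; the class itself is assumed to be in scope.

# Minimal usage sketch (not part of the commit above).
import torch

labels = ("-", "h", "a", "t")  # index 0 acts as the CTC blank, matching blank=0

# Fake logits for 6 time steps over 4 labels; the per-frame argmax path is [0, 1, 1, 2, 0, 3]
emission = torch.full((6, len(labels)), -10.0)
for step, label_idx in enumerate([0, 1, 1, 2, 0, 3]):
    emission[step, label_idx] = 0.0

decoder = GreedyCTCDecoder(labels=labels)
print(decoder(emission))  # collapses repeats, drops blanks -> prints "hat"

With the real model, emission[0] returned by model(waveform) takes the place of the fake tensor, exactly as in the commented-out main() above.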