aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr
diff options
context:
space:
mode:
authorMarvin Borner2023-09-06 17:00:58 +0200
committerMarvin Borner2023-09-06 17:01:14 +0200
commita451ada2efba6113a6ba562c399b423804875897 (patch)
tree25c5641ff8b12fa741e01a16b70991b81f0d271f /swr2_asr
parentf0910dfad95ff29e1bdf0a657558e17ff2230a14 (diff)
Added inference
Diffstat (limited to 'swr2_asr')
-rw-r--r--swr2_asr/inference.py114
-rw-r--r--swr2_asr/inference_test.py75
2 files changed, 114 insertions, 75 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
new file mode 100644
index 0000000..c3eec42
--- /dev/null
+++ b/swr2_asr/inference.py
@@ -0,0 +1,114 @@
+"""Training script for the ASR model."""
+import torch
+import torchaudio
+import torch.nn.functional as F
+from typing import TypedDict
+
+from swr2_asr.tokenizer import CharTokenizer
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
+
+
+class HParams(TypedDict):
+ """Type for the hyperparameters of the model."""
+
+ n_cnn_layers: int
+ n_rnn_layers: int
+ rnn_dim: int
+ n_class: int
+ n_feats: int
+ stride: int
+ dropout: float
+ learning_rate: float
+ batch_size: int
+ epochs: int
+
+
+def greedy_decoder(output, tokenizer, collapse_repeated=True):
+ """Greedily decode a sequence."""
+ arg_maxes = torch.argmax(output, dim=2) # pylint: disable=no-member
+ blank_label = tokenizer.encode(" ").ids[0]
+ decodes = []
+ targets = []
+ for i, args in enumerate(arg_maxes):
+ decode = []
+ for j, index in enumerate(args):
+ if index != blank_label:
+ if collapse_repeated and j != 0 and index == args[j - 1]:
+ continue
+ decode.append(index.item())
+ decodes.append(tokenizer.decode(decode))
+ return decodes
+
+
+def main() -> None:
+ """inference function."""
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ device = torch.device(device)
+
+ tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
+
+ spectrogram_hparams = {
+ "sample_rate": 16000,
+ "n_fft": 400,
+ "win_length": 400,
+ "hop_length": 160,
+ "n_mels": 128,
+ "f_min": 0,
+ "f_max": 8000,
+ "power": 2.0,
+ }
+
+ hparams = HParams(
+ n_cnn_layers=3,
+ n_rnn_layers=5,
+ rnn_dim=512,
+ n_class=tokenizer.get_vocab_size(),
+ n_feats=128,
+ stride=2,
+ dropout=0.1,
+ learning_rate=0.1,
+ batch_size=30,
+ epochs=100,
+ )
+
+ model = SpeechRecognitionModel(
+ hparams["n_cnn_layers"],
+ hparams["n_rnn_layers"],
+ hparams["rnn_dim"],
+ hparams["n_class"],
+ hparams["n_feats"],
+ hparams["stride"],
+ hparams["dropout"],
+ ).to(device)
+
+ checkpoint = torch.load("model8", map_location=device)
+ state_dict = {
+ k[len("module.") :] if k.startswith("module.") else k: v
+ for k, v in checkpoint["model_state_dict"].items()
+ }
+ model.load_state_dict(state_dict)
+
+ # waveform, sample_rate = torchaudio.load("test.opus")
+ waveform, sample_rate = torchaudio.load("marvin_rede.flac")
+ if sample_rate != spectrogram_hparams["sample_rate"]:
+ resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
+ waveform = resampler(waveform)
+
+ spec = (
+ torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform)
+ .squeeze(0)
+ .transpose(0, 1)
+ )
+ specs = [spec]
+ specs = torch.nn.utils.rnn.pad_sequence(specs, batch_first=True).unsqueeze(1).transpose(2, 3)
+
+ output = model(specs)
+ output = F.log_softmax(output, dim=2)
+ output = output.transpose(0, 1) # (time, batch, n_class)
+ decodes = greedy_decoder(output, tokenizer)
+ print(decodes)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
deleted file mode 100644
index 96277fd..0000000
--- a/swr2_asr/inference_test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Training script for the ASR model."""
-import torch
-import torchaudio
-import torchaudio.functional as F
-
-
-class GreedyCTCDecoder(torch.nn.Module):
- """Greedy CTC decoder for the wav2vec2 model."""
-
- def __init__(self, labels, blank=0) -> None:
- super().__init__()
- self.labels = labels
- self.blank = blank
-
- def forward(self, emission: torch.Tensor) -> str:
- """Given a sequence emission over labels, get the best path string
- Args:
- emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
-
- Returns:
- str: The resulting transcript
- """
- indices = torch.argmax(emission, dim=-1) # [num_seq,]
- indices = torch.unique_consecutive(indices, dim=-1)
- indices = [i for i in indices if i != self.blank]
- return "".join([self.labels[i] for i in indices])
-
-
-'''
-Sorry marvin, Please fix this to use the new dataset
-def main() -> None:
- """Main function."""
- # choose between cuda, cpu and mps devices
- device = "cuda" if torch.cuda.is_available() else "cpu"
- # device = "mps"
- device = torch.device(device)
-
- torch.random.manual_seed(42)
-
- bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
-
- print(f"Sample rate (model): {bundle.sample_rate}")
- print(f"Labels (model): {bundle.get_labels()}")
-
- model = bundle.get_model().to(device)
-
- print(model.__class__)
-
- # only do all things for one single sample
- dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True)
-
- print(dataset[0])
-
- # load waveforms and sample rate from dataset
- waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"]
-
- if sample_rate != bundle.sample_rate:
- waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate))
-
- waveform.to(device)
-
- with torch.inference_mode():
- features, _ = model.extract_features(waveform)
-
- with torch.inference_mode():
- emission, _ = model(waveform)
-
- decoder = GreedyCTCDecoder(labels=bundle.get_labels())
- transcript = decoder(emission[0])
-
- print(transcript)
-'''
-
-if __name__ == "__main__":
- main()