diff options
author | Marvin Borner | 2023-09-06 17:00:58 +0200 |
---|---|---|
committer | Marvin Borner | 2023-09-06 17:01:14 +0200 |
commit | a451ada2efba6113a6ba562c399b423804875897 (patch) | |
tree | 25c5641ff8b12fa741e01a16b70991b81f0d271f /swr2_asr | |
parent | f0910dfad95ff29e1bdf0a657558e17ff2230a14 (diff) |
Added inference
Diffstat (limited to 'swr2_asr')
-rw-r--r-- | swr2_asr/inference.py | 114 | ||||
-rw-r--r-- | swr2_asr/inference_test.py | 75 |
2 files changed, 114 insertions, 75 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
new file mode 100644
index 0000000..c3eec42
--- /dev/null
+++ b/swr2_asr/inference.py
@@ -0,0 +1,128 @@
+"""Inference script for the ASR model."""
+from typing import TypedDict
+
+import torch
+import torch.nn.functional as F
+import torchaudio
+
+from swr2_asr.model_deep_speech import SpeechRecognitionModel
+from swr2_asr.tokenizer import CharTokenizer
+
+
+class HParams(TypedDict):
+    """Type for the hyperparameters of the model."""
+
+    n_cnn_layers: int
+    n_rnn_layers: int
+    rnn_dim: int
+    n_class: int
+    n_feats: int
+    stride: int
+    dropout: float
+    learning_rate: float
+    batch_size: int
+    epochs: int
+
+
+def greedy_decoder(output, tokenizer, collapse_repeated=True):
+    """Greedily decode a batch of per-frame class scores into strings.
+
+    Args:
+        output: Score tensor laid out as (batch, time, n_class).
+        tokenizer: Object providing ``encode(text).ids`` and ``decode(ids)``.
+        collapse_repeated: Skip a frame whose id equals the previous frame's id.
+
+    Returns:
+        A list with one decoded string per batch entry.
+    """
+    arg_maxes = torch.argmax(output, dim=2)  # pylint: disable=no-member
+    # The id the tokenizer assigns to " " is treated as the blank label.
+    blank_label = tokenizer.encode(" ").ids[0]
+    decodes = []
+    for args in arg_maxes:
+        decode = []
+        for j, index in enumerate(args):
+            if index != blank_label:
+                if collapse_repeated and j != 0 and index == args[j - 1]:
+                    continue
+                decode.append(index.item())
+        decodes.append(tokenizer.decode(decode))
+    return decodes
+
+
+def main() -> None:
+    """Transcribe one audio file with a trained checkpoint and print the result."""
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = CharTokenizer.from_file("char_tokenizer_german.json")
+
+    spectrogram_hparams = {
+        "sample_rate": 16000,
+        "n_fft": 400,
+        "win_length": 400,
+        "hop_length": 160,
+        "n_mels": 128,
+        "f_min": 0,
+        "f_max": 8000,
+        "power": 2.0,
+    }
+
+    hparams = HParams(
+        n_cnn_layers=3,
+        n_rnn_layers=5,
+        rnn_dim=512,
+        n_class=tokenizer.get_vocab_size(),
+        n_feats=128,
+        stride=2,
+        dropout=0.1,
+        learning_rate=0.1,
+        batch_size=30,
+        epochs=100,
+    )
+
+    model = SpeechRecognitionModel(
+        hparams["n_cnn_layers"],
+        hparams["n_rnn_layers"],
+        hparams["rnn_dim"],
+        hparams["n_class"],
+        hparams["n_feats"],
+        hparams["stride"],
+        hparams["dropout"],
+    ).to(device)
+
+    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
+    checkpoint = torch.load("model8", map_location=device)
+    # Strip a leading "module." from parameter names (presumably left by a
+    # DataParallel wrapper at training time) so they match the bare model.
+    state_dict = {
+        k[len("module.") :] if k.startswith("module.") else k: v
+        for k, v in checkpoint["model_state_dict"].items()
+    }
+    model.load_state_dict(state_dict)
+    model.eval()  # inference only: dropout must not be active while transcribing
+
+    waveform, sample_rate = torchaudio.load("marvin_rede.flac")
+    # torchaudio.load returns (channels, frames); mix down to one channel so
+    # the squeeze(0) below really removes the channel axis.
+    waveform = waveform.mean(dim=0, keepdim=True)
+    if sample_rate != spectrogram_hparams["sample_rate"]:
+        resampler = torchaudio.transforms.Resample(sample_rate, spectrogram_hparams["sample_rate"])
+        waveform = resampler(waveform)
+
+    spec = (
+        torchaudio.transforms.MelSpectrogram(**spectrogram_hparams)(waveform)
+        .squeeze(0)
+        .transpose(0, 1)
+    )
+    specs = torch.nn.utils.rnn.pad_sequence([spec], batch_first=True).unsqueeze(1).transpose(2, 3)
+
+    with torch.no_grad():
+        output = model(specs.to(device))  # (batch, time, n_class)
+    output = F.log_softmax(output, dim=2)
+    # greedy_decoder walks the first axis as the batch, so the batch-first
+    # layout is passed on as is instead of being transposed to time-first.
+    print(greedy_decoder(output, tokenizer))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/swr2_asr/inference_test.py b/swr2_asr/inference_test.py
deleted file mode 100644
index 96277fd..0000000
--- a/swr2_asr/inference_test.py
+++ /dev/null
@@ -1,75 +0,0 @@
-"""Training script for the ASR model."""
-import torch
-import torchaudio
-import torchaudio.functional as F
-
-
-class GreedyCTCDecoder(torch.nn.Module):
-    """Greedy CTC decoder for the wav2vec2 model."""
-
-    def __init__(self, labels, blank=0) -> None:
-        super().__init__()
-        self.labels = labels
-        self.blank = blank
-
-    def forward(self, emission: torch.Tensor) -> str:
-        """Given a sequence emission over labels, get the best path string
-        Args:
-            emission (Tensor): Logit tensors. Shape `[num_seq, num_label]`.
-
-        Returns:
-            str: The resulting transcript
-        """
-        indices = torch.argmax(emission, dim=-1)  # [num_seq,]
-        indices = torch.unique_consecutive(indices, dim=-1)
-        indices = [i for i in indices if i != self.blank]
-        return "".join([self.labels[i] for i in indices])
-
-
-'''
-Sorry marvin, Please fix this to use the new dataset
-def main() -> None:
-    """Main function."""
-    # choose between cuda, cpu and mps devices
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-    # device = "mps"
-    device = torch.device(device)
-
-    torch.random.manual_seed(42)
-
-    bundle = torchaudio.pipelines.WAV2VEC2_ASR_BASE_960H
-
-    print(f"Sample rate (model): {bundle.sample_rate}")
-    print(f"Labels (model): {bundle.get_labels()}")
-
-    model = bundle.get_model().to(device)
-
-    print(model.__class__)
-
-    # only do all things for one single sample
-    dataset = MultilingualLibriSpeech("data", "mls_german_opus", split="train", download=True)
-
-    print(dataset[0])
-
-    # load waveforms and sample rate from dataset
-    waveform, sample_rate = dataset[0]["waveform"], dataset[0]["sample_rate"]
-
-    if sample_rate != bundle.sample_rate:
-        waveform = F.resample(waveform, sample_rate, int(bundle.sample_rate))
-
-    waveform.to(device)
-
-    with torch.inference_mode():
-        features, _ = model.extract_features(waveform)
-
-    with torch.inference_mode():
-        emission, _ = model(waveform)
-
-    decoder = GreedyCTCDecoder(labels=bundle.get_labels())
-    transcript = decoder(emission[0])
-
-    print(transcript)
-'''
-
-if __name__ == "__main__":
-    main()