aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/inference.py
diff options
context:
space:
mode:
authorPherkel2023-09-18 12:44:40 +0200
committerPherkel2023-09-18 12:44:40 +0200
commitd5689047fa7062b284d13271bda39013dcf6150f (patch)
treebd1a843abda1929b826d9441df3ddc3db5cbec29 /swr2_asr/inference.py
parent9475900a1085b8277808b0a0b1555c59f7eb6d36 (diff)
parente06227289ad9d2fa45c736c771d859e9911b9a11 (diff)
Merge branch 'decoder' of github.com:Algo-Boys/SWR2-cool-projekt into decoder
Diffstat (limited to 'swr2_asr/inference.py')
-rw-r--r--swr2_asr/inference.py7
1 files changed, 6 insertions, 1 deletions
diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py
index 3f6a44e..3c58af0 100644
--- a/swr2_asr/inference.py
+++ b/swr2_asr/inference.py
@@ -66,7 +66,12 @@ def main(config_path: str, file_path: str) -> None:
).to(device)
checkpoint = torch.load(inference_config["model_load_path"], map_location=device)
- model.load_state_dict(checkpoint["model_state_dict"], strict=True)
+
+ state_dict = {
+ k[len("module.") :] if k.startswith("module.") else k: v
+ for k, v in checkpoint["model_state_dict"].items()
+ }
+ model.load_state_dict(state_dict, strict=True)
model.eval()
waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member