From 4d42fdd4f9f96018f1fe9024e59278ae78c4f605 Mon Sep 17 00:00:00 2001 From: Marvin Borner Date: Sun, 17 Sep 2023 18:15:45 +0200 Subject: Model loading fix for DataParallel models --- swr2_asr/inference.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/swr2_asr/inference.py b/swr2_asr/inference.py index 3f6a44e..3c58af0 100644 --- a/swr2_asr/inference.py +++ b/swr2_asr/inference.py @@ -66,7 +66,12 @@ def main(config_path: str, file_path: str) -> None: ).to(device) checkpoint = torch.load(inference_config["model_load_path"], map_location=device) - model.load_state_dict(checkpoint["model_state_dict"], strict=True) + + state_dict = { + k[len("module.") :] if k.startswith("module.") else k: v + for k, v in checkpoint["model_state_dict"].items() + } + model.load_state_dict(state_dict, strict=True) model.eval() waveform, sample_rate = torchaudio.load(file_path) # pylint: disable=no-member -- cgit v1.2.3