From 6f5513140f153206cfa91df3077e67ce58043d35 Mon Sep 17 00:00:00 2001 From: Pherkel Date: Mon, 11 Sep 2023 22:58:19 +0200 Subject: model loading is broken :( --- swr2_asr/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'swr2_asr/train.py') diff --git a/swr2_asr/train.py b/swr2_asr/train.py index ca70d21..ec25918 100644 --- a/swr2_asr/train.py +++ b/swr2_asr/train.py @@ -263,7 +263,7 @@ def main(config_path: str): prev_epoch = 0 if checkpoints_config["model_load_path"] is not None: - checkpoint = torch.load(checkpoints_config["model_load_path"]) + checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device) model.load_state_dict(checkpoint["model_state_dict"]) optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) prev_epoch = checkpoint["epoch"] -- cgit v1.2.3