aboutsummaryrefslogtreecommitdiff
path: root/swr2_asr/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'swr2_asr/train.py')
-rw-r--r--swr2_asr/train.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/swr2_asr/train.py b/swr2_asr/train.py
index ca70d21..ec25918 100644
--- a/swr2_asr/train.py
+++ b/swr2_asr/train.py
@@ -263,7 +263,7 @@ def main(config_path: str):
prev_epoch = 0
if checkpoints_config["model_load_path"] is not None:
- checkpoint = torch.load(checkpoints_config["model_load_path"])
+ checkpoint = torch.load(checkpoints_config["model_load_path"], map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
prev_epoch = checkpoint["epoch"]