diff options
author | fredeee | 2023-11-02 10:47:21 +0100 |
---|---|---|
committer | fredeee | 2023-11-02 10:47:21 +0100 |
commit | f8302ee886ef9b631f11a52900dac964a61350e1 (patch) | |
tree | 87288be6f851ab69405e524b81940c501c52789a /scripts/utils/eval_utils.py | |
parent | f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff) |
initiaƶ commit
Diffstat (limited to 'scripts/utils/eval_utils.py')
-rw-r--r-- | scripts/utils/eval_utils.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py new file mode 100644 index 0000000..faab7ec --- /dev/null +++ b/scripts/utils/eval_utils.py @@ -0,0 +1,78 @@ +import os +import torch as th +from model.loci import Loci + +def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views): + + net_name = file.split('/')[-1].split('.')[0] + #root_path = file.split('nets')[0] + root_path = os.path.join(*file.split('/')[0:-1]) + root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type']) + plot_path = os.path.join(root_path, evaluation_mode) + + # create directories + #if os.path.exists(plot_path): + # shutil.rmtree(plot_path) + os.makedirs(plot_path, exist_ok = True) + if object_view: + os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True) + if individual_views: + os.makedirs(os.path.join(plot_path, 'individual'), exist_ok = True) + for group in ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination']: + os.makedirs(os.path.join(plot_path, 'individual', group), exist_ok = True) + os.makedirs(os.path.join(root_path, 'statistics'), exist_ok = True) + + # final directory + plot_path = plot_path + '/' + print(f"save plots to {plot_path}") + + return root_path, plot_path + +def store_statistics(memory, *args, extend=False): + for i,key in enumerate(memory.keys()): + if i >= len(args): + break + if extend: + memory[key].extend(args[i]) + else: + memory[key].append(args[i]) + return memory + +def append_statistics(memory1, memory2, ignore=[], extend=False): + for key in memory1: + if key not in ignore: + if extend: + memory2[key] = memory2[key] + memory1[key] + else: + memory2[key].append(memory1[key]) + return memory2 + +def load_model(cfg, cfg_net, file, device): + + net = Loci( + cfg_net, + teacher_forcing = cfg.defaults.teacher_forcing + ) + + # load model + if file != '': + print(f"load {file} to device {device}") + state = th.load(file, map_location=device) + + # backward compatibility + model = {} + for key, value in state["model"].items(): + # replace update_module with percept_gate_controller in key string: + key = key.replace("update_module", "percept_gate_controller") + model[key.replace(".module.", ".")] = value + + net.load_state_dict(model) + + # ??? + if net.get_init_status() < 1: + net.inc_init_level() + + # set network to evaluation mode + net = net.to(device=device) + + return net
\ No newline at end of file |