1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
|
import os
import torch as th
from model.loci import Loci
def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views):
    """Create the result directory tree next to a checkpoint file.

    The layout is
    ``<dir of file>/results<name>/<net name>/<set_test['type']>/<evaluation_mode>/``
    where ``<net name>`` is the checkpoint's file name up to its first dot.
    Existing directories are reused; nothing is deleted.

    Args:
        file: Path of the checkpoint file. May be absolute, relative, or a
            bare file name.
        name: Suffix appended verbatim to the ``results`` folder name.
        set_test: Mapping providing the ``'type'`` entry used as a folder name.
        evaluation_mode: Name of the sub-folder that receives the plots.
        object_view: If truthy, also create ``<plot_path>/object``.
        individual_views: If truthy, also create ``<plot_path>/individual``
            with one sub-folder per plot group.

    Returns:
        Tuple ``(root_path, plot_path)``. ``plot_path`` carries a trailing
        ``'/'`` so callers can concatenate file names directly onto it.
    """
    # os.path.split keeps the leading separator of absolute paths and copes
    # with a bare file name; splitting on '/' and re-joining the pieces
    # silently turned '/abs/dir/net.pt' into the relative path 'abs/dir'
    # and raised TypeError for a path without any directory part.
    checkpoint_dir, checkpoint_name = os.path.split(file)
    net_name = checkpoint_name.split('.')[0]

    root_path = os.path.join(checkpoint_dir, f'results{name}', net_name, set_test['type'])
    plot_path = os.path.join(root_path, evaluation_mode)

    # create directories
    os.makedirs(plot_path, exist_ok=True)
    if object_view:
        os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
    if individual_views:
        os.makedirs(os.path.join(plot_path, 'individual'), exist_ok=True)
        for group in ['error', 'input', 'background', 'prediction', 'position',
                      'rawmask', 'mask', 'othermask', 'imagination']:
            os.makedirs(os.path.join(plot_path, 'individual', group), exist_ok=True)
    # NOTE(review): the source lost its indentation; the statistics folder is
    # assumed to be created regardless of the view flags -- confirm.
    os.makedirs(os.path.join(root_path, 'statistics'), exist_ok=True)

    # final directory
    plot_path = plot_path + '/'
    print(f"save plots to {plot_path}")

    return root_path, plot_path
def store_statistics(memory, *args, extend=False):
    """Record positional values into ``memory``, matched to its keys by order.

    The first value goes to the first key of ``memory`` (in the dict's
    iteration order), the second value to the second key, and so on. Pairing
    stops as soon as either the keys or the values run out, so surplus keys
    stay untouched and surplus values are dropped.

    Args:
        memory: Dict whose values support ``append`` / ``extend``.
        *args: Values to record, one per key, in key order.
        extend: If true, each value is merged with ``extend``; otherwise it
            is added as a single element with ``append``.

    Returns:
        The same ``memory`` object, modified in place.
    """
    # zip truncates to the shorter of the two sequences.
    for key, value in zip(memory, args):
        bucket = memory[key]
        if extend:
            bucket.extend(value)
        else:
            bucket.append(value)
    return memory
def append_statistics(memory1, memory2, ignore=None, extend=False):
    """Merge the entries of ``memory1`` into ``memory2`` key by key.

    Args:
        memory1: Source dict; every key not listed in ``ignore`` must already
            exist in ``memory2``.
        memory2: Destination dict, updated in place.
        ignore: Optional container of keys to skip. ``None`` (the default)
            skips nothing.
        extend: If true, ``memory2[key]`` is rebound to
            ``memory2[key] + memory1[key]`` (concatenation for lists).
            Otherwise ``memory1[key]`` is appended to ``memory2[key]`` as a
            single element.

    Returns:
        The same ``memory2`` object.

    Raises:
        KeyError: If a non-ignored key of ``memory1`` is missing in ``memory2``.
    """
    # A None sentinel replaces the former mutable ``[]`` default, which would
    # have been shared between all calls.
    skipped = () if ignore is None else ignore

    for key in memory1:
        if key in skipped:
            continue
        if extend:
            # Deliberately ``+`` rather than ``.extend``: it builds a new
            # object instead of mutating the one memory2 currently holds.
            memory2[key] = memory2[key] + memory1[key]
        else:
            memory2[key].append(memory1[key])
    return memory2
def load_model(cfg, cfg_net, file, device):
    """Build a ``Loci`` network, optionally restore a checkpoint, move it to ``device``.

    Args:
        cfg: Configuration object; only ``cfg.defaults.teacher_forcing`` is
            read here.
        cfg_net: Network configuration passed straight to ``Loci``.
        file: Path of the checkpoint to restore. An empty string or ``None``
            skips loading and keeps the freshly initialised weights.
        device: Target device, used both as ``map_location`` when loading and
            for the final ``net.to(...)``.

    Returns:
        The ``Loci`` instance on ``device``.
    """
    net = Loci(
        cfg_net,
        teacher_forcing=cfg.defaults.teacher_forcing,
    )

    # load model -- any falsy ``file`` (empty string, None) means "no checkpoint"
    if file:
        print(f"load {file} to device {device}")
        # NOTE(security): th.load unpickles the file and can execute arbitrary
        # code; only load checkpoints from trusted sources.
        state = th.load(file, map_location=device)

        # backward compatibility: translate parameter names of older checkpoints
        model = {}
        for key, value in state["model"].items():
            # replace update_module with percept_gate_controller in key string
            key = key.replace("update_module", "percept_gate_controller")
            # drop the ".module." path component from the key
            model[key.replace(".module.", ".")] = value

        net.load_state_dict(model)

    # NOTE(review): the source lost its indentation; this step and the device
    # move are assumed to run whether or not a checkpoint was loaded -- confirm.
    # Purpose of the init level is not documented here; presumably the net
    # must leave its initial stage before it is used.
    if net.get_init_status() < 1:
        net.inc_init_level()

    # Move the network to the target device. This function itself never calls
    # ``net.eval()``; callers needing evaluation mode should set it explicitly.
    net = net.to(device=device)

    return net
|