1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
|
import os
import shutil
from einops import rearrange
import torch as th
from model.loci import Loci
def masks_to_boxes(masks: th.Tensor) -> th.Tensor:
    """
    Compute the bounding boxes around the provided masks.

    Returns a [N, 4] tensor containing bounding boxes. The boxes are in
    ``(x1, y1, x2, y2)`` format with inclusive pixel coordinates, so
    ``0 <= x1 <= x2`` and ``0 <= y1 <= y2`` (equality holds for masks that are
    one pixel wide/high). A mask without any non-zero pixel gets the box
    ``(0, 0, 0, 0)``.

    Args:
        masks (Tensor[N, H, W]): masks to transform where N is the number of masks
            and (H, W) are the spatial dimensions. Any non-zero value counts as
            foreground.

    Returns:
        Tensor[N, 4]: bounding boxes (dtype float, on ``masks.device``)
    """
    n = masks.shape[0]
    bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float)
    if masks.numel() == 0:
        # No masks, or masks with zero spatial size: every box stays (0, 0, 0, 0).
        return bounding_boxes
    for index, mask in enumerate(masks):
        foreground = mask != 0
        # Test the same predicate that th.where uses below; a `sum() > 0` guard
        # would skip masks whose non-zero values are negative or cancel out.
        if not foreground.any():
            continue
        y, x = th.where(foreground)
        bounding_boxes[index, 0] = th.min(x)
        bounding_boxes[index, 1] = th.min(y)
        bounding_boxes[index, 2] = th.max(x)
        bounding_boxes[index, 3] = th.max(y)
    return bounding_boxes
def boxes_to_centroids(boxes, *, width=64, height=64):
    """Convert bounding boxes to box centres scaled to [-1, 1].

    Args:
        boxes: [B, T, N, 4] tensor, last axis ``(x1, y1, x2, y2)`` in pixels.
        width: extent used to normalize x (default 64).
        height: extent used to normalize y (default 64).

    Returns:
        centroids with last axis ``(x, y)``: shape [T, N, 2] when B == 1 (the
        leading batch axis is squeezed away), otherwise [B, T, N, 2].
    """
    centroids = (boxes[..., :2] + boxes[..., 2:]) / 2
    # squeeze(0) only removes the batch axis when B == 1.
    centroids = centroids.squeeze(0)
    # Index the coordinate axis from the end so the scaling hits (x, y)
    # whether or not the batch axis was squeezed away.
    centroids[..., 0] = centroids[..., 0] / width * 2 - 1
    centroids[..., 1] = centroids[..., 1] / height * 2 - 1
    return centroids
def compute_position_from_mask(mask):
    """
    Compute one normalized ``(y, x)`` position per object from soft masks.

    Every channel except the last is thresholded at 0.8 and reduced to a
    bounding box with ``masks_to_boxes``. The box centre is then scaled to
    [-1, 1] by ``boxes_to_centroids``.

    Args:
        mask (Tensor[B, N + 1, H, W]): per-object masks; the last channel is
            discarded (presumably a background channel -- TODO confirm).

    Returns:
        Tensor[B, N, 2]: box centres in ``(y, x)`` order.
    """
    # Drop the last channel and binarize.
    masks_binary = (mask > 0.8)[:, :-1].long()
    b, o = masks_binary.shape[:2]
    boxes = masks_to_boxes(rearrange(masks_binary, 'b o h w -> (b o) h w'))
    # boxes_to_centroids squeezes its leading axis, so put the batch on its
    # second ("T") axis: [1, B, N, 4] -> centroids [B, N, 2]. This keeps the
    # (x, y) axis last for every batch size B.
    boxes = rearrange(boxes, '(b o) c -> 1 b o c', b=b, o=o)
    # NOTE(review): boxes_to_centroids normalizes by a fixed 64-pixel extent,
    # i.e. this assumes H == W == 64 -- TODO confirm.
    centroids = boxes_to_centroids(boxes)
    # (x, y) -> (y, x)
    return centroids[..., [1, 0]]
def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views):
    """
    Create a fresh directory tree for storing evaluation results.

    Layout, where ``root`` is the parent of the directory containing ``file``::

        root/results<name>/<net_name>/<set_test['type']>/     <- root_path
            statistics/
            <evaluation_mode>/                                  <- plot_path (cleared first)
                object/                  (only if object_view)
                individual/<group>/      (only if individual_views)

    ``net_name`` is the base name of ``file`` up to its first '.'.

    Args:
        file: path of the network checkpoint file.
        name: suffix appended to the ``results`` folder name.
        set_test: mapping; only ``set_test['type']`` is read.
        evaluation_mode: name of the plot sub-directory.
        object_view: also create ``object/`` inside the plot directory.
        individual_views: also create ``individual/<group>/`` sub-directories.

    Returns:
        (root_path, plot_path); plot_path ends with '/'.
    """
    net_name = os.path.basename(file).split('.')[0]
    # Applying dirname twice keeps the leading '/' of absolute paths (joining
    # the '/'-split components would turn them into cwd-relative paths) and
    # also works for paths with fewer than two directory components.
    root_path = os.path.dirname(os.path.dirname(file))
    root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type'])
    plot_path = os.path.join(root_path, evaluation_mode)

    # Start from an empty plot directory: any previous plots are deleted.
    if os.path.exists(plot_path):
        shutil.rmtree(plot_path)
    os.makedirs(plot_path, exist_ok=True)
    if object_view:
        os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
    if individual_views:
        os.makedirs(os.path.join(plot_path, 'individual'), exist_ok=True)
        for group in ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination']:
            os.makedirs(os.path.join(plot_path, 'individual', group), exist_ok=True)
    os.makedirs(os.path.join(root_path, 'statistics'), exist_ok=True)

    # Callers receive the plot directory with a trailing separator.
    plot_path = plot_path + '/'
    print(f"save plots to {plot_path}")
    return root_path, plot_path
def store_statistics(memory, *args, extend=False):
    """
    Record positional values into the lists held by ``memory``.

    Values are paired with the keys of ``memory`` in iteration order: the
    first value goes to the first key, the second to the second, and so on.
    Keys without a matching value are left untouched and surplus values are
    ignored. With ``extend=True`` each value is an iterable whose items are
    added; otherwise the value is appended as a single element.

    ``memory`` is modified in place and also returned.
    """
    for key, value in zip(memory, args):
        target = memory[key]
        if extend:
            target.extend(value)
        else:
            target.append(value)
    return memory
def append_statistics(memory1, memory2, ignore=(), extend=False):
    """
    Merge the entries of ``memory1`` into the lists stored in ``memory2``.

    For every key of ``memory1`` that is not in ``ignore``:

    - ``extend=False``: ``memory1[key]`` is appended as one element to
      ``memory2[key]``;
    - ``extend=True``: ``memory2[key]`` is replaced by the concatenation
      ``memory2[key] + memory1[key]``.

    ``memory2`` must already contain every non-ignored key (KeyError otherwise).
    It is modified in place and returned.
    """
    # ``ignore`` defaults to an immutable tuple instead of a shared list.
    for key in memory1:
        if key in ignore:
            continue
        if extend:
            # Rebinds the key to a new concatenated object; a list previously
            # taken out of memory2 is not modified.
            memory2[key] = memory2[key] + memory1[key]
        else:
            memory2[key].append(memory1[key])
    return memory2
def load_model(cfg, cfg_net, file, device):
    """
    Build a ``Loci`` network and optionally initialize it from a checkpoint.

    Args:
        cfg: configuration object; only ``cfg.defaults.teacher_forcing`` is read.
        cfg_net: network configuration passed to ``Loci``.
        file: checkpoint path; an empty string (or None) skips loading.
        device: target device, also used as ``map_location`` for ``th.load``.

    Returns:
        The network moved to ``device``. ``net.eval()`` is NOT called here.
    """
    net = Loci(
        cfg_net,
        teacher_forcing = cfg.defaults.teacher_forcing
    )

    # load model
    if file:
        print(f"load {file} to device {device}")
        # th.load unpickles the file; only load trusted checkpoints.
        state = th.load(file, map_location=device)

        # 1. Keys the current model expects. Checkpoint entries with any other
        #    key are dropped, which keeps older checkpoints loadable.
        allowed_keys = set(net.state_dict().keys())

        # 2. Copy the matching values from the checkpoint.
        model = {}
        for key, value in state["model"].items():
            # Checkpoints may name this submodule "update_module"; map it to the
            # current name "percept_gate_controller".
            key = key.replace("update_module", "percept_gate_controller")
            if key in allowed_keys:
                # NOTE(review): the ".module." strip runs after the membership
                # test, so checkpoint keys that carry an extra ".module." (e.g.
                # from a wrapped model) are dropped rather than rescued.
                # Behavior kept as-is -- confirm intent.
                model[key.replace(".module.", ".")] = value

        net.load_state_dict(model)

    # NOTE(review): raises the network's init level to at least 1; the meaning
    # of the levels is defined in Loci -- confirm why this is needed here.
    if net.get_init_status() < 1:
        net.inc_init_level()

    net = net.to(device=device)
    return net
|