import torch.nn as nn
import torch as th

from model.nn.eprop_transformer import EpropGateL0rdTransformer
from model.nn.eprop_transformer_shared import EpropGateL0rdTransformerShared
from model.utils.nn_utils import LambdaModule, Binarize
from model.nn.residual import ResidualBlock

from einops import rearrange, repeat, reduce

__author__ = "Manuel Traub"
class LociPredictor(nn.Module):
    """Transformer-based predictor over the per-object latent state (gestalt, position, priority)."""

    def __init__(
        self,
        heads: int,
        layers: int,
        channels_multiplier: int,
        reg_lambda: float,
        num_objects: int,
        gestalt_size: int,
        batch_size: int,
        bottleneck: str,
        transformer_type = 'standard',
    ):
        super(LociPredictor, self).__init__()
        self.num_objects     = num_objects
        self.std_alpha       = nn.Parameter(th.zeros(1) + 1e-16)
        self.bottleneck_type = bottleneck
        self.gestalt_size    = gestalt_size
        self.reg_lambda      = reg_lambda

        # Either the transformer variant with shared GateL0rd weights or the standard one.
        Transformer = EpropGateL0rdTransformerShared if transformer_type == 'shared' else EpropGateL0rdTransformer

        # One input token per object: gestalt code, position (x, y, std), priority and slot state.
        self.predictor = Transformer(
            channels    = gestalt_size + 3 + 1 + 2,
            multiplier  = channels_multiplier,
            heads       = heads,
            depth       = layers,
            num_objects = num_objects,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )
        if bottleneck == 'binar':
            print("Binary bottleneck")
            # Discretize the predicted gestalt code to binary values.
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                Binarize(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )
        else:
            print("Unrestricted bottleneck")
            # Keep the gestalt code continuous, squashed to (0, 1) by a sigmoid.
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                nn.Sigmoid(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )
        # Fold the object dimension into the batch dimension and back.
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))
    def get_openings(self):
        # Gate openings of the GateL0rd units, detached and reshaped to one value per object.
        openings = self.predictor.get_openings().detach()
        openings = self.to_shared(openings[:, None])
        return openings

    def get_hidden(self):
        return self.predictor.get_hidden()

    def set_hidden(self, hidden):
        self.predictor.set_hidden(hidden)

    def get_att_weights(self):
        # Collect the attention weights of all layers and average them over the layer dimension.
        att_weights = []
        for layer in self.predictor.attention:
            if layer.att_weights is None:
                return []
            att_weights.append(layer.att_weights)

        att_weights = th.stack(att_weights)
        return reduce(att_weights, 'l b o1 o2 -> b o1 o2', 'mean')

    def enable_att_weights(self):
        for layer in self.predictor.attention:
            layer.need_weights = True

    def disable_att_weights(self):
        for layer in self.predictor.attention:
            layer.need_weights = False
    def forward(
        self,
        gestalt: th.Tensor,
        priority: th.Tensor,
        position: th.Tensor,
        slots_closed: th.Tensor,
    ):
        # Fold objects into the batch dimension; the slot states are treated as constants.
        position     = self.to_batch(position)
        gestalt_cur  = self.to_batch(gestalt)
        priority     = self.to_batch(priority)
        slots_closed = rearrange(slots_closed, 'b o c -> (b o) c').detach()

        # One token per object: gestalt, priority, position and slot state.
        input  = th.cat((gestalt_cur, priority, position, slots_closed), dim=1)
        output = self.predictor(input)

        # Split the prediction back into its components.
        gestalt  = output[:, :self.gestalt_size]
        priority = output[:, self.gestalt_size:(self.gestalt_size+1)]
        xy       = output[:, (self.gestalt_size+1):(self.gestalt_size+3)]
        std      = output[:, (self.gestalt_size+3):(self.gestalt_size+4)]

        # Scale the predicted std by a learnable factor and restore the object dimension.
        position = th.cat((xy, std * self.std_alpha), dim=1)
        position = self.to_shared(position)
        gestalt  = self.bottleneck(gestalt)
        priority = self.to_shared(priority)

        return position, gestalt, priority
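

# Minimal usage sketch (not part of the original module): builds a small LociPredictor and
# runs one forward pass on random tensors. The hyperparameters below are illustrative
# assumptions, and whether this runs as-is depends on the repo's transformer implementation;
# the tensor shapes are inferred from the rearranges above (position carries x, y and std per
# object, slots_closed carries 2 values per object).
if __name__ == "__main__":
    batch_size, num_objects, gestalt_size = 4, 8, 64

    predictor = LociPredictor(
        heads               = 4,
        layers              = 2,
        channels_multiplier = 2,
        reg_lambda          = 1e-6,
        num_objects         = num_objects,
        gestalt_size        = gestalt_size,
        batch_size          = batch_size,
        bottleneck          = 'binar',
    )

    gestalt      = th.rand(batch_size, num_objects * gestalt_size)
    priority     = th.rand(batch_size, num_objects * 1)
    position     = th.rand(batch_size, num_objects * 3)
    slots_closed = th.rand(batch_size, num_objects, 2)

    position, gestalt, priority = predictor(gestalt, priority, position, slots_closed)
    print(position.shape, gestalt.shape, priority.shape)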