aboutsummaryrefslogtreecommitdiff
path: root/model/nn/predictor.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/nn/predictor.py')
-rw-r--r--model/nn/predictor.py22
1 files changed, 21 insertions, 1 deletions
diff --git a/model/nn/predictor.py b/model/nn/predictor.py
index 94f13b8..5f08de9 100644
--- a/model/nn/predictor.py
+++ b/model/nn/predictor.py
@@ -61,7 +61,9 @@ class LociPredictor(nn.Module):
self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))
def get_openings(self):
- return self.predictor.get_openings()
+ openings = self.predictor.get_openings().detach()
+ openings = self.to_shared(openings[:, None])
+ return openings
def get_hidden(self):
return self.predictor.get_hidden()
@@ -69,6 +71,24 @@ class LociPredictor(nn.Module):
def set_hidden(self, hidden):
self.predictor.set_hidden(hidden)
+ def get_att_weights(self):
+ att_weights = []
+ for layer in self.predictor.attention:
+ if layer.att_weights is None:
+ return []
+ else:
+ att_weights.append(layer.att_weights)
+ att_weights = th.stack(att_weights)
+ return reduce(att_weights, 'l b o1 o2-> b o1 o2', 'mean')
+
+ def enable_att_weights(self):
+ for layer in self.predictor.attention:
+ layer.need_weights = True
+
+ def disable_att_weights(self):
+ for layer in self.predictor.attention:
+ layer.need_weights = False
+
def forward(
self,
gestalt: th.Tensor,