diff options
Diffstat (limited to 'model/nn/predictor.py')
-rw-r--r-- | model/nn/predictor.py | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/model/nn/predictor.py b/model/nn/predictor.py index 94f13b8..5f08de9 100644 --- a/model/nn/predictor.py +++ b/model/nn/predictor.py @@ -61,7 +61,9 @@ class LociPredictor(nn.Module): self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects)) def get_openings(self): - return self.predictor.get_openings() + openings = self.predictor.get_openings().detach() + openings = self.to_shared(openings[:, None]) + return openings def get_hidden(self): return self.predictor.get_hidden() @@ -69,6 +71,24 @@ class LociPredictor(nn.Module): def set_hidden(self, hidden): self.predictor.set_hidden(hidden) + def get_att_weights(self): + att_weights = [] + for layer in self.predictor.attention: + if layer.att_weights is None: + return [] + else: + att_weights.append(layer.att_weights) + att_weights = th.stack(att_weights) + return reduce(att_weights, 'l b o1 o2-> b o1 o2', 'mean') + + def enable_att_weights(self): + for layer in self.predictor.attention: + layer.need_weights = True + + def disable_att_weights(self): + for layer in self.predictor.attention: + layer.need_weights = False + def forward( self, gestalt: th.Tensor, |