import torch.nn as nn
import torch as th

from model.nn.eprop_transformer import EpropGateL0rdTransformer
from model.nn.eprop_transformer_shared import EpropGateL0rdTransformerShared
from model.utils.nn_utils import LambdaModule, Binarize
from model.nn.residual import ResidualBlock
from einops import rearrange, repeat, reduce

__author__ = "Manuel Traub"


class LociPredictor(nn.Module):
    def __init__(
        self,
        heads: int,
        layers: int,
        channels_multiplier: int,
        reg_lambda: float,
        num_objects: int,
        gestalt_size: int,
        batch_size: int,
        bottleneck: str,
        transformer_type='standard',
    ):
        super(LociPredictor, self).__init__()

        self.num_objects     = num_objects
        self.std_alpha       = nn.Parameter(th.zeros(1) + 1e-16)
        self.bottleneck_type = bottleneck
        self.gestalt_size    = gestalt_size
        self.reg_lambda      = reg_lambda

        Transformer = (
            EpropGateL0rdTransformerShared
            if transformer_type == 'shared'
            else EpropGateL0rdTransformer
        )

        # Per-object input channels: gestalt code (gestalt_size), priority (1),
        # position (x, y, std -> 3), and two slots_closed channels.
        self.predictor = Transformer(
            channels    = gestalt_size + 3 + 1 + 2,
            multiplier  = channels_multiplier,
            heads       = heads,
            depth       = layers,
            num_objects = num_objects,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )

        if bottleneck == 'binar':
            print("Binary bottleneck")
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                Binarize(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects)),
            )
        else:
            print("Unrestricted bottleneck")
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                nn.Sigmoid(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects)),
            )

        # Fold the object dimension into the batch dimension and back.
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))

    def get_openings(self):
        openings = self.predictor.get_openings().detach()
        openings = self.to_shared(openings[:, None])
        return openings

    def get_hidden(self):
        return self.predictor.get_hidden()

    def set_hidden(self, hidden):
        self.predictor.set_hidden(hidden)

    def get_att_weights(self):
        # Collect per-layer attention weights, averaged over layers;
        # returns an empty list if weights were not recorded.
        att_weights = []
        for layer in self.predictor.attention:
            if layer.att_weights is None:
                return []
            att_weights.append(layer.att_weights)

        att_weights = th.stack(att_weights)
        return reduce(att_weights, 'l b o1 o2 -> b o1 o2', 'mean')

    def enable_att_weights(self):
        for layer in self.predictor.attention:
            layer.need_weights = True

    def disable_att_weights(self):
        for layer in self.predictor.attention:
            layer.need_weights = False

    def forward(
        self,
        gestalt: th.Tensor,
        priority: th.Tensor,
        position: th.Tensor,
        slots_closed: th.Tensor,
    ):
        position     = self.to_batch(position)
        gestalt_cur  = self.to_batch(gestalt)
        priority     = self.to_batch(priority)
        slots_closed = rearrange(slots_closed, 'b o c -> (b o) c').detach()

        # Renamed from `input` to avoid shadowing the Python built-in.
        x      = th.cat((gestalt_cur, priority, position, slots_closed), dim=1)
        output = self.predictor(x)

        # Slice the transformer output back into its components.
        gestalt  = output[:, :self.gestalt_size]
        priority = output[:, self.gestalt_size:(self.gestalt_size + 1)]
        xy       = output[:, (self.gestalt_size + 1):(self.gestalt_size + 3)]
        std      = output[:, (self.gestalt_size + 3):(self.gestalt_size + 4)]

        position = th.cat((xy, std * self.std_alpha), dim=1)

        position = self.to_shared(position)
        gestalt  = self.bottleneck(gestalt)
        priority = self.to_shared(priority)

        return position, gestalt, priority
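

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the original module): shows the
# tensor shapes a forward pass expects. The per-object channel counts for
# position (3), priority (1), and slots_closed (2) are inferred from the
# channel arithmetic in __init__; it also assumes the Eprop transformer
# returns outputs with the same channel count as its inputs. Adjust if the
# surrounding Loci code differs.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    batch_size, num_objects, gestalt_size = 4, 6, 32

    predictor = LociPredictor(
        heads               = 4,
        layers              = 2,
        channels_multiplier = 2,
        reg_lambda          = 1e-6,
        num_objects         = num_objects,
        gestalt_size        = gestalt_size,
        batch_size          = batch_size,
        bottleneck          = 'binar',
    )

    # Shared layout 'b (o c)' for gestalt, priority, and position;
    # slots_closed arrives as 'b o c' and is flattened inside forward().
    gestalt      = th.rand(batch_size, num_objects * gestalt_size)
    priority     = th.rand(batch_size, num_objects * 1)
    position     = th.rand(batch_size, num_objects * 3)
    slots_closed = th.rand(batch_size, num_objects, 2)

    position, gestalt, priority = predictor(gestalt, priority, position, slots_closed)
    print(position.shape, gestalt.shape, priority.shape)
    # Expected: (4, 18), (4, 192), (4, 6)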