aboutsummaryrefslogtreecommitdiff
path: root/model/nn/predictor.py
blob: 5f08de9defb30122fad5329af6fb3b095a5a3aa8 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import torch.nn as nn
import torch as th
from model.nn.eprop_transformer import EpropGateL0rdTransformer
from model.nn.eprop_transformer_shared import EpropGateL0rdTransformerShared
from model.utils.nn_utils import  LambdaModule, Binarize
from model.nn.residual import ResidualBlock
from einops import rearrange, repeat, reduce

__author__ = "Manuel Traub"
    
class LociPredictor(nn.Module):
    """Predict per-object position, gestalt and priority codes.

    Wraps an ``EpropGateL0rdTransformer`` (or its ``Shared`` variant) together
    with a gestalt bottleneck. The public tensors use a "shared" layout in
    which the codes of all ``num_objects`` objects are concatenated along the
    channel axis (``b (o c)``); internally objects are folded into the batch
    axis (``(b o) c``) before being handed to the transformer.
    """

    def __init__(
        self,
        heads: int,
        layers: int,
        channels_multiplier: int,
        reg_lambda: float,
        num_objects: int,
        gestalt_size: int,
        batch_size: int,
        bottleneck: str,
        transformer_type: str = 'standard',
    ):
        """Build the transformer and the gestalt bottleneck.

        Args:
            heads: Forwarded to the transformer as ``heads``.
            layers: Forwarded to the transformer as ``depth``.
            channels_multiplier: Forwarded to the transformer as ``multiplier``.
            reg_lambda: Stored on the module and forwarded to the transformer.
            num_objects: Number of object slots packed into the channel axis.
            gestalt_size: Number of gestalt channels per object.
            batch_size: Forwarded to the transformer.
            bottleneck: ``'binar'`` selects a ``Binarize`` bottleneck; any
                other value selects a ``Sigmoid`` ("unrestricted") one.
            transformer_type: ``'shared'`` selects
                ``EpropGateL0rdTransformerShared``; any other value selects
                ``EpropGateL0rdTransformer``.
        """
        super().__init__()
        self.num_objects = num_objects
        # Learnable scale applied to the predicted std in forward(); it starts
        # at 1e-16, so the std output is almost zeroed initially.
        self.std_alpha = nn.Parameter(th.zeros(1) + 1e-16)
        self.bottleneck_type = bottleneck
        self.gestalt_size = gestalt_size
        self.reg_lambda = reg_lambda

        if transformer_type == 'shared':
            transformer_cls = EpropGateL0rdTransformerShared
        else:
            transformer_cls = EpropGateL0rdTransformer

        self.predictor = transformer_cls(
            # Per-object channel count of the concatenated input built in
            # forward(): gestalt + (presumably) 3 position + 1 priority
            # + 2 slots_closed channels -- TODO confirm against the caller.
            channels    = gestalt_size + 3 + 1 + 2,
            multiplier  = channels_multiplier,
            heads       = heads,
            depth       = layers,
            num_objects = num_objects,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )

        # Both bottleneck variants share the same pipeline and differ only in
        # the squashing module. The class is chosen here but instantiated
        # inside the Sequential so sub-modules are created in the original
        # order and keep their indices (and therefore state_dict keys).
        if bottleneck == 'binar':
            print("Binary bottleneck")
            squash_cls = Binarize
        else:
            print("unrestricted bottleneck")
            squash_cls = nn.Sigmoid

        self.bottleneck = nn.Sequential(
            # Add two singleton spatial axes for the kernel_size=1 block.
            LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
            ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
            squash_cls(),
            # Drop the singleton axes and regroup objects into the channel axis.
            LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects)),
        )

        # Layout converters between "shared" (b, o*c) and "batch" (b*o, c).
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))

    def get_openings(self):
        """Return the transformer's detached openings in shared layout.

        A trailing axis is appended to the value returned by
        ``predictor.get_openings()`` before regrouping, so that value is
        presumably one entry per ``(batch, object)`` pair -- TODO confirm.
        """
        openings = self.predictor.get_openings().detach()
        return self.to_shared(openings[:, None])

    def get_hidden(self):
        """Return the hidden state of the wrapped transformer."""
        return self.predictor.get_hidden()

    def set_hidden(self, hidden):
        """Overwrite the hidden state of the wrapped transformer."""
        self.predictor.set_hidden(hidden)

    def get_att_weights(self):
        """Return the attention weights averaged over all attention layers.

        Returns:
            A tensor reduced from ``l b o1 o2`` to ``b o1 o2`` by taking the
            mean over layers, or an empty list when weights are unavailable:
            either some layer has ``att_weights`` set to ``None`` or the
            transformer has no attention layers at all.
        """
        weights = [layer.att_weights for layer in self.predictor.attention]
        # th.stack() raises on an empty list, so treat "no layers" the same
        # way as "weights were not recorded".
        if not weights or any(w is None for w in weights):
            return []
        return reduce(th.stack(weights), 'l b o1 o2-> b o1 o2', 'mean')

    def enable_att_weights(self):
        """Set ``need_weights`` to True on every attention layer."""
        for layer in self.predictor.attention:
            layer.need_weights = True

    def disable_att_weights(self):
        """Set ``need_weights`` to False on every attention layer."""
        for layer in self.predictor.attention:
            layer.need_weights = False

    def forward(
        self,
        gestalt: th.Tensor,
        priority: th.Tensor,
        position: th.Tensor,
        slots_closed: th.Tensor,
    ):
        """Run one prediction step.

        Args:
            gestalt: Gestalt codes in shared layout ``b (o c)``.
            priority: Priority codes in shared layout ``b (o c)``.
            position: Position codes in shared layout ``b (o c)``.
            slots_closed: Tensor laid out as ``b o c``; it is detached, so no
                gradient flows back through it.

        Returns:
            Tuple ``(position, gestalt, priority)``, each in shared layout
            ``b (o c)``. The gestalt has passed through the bottleneck and the
            position is ``xy`` concatenated with ``std * std_alpha``.
        """
        # Fold objects into the batch axis so every row describes one object.
        position_in     = self.to_batch(position)
        gestalt_in      = self.to_batch(gestalt)
        priority_in     = self.to_batch(priority)
        slots_closed_in = rearrange(slots_closed, 'b o c -> (b o) c').detach()

        predictor_input = th.cat(
            (gestalt_in, priority_in, position_in, slots_closed_in), dim=1
        )
        output = self.predictor(predictor_input)

        # Split the output channels: [gestalt | priority | xy | std | unused].
        # Channels beyond gestalt_size + 4 are not read.
        g = self.gestalt_size
        gestalt_out  = output[:, :g]
        priority_out = output[:, g:(g + 1)]
        xy           = output[:, (g + 1):(g + 3)]
        std          = output[:, (g + 3):(g + 4)]

        position_out = th.cat((xy, std * self.std_alpha), dim=1)

        position_out = self.to_shared(position_out)
        gestalt_out  = self.bottleneck(gestalt_out)
        priority_out = self.to_shared(priority_out)

        return position_out, gestalt_out, priority_out