author | fredeee | 2024-03-23 13:27:00 +0100
committer | fredeee | 2024-03-23 13:27:00 +0100
commit | 6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree | 0545ff1b8beb051993c2d75fd81306db1a22274d /model/nn
parent | ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'model/nn')
-rw-r--r-- | model/nn/background.py | 149
-rw-r--r-- | model/nn/eprop_gate_l0rd.py | 8
-rw-r--r-- | model/nn/eprop_transformer.py | 14
-rw-r--r-- | model/nn/eprop_transformer_shared.py | 7
-rw-r--r-- | model/nn/eprop_transformer_utils.py | 8
-rw-r--r-- | model/nn/predictor.py | 22
6 files changed, 55 insertions(+), 153 deletions(-)
diff --git a/model/nn/background.py b/model/nn/background.py
index 38ec325..ff14e02 100644
--- a/model/nn/background.py
+++ b/model/nn/background.py
@@ -14,154 +14,21 @@ class BackgroundEnhancer(nn.Module):
     def __init__(
         self,
         input_size: Tuple[int, int],
-        img_channels: int,
-        level1_channels,
-        latent_channels,
-        gestalt_size,
         batch_size,
-        depth
     ):
         super(BackgroundEnhancer, self).__init__()
-
-        latent_size = [input_size[0] // 16, input_size[1] // 16]
-        self.input_size = input_size
-
-        self.register_buffer('init', th.zeros(1).long())
-        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
-
-        self.level = 1
-        self.down_level2 = nn.Sequential(
-            PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16),
-            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)]
-        )
-
-        self.down_level1 = nn.Sequential(
-            PatchDownConv(level1_channels, latent_channels, alpha = 1),
-            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)]
-        )
-
-        self.down_level0 = nn.Sequential(
-            *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
-            AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size),
-            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
-            Binarize(),
-        )
-
-        self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
-
-        self.to_grid = nn.Sequential(
-            LinearResidual(gestalt_size, gestalt_size, input_relu = False),
-            LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
-            LambdaModule(lambda x: x + self.bias),
-            *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)],
-        )
-
-
-        self.up_level0 = nn.Sequential(
-            ResidualBlock(gestalt_size, latent_channels),
-            *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
-        )
-
-        self.up_level1 = nn.Sequential(
-            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)],
-            PatchUpscale(latent_channels, level1_channels, alpha = 1),
-        )
-
-        self.up_level2 = nn.Sequential(
-            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)],
-            PatchUpscale(level1_channels, img_channels, alpha = 1e-16),
-        )
-
-        self.to_channels = nn.ModuleList([
-            SkipConnection(img_channels*2+2, latent_channels),
-            SkipConnection(img_channels*2+2, level1_channels),
-            SkipConnection(img_channels*2+2, img_channels*2+2),
-        ])
-
-        self.to_img = nn.ModuleList([
-            SkipConnection(latent_channels, img_channels),
-            SkipConnection(level1_channels, img_channels),
-            SkipConnection(img_channels, img_channels),
-        ])
-
+        self.batch_size = batch_size
+        self.height = input_size[0]
+        self.width = input_size[1]
         self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
-        self.object = nn.Parameter(th.ones(1, img_channels, *input_size))
-
-        self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False)
+        self.register_buffer('init', th.zeros(1).long())
 
     def get_init(self):
         return self.init.item()
-
-    def step_init(self):
-        self.init = self.init + 1
-
-    def detach(self):
-        self.latent = self.latent.detach()
-
-    def reset_state(self):
-        self.latent = th.zeros_like(self.latent)
-
-    def set_level(self, level):
-        self.level = level
-
-    def encoder(self, input):
-        latent = self.to_channels[self.level](input)
-
-        if self.level >= 2:
-            latent = self.down_level2(latent)
-
-        if self.level >= 1:
-            latent = self.down_level1(latent)
-
-        return self.down_level0(latent)
-
-    def get_last_latent_gird(self):
-        return self.to_grid(self.latent) * self.alpha
-
-    def decoder(self, latent, input):
-        grid = self.to_grid(latent)
-        latent = self.up_level0(grid)
-
-        if self.level >= 1:
-            latent = self.up_level1(latent)
-
-        if self.level >= 2:
-            latent = self.up_level2(latent)
-
-        object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3])
-        object = repeat(object, '1 c h w -> b c h w', b = input.shape[0])
-
-        return th.sigmoid(object + self.to_img[self.level](latent)), grid
-
-    def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False):
+
+    def forward(self, input: th.Tensor):
 
-        if only_mask:
-            mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
-            mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
-            return mask
-
-        last_bg = self.decoder(self.latent, input)[0]
-
-        bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach()
-        bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach()
-
-        if error is None or self.get_init() < 2:
-            error = bg_error
-
-        if mask is None or self.get_init() < 2:
-            mask = bg_mask
-
-        self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1))
-
         mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
-        mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
-
-        background, grid = self.decoder(self.latent, input)
-
-        if self.get_init() < 1:
-            return mask, background
-
-        if self.get_init() < 2:
-            return mask, th.zeros_like(background), th.zeros_like(grid), background
-
-        return mask, background, grid * self.alpha, background
+        mask = repeat(mask, '1 1 h w -> b 1 h w', b = self.batch_size) * 0.1
+        return mask
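Note: after this change, BackgroundEnhancer no longer encodes or decodes a latent background; forward() just rescales the learned mask parameter to the input resolution and broadcasts it over the batch. A minimal usage sketch (the import path matches this repo, but the 64x64 resolution and batch size of 8 are made-up values):

import torch as th
from model.nn.background import BackgroundEnhancer

bg = BackgroundEnhancer(input_size=(64, 64), batch_size=8)
frame = th.randn(8, 3, 64, 64)  # B x C x H x W; only the spatial shape is used
mask = bg(frame)                # (8, 1, 64, 64) mask values, down-weighted by 0.1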
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py
index 82a9895..c2473d1 100644
--- a/model/nn/eprop_gate_l0rd.py
+++ b/model/nn/eprop_gate_l0rd.py
@@ -49,7 +49,7 @@ class EpropGateL0rdFunction(Function):
             e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
         )
 
-        return h, th.mean(H_g)
+        return h, H_g
 
     @staticmethod
     def backward(ctx, dh, _):
@@ -169,6 +169,7 @@ class EpropGateL0rd(nn.Module):
         self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)
         self.register_buffer("openings", th.zeros(1), persistent=False)
+        self.register_buffer("openings_perslot", th.zeros(batch_size), persistent=False)
 
         # initialize weights
         stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
@@ -294,7 +295,7 @@ class EpropGateL0rdShared(EpropGateL0rd):
         return o * p, h_last
 
     def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
-        h, openings = self.fcn(
+        h, H_g = self.fcn(
             x, h_last,
             self.w_gx, self.w_gh, self.b_g,
             self.w_rx, self.w_rh, self.b_r,
@@ -305,7 +306,8 @@ class EpropGateL0rdShared(EpropGateL0rd):
             )
         )
 
-        self.openings = openings
+        self.openings = th.mean(H_g)
+        self.openings_perslot = th.mean(H_g, dim=1)
 
         p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
         o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
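Note: EpropGateL0rdFunction now returns the full gate-activity tensor H_g instead of its scalar mean, so EpropGateL0rdShared can track both a global average opening and a per-slot average. A self-contained sketch of the two reductions (the 4x16 shape is hypothetical; in the module, rows correspond to the batch/slot dimension):

import torch as th

H_g = th.rand(4, 16)                    # gate activities: (batch, num_hidden)
openings = th.mean(H_g)                 # scalar, as before: mean over all gates
openings_perslot = th.mean(H_g, dim=1)  # (4,): one mean opening per slot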
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py
index 4d89ec4..d2341dd 100644
--- a/model/nn/eprop_transformer.py
+++ b/model/nn/eprop_transformer.py
@@ -35,13 +35,21 @@ class EpropGateL0rdTransformer(nn.Module):
             _layers.append(OutputEmbeding(num_hidden, num_outputs))
 
         self.layers = nn.Sequential(*_layers)
+        self.attention = []
+        self.l0rds = []
+        for l in self.layers:
+            if 'AlphaAttention' in type(l).__name__:
+                self.attention.append(l)
+            elif 'EpropAlphaGateL0rd' in type(l).__name__:
+                self.l0rds.append(l)
 
     def get_openings(self):
-        openings = 0
+        openings = []
         for i in range(self.depth):
-            openings += self.layers[2 * (i + 1)].l0rd.openings.item()
+            openings.append(self.l0rds[i].l0rd.openings_perslot)
 
-        return openings / self.depth
+        openings = th.mean(th.stack(openings, dim=0), dim=0)
+        return openings
 
     def get_hidden(self):
         states = []
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py
index 23a1b0f..79223c1 100644
--- a/model/nn/eprop_transformer_shared.py
+++ b/model/nn/eprop_transformer_shared.py
@@ -39,11 +39,12 @@ class EpropGateL0rdTransformerShared(nn.Module):
         self.output_embeding = OutputEmbeding(num_hidden, num_outputs)
 
     def get_openings(self):
-        openings = 0
+        openings = []
         for i in range(self.depth):
-            openings += self.l0rds[i].l0rd.openings.item()
+            openings.append(self.l0rds[i].l0rd.openings_perslot)
 
-        return openings / self.depth
+        openings = th.mean(th.stack(openings, dim=0), dim=0)
+        return openings
 
     def get_hidden(self):
         return self.hidden
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
index 9e5e874..3219cd0 100644
--- a/model/nn/eprop_transformer_utils.py
+++ b/model/nn/eprop_transformer_utils.py
@@ -9,7 +9,8 @@ class AlphaAttention(nn.Module):
         num_hidden,
         num_objects,
         heads,
-        dropout = 0.0
+        dropout = 0.0,
+        need_weights = False
     ):
         super(AlphaAttention, self).__init__()
 
@@ -23,10 +24,13 @@ class AlphaAttention(nn.Module):
             dropout = dropout,
             batch_first = True
         )
+        self.need_weights = need_weights
+        self.att_weights = None
 
     def forward(self, x: th.Tensor):
         x = self.to_sequence(x)
-        x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+        att, self.att_weights = self.attention(x, x, x, need_weights=self.need_weights)
+        x = x + self.alpha * att
         return self.to_batch(x)
 
 class InputEmbeding(nn.Module):
diff --git a/model/nn/predictor.py b/model/nn/predictor.py
index 94f13b8..5f08de9 100644
--- a/model/nn/predictor.py
+++ b/model/nn/predictor.py
@@ -61,7 +61,9 @@ class LociPredictor(nn.Module):
         self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))
 
     def get_openings(self):
-        return self.predictor.get_openings()
+        openings = self.predictor.get_openings().detach()
+        openings = self.to_shared(openings[:, None])
+        return openings
 
     def get_hidden(self):
         return self.predictor.get_hidden()
@@ -69,6 +71,24 @@ class LociPredictor(nn.Module):
     def set_hidden(self, hidden):
         self.predictor.set_hidden(hidden)
 
+    def get_att_weights(self):
+        att_weights = []
+        for layer in self.predictor.attention:
+            if layer.att_weights is None:
+                return []
+            else:
+                att_weights.append(layer.att_weights)
+
+        att_weights = th.stack(att_weights)
+        return reduce(att_weights, 'l b o1 o2-> b o1 o2', 'mean')
+
+    def enable_att_weights(self):
+        for layer in self.predictor.attention:
+            layer.need_weights = True
+
+    def disable_att_weights(self):
+        for layer in self.predictor.attention:
+            layer.need_weights = False
+
     def forward(
         self,
         gestalt: th.Tensor,
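Note: AlphaAttention can now optionally record its attention map, and LociPredictor exposes this via enable_att_weights()/get_att_weights() (the latter returns [] until a forward pass has filled the buffers). A sketch of the layer-level flag; the sizes are made up, and it assumes to_sequence reshapes (batch*objects, channels) into (batch, objects, channels) so that nn.MultiheadAttention returns head-averaged weights of shape (batch, objects, objects):

import torch as th
from model.nn.eprop_transformer_utils import AlphaAttention

att = AlphaAttention(num_hidden=32, num_objects=4, heads=2, need_weights=True)
x = th.randn(8 * 4, 32)       # (batch * objects, channels), as used in the predictor
y = att(x)                    # residual attention output, same shape as x
print(att.att_weights.shape)  # expected: (8, 4, 4) object-to-object weights

At the predictor level, get_att_weights() stacks these per-layer maps and averages over layers, while disable_att_weights() switches the recording off again.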