diff options
author | fredeee | 2024-03-23 13:27:00 +0100 |
---|---|---|
committer | fredeee | 2024-03-23 13:27:00 +0100 |
commit | 6bcf6b8306ce4903734fb31824799a50281cea69 (patch) | |
tree | 0545ff1b8beb051993c2d75fd81306db1a22274d /model | |
parent | ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff) |
add bouncingball experiment and ablation studies
Diffstat (limited to 'model')
-rw-r--r-- | model/loci.py | 31 | ||||
-rw-r--r-- | model/nn/background.py | 149 | ||||
-rw-r--r-- | model/nn/eprop_gate_l0rd.py | 8 | ||||
-rw-r--r-- | model/nn/eprop_transformer.py | 14 | ||||
-rw-r--r-- | model/nn/eprop_transformer_shared.py | 7 | ||||
-rw-r--r-- | model/nn/eprop_transformer_utils.py | 8 | ||||
-rw-r--r-- | model/nn/predictor.py | 22 |
7 files changed, 73 insertions, 166 deletions
diff --git a/model/loci.py b/model/loci.py index 96088bd..46f89f7 100644 --- a/model/loci.py +++ b/model/loci.py @@ -66,11 +66,6 @@ class Loci(nn.Module): self.background = BackgroundEnhancer( input_size = cfg.input_size, - gestalt_size = cfg.background.gestalt_size, - img_channels = cfg.img_channels, - depth = cfg.background.num_layers, - latent_channels = cfg.background.latent_channels, - level1_channels = cfg.background.level1_channels, batch_size = cfg.batch_size, ) @@ -93,11 +88,16 @@ class Loci(nn.Module): self.modulator = ObjectModulator(cfg.num_objects) self.linear_gate = LinearInterpolation(cfg.num_objects) - self.background.set_level(cfg.level) self.encoder.set_level(cfg.level) self.decoder.set_level(cfg.level) self.initial_states.set_level(cfg.level) + # add flag option to enable/disable latent loss + if 'latent_loss_enabled' in cfg: + self.latent_loss_enabled = cfg.latent_loss_enabled + else: + self.latent_loss_enabled = False + def get_init_status(self): init = [] for module in self.modules(): @@ -133,9 +133,6 @@ class Loci(nn.Module): if reset: self.reset_state() - if train_background or self.get_init_status() < 1: - return self.background(*input) - return self.run_end2end(*input, evaluate=evaluate, warmup=warmup, shuffleslots = shuffleslots, reset_mask = reset_mask, allow_spawn = allow_spawn, show_hidden = show_hidden, clean_slots = clean_slots) def run_decoder( @@ -196,11 +193,12 @@ class Loci(nn.Module): ): position_loss = th.tensor(0, device=input.device) time_loss = th.tensor(0, device=input.device) + latent_loss = th.tensor(0, device=input.device) bg_mask = None position_encoder = None if error_last is None or mask_last is None: - bg_mask = self.background(input, only_mask=True) + bg_mask = self.background(input) error_last = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() position_last, gestalt_last, priority_last, error_cur = self.initial_states( @@ -214,7 +212,7 @@ class Loci(nn.Module): object_last_unprioritized = self.decoder(position_last, gestalt_last)[-1] # background and bg_mask for the next time point - bg_mask = self.background(input, error_last, mask_last[:,-1:], only_mask=True) + bg_mask = self.background(input) # position and gestalt for the current time point position_cur, gestalt_cur, priority_cur = self.encoder(input, error_last, mask_last, object_last_unprioritized, position_last, rawmask_last) @@ -228,7 +226,7 @@ class Loci(nn.Module): slots_bounded, slots_bounded_smooth, slots_occluded_cur, slots_partially_occluded_cur, slots_fully_visible_cur, slots_occlusionfactor_cur = self.occlusion_tracker(mask_cur, rawmask_cur, reset_mask) # do not project into the future in the warmup phase - slots_closed = th.ones_like(repeat(slots_bounded, 'b o -> b o c', c=2)) + slots_closed = th.zeros_like(repeat(slots_bounded, 'b o -> b o c', c=2)) if warmup: position_next = position_cur gestalt_next = gestalt_cur @@ -239,6 +237,7 @@ class Loci(nn.Module): # update module slots_closed = (1-self.percept_gate_controller(position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, self.position_last2, evaluate=evaluate)) + #slots_closed = repeat(rearrange(slots_occlusionfactor_cur, 'b (o c) -> b o c', c=1), 'b o 1 -> b o 2') position_cur = self.linear_gate(position_cur, position_last, slots_closed[:, :, 1]) priority_cur = self.linear_gate(priority_cur, priority_last, slots_closed[:, :, 1]) @@ -313,6 +312,10 @@ class Loci(nn.Module): position_next.detach(), ) + # add latent loss: current perception as target for last prediction + if self.cfg.inner_loop_enabled and self.latent_loss_enabled: + latent_loss = self.position_loss(position_cur.detach(), position_last, slots_bounded_smooth) + return ( output_next, output_cur, @@ -326,6 +329,8 @@ class Loci(nn.Module): slots_occlusionfactor_next, position_loss, time_loss, - slots_closed + latent_loss, + slots_closed, + slots_bounded ) diff --git a/model/nn/background.py b/model/nn/background.py index 38ec325..ff14e02 100644 --- a/model/nn/background.py +++ b/model/nn/background.py @@ -14,154 +14,21 @@ class BackgroundEnhancer(nn.Module): def __init__( self, input_size: Tuple[int, int], - img_channels: int, - level1_channels, - latent_channels, - gestalt_size, batch_size, - depth ): super(BackgroundEnhancer, self).__init__() - latent_size = [input_size[0] // 16, input_size[1] // 16] - self.input_size = input_size - - self.register_buffer('init', th.zeros(1).long()) - self.alpha = nn.Parameter(th.zeros(1)+1e-16) - - self.level = 1 - self.down_level2 = nn.Sequential( - PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16), - *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)] - ) - - self.down_level1 = nn.Sequential( - PatchDownConv(level1_channels, latent_channels, alpha = 1), - *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)] - ) - - self.down_level0 = nn.Sequential( - *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], - AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size), - LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')), - Binarize(), - ) - - self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size))) - - self.to_grid = nn.Sequential( - LinearResidual(gestalt_size, gestalt_size, input_relu = False), - LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')), - LambdaModule(lambda x: x + self.bias), - *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)], - ) - - - self.up_level0 = nn.Sequential( - ResidualBlock(gestalt_size, latent_channels), - *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], - ) - - self.up_level1 = nn.Sequential( - *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)], - PatchUpscale(latent_channels, level1_channels, alpha = 1), - ) - - self.up_level2 = nn.Sequential( - *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)], - PatchUpscale(level1_channels, img_channels, alpha = 1e-16), - ) - - self.to_channels = nn.ModuleList([ - SkipConnection(img_channels*2+2, latent_channels), - SkipConnection(img_channels*2+2, level1_channels), - SkipConnection(img_channels*2+2, img_channels*2+2), - ]) - - self.to_img = nn.ModuleList([ - SkipConnection(latent_channels, img_channels), - SkipConnection(level1_channels, img_channels), - SkipConnection(img_channels, img_channels), - ]) - + self.batch_size = batch_size + self.height = input_size[0] + self.width = input_size[1] self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10) - self.object = nn.Parameter(th.ones(1, img_channels, *input_size)) - - self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False) + self.register_buffer('init', th.zeros(1).long()) def get_init(self): return self.init.item() - - def step_init(self): - self.init = self.init + 1 - - def detach(self): - self.latent = self.latent.detach() - - def reset_state(self): - self.latent = th.zeros_like(self.latent) - - def set_level(self, level): - self.level = level - - def encoder(self, input): - latent = self.to_channels[self.level](input) - - if self.level >= 2: - latent = self.down_level2(latent) - - if self.level >= 1: - latent = self.down_level1(latent) - - return self.down_level0(latent) - - def get_last_latent_gird(self): - return self.to_grid(self.latent) * self.alpha - - def decoder(self, latent, input): - grid = self.to_grid(latent) - latent = self.up_level0(grid) - - if self.level >= 1: - latent = self.up_level1(latent) - - if self.level >= 2: - latent = self.up_level2(latent) - - object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3]) - object = repeat(object, '1 c h w -> b c h w', b = input.shape[0]) - - return th.sigmoid(object + self.to_img[self.level](latent)), grid - - def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False): + + def forward(self, input: th.Tensor): - if only_mask: - mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3]) - mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1 - return mask - - last_bg = self.decoder(self.latent, input)[0] - - bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach() - bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach() - - if error is None or self.get_init() < 2: - error = bg_error - - if mask is None or self.get_init() < 2: - mask = bg_mask - - self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1)) - mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3]) - mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1 - - background, grid = self.decoder(self.latent, input) - - if self.get_init() < 1: - return mask, background - - if self.get_init() < 2: - return mask, th.zeros_like(background), th.zeros_like(grid), background - - return mask, background, grid * self.alpha, background + mask = repeat(mask, '1 1 h w -> b 1 h w', b = self.batch_size) * 0.1 + return mask diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py index 82a9895..c2473d1 100644 --- a/model/nn/eprop_gate_l0rd.py +++ b/model/nn/eprop_gate_l0rd.py @@ -49,7 +49,7 @@ class EpropGateL0rdFunction(Function): e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(), ) - return h, th.mean(H_g) + return h, H_g @staticmethod def backward(ctx, dh, _): @@ -169,6 +169,7 @@ class EpropGateL0rd(nn.Module): self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False) self.register_buffer("openings", th.zeros(1), persistent=False) + self.register_buffer("openings_perslot", th.zeros(batch_size), persistent=False) # initialize weights stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden)) @@ -294,7 +295,7 @@ class EpropGateL0rdShared(EpropGateL0rd): return o * p, h_last def eprop_forward(self, x: th.Tensor, h_last: th.Tensor): - h, openings = self.fcn( + h, H_g = self.fcn( x, h_last, self.w_gx, self.w_gh, self.b_g, self.w_rx, self.w_rh, self.b_r, @@ -305,7 +306,8 @@ class EpropGateL0rdShared(EpropGateL0rd): ) ) - self.openings = openings + self.openings = th.mean(H_g) + self.openings_perslot = th.mean(H_g, dim=1) p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p) o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o) diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py index 4d89ec4..d2341dd 100644 --- a/model/nn/eprop_transformer.py +++ b/model/nn/eprop_transformer.py @@ -35,13 +35,21 @@ class EpropGateL0rdTransformer(nn.Module): _layers.append(OutputEmbeding(num_hidden, num_outputs)) self.layers = nn.Sequential(*_layers) + self.attention = [] + self.l0rds = [] + for l in self.layers: + if 'AlphaAttention' in type(l).__name__: + self.attention.append(l) + elif 'EpropAlphaGateL0rd' in type(l).__name__: + self.l0rds.append(l) def get_openings(self): - openings = 0 + openings = [] for i in range(self.depth): - openings += self.layers[2 * (i + 1)].l0rd.openings.item() + openings.append(self.l0rds[i].l0rd.openings_perslot) - return openings / self.depth + openings = th.mean(th.stack(openings, dim=0), dim=0) + return openings def get_hidden(self): states = [] diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py index 23a1b0f..79223c1 100644 --- a/model/nn/eprop_transformer_shared.py +++ b/model/nn/eprop_transformer_shared.py @@ -39,11 +39,12 @@ class EpropGateL0rdTransformerShared(nn.Module): self.output_embeding = OutputEmbeding(num_hidden, num_outputs) def get_openings(self): - openings = 0 + openings = [] for i in range(self.depth): - openings += self.l0rds[i].l0rd.openings.item() + openings.append(self.l0rds[i].l0rd.openings_perslot) - return openings / self.depth + openings = th.mean(th.stack(openings, dim=0), dim=0) + return openings def get_hidden(self): return self.hidden diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py index 9e5e874..3219cd0 100644 --- a/model/nn/eprop_transformer_utils.py +++ b/model/nn/eprop_transformer_utils.py @@ -9,7 +9,8 @@ class AlphaAttention(nn.Module): num_hidden, num_objects, heads, - dropout = 0.0 + dropout = 0.0, + need_weights = False ): super(AlphaAttention, self).__init__() @@ -23,10 +24,13 @@ class AlphaAttention(nn.Module): dropout = dropout, batch_first = True ) + self.need_weights = need_weights + self.att_weights = None def forward(self, x: th.Tensor): x = self.to_sequence(x) - x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0] + att, self.att_weights = self.attention(x, x, x, need_weights=self.need_weights) + x = x + self.alpha * att return self.to_batch(x) class InputEmbeding(nn.Module): diff --git a/model/nn/predictor.py b/model/nn/predictor.py index 94f13b8..5f08de9 100644 --- a/model/nn/predictor.py +++ b/model/nn/predictor.py @@ -61,7 +61,9 @@ class LociPredictor(nn.Module): self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects)) def get_openings(self): - return self.predictor.get_openings() + openings = self.predictor.get_openings().detach() + openings = self.to_shared(openings[:, None]) + return openings def get_hidden(self): return self.predictor.get_hidden() @@ -69,6 +71,24 @@ class LociPredictor(nn.Module): def set_hidden(self, hidden): self.predictor.set_hidden(hidden) + def get_att_weights(self): + att_weights = [] + for layer in self.predictor.attention: + if layer.att_weights is None: + return [] + else: + att_weights.append(layer.att_weights) + att_weights = th.stack(att_weights) + return reduce(att_weights, 'l b o1 o2-> b o1 o2', 'mean') + + def enable_att_weights(self): + for layer in self.predictor.attention: + layer.need_weights = True + + def disable_att_weights(self): + for layer in self.predictor.attention: + layer.need_weights = False + def forward( self, gestalt: th.Tensor, |