author     fredeee  2024-03-23 13:27:00 +0100
committer  fredeee  2024-03-23 13:27:00 +0100
commit     6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree       0545ff1b8beb051993c2d75fd81306db1a22274d /model/loci.py
parent     ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'model/loci.py')
-rw-r--r--  model/loci.py | 31
1 file changed, 18 insertions(+), 13 deletions(-)
diff --git a/model/loci.py b/model/loci.py
index 96088bd..46f89f7 100644
--- a/model/loci.py
+++ b/model/loci.py
@@ -66,11 +66,6 @@ class Loci(nn.Module):
 
         self.background = BackgroundEnhancer(
             input_size      = cfg.input_size,
-            gestalt_size    = cfg.background.gestalt_size,
-            img_channels    = cfg.img_channels,
-            depth           = cfg.background.num_layers,
-            latent_channels = cfg.background.latent_channels,
-            level1_channels = cfg.background.level1_channels,
             batch_size      = cfg.batch_size,
         )
 
@@ -93,11 +88,16 @@ class Loci(nn.Module):
         self.modulator   = ObjectModulator(cfg.num_objects)
         self.linear_gate = LinearInterpolation(cfg.num_objects)
 
-        self.background.set_level(cfg.level)
         self.encoder.set_level(cfg.level)
         self.decoder.set_level(cfg.level)
         self.initial_states.set_level(cfg.level)
 
+        # add flag option to enable/disable latent loss
+        if 'latent_loss_enabled' in cfg:
+            self.latent_loss_enabled = cfg.latent_loss_enabled
+        else:
+            self.latent_loss_enabled = False
+
     def get_init_status(self):
         init = []
         for module in self.modules():
@@ -133,9 +133,6 @@ class Loci(nn.Module):
         if reset:
             self.reset_state()
 
-        if train_background or self.get_init_status() < 1:
-            return self.background(*input)
-
         return self.run_end2end(*input, evaluate=evaluate, warmup=warmup, shuffleslots = shuffleslots, reset_mask = reset_mask, allow_spawn = allow_spawn, show_hidden = show_hidden, clean_slots = clean_slots)
 
     def run_decoder(
@@ -196,11 +193,12 @@ class Loci(nn.Module):
     ):
         position_loss = th.tensor(0, device=input.device)
         time_loss     = th.tensor(0, device=input.device)
+        latent_loss   = th.tensor(0, device=input.device)
         bg_mask = None
         position_encoder = None
 
         if error_last is None or mask_last is None:
-            bg_mask = self.background(input, only_mask=True)
+            bg_mask = self.background(input)
             error_last = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
 
             position_last, gestalt_last, priority_last, error_cur = self.initial_states(
@@ -214,7 +212,7 @@ class Loci(nn.Module):
             object_last_unprioritized = self.decoder(position_last, gestalt_last)[-1]
 
         # background and bg_mask for the next time point
-        bg_mask = self.background(input, error_last, mask_last[:,-1:], only_mask=True)
+        bg_mask = self.background(input)
 
         # position and gestalt for the current time point
         position_cur, gestalt_cur, priority_cur = self.encoder(input, error_last, mask_last, object_last_unprioritized, position_last, rawmask_last)
@@ -228,7 +226,7 @@ class Loci(nn.Module):
         slots_bounded, slots_bounded_smooth, slots_occluded_cur, slots_partially_occluded_cur, slots_fully_visible_cur, slots_occlusionfactor_cur = self.occlusion_tracker(mask_cur, rawmask_cur, reset_mask)
 
         # do not project into the future in the warmup phase
-        slots_closed = th.ones_like(repeat(slots_bounded, 'b o -> b o c', c=2))
+        slots_closed = th.zeros_like(repeat(slots_bounded, 'b o -> b o c', c=2))
         if warmup:
             position_next = position_cur
             gestalt_next  = gestalt_cur
@@ -239,6 +237,7 @@ class Loci(nn.Module):
 
             # update module
             slots_closed = (1-self.percept_gate_controller(position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, self.position_last2, evaluate=evaluate))
+            #slots_closed = repeat(rearrange(slots_occlusionfactor_cur, 'b (o c) -> b o c', c=1), 'b o 1 -> b o 2')
 
             position_cur = self.linear_gate(position_cur, position_last, slots_closed[:, :, 1])
             priority_cur = self.linear_gate(priority_cur, priority_last, slots_closed[:, :, 1])
@@ -313,6 +312,10 @@ class Loci(nn.Module):
                 position_next.detach(),
             )
 
+        # add latent loss: current perception as target for last prediction
+        if self.cfg.inner_loop_enabled and self.latent_loss_enabled:
+            latent_loss = self.position_loss(position_cur.detach(), position_last, slots_bounded_smooth)
+
         return (
             output_next,
             output_cur,
@@ -326,6 +329,8 @@ class Loci(nn.Module):
             slots_occlusionfactor_next,
             position_loss,
             time_loss,
-            slots_closed
+            latent_loss,
+            slots_closed,
+            slots_bounded
         )
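
Note on the new latent loss: the hunk at line 312 gates it behind both cfg.inner_loop_enabled and the new latent_loss_enabled flag, and reuses position_loss with the current perception, detached, as the regression target for the previous step's prediction. Below is a minimal self-contained sketch of that pattern, assuming position_loss acts as a slot-masked mean squared error over (batch, slots, dims) tensors; the helper name and shapes are illustrative, not the repository's actual implementation.

import torch as th

def slot_masked_mse(target, prediction, slot_mask):
    # mean squared error per slot, averaged over bounded (active) slots only;
    # target/prediction: (batch, slots, dims), slot_mask: (batch, slots)
    per_slot = ((target - prediction) ** 2).mean(dim=-1)
    return (per_slot * slot_mask).sum() / slot_mask.sum().clamp(min=1)

batch, slots, dims = 4, 6, 3
position_cur  = th.randn(batch, slots, dims)                      # current perception (encoder estimate)
position_last = th.randn(batch, slots, dims, requires_grad=True)  # last step's prediction
slots_bounded = (th.rand(batch, slots) > 0.3).float()             # stand-in for slots_bounded_smooth

# detaching the target means gradients flow only into the predictor that
# produced position_last, mirroring position_cur.detach() in the diff
latent_loss = slot_masked_mse(position_cur.detach(), position_last, slots_bounded)
latent_loss.backward()

Detaching the target rather than the prediction is what makes this a training signal for the transition model alone: the current perception is treated as ground truth and is not pulled toward the possibly wrong prediction.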