aboutsummaryrefslogtreecommitdiff
path: root/model/loci.py
diff options
context:
space:
mode:
authorfredeee2024-03-23 13:27:00 +0100
committerfredeee2024-03-23 13:27:00 +0100
commit6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree0545ff1b8beb051993c2d75fd81306db1a22274d /model/loci.py
parentad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'model/loci.py')
-rw-r--r--model/loci.py31
1 files changed, 18 insertions, 13 deletions
diff --git a/model/loci.py b/model/loci.py
index 96088bd..46f89f7 100644
--- a/model/loci.py
+++ b/model/loci.py
@@ -66,11 +66,6 @@ class Loci(nn.Module):
self.background = BackgroundEnhancer(
input_size = cfg.input_size,
- gestalt_size = cfg.background.gestalt_size,
- img_channels = cfg.img_channels,
- depth = cfg.background.num_layers,
- latent_channels = cfg.background.latent_channels,
- level1_channels = cfg.background.level1_channels,
batch_size = cfg.batch_size,
)
@@ -93,11 +88,16 @@ class Loci(nn.Module):
self.modulator = ObjectModulator(cfg.num_objects)
self.linear_gate = LinearInterpolation(cfg.num_objects)
- self.background.set_level(cfg.level)
self.encoder.set_level(cfg.level)
self.decoder.set_level(cfg.level)
self.initial_states.set_level(cfg.level)
+ # add flag option to enable/disable latent loss
+ if 'latent_loss_enabled' in cfg:
+ self.latent_loss_enabled = cfg.latent_loss_enabled
+ else:
+ self.latent_loss_enabled = False
+
def get_init_status(self):
init = []
for module in self.modules():
@@ -133,9 +133,6 @@ class Loci(nn.Module):
if reset:
self.reset_state()
- if train_background or self.get_init_status() < 1:
- return self.background(*input)
-
return self.run_end2end(*input, evaluate=evaluate, warmup=warmup, shuffleslots = shuffleslots, reset_mask = reset_mask, allow_spawn = allow_spawn, show_hidden = show_hidden, clean_slots = clean_slots)
def run_decoder(
@@ -196,11 +193,12 @@ class Loci(nn.Module):
):
position_loss = th.tensor(0, device=input.device)
time_loss = th.tensor(0, device=input.device)
+ latent_loss = th.tensor(0, device=input.device)
bg_mask = None
position_encoder = None
if error_last is None or mask_last is None:
- bg_mask = self.background(input, only_mask=True)
+ bg_mask = self.background(input)
error_last = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
position_last, gestalt_last, priority_last, error_cur = self.initial_states(
@@ -214,7 +212,7 @@ class Loci(nn.Module):
object_last_unprioritized = self.decoder(position_last, gestalt_last)[-1]
# background and bg_mask for the next time point
- bg_mask = self.background(input, error_last, mask_last[:,-1:], only_mask=True)
+ bg_mask = self.background(input)
# position and gestalt for the current time point
position_cur, gestalt_cur, priority_cur = self.encoder(input, error_last, mask_last, object_last_unprioritized, position_last, rawmask_last)
@@ -228,7 +226,7 @@ class Loci(nn.Module):
slots_bounded, slots_bounded_smooth, slots_occluded_cur, slots_partially_occluded_cur, slots_fully_visible_cur, slots_occlusionfactor_cur = self.occlusion_tracker(mask_cur, rawmask_cur, reset_mask)
# do not project into the future in the warmup phase
- slots_closed = th.ones_like(repeat(slots_bounded, 'b o -> b o c', c=2))
+ slots_closed = th.zeros_like(repeat(slots_bounded, 'b o -> b o c', c=2))
if warmup:
position_next = position_cur
gestalt_next = gestalt_cur
@@ -239,6 +237,7 @@ class Loci(nn.Module):
# update module
slots_closed = (1-self.percept_gate_controller(position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, self.position_last2, evaluate=evaluate))
+ #slots_closed = repeat(rearrange(slots_occlusionfactor_cur, 'b (o c) -> b o c', c=1), 'b o 1 -> b o 2')
position_cur = self.linear_gate(position_cur, position_last, slots_closed[:, :, 1])
priority_cur = self.linear_gate(priority_cur, priority_last, slots_closed[:, :, 1])
@@ -313,6 +312,10 @@ class Loci(nn.Module):
position_next.detach(),
)
+ # add latent loss: current perception as target for last prediction
+ if self.cfg.inner_loop_enabled and self.latent_loss_enabled:
+ latent_loss = self.position_loss(position_cur.detach(), position_last, slots_bounded_smooth)
+
return (
output_next,
output_cur,
@@ -326,6 +329,8 @@ class Loci(nn.Module):
slots_occlusionfactor_next,
position_loss,
time_loss,
- slots_closed
+ latent_loss,
+ slots_closed,
+ slots_bounded
)