about | summary | refs | log | tree | commit | diff
path: root/model
diff options
context:
space:
mode:
author fredeee 2024-03-23 13:27:00 +0100
committer fredeee 2024-03-23 13:27:00 +0100
commit 6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree 0545ff1b8beb051993c2d75fd81306db1a22274d /model
parent ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'model')
-rw-r--r-- model/loci.py 31
-rw-r--r-- model/nn/background.py 149
-rw-r--r-- model/nn/eprop_gate_l0rd.py 8
-rw-r--r-- model/nn/eprop_transformer.py 14
-rw-r--r-- model/nn/eprop_transformer_shared.py 7
-rw-r--r-- model/nn/eprop_transformer_utils.py 8
-rw-r--r-- model/nn/predictor.py 22
7 files changed, 73 insertions, 166 deletions
diff --git a/model/loci.py b/model/loci.py
index 96088bd..46f89f7 100644
--- a/model/loci.py
+++ b/model/loci.py
@@ -66,11 +66,6 @@ class Loci(nn.Module):
self.background = BackgroundEnhancer(
input_size = cfg.input_size,
- gestalt_size = cfg.background.gestalt_size,
- img_channels = cfg.img_channels,
- depth = cfg.background.num_layers,
- latent_channels = cfg.background.latent_channels,
- level1_channels = cfg.background.level1_channels,
batch_size = cfg.batch_size,
)
@@ -93,11 +88,16 @@ class Loci(nn.Module):
self.modulator = ObjectModulator(cfg.num_objects)
self.linear_gate = LinearInterpolation(cfg.num_objects)
- self.background.set_level(cfg.level)
self.encoder.set_level(cfg.level)
self.decoder.set_level(cfg.level)
self.initial_states.set_level(cfg.level)
+ # add flag option to enable/disable latent loss
+ if 'latent_loss_enabled' in cfg:
+ self.latent_loss_enabled = cfg.latent_loss_enabled
+ else:
+ self.latent_loss_enabled = False
+
def get_init_status(self):
init = []
for module in self.modules():
@@ -133,9 +133,6 @@ class Loci(nn.Module):
if reset:
self.reset_state()
- if train_background or self.get_init_status() < 1:
- return self.background(*input)
-
return self.run_end2end(*input, evaluate=evaluate, warmup=warmup, shuffleslots = shuffleslots, reset_mask = reset_mask, allow_spawn = allow_spawn, show_hidden = show_hidden, clean_slots = clean_slots)
def run_decoder(
@@ -196,11 +193,12 @@ class Loci(nn.Module):
):
position_loss = th.tensor(0, device=input.device)
time_loss = th.tensor(0, device=input.device)
+ latent_loss = th.tensor(0, device=input.device)
bg_mask = None
position_encoder = None
if error_last is None or mask_last is None:
- bg_mask = self.background(input, only_mask=True)
+ bg_mask = self.background(input)
error_last = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
position_last, gestalt_last, priority_last, error_cur = self.initial_states(
@@ -214,7 +212,7 @@ class Loci(nn.Module):
object_last_unprioritized = self.decoder(position_last, gestalt_last)[-1]
# background and bg_mask for the next time point
- bg_mask = self.background(input, error_last, mask_last[:,-1:], only_mask=True)
+ bg_mask = self.background(input)
# position and gestalt for the current time point
position_cur, gestalt_cur, priority_cur = self.encoder(input, error_last, mask_last, object_last_unprioritized, position_last, rawmask_last)
@@ -228,7 +226,7 @@ class Loci(nn.Module):
slots_bounded, slots_bounded_smooth, slots_occluded_cur, slots_partially_occluded_cur, slots_fully_visible_cur, slots_occlusionfactor_cur = self.occlusion_tracker(mask_cur, rawmask_cur, reset_mask)
# do not project into the future in the warmup phase
- slots_closed = th.ones_like(repeat(slots_bounded, 'b o -> b o c', c=2))
+ slots_closed = th.zeros_like(repeat(slots_bounded, 'b o -> b o c', c=2))
if warmup:
position_next = position_cur
gestalt_next = gestalt_cur
@@ -239,6 +237,7 @@ class Loci(nn.Module):
# update module
slots_closed = (1-self.percept_gate_controller(position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, self.position_last2, evaluate=evaluate))
+ #slots_closed = repeat(rearrange(slots_occlusionfactor_cur, 'b (o c) -> b o c', c=1), 'b o 1 -> b o 2')
position_cur = self.linear_gate(position_cur, position_last, slots_closed[:, :, 1])
priority_cur = self.linear_gate(priority_cur, priority_last, slots_closed[:, :, 1])
@@ -313,6 +312,10 @@ class Loci(nn.Module):
position_next.detach(),
)
+ # add latent loss: current perception as target for last prediction
+ if self.cfg.inner_loop_enabled and self.latent_loss_enabled:
+ latent_loss = self.position_loss(position_cur.detach(), position_last, slots_bounded_smooth)
+
return (
output_next,
output_cur,
@@ -326,6 +329,8 @@ class Loci(nn.Module):
slots_occlusionfactor_next,
position_loss,
time_loss,
- slots_closed
+ latent_loss,
+ slots_closed,
+ slots_bounded
)
diff --git a/model/nn/background.py b/model/nn/background.py
index 38ec325..ff14e02 100644
--- a/model/nn/background.py
+++ b/model/nn/background.py
@@ -14,154 +14,21 @@ class BackgroundEnhancer(nn.Module):
def __init__(
self,
input_size: Tuple[int, int],
- img_channels: int,
- level1_channels,
- latent_channels,
- gestalt_size,
batch_size,
- depth
):
super(BackgroundEnhancer, self).__init__()
- latent_size = [input_size[0] // 16, input_size[1] // 16]
- self.input_size = input_size
-
- self.register_buffer('init', th.zeros(1).long())
- self.alpha = nn.Parameter(th.zeros(1)+1e-16)
-
- self.level = 1
- self.down_level2 = nn.Sequential(
- PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16),
- *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)]
- )
-
- self.down_level1 = nn.Sequential(
- PatchDownConv(level1_channels, latent_channels, alpha = 1),
- *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)]
- )
-
- self.down_level0 = nn.Sequential(
- *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
- AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size),
- LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
- Binarize(),
- )
-
- self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
-
- self.to_grid = nn.Sequential(
- LinearResidual(gestalt_size, gestalt_size, input_relu = False),
- LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
- LambdaModule(lambda x: x + self.bias),
- *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)],
- )
-
-
- self.up_level0 = nn.Sequential(
- ResidualBlock(gestalt_size, latent_channels),
- *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
- )
-
- self.up_level1 = nn.Sequential(
- *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)],
- PatchUpscale(latent_channels, level1_channels, alpha = 1),
- )
-
- self.up_level2 = nn.Sequential(
- *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)],
- PatchUpscale(level1_channels, img_channels, alpha = 1e-16),
- )
-
- self.to_channels = nn.ModuleList([
- SkipConnection(img_channels*2+2, latent_channels),
- SkipConnection(img_channels*2+2, level1_channels),
- SkipConnection(img_channels*2+2, img_channels*2+2),
- ])
-
- self.to_img = nn.ModuleList([
- SkipConnection(latent_channels, img_channels),
- SkipConnection(level1_channels, img_channels),
- SkipConnection(img_channels, img_channels),
- ])
-
+ self.batch_size = batch_size
+ self.height = input_size[0]
+ self.width = input_size[1]
self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
- self.object = nn.Parameter(th.ones(1, img_channels, *input_size))
-
- self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False)
+ self.register_buffer('init', th.zeros(1).long())
def get_init(self):
return self.init.item()
-
- def step_init(self):
- self.init = self.init + 1
-
- def detach(self):
- self.latent = self.latent.detach()
-
- def reset_state(self):
- self.latent = th.zeros_like(self.latent)
-
- def set_level(self, level):
- self.level = level
-
- def encoder(self, input):
- latent = self.to_channels[self.level](input)
-
- if self.level >= 2:
- latent = self.down_level2(latent)
-
- if self.level >= 1:
- latent = self.down_level1(latent)
-
- return self.down_level0(latent)
-
- def get_last_latent_gird(self):
- return self.to_grid(self.latent) * self.alpha
-
- def decoder(self, latent, input):
- grid = self.to_grid(latent)
- latent = self.up_level0(grid)
-
- if self.level >= 1:
- latent = self.up_level1(latent)
-
- if self.level >= 2:
- latent = self.up_level2(latent)
-
- object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3])
- object = repeat(object, '1 c h w -> b c h w', b = input.shape[0])
-
- return th.sigmoid(object + self.to_img[self.level](latent)), grid
-
- def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False):
+
+ def forward(self, input: th.Tensor):
- if only_mask:
- mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
- mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
- return mask
-
- last_bg = self.decoder(self.latent, input)[0]
-
- bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach()
- bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach()
-
- if error is None or self.get_init() < 2:
- error = bg_error
-
- if mask is None or self.get_init() < 2:
- mask = bg_mask
-
- self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1))
-
mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
- mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
-
- background, grid = self.decoder(self.latent, input)
-
- if self.get_init() < 1:
- return mask, background
-
- if self.get_init() < 2:
- return mask, th.zeros_like(background), th.zeros_like(grid), background
-
- return mask, background, grid * self.alpha, background
+ mask = repeat(mask, '1 1 h w -> b 1 h w', b = self.batch_size) * 0.1
+ return mask
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py
index 82a9895..c2473d1 100644
--- a/model/nn/eprop_gate_l0rd.py
+++ b/model/nn/eprop_gate_l0rd.py
@@ -49,7 +49,7 @@ class EpropGateL0rdFunction(Function):
e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
)
- return h, th.mean(H_g)
+ return h, H_g
@staticmethod
def backward(ctx, dh, _):
@@ -169,6 +169,7 @@ class EpropGateL0rd(nn.Module):
self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)
self.register_buffer("openings", th.zeros(1), persistent=False)
+ self.register_buffer("openings_perslot", th.zeros(batch_size), persistent=False)
# initialize weights
stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
@@ -294,7 +295,7 @@ class EpropGateL0rdShared(EpropGateL0rd):
return o * p, h_last
def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
- h, openings = self.fcn(
+ h, H_g = self.fcn(
x, h_last,
self.w_gx, self.w_gh, self.b_g,
self.w_rx, self.w_rh, self.b_r,
@@ -305,7 +306,8 @@ class EpropGateL0rdShared(EpropGateL0rd):
)
)
- self.openings = openings
+ self.openings = th.mean(H_g)
+ self.openings_perslot = th.mean(H_g, dim=1)
p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py
index 4d89ec4..d2341dd 100644
--- a/model/nn/eprop_transformer.py
+++ b/model/nn/eprop_transformer.py
@@ -35,13 +35,21 @@ class EpropGateL0rdTransformer(nn.Module):
_layers.append(OutputEmbeding(num_hidden, num_outputs))
self.layers = nn.Sequential(*_layers)
+ self.attention = []
+ self.l0rds = []
+ for l in self.layers:
+ if 'AlphaAttention' in type(l).__name__:
+ self.attention.append(l)
+ elif 'EpropAlphaGateL0rd' in type(l).__name__:
+ self.l0rds.append(l)
def get_openings(self):
- openings = 0
+ openings = []
for i in range(self.depth):
- openings += self.layers[2 * (i + 1)].l0rd.openings.item()
+ openings.append(self.l0rds[i].l0rd.openings_perslot)
- return openings / self.depth
+ openings = th.mean(th.stack(openings, dim=0), dim=0)
+ return openings
def get_hidden(self):
states = []
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py
index 23a1b0f..79223c1 100644
--- a/model/nn/eprop_transformer_shared.py
+++ b/model/nn/eprop_transformer_shared.py
@@ -39,11 +39,12 @@ class EpropGateL0rdTransformerShared(nn.Module):
self.output_embeding = OutputEmbeding(num_hidden, num_outputs)
def get_openings(self):
- openings = 0
+ openings = []
for i in range(self.depth):
- openings += self.l0rds[i].l0rd.openings.item()
+ openings.append(self.l0rds[i].l0rd.openings_perslot)
- return openings / self.depth
+ openings = th.mean(th.stack(openings, dim=0), dim=0)
+ return openings
def get_hidden(self):
return self.hidden
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
index 9e5e874..3219cd0 100644
--- a/model/nn/eprop_transformer_utils.py
+++ b/model/nn/eprop_transformer_utils.py
@@ -9,7 +9,8 @@ class AlphaAttention(nn.Module):
num_hidden,
num_objects,
heads,
- dropout = 0.0
+ dropout = 0.0,
+ need_weights = False
):
super(AlphaAttention, self).__init__()
@@ -23,10 +24,13 @@ class AlphaAttention(nn.Module):
dropout = dropout,
batch_first = True
)
+ self.need_weights = need_weights
+ self.att_weights = None
def forward(self, x: th.Tensor):
x = self.to_sequence(x)
- x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+ att, self.att_weights = self.attention(x, x, x, need_weights=self.need_weights)
+ x = x + self.alpha * att
return self.to_batch(x)
class InputEmbeding(nn.Module):
diff --git a/model/nn/predictor.py b/model/nn/predictor.py
index 94f13b8..5f08de9 100644
--- a/model/nn/predictor.py
+++ b/model/nn/predictor.py
@@ -61,7 +61,9 @@ class LociPredictor(nn.Module):
self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))
def get_openings(self):
- return self.predictor.get_openings()
+ openings = self.predictor.get_openings().detach()
+ openings = self.to_shared(openings[:, None])
+ return openings
def get_hidden(self):
return self.predictor.get_hidden()
@@ -69,6 +71,24 @@ class LociPredictor(nn.Module):
def set_hidden(self, hidden):
self.predictor.set_hidden(hidden)
+ def get_att_weights(self):
+ att_weights = []
+ for layer in self.predictor.attention:
+ if layer.att_weights is None:
+ return []
+ else:
+ att_weights.append(layer.att_weights)
+ att_weights = th.stack(att_weights)
+ return reduce(att_weights, 'l b o1 o2-> b o1 o2', 'mean')
+
+ def enable_att_weights(self):
+ for layer in self.predictor.attention:
+ layer.need_weights = True
+
+ def disable_att_weights(self):
+ for layer in self.predictor.attention:
+ layer.need_weights = False
+
def forward(
self,
gestalt: th.Tensor,