author     fredeee 2023-11-02 10:47:21 +0100
committer  fredeee 2023-11-02 10:47:21 +0100
commit     f8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree       87288be6f851ab69405e524b81940c501c52789a /model
parent     f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initial commit
Diffstat (limited to 'model')
-rw-r--r--  model/loci.py                         331
-rw-r--r--  model/nn/background.py                167
-rw-r--r--  model/nn/decoder.py                   169
-rw-r--r--  model/nn/encoder.py                   269
-rw-r--r--  model/nn/eprop_gate_l0rd.py           329
-rw-r--r--  model/nn/eprop_transformer.py          76
-rw-r--r--  model/nn/eprop_transformer_shared.py   92
-rw-r--r--  model/nn/eprop_transformer_utils.py    66
-rw-r--r--  model/nn/percept_gate_controller.py    59
-rw-r--r--  model/nn/predictor.py                  99
-rw-r--r--  model/nn/residual.py                  396
-rw-r--r--  model/utils/loss.py                    97
-rw-r--r--  model/utils/nn_utils.py               298
-rw-r--r--  model/utils/slot_utils.py             326
14 files changed, 2774 insertions, 0 deletions
diff --git a/model/loci.py b/model/loci.py
new file mode 100644
index 0000000..96088bd
--- /dev/null
+++ b/model/loci.py
@@ -0,0 +1,331 @@
+import torch as th
+import torch.nn as nn
+from einops import rearrange, repeat, reduce
+from model.nn.percept_gate_controller import PerceptGateController
+from model.nn.decoder import LociDecoder
+from model.nn.encoder import LociEncoder
+from model.nn.predictor import LociPredictor
+from model.nn.background import BackgroundEnhancer
+from model.utils.nn_utils import LinearInterpolation
+from model.utils.loss import ObjectModulator, TranslationInvariantObjectLoss, PositionLoss
+from model.utils.slot_utils import OcclusionTracker, InitialLatentStates, compute_rawmask
+
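+# Loci: top-level module that wires the background enhancer, encoder, predictor, decoder,
+# occlusion tracker and percept gate controller together into a single object-slot scene model.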
+class Loci(nn.Module):
+ def __init__(
+ self,
+ cfg,
+ teacher_forcing=1,
+ ):
+ super(Loci, self).__init__()
+
+ self.teacher_forcing = teacher_forcing
+ self.cfg = cfg
+
+ self.encoder = LociEncoder(
+ input_size = cfg.input_size,
+ latent_size = cfg.latent_size,
+ num_objects = cfg.num_objects,
+ img_channels = cfg.img_channels * 2 + 6,
+ hidden_channels = cfg.encoder.channels,
+ level1_channels = cfg.encoder.level1_channels,
+ num_layers = cfg.encoder.num_layers,
+ gestalt_size = cfg.gestalt_size,
+ bottleneck = cfg.bottleneck,
+ )
+
+ self.percept_gate_controller = PerceptGateController(
+ num_inputs = 2*(cfg.gestalt_size + 3 + 1 + 1) + 3,
+ num_hidden = [32, 16],
+ bias = True,
+ num_objects = cfg.num_objects,
+ reg_lambda=cfg.update_module.reg_lambda,
+ )
+
+ self.predictor = LociPredictor(
+ num_objects = cfg.num_objects,
+ gestalt_size = cfg.gestalt_size,
+ bottleneck = cfg.bottleneck,
+ channels_multiplier = cfg.predictor.channels_multiplier,
+ heads = cfg.predictor.heads,
+ layers = cfg.predictor.layers,
+ reg_lambda = cfg.predictor.reg_lambda,
+ batch_size = cfg.batch_size,
+ transformer_type = cfg.predictor.transformer_type,
+ )
+
+ self.decoder = LociDecoder(
+ latent_size = cfg.latent_size,
+ num_objects = cfg.num_objects,
+ gestalt_size = cfg.gestalt_size,
+ img_channels = cfg.img_channels,
+ hidden_channels = cfg.decoder.channels,
+ level1_channels = cfg.decoder.level1_channels,
+ num_layers = cfg.decoder.num_layers,
+ batch_size = cfg.batch_size,
+ )
+
+ self.background = BackgroundEnhancer(
+ input_size = cfg.input_size,
+ gestalt_size = cfg.background.gestalt_size,
+ img_channels = cfg.img_channels,
+ depth = cfg.background.num_layers,
+ latent_channels = cfg.background.latent_channels,
+ level1_channels = cfg.background.level1_channels,
+ batch_size = cfg.batch_size,
+ )
+
+ self.initial_states = InitialLatentStates(
+ gestalt_size = cfg.gestalt_size,
+ bottleneck = cfg.bottleneck,
+ num_objects = cfg.num_objects,
+ size = cfg.input_size,
+ teacher_forcing = teacher_forcing
+ )
+
+ self.occlusion_tracker = OcclusionTracker(
+ batch_size=cfg.batch_size,
+ num_objects=cfg.num_objects,
+ device=cfg.device
+ )
+
+ self.translation_invariant_object_loss = TranslationInvariantObjectLoss(cfg.num_objects)
+ self.position_loss = PositionLoss(cfg.num_objects)
+ self.modulator = ObjectModulator(cfg.num_objects)
+ self.linear_gate = LinearInterpolation(cfg.num_objects)
+
+ self.background.set_level(cfg.level)
+ self.encoder.set_level(cfg.level)
+ self.decoder.set_level(cfg.level)
+ self.initial_states.set_level(cfg.level)
+
+ def get_init_status(self):
+ init = []
+ for module in self.modules():
+ if callable(getattr(module, "get_init", None)):
+ init.append(module.get_init())
+
+ assert len(set(init)) == 1
+ return init[0]
+
+ def inc_init_level(self):
+ for module in self.modules():
+ if callable(getattr(module, "step_init", None)):
+ module.step_init()
+
+ def get_openings(self):
+ return self.predictor.get_openings()
+
+ def detach(self):
+ for module in self.modules():
+ if module != self and callable(getattr(module, "detach", None)):
+ module.detach()
+
+ def reset_state(self):
+ for module in self.modules():
+ if module != self and callable(getattr(module, "reset_state", None)):
+ module.reset_state()
+
+ def forward(self, *input, reset=True, detach=True, mode='end2end', evaluate=False, train_background=False, warmup=False, shuffleslots = True, reset_mask = False, allow_spawn = True, show_hidden = False, clean_slots = False):
+
+ if detach:
+ self.detach()
+
+ if reset:
+ self.reset_state()
+
+ if train_background or self.get_init_status() < 1:
+ return self.background(*input)
+
+ return self.run_end2end(*input, evaluate=evaluate, warmup=warmup, shuffleslots = shuffleslots, reset_mask = reset_mask, allow_spawn = allow_spawn, show_hidden = show_hidden, clean_slots = clean_slots)
+
+ def run_decoder(
+ self,
+ position: th.Tensor,
+ gestalt: th.Tensor,
+ priority: th.Tensor,
+ bg_mask: th.Tensor,
+ background: th.Tensor,
+ only_mask: bool = False
+ ):
+ mask, object = self.decoder(position, gestalt, priority)
+
+ rawmask = compute_rawmask(mask, bg_mask)
+ mask = th.softmax(th.cat((mask, bg_mask), dim=1), dim=1)
+
+ if only_mask:
+ return mask, rawmask
+
+ object = th.cat((th.sigmoid(object - 2.5), background), dim=1)
+ _mask = mask.unsqueeze(dim=2)
+ _object = object.view(
+ mask.shape[0],
+ self.cfg.num_objects + 1,
+ self.cfg.img_channels,
+ *mask.shape[2:]
+ )
+
+ output = th.sum(_mask * _object, dim=1)
+ return output, mask, object, rawmask
+
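+ # Gate slots flagged as unassigned back toward neutral values (position -1, zero gestalt and priority).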
+ def reset_unassigned_slots(self, position, gestalt, priority):
+
+ position = self.linear_gate(position, th.zeros_like(position)-1, self.initial_states.get_slots_unassigned())
+ gestalt = self.linear_gate(gestalt, th.zeros_like(gestalt), self.initial_states.get_slots_unassigned())
+ priority = self.linear_gate(priority, th.zeros_like(priority), self.initial_states.get_slots_unassigned())
+
+ return position, gestalt, priority
+
+ def run_end2end(
+ self,
+ input: th.Tensor,
+ error_last: th.Tensor = None,
+ mask_last: th.Tensor = None,
+ rawmask_last: th.Tensor = None,
+ position_last: th.Tensor = None,
+ gestalt_last: th.Tensor = None,
+ priority_last: th.Tensor = None,
+ background = None,
+ slots_occlusionfactor_last: th.Tensor = None,
+ evaluate = False,
+ warmup = False,
+ shuffleslots = True,
+ reset_mask = False,
+ allow_spawn = True,
+ show_hidden = False,
+ clean_slots = False
+ ):
+ position_loss = th.tensor(0, device=input.device)
+ time_loss = th.tensor(0, device=input.device)
+ bg_mask = None
+ position_encoder = None
+
+ if error_last is None or mask_last is None:
+ bg_mask = self.background(input, only_mask=True)
+ error_last = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+ position_last, gestalt_last, priority_last, error_cur = self.initial_states(
+ error_last, mask_last, position_last, gestalt_last, priority_last, shuffleslots, self.occlusion_tracker.slots_bounded_all, slots_occlusionfactor_last, allow_spawn=allow_spawn, clean_slots=clean_slots
+ )
+ # only use assigned slots
+ position_last, gestalt_last, priority_last = self.reset_unassigned_slots(position_last, gestalt_last, priority_last)
+
+ if mask_last is None:
+ mask_last, rawmask_last = self.run_decoder(position_last, gestalt_last, priority_last, bg_mask, background, only_mask=True)
+ object_last_unprioritized = self.decoder(position_last, gestalt_last)[-1]
+
+ # background and bg_mask for the next time point
+ bg_mask = self.background(input, error_last, mask_last[:,-1:], only_mask=True)
+
+ # position and gestalt for the current time point
+ position_cur, gestalt_cur, priority_cur = self.encoder(input, error_last, mask_last, object_last_unprioritized, position_last, rawmask_last)
+ if evaluate:
+ position_encoder = position_cur.clone().detach()
+
+ # only use assigned slots
+ position_cur, gestalt_cur, priority_cur = self.reset_unassigned_slots(position_cur, gestalt_cur, priority_cur)
+
+ output_cur, mask_cur, object_cur, rawmask_cur = self.run_decoder(position_cur, gestalt_cur, priority_cur, bg_mask, background)
+ slots_bounded, slots_bounded_smooth, slots_occluded_cur, slots_partially_occluded_cur, slots_fully_visible_cur, slots_occlusionfactor_cur = self.occlusion_tracker(mask_cur, rawmask_cur, reset_mask)
+
+ # do not project into the future in the warmup phase
+ slots_closed = th.ones_like(repeat(slots_bounded, 'b o -> b o c', c=2))
+ if warmup:
+ position_next = position_cur
+ gestalt_next = gestalt_cur
+ priority_next = priority_cur
+ else:
+
+ if self.cfg.inner_loop_enabled:
+
+ # update module
+ slots_closed = (1-self.percept_gate_controller(position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, self.position_last2, evaluate=evaluate))
+
+ position_cur = self.linear_gate(position_cur, position_last, slots_closed[:, :, 1])
+ priority_cur = self.linear_gate(priority_cur, priority_last, slots_closed[:, :, 1])
+ gestalt_cur = self.linear_gate(gestalt_cur, gestalt_last, slots_closed[:, :, 0])
+
+ # position and gestalt for the next time point
+ position_next, gestalt_next, priority_next = self.predictor(gestalt_cur, priority_cur, position_cur, slots_closed)
+
+ # combined background and objects (masks) for the next time point
+ self.position_last2 = position_last.clone().detach()
+ output_next, mask_next, object_next, rawmask_next = self.run_decoder(position_next, gestalt_next, priority_next, bg_mask, background)
+ slots_bounded_next, slots_bounded_smooth_next, slots_occluded_next, slots_partially_occluded_next, slots_fully_visible_next, slots_occlusionfactor_next = self.occlusion_tracker(mask_next, rawmask_next, update=False)
+
+ if evaluate:
+
+ if show_hidden:
+ pos_next = rearrange(position_next.clone(), '1 (o c) -> o c', c=3)
+ largest_object = th.argmax(pos_next[:, 2], dim=0)
+ pos_next[largest_object] = th.tensor([2, 2, 0.001])
+ pos_next = rearrange(pos_next, 'o c -> 1 (o c)')
+ output_hidden, _, object_hidden, rawmask_hidden = self.run_decoder(pos_next, gestalt_next, priority_next, bg_mask, background)
+ else:
+ output_hidden = None
+ largest_object = None
+ rawmask_hidden = None
+ object_hidden = None
+
+ return (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor_next,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ )
+
+ else:
+
+ if not warmup:
+
+ # regularize to small position changes over time
+ position_loss = position_loss + self.position_loss(position_next, position_last.detach(), slots_bounded_smooth)
+
+ # regularize to produce consistent object codes over time
+ object_next_unprioritized = self.decoder(position_next, gestalt_next)[-1]
+ time_loss = time_loss + self.translation_invariant_object_loss(
+ slots_bounded_smooth,
+ object_last_unprioritized.detach(),
+ position_last.detach(),
+ object_next_unprioritized,
+ position_next.detach(),
+ )
+
+ return (
+ output_next,
+ output_cur,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor_next,
+ position_loss,
+ time_loss,
+ slots_closed
+ )
+
diff --git a/model/nn/background.py b/model/nn/background.py
new file mode 100644
index 0000000..38ec325
--- /dev/null
+++ b/model/nn/background.py
@@ -0,0 +1,167 @@
+import torch.nn as nn
+import torch as th
+from model.nn.residual import ResidualBlock, SkipConnection, LinearResidual
+from model.nn.encoder import PatchDownConv
+from model.nn.encoder import AggressiveConvToGestalt
+from model.nn.decoder import PatchUpscale
+from model.utils.nn_utils import LambdaModule, Binarize
+from einops import rearrange, repeat, reduce
+from typing import Tuple
+
+__author__ = "Manuel Traub"
+
+class BackgroundEnhancer(nn.Module):
+ def __init__(
+ self,
+ input_size: Tuple[int, int],
+ img_channels: int,
+ level1_channels,
+ latent_channels,
+ gestalt_size,
+ batch_size,
+ depth
+ ):
+ super(BackgroundEnhancer, self).__init__()
+
+ latent_size = [input_size[0] // 16, input_size[1] // 16]
+ self.input_size = input_size
+
+ self.register_buffer('init', th.zeros(1).long())
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+ self.level = 1
+ self.down_level2 = nn.Sequential(
+ PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16),
+ *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)]
+ )
+
+ self.down_level1 = nn.Sequential(
+ PatchDownConv(level1_channels, latent_channels, alpha = 1),
+ *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)]
+ )
+
+ self.down_level0 = nn.Sequential(
+ *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
+ AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ Binarize(),
+ )
+
+ self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
+
+ self.to_grid = nn.Sequential(
+ LinearResidual(gestalt_size, gestalt_size, input_relu = False),
+ LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
+ LambdaModule(lambda x: x + self.bias),
+ *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)],
+ )
+
+
+ self.up_level0 = nn.Sequential(
+ ResidualBlock(gestalt_size, latent_channels),
+ *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
+ )
+
+ self.up_level1 = nn.Sequential(
+ *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)],
+ PatchUpscale(latent_channels, level1_channels, alpha = 1),
+ )
+
+ self.up_level2 = nn.Sequential(
+ *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)],
+ PatchUpscale(level1_channels, img_channels, alpha = 1e-16),
+ )
+
+ self.to_channels = nn.ModuleList([
+ SkipConnection(img_channels*2+2, latent_channels),
+ SkipConnection(img_channels*2+2, level1_channels),
+ SkipConnection(img_channels*2+2, img_channels*2+2),
+ ])
+
+ self.to_img = nn.ModuleList([
+ SkipConnection(latent_channels, img_channels),
+ SkipConnection(level1_channels, img_channels),
+ SkipConnection(img_channels, img_channels),
+ ])
+
+ self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
+ self.object = nn.Parameter(th.ones(1, img_channels, *input_size))
+
+ self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False)
+
+ def get_init(self):
+ return self.init.item()
+
+ def step_init(self):
+ self.init = self.init + 1
+
+ def detach(self):
+ self.latent = self.latent.detach()
+
+ def reset_state(self):
+ self.latent = th.zeros_like(self.latent)
+
+ def set_level(self, level):
+ self.level = level
+
+ def encoder(self, input):
+ latent = self.to_channels[self.level](input)
+
+ if self.level >= 2:
+ latent = self.down_level2(latent)
+
+ if self.level >= 1:
+ latent = self.down_level1(latent)
+
+ return self.down_level0(latent)
+
+ def get_last_latent_gird(self):
+ return self.to_grid(self.latent) * self.alpha
+
+ def decoder(self, latent, input):
+ grid = self.to_grid(latent)
+ latent = self.up_level0(grid)
+
+ if self.level >= 1:
+ latent = self.up_level1(latent)
+
+ if self.level >= 2:
+ latent = self.up_level2(latent)
+
+ object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3])
+ object = repeat(object, '1 c h w -> b c h w', b = input.shape[0])
+
+ return th.sigmoid(object + self.to_img[self.level](latent)), grid
+
+ def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False):
+
+ if only_mask:
+ mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
+ mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
+ return mask
+
+ last_bg = self.decoder(self.latent, input)[0]
+
+ bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach()
+
+ if error is None or self.get_init() < 2:
+ error = bg_error
+
+ if mask is None or self.get_init() < 2:
+ mask = bg_mask
+
+ self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1))
+
+ mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
+ mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
+
+ background, grid = self.decoder(self.latent, input)
+
+ if self.get_init() < 1:
+ return mask, background
+
+ if self.get_init() < 2:
+ return mask, th.zeros_like(background), th.zeros_like(grid), background
+
+ return mask, background, grid * self.alpha, background
diff --git a/model/nn/decoder.py b/model/nn/decoder.py
new file mode 100644
index 0000000..3e56640
--- /dev/null
+++ b/model/nn/decoder.py
@@ -0,0 +1,169 @@
+import torch.nn as nn
+import torch as th
+from model.nn.residual import SkipConnection, ResidualBlock
+from model.utils.nn_utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize
+from einops import rearrange, repeat, reduce
+from typing import Tuple, Union, List
+
+__author__ = "Manuel Traub"
+
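+# Encodes per-slot priorities: scales and slightly perturbs the raw priorities and adds a learned
+# per-slot index offset so that slots receive well-separated priority values.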
+class PriorityEncoder(nn.Module):
+ def __init__(self, num_objects, batch_size):
+ super(PriorityEncoder, self).__init__()
+
+ self.num_objects = num_objects
+ self.register_buffer("indices", repeat(th.arange(num_objects), 'a -> b a', b=batch_size), persistent=False)
+
+ self.index_factor = nn.Parameter(th.ones(1))
+ self.priority_factor = nn.Parameter(th.ones(1))
+
+ def forward(self, priority: th.Tensor) -> th.Tensor:
+
+ if priority is None:
+ return None
+
+ priority = priority * self.num_objects + th.randn_like(priority) * 0.1
+ priority = priority * self.priority_factor
+ priority = priority + self.indices * self.index_factor
+
+ return priority * 25
+
+
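+# Renders each slot's position as a 2D Gaussian map, applies the Prioritize operator using the
+# encoded priorities, and multiplies the result with the gestalt code to form the decoder input.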
+class GestaltPositionMerge(nn.Module):
+ def __init__(
+ self,
+ latent_size: Union[int, Tuple[int, int]],
+ num_objects: int,
+ batch_size: int
+ ):
+
+ super(GestaltPositionMerge, self).__init__()
+ self.num_objects = num_objects
+
+ self.gaus2d = Gaus2D(size=latent_size)
+
+ self.to_batch = SharedObjectsToBatch(num_objects)
+ self.to_shared = BatchToSharedObjects(num_objects)
+
+ self.prioritize = Prioritize(num_objects)
+
+ self.priority_encoder = PriorityEncoder(num_objects, batch_size)
+
+ def forward(self, position, gestalt, priority):
+
+ position = rearrange(position, 'b (o c) -> (b o) c', o = self.num_objects)
+ gestalt = rearrange(gestalt, 'b (o c) -> (b o) c 1 1', o = self.num_objects)
+ priority = self.priority_encoder(priority)
+
+ position = self.gaus2d(position)
+ position = self.to_batch(self.prioritize(self.to_shared(position), priority))
+
+ return position * gestalt
+
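+# Upsamples feature maps by scale_factor: a parameter-free skip path plus a transposed-convolution
+# branch scaled by a learnable alpha.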
+class PatchUpscale(nn.Module):
+ def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
+ super(PatchUpscale, self).__init__()
+ assert in_channels % out_channels == 0
+
+ self.skip = SkipConnection(in_channels, out_channels, scale_factor=scale_factor)
+
+ self.residual = nn.Sequential(
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels = in_channels,
+ out_channels = in_channels,
+ kernel_size = 3,
+ padding = 1
+ ),
+ nn.ReLU(),
+ nn.ConvTranspose2d(
+ in_channels = in_channels,
+ out_channels = out_channels,
+ kernel_size = scale_factor,
+ stride = scale_factor,
+ ),
+ )
+
+ self.alpha = nn.Parameter(th.zeros(1) + alpha)
+
+ def forward(self, input):
+ return self.skip(input) + self.alpha * self.residual(input)
+
+
+class LociDecoder(nn.Module):
+ def __init__(
+ self,
+ latent_size: Union[int, Tuple[int, int]],
+ gestalt_size: int,
+ num_objects: int,
+ img_channels: int,
+ hidden_channels: int,
+ level1_channels: int,
+ num_layers: int,
+ batch_size: int
+ ):
+
+ super(LociDecoder, self).__init__()
+ self.to_batch = SharedObjectsToBatch(num_objects)
+ self.to_shared = BatchToSharedObjects(num_objects)
+ self.level = 1
+
+ assert(level1_channels % img_channels == 0)
+ level1_factor = level1_channels // img_channels
+ print(f"Level1 channels: {level1_channels}")
+
+ self.merge = GestaltPositionMerge(
+ latent_size = latent_size,
+ num_objects = num_objects,
+ batch_size = batch_size
+ )
+
+ self.layer0 = nn.Sequential(
+ ResidualBlock(gestalt_size, hidden_channels, input_nonlinearity = False),
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers-1)],
+ )
+
+ self.to_mask_level0 = ResidualBlock(hidden_channels, hidden_channels)
+ self.to_mask_level1 = PatchUpscale(hidden_channels, 1)
+
+ self.to_mask_level2 = nn.Sequential(
+ ResidualBlock(hidden_channels, hidden_channels),
+ ResidualBlock(hidden_channels, hidden_channels),
+ PatchUpscale(hidden_channels, level1_factor, alpha = 1),
+ PatchUpscale(level1_factor, 1, alpha = 1)
+ )
+
+ self.to_object_level0 = ResidualBlock(hidden_channels, hidden_channels)
+ self.to_object_level1 = PatchUpscale(hidden_channels, img_channels)
+
+ self.to_object_level2 = nn.Sequential(
+ ResidualBlock(hidden_channels, hidden_channels),
+ ResidualBlock(hidden_channels, hidden_channels),
+ PatchUpscale(hidden_channels, level1_channels, alpha = 1),
+ PatchUpscale(level1_channels, img_channels, alpha = 1)
+ )
+
+ self.mask_alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.object_alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+
+ def set_level(self, level):
+ self.level = level
+
+ def forward(self, position, gestalt, priority = None):
+
+ maps = self.layer0(self.merge(position, gestalt, priority))
+ mask0 = self.to_mask_level0(maps)
+ object0 = self.to_object_level0(maps)
+
+ mask = self.to_mask_level1(mask0)
+ object = self.to_object_level1(object0)
+
+ if self.level > 1:
+ mask = repeat(mask, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)
+ object = repeat(object, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)
+
+ mask = mask + self.to_mask_level2(mask0) * self.mask_alpha
+ object = object + self.to_object_level2(object0) * self.object_alpha
+
+ return self.to_shared(mask), self.to_shared(object)
diff --git a/model/nn/encoder.py b/model/nn/encoder.py
new file mode 100644
index 0000000..8eb2e7f
--- /dev/null
+++ b/model/nn/encoder.py
@@ -0,0 +1,269 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import Gaus2D, BatchToSharedObjects, LambdaModule, ForcedAlpha, Binarize
+from model.nn.residual import ResidualBlock, SkipConnection
+from einops import rearrange, repeat, reduce
+from typing import Tuple, Union, List
+
+__author__ = "Manuel Traub"
+
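+# For each channel, sums all other channels (1x1 convolution with a zeroed diagonal); used to give
+# each slot access to the combined masks of the other slots.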
+class NeighbourChannels(nn.Module):
+ def __init__(self, channels):
+ super(NeighbourChannels, self).__init__()
+
+ self.register_buffer("weights", th.ones(channels, channels, 1, 1), persistent=False)
+
+ for i in range(channels):
+ self.weights[i,i,0,0] = 0
+
+ def forward(self, input: th.Tensor):
+ return nn.functional.conv2d(input, self.weights)
+
+class InputPreprocessing(nn.Module):
+ def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]):
+ super(InputPreprocessing, self).__init__()
+ self.num_objects = num_objects
+ self.neighbours = NeighbourChannels(num_objects)
+ self.gaus2d = Gaus2D(size)
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+ self.to_shared = BatchToSharedObjects(num_objects)
+
+ def forward(
+ self,
+ input: th.Tensor,
+ error: th.Tensor,
+ mask: th.Tensor,
+ object: th.Tensor,
+ position: th.Tensor,
+ rawmask: th.Tensor
+ ):
+ bg_mask = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects)
+ mask = mask[:,:-1]
+ mask_others = self.neighbours(mask)
+ rawmask = rawmask[:,:-1]
+
+ own_gaus2d = self.to_shared(self.gaus2d(self.to_batch(position)))
+
+ input = repeat(input, 'b c h w -> b o c h w', o = self.num_objects)
+ error = repeat(error, 'b 1 h w -> b o 1 h w', o = self.num_objects)
+ bg_mask = rearrange(bg_mask, 'b o h w -> b o 1 h w')
+ mask_others = rearrange(mask_others, 'b o h w -> b o 1 h w')
+ mask = rearrange(mask, 'b o h w -> b o 1 h w')
+ object = rearrange(object, 'b (o c) h w -> b o c h w', o = self.num_objects)
+ own_gaus2d = rearrange(own_gaus2d, 'b o h w -> b o 1 h w')
+ rawmask = rearrange(rawmask, 'b o h w -> b o 1 h w')
+
+ output = th.cat((input, error, mask, mask_others, bg_mask, object, own_gaus2d, rawmask), dim=2)
+ output = rearrange(output, 'b o c h w -> (b o) c h w')
+
+ return output
+
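+# Downsamples by kernel_size with a strided convolution; the skip path mean-pools each patch and
+# repeats channels, while the convolution branch is scaled by a learnable alpha.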
+class PatchDownConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
+ super(PatchDownConv, self).__init__()
+ assert out_channels % in_channels == 0
+
+ self.layers = nn.Conv2d(
+ in_channels = in_channels,
+ out_channels = out_channels,
+ kernel_size = kernel_size,
+ stride = kernel_size,
+ )
+
+ self.alpha = nn.Parameter(th.zeros(1) + alpha)
+ self.kernel_size = kernel_size
+ self.channels_factor = out_channels // in_channels
+
+ def forward(self, input: th.Tensor):
+ k = self.kernel_size
+ c = self.channels_factor
+ skip = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=k, w2=k)
+ skip = repeat(skip, 'b c h w -> b (c n) h w', n=c)
+ return skip + self.alpha * self.layers(input)
+
+class AggressiveConvToGestalt(nn.Module):
+ def __init__(self, channels, gestalt_size, size: Union[int, Tuple[int, int]]):
+ super(AggressiveConvToGestalt, self).__init__()
+
+ assert gestalt_size % channels == 0 or channels % gestalt_size == 0
+
+ self.layers = nn.Sequential(
+ nn.Conv2d(
+ in_channels = channels,
+ out_channels = gestalt_size,
+ kernel_size = 5,
+ stride = 3,
+ padding = 3
+ ),
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels = gestalt_size,
+ out_channels = gestalt_size,
+ kernel_size = ((size[0] + 1)//3 + 1, (size[1] + 1)//3 + 1)
+ )
+ )
+ if gestalt_size > channels:
+ self.skip = nn.Sequential(
+ LambdaModule(lambda x: reduce(x, 'b c h w -> b c 1 1', 'mean')),
+ LambdaModule(lambda x: repeat(x, 'b c 1 1 -> b (c n) 1 1', n = gestalt_size // channels))
+ )
+ else:
+ self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c 1 1', 'mean', n = channels // gestalt_size))
+
+
+ def forward(self, input: th.Tensor):
+ return self.skip(input) + self.layers(input)
+
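+# Converts a single-channel activation map into (x, y) coordinates in [-1, 1] via a spatial
+# softmax followed by an expectation over a normalized coordinate grid.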
+class PixelToPosition(nn.Module):
+ def __init__(self, size: Union[int, Tuple[int, int]]):
+ super(PixelToPosition, self).__init__()
+
+ self.register_buffer("grid_x", th.arange(size[0]), persistent=False)
+ self.register_buffer("grid_y", th.arange(size[1]), persistent=False)
+
+ self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1
+ self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1
+
+ self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
+ self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()
+
+ self.size = size
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+
+ input = rearrange(input, 'b c h w -> b c (h w)')
+ input = th.softmax(input, dim=2)
+ input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1])
+
+ x = th.sum(input * self.grid_x, dim=(2,3))
+ y = th.sum(input * self.grid_y, dim=(2,3))
+
+ return th.cat((x,y),dim=1)
+
+class PixelToSTD(nn.Module):
+ def __init__(self):
+ super(PixelToSTD, self).__init__()
+ self.alpha = ForcedAlpha()
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+ return self.alpha(reduce(th.sigmoid(input - 10), 'b c h w -> b c', 'mean'))
+
+class PixelToPriority(nn.Module):
+ def __init__(self):
+ super(PixelToPriority, self).__init__()
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+ return reduce(th.tanh(input), 'b c h w -> b c', 'mean')
+
+class LociEncoder(nn.Module):
+ def __init__(
+ self,
+ input_size: Union[int, Tuple[int, int]],
+ latent_size: Union[int, Tuple[int, int]],
+ num_objects: int,
+ img_channels: int,
+ hidden_channels: int,
+ level1_channels: int,
+ num_layers: int,
+ gestalt_size: int,
+ bottleneck: str
+ ):
+ super(LociEncoder, self).__init__()
+
+ self.num_objects = num_objects
+ self.latent_size = latent_size
+ self.level = 1
+
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects))
+
+ print(f"Level1 channels: {level1_channels}")
+
+ self.preprocess = nn.ModuleList([
+ InputPreprocessing(num_objects, (input_size[0] // 16, input_size[1] // 16)),
+ InputPreprocessing(num_objects, (input_size[0] // 4, input_size[1] // 4)),
+ InputPreprocessing(num_objects, (input_size[0], input_size[1]))
+ ])
+
+ self.to_channels = nn.ModuleList([
+ SkipConnection(img_channels, hidden_channels),
+ SkipConnection(img_channels, level1_channels),
+ SkipConnection(img_channels, img_channels)
+ ])
+
+ self.layers2 = nn.Sequential(
+ PatchDownConv(img_channels, level1_channels, alpha = 1e-16),
+ *[ResidualBlock(level1_channels, level1_channels, alpha_residual=True) for _ in range(num_layers)]
+ )
+
+ self.layers1 = PatchDownConv(level1_channels, hidden_channels)
+
+ self.layers0 = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)]
+ )
+
+ self.position_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ ResidualBlock(hidden_channels, 3),
+ )
+
+ self.xy_encoder = PixelToPosition(latent_size)
+ self.std_encoder = PixelToSTD()
+ self.priority_encoder = PixelToPriority()
+
+ if bottleneck == "binar":
+ print("Binary bottleneck")
+ self.gestalt_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ Binarize(),
+ )
+
+ else:
+ print("unrestricted bottleneck")
+ self.gestalt_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ nn.Sigmoid(),
+ )
+
+ def set_level(self, level):
+ self.level = level
+
+ def forward(
+ self,
+ input: th.Tensor,
+ error: th.Tensor,
+ mask: th.Tensor,
+ object: th.Tensor,
+ position: th.Tensor,
+ rawmask: th.Tensor
+ ):
+
+ latent = self.preprocess[self.level](input, error, mask, object, position, rawmask)
+ latent = self.to_channels[self.level](latent)
+
+ if self.level >= 2:
+ latent = self.layers2(latent)
+
+ if self.level >= 1:
+ latent = self.layers1(latent)
+
+ latent = self.layers0(latent)
+ gestalt = self.gestalt_encoder(latent)
+
+ latent = self.position_encoder(latent)
+ std = self.std_encoder(latent[:,0:1])
+ xy = self.xy_encoder(latent[:,1:2])
+ priority = self.priority_encoder(latent[:,2:3])
+
+ position = self.to_shared(th.cat((xy, std), dim=1))
+ gestalt = self.to_shared(gestalt)
+ priority = self.to_shared(priority)
+
+ return position, gestalt, priority
+
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py
new file mode 100644
index 0000000..82a9895
--- /dev/null
+++ b/model/nn/eprop_gate_l0rd.py
@@ -0,0 +1,329 @@
+import torch.nn as nn
+import torch as th
+import numpy as np
+from torch.autograd import Function
+from einops import rearrange, repeat, reduce
+
+__author__ = "Manuel Traub"
+
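+# Custom autograd function for the gated L0rd cell: the forward pass computes the gated state
+# update h = g * r + (1 - g) * h_last and accumulates eligibility traces in place; the backward
+# pass uses those traces (e-prop style) together with an L0-style gate-opening regularizer.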
+class EpropGateL0rdFunction(Function):
+ @staticmethod
+ def forward(ctx, x, h_last, w_gx, w_gh, b_g, w_rx, w_rh, b_r, args):
+
+ e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r, reg, noise_level = args
+
+ noise = th.normal(mean=0, std=noise_level, size=b_g.shape, device=b_g.device)
+ g = th.relu(th.tanh(x.mm(w_gx.t()) + h_last.mm(w_gh.t()) + b_g + noise))
+ r = th.tanh(x.mm(w_rx.t()) + h_last.mm(w_rh.t()) + b_r)
+
+ h = g * r + (1 - g) * h_last
+
+ # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ dg = (1 - g**2) * H_g
+ dr = (1 - r**2)
+
+ delta_h = r - h_last
+
+ g_j = g.unsqueeze(dim=2)
+ dg_j = dg.unsqueeze(dim=2)
+ dr_j = dr.unsqueeze(dim=2)
+
+ x_i = x.unsqueeze(dim=1)
+ h_last_i = h_last.unsqueeze(dim=1)
+ delta_h_j = delta_h.unsqueeze(dim=2)
+
+ e_w_gh.copy_(e_w_gh * (1 - g_j) + dg_j * h_last_i * delta_h_j)
+ e_w_gx.copy_(e_w_gx * (1 - g_j) + dg_j * x_i * delta_h_j)
+ e_b_g.copy_( e_b_g * (1 - g) + dg * delta_h )
+
+ e_w_rh.copy_(e_w_rh * (1 - g_j) + dr_j * h_last_i * g_j)
+ e_w_rx.copy_(e_w_rx * (1 - g_j) + dr_j * x_i * g_j)
+ e_b_r.copy_( e_b_r * (1 - g) + dr * g )
+
+ ctx.save_for_backward(
+ g.clone(), dg.clone(), dg_j.clone(), dr.clone(), x_i.clone(), h_last_i.clone(),
+ reg.clone(), H_g.clone(), delta_h.clone(), w_gx.clone(), w_gh.clone(), w_rx.clone(), w_rh.clone(),
+ e_w_gx.clone(), e_w_gh.clone(), e_b_g.clone(),
+ e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
+ )
+
+ return h, th.mean(H_g)
+
+ @staticmethod
+ def backward(ctx, dh, _):
+
+ g, dg, dg_j, dr, x_i, h_last_i, reg, H_g, delta_h, w_gx, w_gh, w_rx, w_rh, \
+ e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r = ctx.saved_tensors
+
+ dh_j = dh.unsqueeze(dim=2)
+ H_g_reg = reg * H_g
+ H_g_reg_j = H_g_reg.unsqueeze(dim=2)
+
+ dw_gx = th.sum(dh_j * e_w_gx + H_g_reg_j * dg_j * x_i, dim=0)
+ dw_gh = th.sum(dh_j * e_w_gh + H_g_reg_j * dg_j * h_last_i, dim=0)
+ db_g = th.sum(dh * e_b_g + H_g_reg * dg, dim=0)
+
+ dw_rx = th.sum(dh_j * e_w_rx, dim=0)
+ dw_rh = th.sum(dh_j * e_w_rh, dim=0)
+ db_r = th.sum(dh * e_b_r , dim=0)
+
+ dh_dg = (dh * delta_h + H_g_reg) * dg
+ dh_dr = dh * g * dr
+
+ dx = dh_dg.mm(w_gx) + dh_dr.mm(w_rx)
+ dh = dh * (1 - g) + dh_dg.mm(w_gh) + dh_dr.mm(w_rh)
+
+ return dx, dh, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None
+
+class ReTanhFunction(Function):
+ @staticmethod
+ def forward(ctx, x, reg):
+
+ g = th.relu(th.tanh(x))
+
+ # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ dg = (1 - g**2) * H_g
+
+ ctx.save_for_backward(g, dg, H_g, reg)
+ return g, th.mean(H_g)
+
+ @staticmethod
+ def backward(ctx, dh, _):
+
+ g, dg, H_g, reg = ctx.saved_tensors
+
+ dx = (dh + reg * H_g) * dg
+
+ return dx, None
+
+class ReTanh(nn.Module):
+ def __init__(self, reg_lambda):
+ super(ReTanh, self).__init__()
+
+ self.re_tanh = ReTanhFunction().apply
+ self.register_buffer("reg_lambda", th.tensor(reg_lambda), persistent=False)
+
+ def forward(self, input):
+ h, openings = self.re_tanh(input, self.reg_lambda)
+ self.openings = openings.item()
+
+ return h
+
+
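+# Recurrent cell with a ReTanh input gate and a tanh candidate, trainable either with standard
+# backpropagation (backprop_forward) or with eligibility-trace based e-prop (eprop_forward);
+# gate openings are regularized via reg_lambda.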
+class EpropGateL0rd(nn.Module):
+ def __init__(
+ self,
+ num_inputs,
+ num_hidden,
+ num_outputs,
+ batch_size,
+ reg_lambda = 0,
+ gate_noise_level = 0,
+ ):
+ super(EpropGateL0rd, self).__init__()
+
+ self.register_buffer("reg", th.tensor(reg_lambda).view(1,1), persistent=False)
+ self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
+ self.num_inputs = num_inputs
+ self.num_hidden = num_hidden
+ self.num_outputs = num_outputs
+
+ self.fcn = EpropGateL0rdFunction().apply
+ self.retanh = ReTanh(reg_lambda)
+
+ # gate weights and biases
+ self.w_gx = nn.Parameter(th.empty(num_hidden, num_inputs))
+ self.w_gh = nn.Parameter(th.empty(num_hidden, num_hidden))
+ self.b_g = nn.Parameter(th.zeros(num_hidden))
+
+ # candidate weights and biases
+ self.w_rx = nn.Parameter(th.empty(num_hidden, num_inputs))
+ self.w_rh = nn.Parameter(th.empty(num_hidden, num_hidden))
+ self.b_r = nn.Parameter(th.zeros(num_hidden))
+
+ # output projection weights and bias
+ self.w_px = nn.Parameter(th.empty(num_outputs, num_inputs))
+ self.w_ph = nn.Parameter(th.empty(num_outputs, num_hidden))
+ self.b_p = nn.Parameter(th.zeros(num_outputs))
+
+ # output gate weights and bias
+ self.w_ox = nn.Parameter(th.empty(num_outputs, num_inputs))
+ self.w_oh = nn.Parameter(th.empty(num_outputs, num_hidden))
+ self.b_o = nn.Parameter(th.zeros(num_outputs))
+
+ # input gate eligibility traces
+ self.register_buffer("e_w_gx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
+ self.register_buffer("e_w_gh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
+ self.register_buffer("e_b_g", th.zeros(batch_size, num_hidden), persistent=False)
+
+ # candidate (r) eligibility traces
+ self.register_buffer("e_w_rx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
+ self.register_buffer("e_w_rh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
+ self.register_buffer("e_b_r", th.zeros(batch_size, num_hidden), persistent=False)
+
+ # hidden state
+ self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)
+
+ self.register_buffer("openings", th.zeros(1), persistent=False)
+
+ # initialize weights
+ stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
+ stdv_hh = np.sqrt(3/self.num_hidden)
+ stdv_io = np.sqrt(6/(self.num_inputs + self.num_outputs))
+ stdv_ho = np.sqrt(6/(self.num_hidden + self.num_outputs))
+
+ nn.init.uniform_(self.w_gx, -stdv_ih, stdv_ih)
+ nn.init.uniform_(self.w_gh, -stdv_hh, stdv_hh)
+
+ nn.init.uniform_(self.w_rx, -stdv_ih, stdv_ih)
+ nn.init.uniform_(self.w_rh, -stdv_hh, stdv_hh)
+
+ nn.init.uniform_(self.w_px, -stdv_io, stdv_io)
+ nn.init.uniform_(self.w_ph, -stdv_ho, stdv_ho)
+
+ nn.init.uniform_(self.w_ox, -stdv_io, stdv_io)
+ nn.init.uniform_(self.w_oh, -stdv_ho, stdv_ho)
+
+ self.backprop = False
+
+ def reset_state(self):
+ self.h_last.zero_()
+ self.e_w_gx.zero_()
+ self.e_w_gh.zero_()
+ self.e_b_g.zero_()
+ self.e_w_rx.zero_()
+ self.e_w_rh.zero_()
+ self.e_b_r.zero_()
+ self.openings.zero_()
+
+ def backprop_forward(self, x: th.Tensor):
+
+ noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
+ g = self.retanh(x.mm(self.w_gx.t()) + self.h_last.mm(self.w_gh.t()) + self.b_g + noise)
+ r = th.tanh(x.mm(self.w_rx.t()) + self.h_last.mm(self.w_rh.t()) + self.b_r)
+
+ self.h_last = g * r + (1 - g) * self.h_last
+
+ # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ self.openings = th.mean(H_g)
+
+ p = th.tanh(x.mm(self.w_px.t()) + self.h_last.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + self.h_last.mm(self.w_oh.t()) + self.b_o)
+ return o * p
+
+ def activate_backprop(self):
+ self.backprop = True
+
+ def deactivate_backprop(self):
+ self.backprop = False
+
+ def detach(self):
+ self.h_last.detach_()
+
+ def eprop_forward(self, x: th.Tensor):
+ h, openings = self.fcn(
+ x, self.h_last,
+ self.w_gx, self.w_gh, self.b_g,
+ self.w_rx, self.w_rh, self.b_r,
+ (
+ self.e_w_gx, self.e_w_gh, self.e_b_g,
+ self.e_w_rx, self.e_w_rh, self.e_b_r,
+ self.reg, self.noise
+ )
+ )
+
+ self.openings = openings
+ self.h_last = h
+
+ p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
+ return o * p
+
+ def save_hidden(self):
+ self.h_last_saved = self.h_last.detach()
+
+ def restore_hidden(self):
+ self.h_last = self.h_last_saved
+
+ def get_hidden(self):
+ return self.h_last
+
+ def set_hidden(self, h_last):
+ self.h_last = h_last
+
+ def forward(self, x: th.Tensor):
+ if self.backprop:
+ return self.backprop_forward(x)
+
+ return self.eprop_forward(x)
+
+
+class EpropGateL0rdShared(EpropGateL0rd):
+ def __init__(
+ self,
+ num_inputs,
+ num_hidden,
+ num_outputs,
+ batch_size,
+ reg_lambda = 0,
+ gate_noise_level = 0,
+ ):
+ super().__init__(num_inputs, num_hidden, num_outputs, batch_size, reg_lambda, gate_noise_level)
+
+ def backprop_forward(self, x: th.Tensor, h_last: th.Tensor):
+
+ noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
+ g = self.retanh(x.mm(self.w_gx.t()) + h_last.mm(self.w_gh.t()) + self.b_g + noise)
+ r = th.tanh(x.mm(self.w_rx.t()) + h_last.mm(self.w_rh.t()) + self.b_r)
+
+ h_last = g * r + (1 - g) * h_last
+
+ # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ self.openings = th.mean(H_g)
+
+ p = th.tanh(x.mm(self.w_px.t()) + h_last.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + h_last.mm(self.w_oh.t()) + self.b_o)
+ return o * p, h_last
+
+ def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
+ h, openings = self.fcn(
+ x, h_last,
+ self.w_gx, self.w_gh, self.b_g,
+ self.w_rx, self.w_rh, self.b_r,
+ (
+ self.e_w_gx, self.e_w_gh, self.e_b_g,
+ self.e_w_rx, self.e_w_rh, self.e_b_r,
+ self.reg, self.noise
+ )
+ )
+
+ self.openings = openings
+
+ p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
+ return o * p, h
+
+ def forward(self, x: th.Tensor, h_last: th.Tensor = None):
+
+ if h_last is not None:
+ if self.backprop:
+ return self.backprop_forward(x, h_last)
+
+ return self.eprop_forward(x, h_last)
+
+ # backward compatibility: fall back to the module's own hidden state
+ if self.backprop:
+ x, h = self.backprop_forward(x, self.h_last)
+ else:
+ x, h = self.eprop_forward(x, self.h_last)
+
+ self.h_last = h
+ return x \ No newline at end of file
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py
new file mode 100644
index 0000000..4d89ec4
--- /dev/null
+++ b/model/nn/eprop_transformer.py
@@ -0,0 +1,76 @@
+import torch.nn as nn
+import torch as th
+from model.nn.eprop_gate_l0rd import EpropGateL0rd
+from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding
+
+class EpropGateL0rdTransformer(nn.Module):
+ def __init__(
+ self,
+ channels,
+ multiplier,
+ num_objects,
+ batch_size,
+ heads,
+ depth,
+ reg_lambda,
+ dropout=0.0
+ ):
+ super(EpropGateL0rdTransformer, self).__init__()
+
+ num_inputs = channels
+ num_outputs = channels
+ num_hidden = channels * multiplier
+
+ print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")
+
+
+ self.depth = depth
+ _layers = []
+ _layers.append(InputEmbeding(num_inputs, num_hidden))
+
+ for i in range(depth):
+ _layers.append(AlphaAttention(num_hidden, num_objects, heads, dropout))
+ _layers.append(EpropAlphaGateL0rd(num_hidden, batch_size * num_objects, reg_lambda))
+
+ _layers.append(OutputEmbeding(num_hidden, num_outputs))
+ self.layers = nn.Sequential(*_layers)
+
+ def get_openings(self):
+ openings = 0
+ for i in range(self.depth):
+ openings += self.layers[2 * (i + 1)].l0rd.openings.item()
+
+ return openings / self.depth
+
+ def get_hidden(self):
+ states = []
+ for i in range(self.depth):
+ states.append(self.layers[2 * (i + 1)].l0rd.get_hidden())
+
+ return th.cat(states, dim=1)
+
+ def set_hidden(self, hidden):
+ states = th.chunk(hidden, self.depth, dim=1)
+ for i in range(self.depth):
+ self.layers[2 * (i + 1)].l0rd.set_hidden(states[i])
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return self.layers(input)
+
+
+class EpropAlphaGateL0rd(nn.Module):
+ def __init__(self, num_hidden, batch_size, reg_lambda):
+ super(EpropAlphaGateL0rd, self).__init__()
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+ self.l0rd = EpropGateL0rd(
+ num_inputs = num_hidden,
+ num_hidden = num_hidden,
+ num_outputs = num_hidden,
+ reg_lambda = reg_lambda,
+ batch_size = batch_size
+ )
+
+ def forward(self, input):
+ return input + self.alpha * self.l0rd(input) \ No newline at end of file
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py
new file mode 100644
index 0000000..23a1b0f
--- /dev/null
+++ b/model/nn/eprop_transformer_shared.py
@@ -0,0 +1,92 @@
+import torch.nn as nn
+import torch as th
+from model.nn.eprop_gate_l0rd import EpropGateL0rdShared
+from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding
+
+class EpropGateL0rdTransformerShared(nn.Module):
+ def __init__(
+ self,
+ channels,
+ multiplier,
+ num_objects,
+ batch_size,
+ heads,
+ depth,
+ reg_lambda,
+ dropout=0.0,
+ exchange_length = 48,
+ ):
+ super(EpropGateL0rdTransformerShared, self).__init__()
+
+ num_inputs = channels
+ num_outputs = channels
+ num_hidden = channels * multiplier
+ num_hidden_gatelord = num_hidden + exchange_length
+ num_hidden_attention = num_hidden + exchange_length + num_hidden_gatelord
+
+ self.num_hidden = num_hidden
+ self.num_hidden_gatelord = num_hidden_gatelord
+
+ #print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")
+
+ self.register_buffer('hidden', th.zeros(batch_size * num_objects, num_hidden_gatelord), persistent=False)
+ self.register_buffer('exchange_code', th.zeros(batch_size * num_objects, exchange_length), persistent=False)
+
+ self.depth = depth
+ self.input_embeding = InputEmbeding(num_inputs, num_hidden)
+ self.attention = nn.Sequential(*[AlphaAttention(num_hidden_attention, num_objects, heads, dropout) for _ in range(depth)])
+ self.l0rds = nn.Sequential(*[EpropAlphaGateL0rdShared(num_hidden_gatelord, batch_size * num_objects, reg_lambda) for _ in range(depth)])
+ self.output_embeding = OutputEmbeding(num_hidden, num_outputs)
+
+ def get_openings(self):
+ openings = 0
+ for i in range(self.depth):
+ openings += self.l0rds[i].l0rd.openings.item()
+
+ return openings / self.depth
+
+ def get_hidden(self):
+ return self.hidden
+
+ def set_hidden(self, hidden):
+ self.hidden = hidden
+
+ def detach(self):
+ self.hidden = self.hidden.detach()
+
+ def reset_state(self):
+ self.hidden = th.zeros_like(self.hidden)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ x = self.input_embeding(x)
+ exchange_code = self.exchange_code.clone() * 0.0
+ x_ex = th.concat((x, exchange_code), dim=1)
+
+ for i in range(self.depth):
+ # attention layer
+ att = self.attention[i](th.concat((x_ex, self.hidden), dim=1))
+ x_ex = att[:, :self.num_hidden_gatelord]
+
+ # gatelord layer
+ x_ex, self.hidden = self.l0rds[i](x_ex, self.hidden)
+
+ # only yield x
+ x = x_ex[:, :self.num_hidden]
+ return self.output_embeding(x)
+
+class EpropAlphaGateL0rdShared(nn.Module):
+ def __init__(self, num_hidden, batch_size, reg_lambda):
+ super(EpropAlphaGateL0rdShared, self).__init__()
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+ self.l0rd = EpropGateL0rdShared(
+ num_inputs = num_hidden,
+ num_hidden = num_hidden,
+ num_outputs = num_hidden,
+ reg_lambda = reg_lambda,
+ batch_size = batch_size
+ )
+
+ def forward(self, input, hidden):
+ output, hidden = self.l0rd(input, hidden)
+ return input + self.alpha * output, hidden \ No newline at end of file
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
new file mode 100644
index 0000000..9e5e874
--- /dev/null
+++ b/model/nn/eprop_transformer_utils.py
@@ -0,0 +1,66 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import LambdaModule
+from einops import rearrange, repeat, reduce
+
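+# Multi-head self-attention across the slot dimension with a residual connection scaled by a
+# learnable alpha initialized close to zero.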
+class AlphaAttention(nn.Module):
+ def __init__(
+ self,
+ num_hidden,
+ num_objects,
+ heads,
+ dropout = 0.0
+ ):
+ super(AlphaAttention, self).__init__()
+
+ self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects))
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects))
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+ self.attention = nn.MultiheadAttention(
+ num_hidden,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ def forward(self, x: th.Tensor):
+ x = self.to_sequence(x)
+ x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+ return self.to_batch(x)
+
+class InputEmbeding(nn.Module):
+ def __init__(self, num_inputs, num_hidden):
+ super(InputEmbeding, self).__init__()
+
+ self.embeding = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(num_inputs, num_hidden),
+ nn.ReLU(),
+ nn.Linear(num_hidden, num_hidden),
+ )
+ self.skip = LambdaModule(
+ lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs)
+ )
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+ def forward(self, input: th.Tensor):
+ return self.skip(input) + self.alpha * self.embeding(input)
+
+class OutputEmbeding(nn.Module):
+ def __init__(self, num_hidden, num_outputs):
+ super(OutputEmbeding, self).__init__()
+
+ self.embeding = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(num_hidden, num_outputs),
+ nn.ReLU(),
+ nn.Linear(num_outputs, num_outputs),
+ )
+ self.skip = LambdaModule(
+ lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs)
+ )
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+ def forward(self, input: th.Tensor):
+ return self.skip(input) + self.alpha * self.embeding(input) \ No newline at end of file
diff --git a/model/nn/percept_gate_controller.py b/model/nn/percept_gate_controller.py
new file mode 100644
index 0000000..a548c20
--- /dev/null
+++ b/model/nn/percept_gate_controller.py
@@ -0,0 +1,59 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import LambdaModule
+from einops import rearrange, repeat, reduce
+from model.nn.eprop_gate_l0rd import ReTanh
+
+class PerceptGateController(nn.Module):
+ def __init__(
+ self,
+ num_inputs: int,
+ num_hidden: list,
+ bias: bool,
+ num_objects: int,
+ gate_noise_level: float = 0.1,
+ reg_lambda: float = 0.000005
+ ):
+ super(PerceptGateController, self).__init__()
+
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o=num_objects))
+
+ self.layers = nn.Sequential(
+ nn.Linear(num_inputs, num_hidden[0], bias = bias),
+ nn.Tanh(),
+ nn.Linear(num_hidden[0], num_hidden[1], bias = bias),
+ nn.Tanh(),
+ nn.Linear(num_hidden[1], 2, bias = bias)
+ )
+ self.output_function = ReTanh(reg_lambda)
+ self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
+ self.init_weights()
+
+ def init_weights(self):
+ for layer in self.layers:
+ if isinstance(layer, nn.Linear):
+ nn.init.xavier_uniform_(layer.weight)
+ layer.bias.data.fill_(3.00)
+
+ def forward(self, position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2, evaluate=False):
+
+ position_cur = self.to_batch(position_cur)
+ gestalt_cur = self.to_batch(gestalt_cur)
+ priority_cur = self.to_batch(priority_cur)
+ position_last = self.to_batch(position_last)
+ gestalt_last = self.to_batch(gestalt_last)
+ priority_last = self.to_batch(priority_last)
+ slots_occlusionfactor_cur = self.to_batch(slots_occlusionfactor_cur).detach()
+ slots_occlusionfactor_last = self.to_batch(slots_occlusionfactor_last).detach()
+ position_last2 = self.to_batch(position_last2).detach()
+
+ input = th.cat((position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2), dim=1)
+ output = self.layers(input)
+ if evaluate:
+ output = self.output_function(output)
+ else:
+ noise = th.normal(mean=0, std=self.noise, size=output.shape, device=output.device)
+ output = self.output_function(output + noise)
+
+ return self.to_shared(output)
diff --git a/model/nn/predictor.py b/model/nn/predictor.py
new file mode 100644
index 0000000..94f13b8
--- /dev/null
+++ b/model/nn/predictor.py
@@ -0,0 +1,99 @@
+import torch.nn as nn
+import torch as th
+from model.nn.eprop_transformer import EpropGateL0rdTransformer
+from model.nn.eprop_transformer_shared import EpropGateL0rdTransformerShared
+from model.utils.nn_utils import LambdaModule, Binarize
+from model.nn.residual import ResidualBlock
+from einops import rearrange, repeat, reduce
+
+__author__ = "Manuel Traub"
+
+class LociPredictor(nn.Module):
+ def __init__(
+ self,
+ heads: int,
+ layers: int,
+ channels_multiplier: int,
+ reg_lambda: float,
+ num_objects: int,
+ gestalt_size: int,
+ batch_size: int,
+ bottleneck: str,
+ transformer_type = 'standard',
+ ):
+ super(LociPredictor, self).__init__()
+ self.num_objects = num_objects
+ self.std_alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.bottleneck_type = bottleneck
+ self.gestalt_size = gestalt_size
+
+ self.reg_lambda = reg_lambda
+ Transformer = EpropGateL0rdTransformerShared if transformer_type == 'shared' else EpropGateL0rdTransformer
+ self.predictor = Transformer(
+ channels = gestalt_size + 3 + 1 + 2,
+ multiplier = channels_multiplier,
+ heads = heads,
+ depth = layers,
+ num_objects = num_objects,
+ reg_lambda = reg_lambda,
+ batch_size = batch_size,
+ )
+
+ if bottleneck == 'binar':
+ print("Binary bottleneck")
+ self.bottleneck = nn.Sequential(
+ LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
+ ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
+ Binarize(),
+ LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
+ )
+
+ else:
+ print("unrestricted bottleneck")
+ self.bottleneck = nn.Sequential(
+ LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
+ ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
+ nn.Sigmoid(),
+ LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
+ )
+
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))
+
+ def get_openings(self):
+ return self.predictor.get_openings()
+
+ def get_hidden(self):
+ return self.predictor.get_hidden()
+
+ def set_hidden(self, hidden):
+ self.predictor.set_hidden(hidden)
+
+ def forward(
+ self,
+ gestalt: th.Tensor,
+ priority: th.Tensor,
+ position: th.Tensor,
+ slots_closed: th.Tensor,
+ ):
+
+ position = self.to_batch(position)
+ gestalt_cur = self.to_batch(gestalt)
+ priority = self.to_batch(priority)
+ slots_closed = rearrange(slots_closed, 'b o c -> (b o) c').detach()
+
+ input = th.cat((gestalt_cur, priority, position, slots_closed), dim=1)
+ output = self.predictor(input)
+
+ gestalt = output[:, :self.gestalt_size]
+ priority = output[:,self.gestalt_size:(self.gestalt_size+1)]
+ xy = output[:,(self.gestalt_size+1):(self.gestalt_size+3)]
+ std = output[:,(self.gestalt_size+3):(self.gestalt_size+4)]
+
+ position = th.cat((xy, std * self.std_alpha), dim=1)
+
+ position = self.to_shared(position)
+ gestalt = self.bottleneck(gestalt)
+ priority = self.to_shared(priority)
+
+ return position, gestalt, priority
diff --git a/model/nn/residual.py b/model/nn/residual.py
new file mode 100644
index 0000000..1602e16
--- /dev/null
+++ b/model/nn/residual.py
@@ -0,0 +1,396 @@
+import torch.nn as nn
+import torch as th
+import numpy as np
+from einops import rearrange, repeat, reduce
+from model.utils.nn_utils import LambdaModule
+
+from typing import Union, Tuple
+
+__author__ = "Manuel Traub"
+
+class DynamicLayerNorm(nn.Module):
+
+ def __init__(self, eps: float = 1e-5):
+ super(DynamicLayerNorm, self).__init__()
+ self.eps = eps
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return nn.functional.layer_norm(input, input.shape[2:], None, None, self.eps)
+
+
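+# Parameter-free skip connection that matches channel counts by mean-reducing or repeating
+# channels and matches spatial resolution by repeating or average-pooling.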
+class SkipConnection(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ scale_factor: float = 1.0
+ ):
+ super(SkipConnection, self).__init__()
+ assert scale_factor == 1 or int(scale_factor) > 1 or int(1 / scale_factor) > 1, f'invalid scale factor in SkipConnection: {scale_factor}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+
+ def channel_skip(self, input: th.Tensor):
+ in_channels = self.in_channels
+ out_channels = self.out_channels
+
+ if in_channels == out_channels:
+ return input
+
+ if in_channels % out_channels == 0 or out_channels % in_channels == 0:
+
+ if in_channels > out_channels:
+ return reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // out_channels)
+
+ if out_channels > in_channels:
+ return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // in_channels)
+
+ mean_channels = np.gcd(in_channels, out_channels)
+ input = reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // mean_channels)
+ return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // mean_channels)
+
+ def scale_skip(self, input: th.Tensor):
+ scale_factor = self.scale_factor
+
+ if scale_factor == 1:
+ return input
+
+ if scale_factor > 1:
+ return repeat(
+ input,
+ 'b c h w -> b c (h h2) (w w2)',
+ h2 = int(scale_factor),
+ w2 = int(scale_factor)
+ )
+
+ height = input.shape[2]
+ width = input.shape[3]
+
+ # scale factor < 1
+ scale_factor = int(1 / scale_factor)
+
+ if width % scale_factor == 0 and height % scale_factor == 0:
+ return reduce(
+ input,
+ 'b c (h h2) (w w2) -> b c h w',
+ 'mean',
+ h2 = scale_factor,
+ w2 = scale_factor
+ )
+
+ if width >= scale_factor and height >= scale_factor:
+ return nn.functional.avg_pool2d(
+ input,
+ kernel_size = scale_factor,
+ stride = scale_factor
+ )
+
+ assert width > 1 or height > 1
+ return reduce(input, 'b c h w -> b c 1 1', 'mean')
+
+
+ def forward(self, input: th.Tensor):
+
+ if self.scale_factor > 1:
+ return self.scale_skip(self.channel_skip(input))
+
+ return self.channel_skip(self.scale_skip(input))
+
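+# Learned downscaling: a strided convolution whose kernel size equals its stride,
+# padding inputs that are smaller than the scale factor.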
+class DownScale(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ scale_factor: int,
+ groups: int = 1,
+ bias: bool = True
+ ):
+
+ super(DownScale, self).__init__()
+
+ assert(in_channels % groups == 0)
+ assert(out_channels % groups == 0)
+
+ self.groups = groups
+ self.scale_factor = scale_factor
+ self.weight = nn.Parameter(th.empty((out_channels, in_channels // groups, scale_factor, scale_factor)))
+ self.bias = nn.Parameter(th.empty((out_channels,))) if bias else None
+
+ nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
+
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / np.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input: th.Tensor):
+ height = input.shape[2]
+ width = input.shape[3]
+ assert height > 1 or width > 1, "trying to downscale 1x1"
+
+ scale_factor = self.scale_factor
+ padding = [0, 0]
+
+ if height < scale_factor:
+ padding[0] = scale_factor - height
+
+ if width < scale_factor:
+ padding[1] = scale_factor - width
+
+ return nn.functional.conv2d(
+ input,
+ self.weight,
+ bias=self.bias,
+ stride=scale_factor,
+ padding=padding,
+ groups=self.groups
+ )
+
+
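+# Pre-activation residual block with optional layer norm and optional up-/down-scaling
+# (ConvTranspose2d / DownScale), using a parameter-free SkipConnection as shortcut.
+# With alpha_residual the residual branch is gated by a learned alpha initialized near zero;
+# set_mode() can switch the block to a pure skip connection.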
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = (3, 3),
+ scale_factor: int = 1,
+ groups: Union[int, Tuple[int, int]] = (1, 1),
+ bias: bool = True,
+ layer_norm: bool = False,
+ leaky_relu: bool = False,
+ residual: bool = True,
+ alpha_residual: bool = False,
+ input_nonlinearity = True
+ ):
+
+ super(ResidualBlock, self).__init__()
+ self.residual = residual
+ self.alpha_residual = alpha_residual
+ self.skip = False
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if isinstance(kernel_size, int):
+ kernel_size = [kernel_size, kernel_size]
+
+ if isinstance(groups, int):
+ groups = [groups, groups]
+
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
+
+ _layers = list()
+ if layer_norm:
+ _layers.append(DynamicLayerNorm())
+
+ if input_nonlinearity:
+ if leaky_relu:
+ _layers.append(nn.LeakyReLU())
+ else:
+ _layers.append(nn.ReLU())
+
+ if scale_factor > 1:
+ _layers.append(
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=scale_factor,
+ stride=scale_factor,
+ groups=groups[0],
+ bias=bias
+ )
+ )
+ elif scale_factor < 1:
+ _layers.append(
+ DownScale(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ scale_factor=int(1.0/scale_factor),
+ groups=groups[0],
+ bias=bias
+ )
+ )
+ else:
+ _layers.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=groups[0],
+ bias=bias
+ )
+ )
+
+ if layer_norm:
+ _layers.append(DynamicLayerNorm())
+ if leaky_relu:
+ _layers.append(nn.LeakyReLU())
+ else:
+ _layers.append(nn.ReLU())
+ _layers.append(
+ nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=groups[1],
+ bias=bias
+ )
+ )
+ self.layers = nn.Sequential(*_layers)
+
+ if self.residual:
+ self.skip_connection = SkipConnection(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ scale_factor=scale_factor
+ )
+
+ if self.alpha_residual:
+ self.alpha = nn.Parameter(th.zeros(1) + 1e-12)
+
+ def set_mode(self, **kwargs):
+ if 'skip' in kwargs:
+ self.skip = kwargs['skip']
+
+ if 'residual' in kwargs:
+ self.residual = kwargs['residual']
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ if self.skip:
+ return self.skip_connection(input)
+
+ if not self.residual:
+ return self.layers(input)
+
+ if self.alpha_residual:
+ return self.alpha * self.layers(input) + self.skip_connection(input)
+
+ return self.layers(input) + self.skip_connection(input)
+
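+# Parameter-free skip for flat feature vectors: mean-pools or repeats features to match
+# the output width (the constructor rejects sizes where neither divides the other).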
+class LinearSkip(nn.Module):
+ def __init__(self, num_inputs: int, num_outputs: int):
+ super(LinearSkip, self).__init__()
+
+ self.num_inputs = num_inputs
+ self.num_outputs = num_outputs
+
+ if num_inputs % num_outputs != 0 and num_outputs % num_inputs != 0:
+ mean_channels = np.gcd(num_inputs, num_outputs)
+ print(f"[WW] gcd skip: {num_inputs} -> {mean_channels} -> {num_outputs}")
+ assert(False)
+
+ def forward(self, input: th.Tensor):
+ num_inputs = self.num_inputs
+ num_outputs = self.num_outputs
+
+ if num_inputs == num_outputs:
+ return input
+
+ if num_inputs % num_outputs == 0 or num_outputs % num_inputs == 0:
+
+ if num_inputs > num_outputs:
+ return reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // num_outputs)
+
+ if num_outputs > num_inputs:
+ return repeat(input, 'b c -> b (c n)', n = num_outputs // num_inputs)
+
+ mean_channels = np.gcd(num_inputs, num_outputs)
+ input = reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // mean_channels)
+ return repeat(input, 'b c -> b (c n)', n = num_outputs // mean_channels)
+
+class LinearResidual(nn.Module):
+ def __init__(
+ self,
+ num_inputs: int,
+ num_outputs: int,
+ num_hidden: int = None,
+ residual: bool = True,
+ alpha_residual: bool = False,
+ input_relu: bool = True
+ ):
+ super(LinearResidual, self).__init__()
+
+ self.residual = residual
+ self.alpha_residual = alpha_residual
+
+ if num_hidden is None:
+ num_hidden = num_outputs
+
+ _layers = []
+ if input_relu:
+ _layers.append(nn.ReLU())
+ _layers.append(nn.Linear(num_inputs, num_hidden))
+ _layers.append(nn.ReLU())
+ _layers.append(nn.Linear(num_hidden, num_outputs))
+
+ self.layers = nn.Sequential(*_layers)
+
+ if residual:
+ self.skip = LinearSkip(num_inputs, num_outputs)
+
+ if alpha_residual:
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+ def forward(self, input: th.Tensor):
+ if not self.residual:
+ return self.layers(input)
+
+ if not self.alpha_residual:
+ return self.skip(input) + self.layers(input)
+
+ return self.skip(input) + self.alpha * self.layers(input)
+
+
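+# Multi-head attention across the object slots at every spatial position, added back to
+# the input through a learned alpha gate that starts near zero.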
+class EntityAttention(nn.Module):
+ def __init__(self, channels, num_objects, size, channels_per_head = 12, dropout = 0.0):
+ super(EntityAttention, self).__init__()
+
+ assert channels % channels_per_head == 0
+ heads = channels // channels_per_head
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.attention = nn.MultiheadAttention(
+ channels,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ self.channel_attention = nn.Sequential(
+ LambdaModule(lambda x: rearrange(x, '(b o) c h w -> (b h w) o c', o = num_objects)),
+ LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
+ LambdaModule(lambda x: rearrange(x, '(b h w) o c-> (b o) c h w', h = size[0], w = size[1])),
+ )
+
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return input + self.channel_attention(input) * self.alpha
+
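+# Spatial self-attention over all image positions in a reduced gestalt_size channel space
+# (1x1 conv in and out), added back to the input through a learned alpha gate.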
+class ImageAttention(nn.Module):
+ def __init__(self, channels, gestalt_size, num_objects, size, channels_per_head = 12, dropout = 0.0):
+ super(ImageAttention, self).__init__()
+
+ assert gestalt_size % channels_per_head == 0
+ heads = gestalt_size // channels_per_head
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.attention = nn.MultiheadAttention(
+ gestalt_size,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ self.image_attention = nn.Sequential(
+ nn.Conv2d(channels, gestalt_size, kernel_size=1),
+ LambdaModule(lambda x: rearrange(x, 'b c h w -> b (h w) c')),
+ LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
+ LambdaModule(lambda x: rearrange(x, 'b (h w) c -> b c h w', h = size[0], w = size[1])),
+ nn.Conv2d(gestalt_size, channels, kernel_size=1),
+ )
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return input + self.image_attention(input) * self.alpha
\ No newline at end of file
diff --git a/model/utils/loss.py b/model/utils/loss.py
new file mode 100644
index 0000000..829bfd3
--- /dev/null
+++ b/model/utils/loss.py
@@ -0,0 +1,97 @@
+import torch as th
+from torch import nn
+from model.utils.nn_utils import SharedObjectsToBatch, LambdaModule
+from einops import rearrange, repeat, reduce
+
+__author__ = "Manuel Traub"
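+
+# Slot-masked MSE between the current positions and the (detached) positions of the previous time step.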
+class PositionLoss(nn.Module):
+ def __init__(self, num_objects: int):
+ super(PositionLoss, self).__init__()
+
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+
+ def forward(self, position, position_last, slot_mask):
+
+ slot_mask = rearrange(slot_mask, 'b o -> (b o) 1 1 1')
+ position = self.to_batch(position)
+ position_last = self.to_batch(position_last).detach()
+
+ return th.mean(slot_mask * (position - position_last)**2)
+
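+# For slots with slot_mask == 0, reuses the stored gestalt and the std channel of the
+# position from the previous time step; the x/y position always follows the current prediction.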
+class ObjectModulator(nn.Module):
+ def __init__(self, num_objects: int):
+ super(ObjectModulator, self).__init__()
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))
+ self.position = None
+ self.gestalt = None
+
+ def reset_state(self):
+ self.position = None
+ self.gestalt = None
+
+ def forward(self, position: th.Tensor, gestalt: th.Tensor, slot_mask: th.Tensor):
+
+ position = self.to_batch(position)
+ gestalt = self.to_batch(gestalt)
+ slot_mask = self.to_batch(slot_mask)
+
+ if self.position is None or self.gestalt is None:
+ self.position = position.detach()
+ self.gestalt = gestalt.detach()
+ return self.to_shared(position), self.to_shared(gestalt)
+
+ _position = slot_mask * position + (1 - slot_mask) * self.position
+ position = th.cat((position[:,:-1], _position[:,-1:]), dim=1) # only the std (last channel) is gated; x/y keep the current prediction
+ gestalt = slot_mask * gestalt + (1 - slot_mask) * self.gestalt
+
+ self.gestalt = gestalt.detach()
+ self.position = position.detach()
+
+ return self.to_shared(position), self.to_shared(gestalt)
+
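+# Shifts each object map with an affine grid_sample so that the content at the slot's
+# predicted (x, y) position ends up at the image center.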
+class MoveToCenter(nn.Module):
+ def __init__(self, num_objects: int):
+ super(MoveToCenter, self).__init__()
+
+ self.to_batch2d = SharedObjectsToBatch(num_objects)
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+
+ def forward(self, input: th.Tensor, position: th.Tensor):
+
+ input = self.to_batch2d(input) # b (o c) h w -> (b o) c h w
+ position = self.to_batch(position).detach()
+ position = th.stack((position[:,1], position[:,0]), dim=1)
+
+ theta = th.tensor([1, 0, 0, 1], dtype=th.float, device=input.device).view(1,2,2)
+ theta = repeat(theta, '1 a b -> n a b', n=input.shape[0])
+
+ position = rearrange(position, 'b c -> b c 1')
+ theta = th.cat((theta, position), dim=2)
+
+ grid = nn.functional.affine_grid(theta, input.shape, align_corners=False)
+ output = nn.functional.grid_sample(input, grid, align_corners=False)
+
+ return output
+
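+# Slot-masked MSE between two object reconstructions after each has been moved to the
+# center using its own position, making the comparison invariant to translation.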
+class TranslationInvariantObjectLoss(nn.Module):
+ def __init__(self, num_objects: int):
+ super(TranslationInvariantObjectLoss, self).__init__()
+
+ self.move_to_center = MoveToCenter(num_objects)
+ self.to_batch = SharedObjectsToBatch(num_objects)
+
+ def forward(
+ self,
+ slot_mask: th.Tensor,
+ object1: th.Tensor,
+ position1: th.Tensor,
+ object2: th.Tensor,
+ position2: th.Tensor,
+ ):
+ slot_mask = rearrange(slot_mask, 'b o -> (b o) 1 1 1')
+ object1 = self.move_to_center(th.sigmoid(object1 - 2.5), position1)
+ object2 = self.move_to_center(th.sigmoid(object2 - 2.5), position2)
+
+ return th.mean(slot_mask * (object1 - object2)**2)
+
diff --git a/model/utils/nn_utils.py b/model/utils/nn_utils.py
new file mode 100644
index 0000000..5116e14
--- /dev/null
+++ b/model/utils/nn_utils.py
@@ -0,0 +1,298 @@
+from typing import Tuple
+import torch.nn as nn
+import torch as th
+import numpy as np
+from torch.autograd import Function
+from einops import rearrange, repeat, reduce
+
+class PushToInfFunction(Function):
+ @staticmethod
+ def forward(ctx, tensor):
+ ctx.save_for_backward(tensor)
+ return tensor.clone()
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ tensor = ctx.saved_tensors[0]
+ grad_input = -th.ones_like(grad_output)
+ return grad_input
+
+class PushToInf(nn.Module):
+ def __init__(self):
+ super(PushToInf, self).__init__()
+
+ self.fcn = PushToInfFunction.apply
+
+ def forward(self, input: th.Tensor):
+ return self.fcn(input)
+
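+# Scalar gate tanh(init * speed) applied to the input; PushToInf feeds a constant
+# negative gradient to init, so the gate ramps from 0 toward 1 during training.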
+class ForcedAlpha(nn.Module):
+ def __init__(self, speed = 1):
+ super(ForcedAlpha, self).__init__()
+
+ self.init = nn.Parameter(th.zeros(1))
+ self.speed = speed
+ self.to_inf = PushToInf()
+
+ def item(self):
+ return th.tanh(self.to_inf(self.init * self.speed)).item()
+
+ def forward(self, input: th.Tensor):
+ return input * th.tanh(self.to_inf(self.init * self.speed))
+
+class LinearInterpolation(nn.Module):
+ def __init__(self, num_objects):
+ super(LinearInterpolation, self).__init__()
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))
+
+ def forward(
+ self,
+ tensor_cur: th.Tensor = None,
+ tensor_last: th.Tensor = None,
+ slot_interpolation_value: th.Tensor = None
+ ):
+
+ slot_interpolation_value = rearrange(slot_interpolation_value, 'b o -> (b o) 1')
+ tensor_cur = slot_interpolation_value * self.to_batch(tensor_last) + (1 - slot_interpolation_value) * self.to_batch(tensor_cur)
+
+ return self.to_shared(tensor_cur)
+
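+# Renders (x, y, std) in [-1, 1] coordinates as an isotropic 2D Gaussian heatmap,
+# e.g. Gaus2D((32, 32)) maps a (B, 3) tensor to a (B, 1, 32, 32) blob image.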
+class Gaus2D(nn.Module):
+ def __init__(self, size: Tuple[int, int]):
+ super(Gaus2D, self).__init__()
+
+ self.size = size
+
+ self.register_buffer("grid_x", th.arange(size[0]), persistent=False)
+ self.register_buffer("grid_y", th.arange(size[1]), persistent=False)
+
+ self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1
+ self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1
+
+ self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
+ self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()
+
+ def forward(self, input: th.Tensor):
+
+ x = rearrange(input[:,0:1], 'b c -> b c 1 1')
+ y = rearrange(input[:,1:2], 'b c -> b c 1 1')
+ std = rearrange(input[:,2:3], 'b c -> b c 1 1')
+
+ x = th.clip(x, -1, 1)
+ y = th.clip(y, -1, 1)
+ std = th.clip(std, 0, 1)
+
+ max_size = max(self.size)
+ std_x = (1 + max_size * std) / self.size[0]
+ std_y = (1 + max_size * std) / self.size[1]
+
+ return th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)))
+
+class Vector2D(nn.Module):
+ def __init__(self, size: Tuple[int, int]):
+ super(Vector2D, self).__init__()
+
+ self.size = size
+
+ self.register_buffer("grid_x", th.arange(size[0]), persistent=False)
+ self.register_buffer("grid_y", th.arange(size[1]), persistent=False)
+
+ self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1
+ self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1
+
+ self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 3, *size).clone()
+ self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 3, *size).clone()
+
+ def forward(self, input: th.Tensor, vector: th.Tensor = None):
+
+ x = rearrange(input[:,0:1], 'b c -> b c 1 1')
+ y = rearrange(input[:,1:2], 'b c -> b c 1 1')
+ if vector is not None:
+ x_vec = rearrange(vector[:,0:1], 'b c -> b c 1 1')
+ y_vec = rearrange(vector[:,1:2], 'b c -> b c 1 1')
+
+ x = th.clip(x, -1, 1)
+ y = th.clip(y, -1, 1)
+ std = 0.01
+
+ max_size = max(self.size)
+ std_x = (1 + max_size * std) / self.size[0]
+ std_y = (1 + max_size * std) / self.size[1]
+ grid = th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)))
+
+ # interpolating between start and end point
+ if vector is not None:
+ for length in np.linspace(0, 1, 11):
+ x_end = th.clip(x + x_vec * length, -1, 1)
+ y_end = th.clip(y + y_vec * length, -1, 1)
+
+ grid_point = th.exp(-1 * ((self.grid_x - x_end)**2/(0.5 * std_x**2) + (self.grid_y - y_end)**2/(0.5 * std_y**2)))
+ grid_point[:, 0:2, :, :] = 0
+ grid = th.max(grid, grid_point)
+
+ return grid
+
+class SharedObjectsToBatch(nn.Module):
+ def __init__(self, num_objects):
+ super(SharedObjectsToBatch, self).__init__()
+
+ self.num_objects = num_objects
+
+ def forward(self, input: th.Tensor):
+ return rearrange(input, 'b (o c) h w -> (b o) c h w', o=self.num_objects)
+
+class BatchToSharedObjects(nn.Module):
+ def __init__(self, num_objects):
+ super(BatchToSharedObjects, self).__init__()
+
+ self.num_objects = num_objects
+
+ def forward(self, input: th.Tensor):
+ return rearrange(input, '(b o) c h w -> b (o c) h w', o=self.num_objects)
+
+class LambdaModule(nn.Module):
+ def __init__(self, lambd):
+ super().__init__()
+ import types
+ assert type(lambd) is types.LambdaType
+ self.lambd = lambd
+
+ def forward(self, *x):
+ return self.lambd(*x)
+
+class PrintGradientFunction(Function):
+ @staticmethod
+ def forward(ctx, tensor, msg):
+ ctx.msg = msg
+ return tensor
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output.clone()
+ print(f"{ctx.msg}: {th.mean(grad_output).item()} +- {th.std(grad_output).item()}")
+ return grad_input, None
+
+class PrintGradient(nn.Module):
+ def __init__(self, msg = "PrintGradient"):
+ super(PrintGradient, self).__init__()
+
+ self.fcn = PrintGradientFunction.apply
+ self.msg = msg
+
+ def forward(self, input: th.Tensor):
+ return self.fcn(input, self.msg)
+
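+# Suppresses each object map where objects of higher priority are active: a grouped 1x1
+# convolution subtracts the sigmoid-weighted sum of the competing maps before a ReLU.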
+class Prioritize(nn.Module):
+ def __init__(self, num_objects):
+ super(Prioritize, self).__init__()
+
+ self.num_objects = num_objects
+ self.to_batch = SharedObjectsToBatch(num_objects)
+
+ def forward(self, input: th.Tensor, priority: th.Tensor):
+
+ if priority is None:
+ return input
+
+ batch_size = input.shape[0]
+ weights = th.zeros((batch_size, self.num_objects, self.num_objects, 1, 1), device=input.device)
+
+ for o in range(self.num_objects):
+ weights[:,o,:,0,0] = th.sigmoid(priority[:,:] - priority[:,o:o+1])
+ weights[:,o,o,0,0] = weights[:,o,o,0,0] * 0
+
+ input = rearrange(input, 'b c h w -> 1 (b c) h w')
+ weights = rearrange(weights, 'b o i 1 1 -> (b o) i 1 1')
+
+ output = th.relu(input - nn.functional.conv2d(input, weights, groups=batch_size))
+ output = rearrange(output, '1 (b c) h w -> b c h w ', b=batch_size)
+
+ return output
+
+class MultiArgSequential(nn.Sequential):
+ def __init__(self, *args, **kwargs):
+ super(MultiArgSequential, self).__init__(*args, **kwargs)
+
+ def forward(self, *tensor):
+
+ for n in range(len(self)):
+ if isinstance(tensor, th.Tensor) or tensor is None:
+ tensor = self[n](tensor)
+ else:
+ tensor = self[n](*tensor)
+
+ return tensor
+
+def create_grid(size):
+ grid_x = th.arange(size[0])
+ grid_y = th.arange(size[1])
+
+ grid_x = (grid_x / (size[0]-1)) * 2 - 1
+ grid_y = (grid_y / (size[1]-1)) * 2 - 1
+
+ grid_x = grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
+ grid_y = grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()
+
+ return th.cat((grid_y, grid_x), dim=1)
+
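+# Backward warping with bicubic grid_sample: the flow field is resized to the input size,
+# converted to a displacement, padded alongside the replication-padded input, and the
+# warped result is cropped back to the original size.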
+class Warp(nn.Module):
+ def __init__(self, size, padding = 0.1):
+ super(Warp, self).__init__()
+
+ padding = int(max(size) * padding)
+ padded_size = (size[0] + 2 * padding, size[1] + 2 * padding)
+
+ self.register_buffer('grid', create_grid(size))
+ self.register_buffer('padded_grid', create_grid(padded_size))
+
+ self.replication_pad = nn.ReplicationPad2d(padding)
+ self.interpolate = nn.Sequential(
+ LambdaModule(lambda x:
+ th.nn.functional.interpolate(x, size=size, mode='bicubic', align_corners = True)
+ ),
+ LambdaModule(lambda x: x - self.grid),
+ nn.ConstantPad2d(padding, 0),
+ LambdaModule(lambda x: x + self.padded_grid),
+ LambdaModule(lambda x: rearrange(x, 'b c h w -> b h w c'))
+ )
+
+ self.warp = LambdaModule(lambda input, flow:
+ th.nn.functional.grid_sample(input, flow, mode='bicubic', align_corners=True)
+ )
+
+ self.un_pad = LambdaModule(lambda x: x[:,:,padding:-padding,padding:-padding])
+
+ def get_raw_flow(self, flow):
+ return flow - self.grid
+
+ def forward(self, input, flow):
+ input = self.replication_pad(input)
+ flow = self.interpolate(flow)
+ return self.un_pad(self.warp(input, flow))
+
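+# Sigmoid activation that rounds to {0, 1} at evaluation time; during training it adds
+# noise scaled by p * (1 - p) to push activations toward binary values.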
+class Binarize(nn.Module):
+ def __init__(self):
+ super(Binarize, self).__init__()
+
+ def forward(self, input: th.Tensor):
+ input = th.sigmoid(input)
+ if not self.training:
+ return th.round(input)
+
+ return input + input * (1 - input) * th.randn_like(input)
+
+class TanhAlpha(nn.Module):
+ def __init__(self, start = 0, stepsize = 1e-4, max_value = 1):
+ super(TanhAlpha, self).__init__()
+
+ self.register_buffer('init', th.zeros(1) + start)
+ self.stepsize = stepsize
+ self.max_value = max_value
+
+ def get(self):
+ return (th.tanh(self.init) * self.max_value).item()
+
+ def forward(self):
+ self.init = self.init.detach() + self.stepsize
+ return self.get()
\ No newline at end of file
diff --git a/model/utils/slot_utils.py b/model/utils/slot_utils.py
new file mode 100644
index 0000000..936248b
--- /dev/null
+++ b/model/utils/slot_utils.py
@@ -0,0 +1,326 @@
+import torch.nn as nn
+import torch as th
+import torchvision.transforms as transforms
+import torch.nn.functional as F
+from einops import rearrange, repeat, reduce
+from typing import Tuple, Union, List
+from model.utils.nn_utils import Gaus2D, LambdaModule, TanhAlpha
+
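+# Provides initial latent states for all slots: unbound slots get a freshly sampled gestalt,
+# a default priority and a position, either the best of several random candidates scored
+# against the error map (shuffleslots=True) or spawned one at a time at the highest-error location.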
+class InitialLatentStates(nn.Module):
+ def __init__(
+ self,
+ gestalt_size: int,
+ num_objects: int,
+ bottleneck: str,
+ size: Tuple[int, int],
+ teacher_forcing: int
+ ):
+ super(InitialLatentStates, self).__init__()
+ self.bottleneck = bottleneck
+
+ self.num_objects = num_objects
+ self.gestalt_mean = nn.Parameter(th.zeros(1, gestalt_size))
+ self.gestalt_std = nn.Parameter(th.ones(1, gestalt_size))
+ self.std = nn.Parameter(th.zeros(1))
+ self.gestalt_strength = 2
+ self.teacher_forcing = teacher_forcing
+
+ self.init = TanhAlpha(start = -1)
+ self.register_buffer('priority', th.arange(num_objects).float() * 25, persistent=False)
+ self.register_buffer('threshold', th.ones(1) * 0.8)
+ self.last_mask = None
+ self.binarize_first = round(gestalt_size * 0.8)
+
+ self.gaus2d = nn.Sequential(
+ Gaus2D((size[0] // 16, size[1] // 16)),
+ Gaus2D((size[0] // 4, size[1] // 4)),
+ Gaus2D(size)
+ )
+
+ self.level = 1
+ self.t = 0
+
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))
+
+ self.blur = transforms.GaussianBlur(13)
+ self.size = size
+
+ def reset_state(self):
+ self.last_mask = None
+ self.t = 0
+ self.to_next_spawn = 0
+
+ def set_level(self, level):
+ self.level = level
+ factor = int(4 / (level ** 2))
+ self.to_position = ErrorToPosition((self.size[0] // factor, self.size[1] // factor))
+
+ def forward(
+ self,
+ error: th.Tensor,
+ mask: th.Tensor = None,
+ position: th.Tensor = None,
+ gestalt: th.Tensor = None,
+ priority: th.Tensor = None,
+ shuffleslots: bool = True,
+ slots_bounded_last: th.Tensor = None,
+ slots_occlusionfactor_last: th.Tensor = None,
+ allow_spawn: bool = True,
+ clean_slots: bool = False
+ ):
+
+ batch_size = error.shape[0]
+ device = error.device
+
+ if self.init.get() < 1:
+ self.gestalt_strength = self.init()
+
+ if self.last_mask is None:
+ self.last_mask = th.zeros((batch_size * self.num_objects, 1), device = device)
+ if shuffleslots:
+ self.slots_assigned = th.ones((batch_size * self.num_objects, 1), device = device)
+ else:
+ self.slots_assigned = th.zeros((batch_size * self.num_objects, 1), device = device)
+
+ if not allow_spawn:
+ unassigned = self.slots_assigned - slots_bounded_last
+ self.slots_assigned = self.slots_assigned - unassigned
+
+ if clean_slots and (slots_occlusionfactor_last is not None):
+ occluded = self.slots_assigned * (self.to_batch(slots_occlusionfactor_last) > 0.1).float()
+ self.slots_assigned = self.slots_assigned - occluded
+
+ if (slots_bounded_last is None) or (self.gestalt_strength < 1):
+
+ if mask is not None:
+ # compute the maximum over each object mask --> slot is bound (c = o)
+ mask2 = reduce(mask[:,:-1], 'b c h w -> (b c) 1' , 'max').detach()
+
+ if self.gestalt_strength <= 0:
+ self.last_mask = mask2
+ elif self.gestalt_strength < 1:
+ self.last_mask = th.maximum(self.last_mask, mask2)
+ self.last_mask = self.last_mask - th.relu(-1 * (mask2 - self.threshold) * (1 - self.gestalt_strength))
+ else:
+ self.last_mask = th.maximum(self.last_mask, mask2)
+
+ slots_bounded = (self.last_mask > self.threshold).float().detach() * self.slots_assigned
+ else:
+ slots_bounded = slots_bounded_last * self.slots_assigned
+
+ if self.bottleneck == "binar":
+ gestalt_new = repeat(th.sigmoid(self.gestalt_mean), '1 c -> b c', b = batch_size * self.num_objects)
+ gestalt_new = gestalt_new + gestalt_new * (1 - gestalt_new) * th.randn_like(gestalt_new)
+ else:
+ gestalt_mean = repeat(self.gestalt_mean, '1 c -> b c', b = batch_size * self.num_objects)
+ gestalt_std = repeat(self.gestalt_std, '1 c -> b c', b = batch_size * self.num_objects)
+ gestalt_new = th.sigmoid(gestalt_mean + gestalt_std * th.randn_like(gestalt_std))
+
+ if gestalt is None:
+ gestalt = gestalt_new
+ else:
+ gestalt = self.to_batch(gestalt) * slots_bounded + gestalt_new * (1 - slots_bounded)
+
+ if priority is None:
+ priority = repeat(self.priority, 'o -> (b o) 1', b = batch_size)
+ else:
+ priority = self.to_batch(priority) * slots_bounded + repeat(self.priority, 'o -> (b o) 1', b = batch_size) * (1 - slots_bounded)
+
+
+ if shuffleslots:
+ self.slots_assigned = th.ones_like(self.slots_assigned)
+
+ xy_rand_new = th.rand((batch_size * self.num_objects * 10, 2), device = device) * 2 - 1
+ std_new = th.zeros((batch_size * self.num_objects * 10, 1), device = device)
+ position_new = th.cat((xy_rand_new, std_new), dim=1)
+
+ position2d = self.gaus2d[self.level](position_new)
+ position2d = rearrange(position2d, '(b o) 1 h w -> b o h w', b = batch_size)
+
+ rand_error = reduce(position2d * error, 'b o h w -> (b o) 1', 'sum')
+
+ xy_rand_new = rearrange(xy_rand_new, '(b r) c -> r b c', r = 10)
+ rand_error = rearrange(rand_error, '(b r) c -> r b c', r = 10)
+
+ max_error = th.argmax(rand_error, dim=0, keepdim=True)
+ x, y = th.chunk(xy_rand_new, 2, dim=2)
+ x = th.gather(x, dim=0, index=max_error).detach().squeeze(dim=0)
+ y = th.gather(y, dim=0, index=max_error).detach().squeeze(dim=0)
+ std = repeat(self.std, '1 -> (b o) 1', b = batch_size, o=self.num_objects)
+
+ if position is None:
+ position = th.cat((x, y, std), dim=1)
+ else:
+ position = self.to_batch(position) * slots_bounded + th.cat((x, y, std), dim=1) * (1 - slots_bounded)
+
+ else:
+
+ # set unassigned slots to empty position
+ empty_position = th.tensor([-1,-1,0]).to(device)
+ empty_position = repeat(empty_position, 'c -> (b o) c', b = batch_size, o=self.num_objects).detach()
+
+ if position is None:
+ position = empty_position
+ else:
+ position = self.to_batch(position) * self.slots_assigned + empty_position * (1 - self.slots_assigned)
+
+
+ # blur the error and set masked areas to zero
+ error = self.blur(error)
+ if mask is not None:
+ mask2 = mask[:,:-1] * rearrange(slots_bounded, '(b o) 1 -> b o 1 1', b = batch_size)
+ mask2 = th.sum(mask2, dim=1, keepdim=True)
+ error = error * (1-mask2)
+ max_error = reduce(error, 'b o h w -> (b o) 1', 'max')
+
+ if self.to_next_spawn <= 0 and allow_spawn:
+
+ self.to_next_spawn = 2
+
+ # calculate the position with the highest error
+ new_pos = self.to_position(error)
+ std = repeat(self.std, '1 -> b 1', b = batch_size)
+ new_pos = repeat(th.cat((new_pos, std), dim=1), 'b c -> (b o) c', o = self.num_objects)
+
+ # calculate if an assigned slot is unbound (-->free)
+ n_slots_assigned = self.to_shared(self.slots_assigned).sum(dim=1, keepdim=True)
+ n_slots_bounded = self.to_shared(slots_bounded).sum(dim=1, keepdim=True)
+ free_slot_given = th.clip(n_slots_assigned - n_slots_bounded, 0, 1)
+
+ # either spawn a new slot or use the one that is free
+ slots_new_index = n_slots_assigned * (1-free_slot_given) + n_slots_bounded * free_slot_given # reuse the free slot each time we spawn
+
+ # new slot index
+ free_slot_required = (max_error > 0).float()
+ slots_new_index = F.one_hot(slots_new_index.long(), num_classes=self.num_objects+1).float().squeeze(dim=1)[:,:-1]
+ slots_new_index = self.to_batch(slots_new_index * free_slot_required)
+
+ # place new free slot
+ position = new_pos * slots_new_index + position * (1 - slots_new_index)
+ self.slots_assigned = th.clip(self.slots_assigned + slots_new_index, 0, 1)
+
+ self.to_next_spawn -= 1
+ return self.to_shared(position), self.to_shared(gestalt), self.to_shared(priority), error
+
+ def get_slots_unassigned(self):
+ return self.to_shared(1-self.slots_assigned)
+
+ def get_slots_assigned(self):
+ return self.to_shared(self.slots_assigned)
+
+
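+# Tracks which slots have ever been bound, which of them are currently occluded, and an
+# occlusion factor per slot derived from the difference between raw and final masks.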
+class OcclusionTracker(nn.Module):
+ def __init__(self, batch_size, num_objects, device):
+ super(OcclusionTracker, self).__init__()
+ self.batch_size = batch_size
+ self.num_objects = num_objects
+ self.slots_bounded_all = th.zeros((batch_size * num_objects, 1)).to(device)
+ self.threshold = 0.8
+ self.device = device
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))
+ self.slots_bounded_next_last = None
+
+ def forward(
+ self,
+ mask: th.Tensor = None,
+ rawmask: th.Tensor = None,
+ reset_mask: bool = False,
+ update: bool = True
+ ):
+
+ if mask is not None:
+
+ # compute which slots are currently bound
+ slots_bounded_smooth_cur = reduce(mask[:,:-1], 'b o h w -> (b o) 1' , 'max').detach()
+ slots_bounded_cur = (slots_bounded_smooth_cur > self.threshold).float().detach()
+ if reset_mask:
+ self.slots_bounded_next_last = slots_bounded_cur # allow immediate spawn
+
+ if update:
+ slots_bounded_cur = slots_bounded_cur * th.clip(self.slots_bounded_next_last + self.slots_bounded_all, 0, 1)
+ else:
+ self.slots_bounded_next_last = slots_bounded_cur
+
+ if reset_mask:
+ self.slots_bounded_smooth_all = slots_bounded_smooth_cur
+ self.slots_bounded_all = slots_bounded_cur
+ elif update:
+ self.slots_bounded_all = th.maximum(self.slots_bounded_all, slots_bounded_cur)
+ self.slots_bounded_smooth_all = th.maximum(self.slots_bounded_smooth_all, slots_bounded_smooth_cur)
+
+ # compute occlusion mask
+ slots_occluded_cur = self.slots_bounded_all - slots_bounded_cur
+
+ # compute partially occluded mask
+ mask = (mask[:,:-1] > self.threshold).float().detach()
+ rawmask = (rawmask[:,:-1] > self.threshold).float().detach()
+ masked = rawmask - mask
+
+ masked = reduce(masked, 'b o h w -> (b o) 1' , 'sum')
+ rawmask = reduce(rawmask, 'b o h w -> (b o) 1' , 'sum')
+
+ slots_occlusionfactor_cur = (masked / (rawmask + 1)) * (1-slots_occluded_cur) + slots_occluded_cur
+ slots_partially_occluded = (slots_occlusionfactor_cur > 0.1).float() #* slots_bounded_cur
+ slots_fully_visible = (slots_occlusionfactor_cur <= 0.1).float() * slots_bounded_cur
+
+ if reset_mask:
+ self.slots_fully_visible_all = slots_fully_visible
+ elif update:
+ self.slots_fully_visible_all = th.maximum(self.slots_fully_visible_all, slots_fully_visible)
+
+ return self.to_shared(self.slots_bounded_all), self.to_shared(self.slots_bounded_smooth_all), self.to_shared(slots_occluded_cur), self.to_shared(slots_partially_occluded), self.to_shared(slots_fully_visible), self.to_shared(slots_occlusionfactor_cur)
+
+ def get_slots_fully_visible_all(self):
+ return self.to_shared(self.slots_fully_visible_all)
+
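+# Returns the (x, y) coordinate in [-1, 1] of the argmax of a single-channel error map.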
+class ErrorToPosition(nn.Module):
+ def __init__(self, size: Union[int, Tuple[int, int]]):
+ super(ErrorToPosition, self).__init__()
+
+ self.register_buffer("grid_x", th.arange(size[0]), persistent=False)
+ self.register_buffer("grid_y", th.arange(size[1]), persistent=False)
+
+ self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1
+ self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1
+
+ self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
+ self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()
+
+ self.grid_x = self.grid_x.view(1, 1, -1)
+ self.grid_y = self.grid_y.view(1, 1, -1)
+
+ self.size = size
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+
+ input = rearrange(input, 'b c h w -> b c (h w)')
+ argmax = th.argmax(input, dim=2, keepdim=True)
+
+ x = self.grid_x[0,0,argmax].squeeze(dim=2)
+ y = self.grid_y[0,0,argmax].squeeze(dim=2)
+
+ return th.cat((x,y),dim=1)
+
+
+def compute_rawmask(mask, bg_mask):
+
+ num_objects = mask.shape[1]
+
+ # d_mask selects, for each object, the pair (object mask, background mask) to take the softmax over
+ d_mask = th.diag(th.ones(num_objects+1)).to(mask.device)
+ d_mask[:,-1] = 1
+ d_mask[-1,-1] = 0
+
+ # take subset of rawmask with the diagonal matrix
+ rawmask = th.cat((mask, bg_mask), dim=1)
+ rawmask = repeat(rawmask, 'b o h w -> b r o h w', r = num_objects+1)
+ rawmask = rawmask[:,d_mask.bool()]
+ rawmask = rearrange(rawmask, 'b (o r) h w -> b o r h w', o = num_objects)
+
+ # take softmax between each object mask and the background mask
+ rawmask = th.squeeze(th.softmax(rawmask, dim=2)[:,:,0], dim=2)
+ rawmask = th.cat((rawmask, bg_mask), dim=1) # add background mask
+
+ return rawmask
\ No newline at end of file