diff options
author | fredeee | 2023-11-02 10:47:21 +0100 |
---|---|---|
committer | fredeee | 2023-11-02 10:47:21 +0100 |
commit | f8302ee886ef9b631f11a52900dac964a61350e1 (patch) | |
tree | 87288be6f851ab69405e524b81940c501c52789a /model | |
parent | f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff) |
initiaƶ commit
Diffstat (limited to 'model')
-rw-r--r-- | model/loci.py | 331 | ||||
-rw-r--r-- | model/nn/background.py | 167 | ||||
-rw-r--r-- | model/nn/decoder.py | 169 | ||||
-rw-r--r-- | model/nn/encoder.py | 269 | ||||
-rw-r--r-- | model/nn/eprop_gate_l0rd.py | 329 | ||||
-rw-r--r-- | model/nn/eprop_transformer.py | 76 | ||||
-rw-r--r-- | model/nn/eprop_transformer_shared.py | 92 | ||||
-rw-r--r-- | model/nn/eprop_transformer_utils.py | 66 | ||||
-rw-r--r-- | model/nn/percept_gate_controller.py | 59 | ||||
-rw-r--r-- | model/nn/predictor.py | 99 | ||||
-rw-r--r-- | model/nn/residual.py | 396 | ||||
-rw-r--r-- | model/utils/loss.py | 97 | ||||
-rw-r--r-- | model/utils/nn_utils.py | 298 | ||||
-rw-r--r-- | model/utils/slot_utils.py | 326 |
14 files changed, 2774 insertions, 0 deletions
diff --git a/model/loci.py b/model/loci.py new file mode 100644 index 0000000..96088bd --- /dev/null +++ b/model/loci.py @@ -0,0 +1,331 @@ +import torch as th +import torch.nn as nn +from einops import rearrange, repeat, reduce +from model.nn.percept_gate_controller import PerceptGateController +from model.nn.decoder import LociDecoder +from model.nn.encoder import LociEncoder +from model.nn.predictor import LociPredictor +from model.nn.background import BackgroundEnhancer +from model.utils.nn_utils import LinearInterpolation +from model.utils.loss import ObjectModulator, TranslationInvariantObjectLoss, PositionLoss +from model.utils.slot_utils import OcclusionTracker, InitialLatentStates, compute_rawmask + +class Loci(nn.Module): + def __init__( + self, + cfg, + teacher_forcing=1, + ): + super(Loci, self).__init__() + + self.teacher_forcing = teacher_forcing + self.cfg = cfg + + self.encoder = LociEncoder( + input_size = cfg.input_size, + latent_size = cfg.latent_size, + num_objects = cfg.num_objects, + img_channels = cfg.img_channels * 2 + 6, + hidden_channels = cfg.encoder.channels, + level1_channels = cfg.encoder.level1_channels, + num_layers = cfg.encoder.num_layers, + gestalt_size = cfg.gestalt_size, + bottleneck = cfg.bottleneck, + ) + + self.percept_gate_controller = PerceptGateController( + num_inputs = 2*(cfg.gestalt_size + 3 + 1 + 1) + 3, + num_hidden = [32, 16], + bias = True, + num_objects = cfg.num_objects, + reg_lambda=cfg.update_module.reg_lambda, + ) + + self.predictor = LociPredictor( + num_objects = cfg.num_objects, + gestalt_size = cfg.gestalt_size, + bottleneck = cfg.bottleneck, + channels_multiplier = cfg.predictor.channels_multiplier, + heads = cfg.predictor.heads, + layers = cfg.predictor.layers, + reg_lambda = cfg.predictor.reg_lambda, + batch_size = cfg.batch_size, + transformer_type = cfg.predictor.transformer_type, + ) + + self.decoder = LociDecoder( + latent_size = cfg.latent_size, + num_objects = cfg.num_objects, + gestalt_size = cfg.gestalt_size, + img_channels = cfg.img_channels, + hidden_channels = cfg.decoder.channels, + level1_channels = cfg.decoder.level1_channels, + num_layers = cfg.decoder.num_layers, + batch_size = cfg.batch_size, + ) + + self.background = BackgroundEnhancer( + input_size = cfg.input_size, + gestalt_size = cfg.background.gestalt_size, + img_channels = cfg.img_channels, + depth = cfg.background.num_layers, + latent_channels = cfg.background.latent_channels, + level1_channels = cfg.background.level1_channels, + batch_size = cfg.batch_size, + ) + + self.initial_states = InitialLatentStates( + gestalt_size = cfg.gestalt_size, + bottleneck = cfg.bottleneck, + num_objects = cfg.num_objects, + size = cfg.input_size, + teacher_forcing = teacher_forcing + ) + + self.occlusion_tracker = OcclusionTracker( + batch_size=cfg.batch_size, + num_objects=cfg.num_objects, + device=cfg.device + ) + + self.translation_invariant_object_loss = TranslationInvariantObjectLoss(cfg.num_objects) + self.position_loss = PositionLoss(cfg.num_objects) + self.modulator = ObjectModulator(cfg.num_objects) + self.linear_gate = LinearInterpolation(cfg.num_objects) + + self.background.set_level(cfg.level) + self.encoder.set_level(cfg.level) + self.decoder.set_level(cfg.level) + self.initial_states.set_level(cfg.level) + + def get_init_status(self): + init = [] + for module in self.modules(): + if callable(getattr(module, "get_init", None)): + init.append(module.get_init()) + + assert len(set(init)) == 1 + return init[0] + + def inc_init_level(self): + for module in self.modules(): + if callable(getattr(module, "step_init", None)): + module.step_init() + + def get_openings(self): + return self.predictor.get_openings() + + def detach(self): + for module in self.modules(): + if module != self and callable(getattr(module, "detach", None)): + module.detach() + + def reset_state(self): + for module in self.modules(): + if module != self and callable(getattr(module, "reset_state", None)): + module.reset_state() + + def forward(self, *input, reset=True, detach=True, mode='end2end', evaluate=False, train_background=False, warmup=False, shuffleslots = True, reset_mask = False, allow_spawn = True, show_hidden = False, clean_slots = False): + + if detach: + self.detach() + + if reset: + self.reset_state() + + if train_background or self.get_init_status() < 1: + return self.background(*input) + + return self.run_end2end(*input, evaluate=evaluate, warmup=warmup, shuffleslots = shuffleslots, reset_mask = reset_mask, allow_spawn = allow_spawn, show_hidden = show_hidden, clean_slots = clean_slots) + + def run_decoder( + self, + position: th.Tensor, + gestalt: th.Tensor, + priority: th.Tensor, + bg_mask: th.Tensor, + background: th.Tensor, + only_mask: bool = False + ): + mask, object = self.decoder(position, gestalt, priority) + + rawmask = compute_rawmask(mask, bg_mask) + mask = th.softmax(th.cat((mask, bg_mask), dim=1), dim=1) + + if only_mask: + return mask, rawmask + + object = th.cat((th.sigmoid(object - 2.5), background), dim=1) + _mask = mask.unsqueeze(dim=2) + _object = object.view( + mask.shape[0], + self.cfg.num_objects + 1, + self.cfg.img_channels, + *mask.shape[2:] + ) + + output = th.sum(_mask * _object, dim=1) + return output, mask, object, rawmask + + def reset_unassigned_slots(self, position, gestalt, priority): + + position = self.linear_gate(position, th.zeros_like(position)-1, self.initial_states.get_slots_unassigned()) + gestalt = self.linear_gate(gestalt, th.zeros_like(gestalt), self.initial_states.get_slots_unassigned()) + priority = self.linear_gate(priority, th.zeros_like(priority), self.initial_states.get_slots_unassigned()) + + return position, gestalt, priority + + def run_end2end( + self, + input: th.Tensor, + error_last: th.Tensor = None, + mask_last: th.Tensor = None, + rawmask_last: th.Tensor = None, + position_last: th.Tensor = None, + gestalt_last: th.Tensor = None, + priority_last: th.Tensor = None, + background = None, + slots_occlusionfactor_last: th.Tensor = None, + evaluate = False, + warmup = False, + shuffleslots = True, + reset_mask = False, + allow_spawn = True, + show_hidden = False, + clean_slots = False + ): + position_loss = th.tensor(0, device=input.device) + time_loss = th.tensor(0, device=input.device) + bg_mask = None + position_encoder = None + + if error_last is None or mask_last is None: + bg_mask = self.background(input, only_mask=True) + error_last = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + position_last, gestalt_last, priority_last, error_cur = self.initial_states( + error_last, mask_last, position_last, gestalt_last, priority_last, shuffleslots, self.occlusion_tracker.slots_bounded_all, slots_occlusionfactor_last, allow_spawn=allow_spawn, clean_slots=clean_slots + ) + # only use assigned slots + position_last, gestalt_last, priority_last = self.reset_unassigned_slots(position_last, gestalt_last, priority_last) + + if mask_last is None: + mask_last, rawmask_last = self.run_decoder(position_last, gestalt_last, priority_last, bg_mask, background, only_mask=True) + object_last_unprioritized = self.decoder(position_last, gestalt_last)[-1] + + # background and bg_mask for the next time point + bg_mask = self.background(input, error_last, mask_last[:,-1:], only_mask=True) + + # position and gestalt for the current time point + position_cur, gestalt_cur, priority_cur = self.encoder(input, error_last, mask_last, object_last_unprioritized, position_last, rawmask_last) + if evaluate: + position_encoder = position_cur.clone().detach() + + # only use assigned slots + position_cur, gestalt_cur, priority_cur = self.reset_unassigned_slots(position_cur, gestalt_cur, priority_cur) + + output_cur, mask_cur, object_cur, rawmask_cur = self.run_decoder(position_cur, gestalt_cur, priority_cur, bg_mask, background) + slots_bounded, slots_bounded_smooth, slots_occluded_cur, slots_partially_occluded_cur, slots_fully_visible_cur, slots_occlusionfactor_cur = self.occlusion_tracker(mask_cur, rawmask_cur, reset_mask) + + # do not project into the future in the warmup phase + slots_closed = th.ones_like(repeat(slots_bounded, 'b o -> b o c', c=2)) + if warmup: + position_next = position_cur + gestalt_next = gestalt_cur + priority_next = priority_cur + else: + + if self.cfg.inner_loop_enabled: + + # update module + slots_closed = (1-self.percept_gate_controller(position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, self.position_last2, evaluate=evaluate)) + + position_cur = self.linear_gate(position_cur, position_last, slots_closed[:, :, 1]) + priority_cur = self.linear_gate(priority_cur, priority_last, slots_closed[:, :, 1]) + gestalt_cur = self.linear_gate(gestalt_cur, gestalt_last, slots_closed[:, :, 0]) + + # position and gestalt for the next time point + position_next, gestalt_next, priority_next = self.predictor(gestalt_cur, priority_cur, position_cur, slots_closed) + + # combinded background and objects (masks) for next timepoint + self.position_last2 = position_last.clone().detach() + output_next, mask_next, object_next, rawmask_next = self.run_decoder(position_next, gestalt_next, priority_next, bg_mask, background) + slots_bounded_next, slots_bounded_smooth_next, slots_occluded_next, slots_partially_occluded_next, slots_fully_visible_next, slots_occlusionfactor_next = self.occlusion_tracker(mask_next, rawmask_next, update=False) + + if evaluate: + + if show_hidden: + pos_next = rearrange(position_next.clone(), '1 (o c) -> o c', c=3) + largest_object = th.argmax(pos_next[:, 2], dim=0) + pos_next[largest_object] = th.tensor([2, 2, 0.001]) + pos_next = rearrange(pos_next, 'o c -> 1 (o c)') + output_hidden, _, object_hidden, rawmask_hidden = self.run_decoder(pos_next, gestalt_next, priority_next, bg_mask, background) + else: + output_hidden = None + largest_object = None + rawmask_hidden = None + object_hidden = None + + return ( + output_next, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor_next, + output_cur, + position_cur, + gestalt_cur, + priority_cur, + mask_cur, + rawmask_cur, + object_cur, + position_encoder, + slots_bounded, + slots_partially_occluded_cur, + slots_occluded_cur, + slots_partially_occluded_next, + slots_occluded_next, + slots_closed, + output_hidden, + largest_object, + rawmask_hidden, + object_hidden + ) + + else: + + if not warmup: + + #regularize to small possition chananges over time + position_loss = position_loss + self.position_loss(position_next, position_last.detach(), slots_bounded_smooth) + + # regularize to produce consistent object codes over time + object_next_unprioritized = self.decoder(position_next, gestalt_next)[-1] + time_loss = time_loss + self.translation_invariant_object_loss( + slots_bounded_smooth, + object_last_unprioritized.detach(), + position_last.detach(), + object_next_unprioritized, + position_next.detach(), + ) + + return ( + output_next, + output_cur, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor_next, + position_loss, + time_loss, + slots_closed + ) + diff --git a/model/nn/background.py b/model/nn/background.py new file mode 100644 index 0000000..38ec325 --- /dev/null +++ b/model/nn/background.py @@ -0,0 +1,167 @@ +import torch.nn as nn +import torch as th +from model.nn.residual import ResidualBlock, SkipConnection, LinearResidual +from model.nn.encoder import PatchDownConv +from model.nn.encoder import AggressiveConvToGestalt +from model.nn.decoder import PatchUpscale +from model.utils.nn_utils import LambdaModule, Binarize +from einops import rearrange, repeat, reduce +from typing import Tuple + +__author__ = "Manuel Traub" + +class BackgroundEnhancer(nn.Module): + def __init__( + self, + input_size: Tuple[int, int], + img_channels: int, + level1_channels, + latent_channels, + gestalt_size, + batch_size, + depth + ): + super(BackgroundEnhancer, self).__init__() + + latent_size = [input_size[0] // 16, input_size[1] // 16] + self.input_size = input_size + + self.register_buffer('init', th.zeros(1).long()) + self.alpha = nn.Parameter(th.zeros(1)+1e-16) + + self.level = 1 + self.down_level2 = nn.Sequential( + PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16), + *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)] + ) + + self.down_level1 = nn.Sequential( + PatchDownConv(level1_channels, latent_channels, alpha = 1), + *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)] + ) + + self.down_level0 = nn.Sequential( + *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], + AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size), + LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')), + Binarize(), + ) + + self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size))) + + self.to_grid = nn.Sequential( + LinearResidual(gestalt_size, gestalt_size, input_relu = False), + LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')), + LambdaModule(lambda x: x + self.bias), + *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)], + ) + + + self.up_level0 = nn.Sequential( + ResidualBlock(gestalt_size, latent_channels), + *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], + ) + + self.up_level1 = nn.Sequential( + *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)], + PatchUpscale(latent_channels, level1_channels, alpha = 1), + ) + + self.up_level2 = nn.Sequential( + *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)], + PatchUpscale(level1_channels, img_channels, alpha = 1e-16), + ) + + self.to_channels = nn.ModuleList([ + SkipConnection(img_channels*2+2, latent_channels), + SkipConnection(img_channels*2+2, level1_channels), + SkipConnection(img_channels*2+2, img_channels*2+2), + ]) + + self.to_img = nn.ModuleList([ + SkipConnection(latent_channels, img_channels), + SkipConnection(level1_channels, img_channels), + SkipConnection(img_channels, img_channels), + ]) + + self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10) + self.object = nn.Parameter(th.ones(1, img_channels, *input_size)) + + self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False) + + def get_init(self): + return self.init.item() + + def step_init(self): + self.init = self.init + 1 + + def detach(self): + self.latent = self.latent.detach() + + def reset_state(self): + self.latent = th.zeros_like(self.latent) + + def set_level(self, level): + self.level = level + + def encoder(self, input): + latent = self.to_channels[self.level](input) + + if self.level >= 2: + latent = self.down_level2(latent) + + if self.level >= 1: + latent = self.down_level1(latent) + + return self.down_level0(latent) + + def get_last_latent_gird(self): + return self.to_grid(self.latent) * self.alpha + + def decoder(self, latent, input): + grid = self.to_grid(latent) + latent = self.up_level0(grid) + + if self.level >= 1: + latent = self.up_level1(latent) + + if self.level >= 2: + latent = self.up_level2(latent) + + object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3]) + object = repeat(object, '1 c h w -> b c h w', b = input.shape[0]) + + return th.sigmoid(object + self.to_img[self.level](latent)), grid + + def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False): + + if only_mask: + mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3]) + mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1 + return mask + + last_bg = self.decoder(self.latent, input)[0] + + bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach() + bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach() + + if error is None or self.get_init() < 2: + error = bg_error + + if mask is None or self.get_init() < 2: + mask = bg_mask + + self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1)) + + mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3]) + mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1 + + background, grid = self.decoder(self.latent, input) + + if self.get_init() < 1: + return mask, background + + if self.get_init() < 2: + return mask, th.zeros_like(background), th.zeros_like(grid), background + + return mask, background, grid * self.alpha, background diff --git a/model/nn/decoder.py b/model/nn/decoder.py new file mode 100644 index 0000000..3e56640 --- /dev/null +++ b/model/nn/decoder.py @@ -0,0 +1,169 @@ +import torch.nn as nn +import torch as th +from model.nn.residual import SkipConnection, ResidualBlock +from model.utils.nn_utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize +from einops import rearrange, repeat, reduce +from typing import Tuple, Union, List + +__author__ = "Manuel Traub" + +class PriorityEncoder(nn.Module): + def __init__(self, num_objects, batch_size): + super(PriorityEncoder, self).__init__() + + self.num_objects = num_objects + self.register_buffer("indices", repeat(th.arange(num_objects), 'a -> b a', b=batch_size), persistent=False) + + self.index_factor = nn.Parameter(th.ones(1)) + self.priority_factor = nn.Parameter(th.ones(1)) + + def forward(self, priority: th.Tensor) -> th.Tensor: + + if priority is None: + return None + + priority = priority * self.num_objects + th.randn_like(priority) * 0.1 + priority = priority * self.priority_factor + priority = priority + self.indices * self.index_factor + + return priority * 25 + + +class GestaltPositionMerge(nn.Module): + def __init__( + self, + latent_size: Union[int, Tuple[int, int]], + num_objects: int, + batch_size: int + ): + + super(GestaltPositionMerge, self).__init__() + self.num_objects = num_objects + + self.gaus2d = Gaus2D(size=latent_size) + + self.to_batch = SharedObjectsToBatch(num_objects) + self.to_shared = BatchToSharedObjects(num_objects) + + self.prioritize = Prioritize(num_objects) + + self.priority_encoder = PriorityEncoder(num_objects, batch_size) + + def forward(self, position, gestalt, priority): + + position = rearrange(position, 'b (o c) -> (b o) c', o = self.num_objects) + gestalt = rearrange(gestalt, 'b (o c) -> (b o) c 1 1', o = self.num_objects) + priority = self.priority_encoder(priority) + + position = self.gaus2d(position) + position = self.to_batch(self.prioritize(self.to_shared(position), priority)) + + return position * gestalt + +class PatchUpscale(nn.Module): + def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1): + super(PatchUpscale, self).__init__() + assert in_channels % out_channels == 0 + + self.skip = SkipConnection(in_channels, out_channels, scale_factor=scale_factor) + + self.residual = nn.Sequential( + nn.ReLU(), + nn.Conv2d( + in_channels = in_channels, + out_channels = in_channels, + kernel_size = 3, + padding = 1 + ), + nn.ReLU(), + nn.ConvTranspose2d( + in_channels = in_channels, + out_channels = out_channels, + kernel_size = scale_factor, + stride = scale_factor, + ), + ) + + self.alpha = nn.Parameter(th.zeros(1) + alpha) + + def forward(self, input): + return self.skip(input) + self.alpha * self.residual(input) + + +class LociDecoder(nn.Module): + def __init__( + self, + latent_size: Union[int, Tuple[int, int]], + gestalt_size: int, + num_objects: int, + img_channels: int, + hidden_channels: int, + level1_channels: int, + num_layers: int, + batch_size: int + ): + + super(LociDecoder, self).__init__() + self.to_batch = SharedObjectsToBatch(num_objects) + self.to_shared = BatchToSharedObjects(num_objects) + self.level = 1 + + assert(level1_channels % img_channels == 0) + level1_factor = level1_channels // img_channels + print(f"Level1 channels: {level1_channels}") + + self.merge = GestaltPositionMerge( + latent_size = latent_size, + num_objects = num_objects, + batch_size = batch_size + ) + + self.layer0 = nn.Sequential( + ResidualBlock(gestalt_size, hidden_channels, input_nonlinearity = False), + *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers-1)], + ) + + self.to_mask_level0 = ResidualBlock(hidden_channels, hidden_channels) + self.to_mask_level1 = PatchUpscale(hidden_channels, 1) + + self.to_mask_level2 = nn.Sequential( + ResidualBlock(hidden_channels, hidden_channels), + ResidualBlock(hidden_channels, hidden_channels), + PatchUpscale(hidden_channels, level1_factor, alpha = 1), + PatchUpscale(level1_factor, 1, alpha = 1) + ) + + self.to_object_level0 = ResidualBlock(hidden_channels, hidden_channels) + self.to_object_level1 = PatchUpscale(hidden_channels, img_channels) + + self.to_object_level2 = nn.Sequential( + ResidualBlock(hidden_channels, hidden_channels), + ResidualBlock(hidden_channels, hidden_channels), + PatchUpscale(hidden_channels, level1_channels, alpha = 1), + PatchUpscale(level1_channels, img_channels, alpha = 1) + ) + + self.mask_alpha = nn.Parameter(th.zeros(1)+1e-16) + self.object_alpha = nn.Parameter(th.zeros(1)+1e-16) + + + def set_level(self, level): + self.level = level + + def forward(self, position, gestalt, priority = None): + + maps = self.layer0(self.merge(position, gestalt, priority)) + mask0 = self.to_mask_level0(maps) + object0 = self.to_object_level0(maps) + + mask = self.to_mask_level1(mask0) + object = self.to_object_level1(object0) + + if self.level > 1: + mask = repeat(mask, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4) + object = repeat(object, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4) + + mask = mask + self.to_mask_level2(mask0) * self.mask_alpha + object = object + self.to_object_level2(object0) * self.object_alpha + + return self.to_shared(mask), self.to_shared(object) diff --git a/model/nn/encoder.py b/model/nn/encoder.py new file mode 100644 index 0000000..8eb2e7f --- /dev/null +++ b/model/nn/encoder.py @@ -0,0 +1,269 @@ +import torch.nn as nn +import torch as th +from model.utils.nn_utils import Gaus2D, BatchToSharedObjects, LambdaModule, ForcedAlpha, Binarize +from model.nn.residual import ResidualBlock, SkipConnection +from einops import rearrange, repeat, reduce +from typing import Tuple, Union, List + +__author__ = "Manuel Traub" + +class NeighbourChannels(nn.Module): + def __init__(self, channels): + super(NeighbourChannels, self).__init__() + + self.register_buffer("weights", th.ones(channels, channels, 1, 1), persistent=False) + + for i in range(channels): + self.weights[i,i,0,0] = 0 + + def forward(self, input: th.Tensor): + return nn.functional.conv2d(input, self.weights) + +class InputPreprocessing(nn.Module): + def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]): + super(InputPreprocessing, self).__init__() + self.num_objects = num_objects + self.neighbours = NeighbourChannels(num_objects) + self.gaus2d = Gaus2D(size) + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) + self.to_shared = BatchToSharedObjects(num_objects) + + def forward( + self, + input: th.Tensor, + error: th.Tensor, + mask: th.Tensor, + object: th.Tensor, + position: th.Tensor, + rawmask: th.Tensor + ): + bg_mask = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects) + mask = mask[:,:-1] + mask_others = self.neighbours(mask) + rawmask = rawmask[:,:-1] + + own_gaus2d = self.to_shared(self.gaus2d(self.to_batch(position))) + + input = repeat(input, 'b c h w -> b o c h w', o = self.num_objects) + error = repeat(error, 'b 1 h w -> b o 1 h w', o = self.num_objects) + bg_mask = rearrange(bg_mask, 'b o h w -> b o 1 h w') + mask_others = rearrange(mask_others, 'b o h w -> b o 1 h w') + mask = rearrange(mask, 'b o h w -> b o 1 h w') + object = rearrange(object, 'b (o c) h w -> b o c h w', o = self.num_objects) + own_gaus2d = rearrange(own_gaus2d, 'b o h w -> b o 1 h w') + rawmask = rearrange(rawmask, 'b o h w -> b o 1 h w') + + output = th.cat((input, error, mask, mask_others, bg_mask, object, own_gaus2d, rawmask), dim=2) + output = rearrange(output, 'b o c h w -> (b o) c h w') + + return output + +class PatchDownConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1): + super(PatchDownConv, self).__init__() + assert out_channels % in_channels == 0 + + self.layers = nn.Conv2d( + in_channels = in_channels, + out_channels = out_channels, + kernel_size = kernel_size, + stride = kernel_size, + ) + + self.alpha = nn.Parameter(th.zeros(1) + alpha) + self.kernel_size = 4 + self.channels_factor = out_channels // in_channels + + def forward(self, input: th.Tensor): + k = self.kernel_size + c = self.channels_factor + skip = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=k, w2=k) + skip = repeat(skip, 'b c h w -> b (c n) h w', n=c) + return skip + self.alpha * self.layers(input) + +class AggressiveConvToGestalt(nn.Module): + def __init__(self, channels, gestalt_size, size: Union[int, Tuple[int, int]]): + super(AggressiveConvToGestalt, self).__init__() + + assert gestalt_size % channels == 0 or channels % gestalt_size == 0 + + self.layers = nn.Sequential( + nn.Conv2d( + in_channels = channels, + out_channels = gestalt_size, + kernel_size = 5, + stride = 3, + padding = 3 + ), + nn.ReLU(), + nn.Conv2d( + in_channels = gestalt_size, + out_channels = gestalt_size, + kernel_size = ((size[0] + 1)//3 + 1, (size[1] + 1)//3 + 1) + ) + ) + if gestalt_size > channels: + self.skip = nn.Sequential( + LambdaModule(lambda x: reduce(x, 'b c h w -> b c 1 1', 'mean')), + LambdaModule(lambda x: repeat(x, 'b c 1 1 -> b (c n) 1 1', n = gestalt_size // channels)) + ) + else: + self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c 1 1', 'mean', n = channels // gestalt_size)) + + + def forward(self, input: th.Tensor): + return self.skip(input) + self.layers(input) + +class PixelToPosition(nn.Module): + def __init__(self, size: Union[int, Tuple[int, int]]): + super(PixelToPosition, self).__init__() + + self.register_buffer("grid_x", th.arange(size[0]), persistent=False) + self.register_buffer("grid_y", th.arange(size[1]), persistent=False) + + self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1 + self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1 + + self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone() + self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone() + + self.size = size + + def forward(self, input: th.Tensor): + assert input.shape[1] == 1 + + input = rearrange(input, 'b c h w -> b c (h w)') + input = th.softmax(input, dim=2) + input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1]) + + x = th.sum(input * self.grid_x, dim=(2,3)) + y = th.sum(input * self.grid_y, dim=(2,3)) + + return th.cat((x,y),dim=1) + +class PixelToSTD(nn.Module): + def __init__(self): + super(PixelToSTD, self).__init__() + self.alpha = ForcedAlpha() + + def forward(self, input: th.Tensor): + assert input.shape[1] == 1 + return self.alpha(reduce(th.sigmoid(input - 10), 'b c h w -> b c', 'mean')) + +class PixelToPriority(nn.Module): + def __init__(self): + super(PixelToPriority, self).__init__() + + def forward(self, input: th.Tensor): + assert input.shape[1] == 1 + return reduce(th.tanh(input), 'b c h w -> b c', 'mean') + +class LociEncoder(nn.Module): + def __init__( + self, + input_size: Union[int, Tuple[int, int]], + latent_size: Union[int, Tuple[int, int]], + num_objects: int, + img_channels: int, + hidden_channels: int, + level1_channels: int, + num_layers: int, + gestalt_size: int, + bottleneck: str + ): + super(LociEncoder, self).__init__() + + self.num_objects = num_objects + self.latent_size = latent_size + self.level = 1 + + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects)) + + print(f"Level1 channels: {level1_channels}") + + self.preprocess = nn.ModuleList([ + InputPreprocessing(num_objects, (input_size[0] // 16, input_size[1] // 16)), + InputPreprocessing(num_objects, (input_size[0] // 4, input_size[1] // 4)), + InputPreprocessing(num_objects, (input_size[0], input_size[1])) + ]) + + self.to_channels = nn.ModuleList([ + SkipConnection(img_channels, hidden_channels), + SkipConnection(img_channels, level1_channels), + SkipConnection(img_channels, img_channels) + ]) + + self.layers2 = nn.Sequential( + PatchDownConv(img_channels, level1_channels, alpha = 1e-16), + *[ResidualBlock(level1_channels, level1_channels, alpha_residual=True) for _ in range(num_layers)] + ) + + self.layers1 = PatchDownConv(level1_channels, hidden_channels) + + self.layers0 = nn.Sequential( + *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)] + ) + + self.position_encoder = nn.Sequential( + *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)], + ResidualBlock(hidden_channels, 3), + ) + + self.xy_encoder = PixelToPosition(latent_size) + self.std_encoder = PixelToSTD() + self.priority_encoder = PixelToPriority() + + if bottleneck == "binar": + print("Binary bottleneck") + self.gestalt_encoder = nn.Sequential( + *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)], + AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size), + LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')), + Binarize(), + ) + + else: + print("unrestricted bottleneck") + self.gestalt_encoder = nn.Sequential( + *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)], + AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size), + LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')), + nn.Sigmoid(), + ) + + def set_level(self, level): + self.level = level + + def forward( + self, + input: th.Tensor, + error: th.Tensor, + mask: th.Tensor, + object: th.Tensor, + position: th.Tensor, + rawmask: th.Tensor + ): + + latent = self.preprocess[self.level](input, error, mask, object, position, rawmask) + latent = self.to_channels[self.level](latent) + + if self.level >= 2: + latent = self.layers2(latent) + + if self.level >= 1: + latent = self.layers1(latent) + + latent = self.layers0(latent) + gestalt = self.gestalt_encoder(latent) + + latent = self.position_encoder(latent) + std = self.std_encoder(latent[:,0:1]) + xy = self.xy_encoder(latent[:,1:2]) + priority = self.priority_encoder(latent[:,2:3]) + + position = self.to_shared(th.cat((xy, std), dim=1)) + gestalt = self.to_shared(gestalt) + priority = self.to_shared(priority) + + return position, gestalt, priority + diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py new file mode 100644 index 0000000..82a9895 --- /dev/null +++ b/model/nn/eprop_gate_l0rd.py @@ -0,0 +1,329 @@ +import torch.nn as nn +import torch as th +import numpy as np +from torch.autograd import Function +from einops import rearrange, repeat, reduce + +__author__ = "Manuel Traub" + +class EpropGateL0rdFunction(Function): + @staticmethod + def forward(ctx, x, h_last, w_gx, w_gh, b_g, w_rx, w_rh, b_r, args): + + e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r, reg, noise_level = args + + noise = th.normal(mean=0, std=noise_level, size=b_g.shape, device=b_g.device) + g = th.relu(th.tanh(x.mm(w_gx.t()) + h_last.mm(w_gh.t()) + b_g + noise)) + r = th.tanh(x.mm(w_rx.t()) + h_last.mm(w_rh.t()) + b_r) + + h = g * r + (1 - g) * h_last + + # Haevisite step function + H_g = th.ceil(g).clamp(0, 1) + + dg = (1 - g**2) * H_g + dr = (1 - r**2) + + delta_h = r - h_last + + g_j = g.unsqueeze(dim=2) + dg_j = dg.unsqueeze(dim=2) + dr_j = dr.unsqueeze(dim=2) + + x_i = x.unsqueeze(dim=1) + h_last_i = h_last.unsqueeze(dim=1) + delta_h_j = delta_h.unsqueeze(dim=2) + + e_w_gh.copy_(e_w_gh * (1 - g_j) + dg_j * h_last_i * delta_h_j) + e_w_gx.copy_(e_w_gx * (1 - g_j) + dg_j * x_i * delta_h_j) + e_b_g.copy_( e_b_g * (1 - g) + dg * delta_h ) + + e_w_rh.copy_(e_w_rh * (1 - g_j) + dr_j * h_last_i * g_j) + e_w_rx.copy_(e_w_rx * (1 - g_j) + dr_j * x_i * g_j) + e_b_r.copy_( e_b_r * (1 - g) + dr * g ) + + ctx.save_for_backward( + g.clone(), dg.clone(), dg_j.clone(), dr.clone(), x_i.clone(), h_last_i.clone(), + reg.clone(), H_g.clone(), delta_h.clone(), w_gx.clone(), w_gh.clone(), w_rx.clone(), w_rh.clone(), + e_w_gx.clone(), e_w_gh.clone(), e_b_g.clone(), + e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(), + ) + + return h, th.mean(H_g) + + @staticmethod + def backward(ctx, dh, _): + + g, dg, dg_j, dr, x_i, h_last_i, reg, H_g, delta_h, w_gx, w_gh, w_rx, w_rh, \ + e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r = ctx.saved_tensors + + dh_j = dh.unsqueeze(dim=2) + H_g_reg = reg * H_g + H_g_reg_j = H_g_reg.unsqueeze(dim=2) + + dw_gx = th.sum(dh_j * e_w_gx + H_g_reg_j * dg_j * x_i, dim=0) + dw_gh = th.sum(dh_j * e_w_gh + H_g_reg_j * dg_j * h_last_i, dim=0) + db_g = th.sum(dh * e_b_g + H_g_reg * dg, dim=0) + + dw_rx = th.sum(dh_j * e_w_rx, dim=0) + dw_rh = th.sum(dh_j * e_w_rh, dim=0) + db_r = th.sum(dh * e_b_r , dim=0) + + dh_dg = (dh * delta_h + H_g_reg) * dg + dh_dr = dh * g * dr + + dx = dh_dg.mm(w_gx) + dh_dr.mm(w_rx) + dh = dh * (1 - g) + dh_dg.mm(w_gh) + dh_dr.mm(w_rh) + + return dx, dh, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None + +class ReTanhFunction(Function): + @staticmethod + def forward(ctx, x, reg): + + g = th.relu(th.tanh(x)) + + # Haevisite step function + H_g = th.ceil(g).clamp(0, 1) + + dg = (1 - g**2) * H_g + + ctx.save_for_backward(g, dg, H_g, reg) + return g, th.mean(H_g) + + @staticmethod + def backward(ctx, dh, _): + + g, dg, H_g, reg = ctx.saved_tensors + + dx = (dh + reg * H_g) * dg + + return dx, None + +class ReTanh(nn.Module): + def __init__(self, reg_lambda): + super(ReTanh, self).__init__() + + self.re_tanh = ReTanhFunction().apply + self.register_buffer("reg_lambda", th.tensor(reg_lambda), persistent=False) + + def forward(self, input): + h, openings = self.re_tanh(input, self.reg_lambda) + self.openings = openings.item() + + return h + + +class EpropGateL0rd(nn.Module): + def __init__( + self, + num_inputs, + num_hidden, + num_outputs, + batch_size, + reg_lambda = 0, + gate_noise_level = 0, + ): + super(EpropGateL0rd, self).__init__() + + self.register_buffer("reg", th.tensor(reg_lambda).view(1,1), persistent=False) + self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False) + self.num_inputs = num_inputs + self.num_hidden = num_hidden + self.num_outputs = num_outputs + + self.fcn = EpropGateL0rdFunction().apply + self.retanh = ReTanh(reg_lambda) + + # gate weights and biases + self.w_gx = nn.Parameter(th.empty(num_hidden, num_inputs)) + self.w_gh = nn.Parameter(th.empty(num_hidden, num_hidden)) + self.b_g = nn.Parameter(th.zeros(num_hidden)) + + # candidate weights and biases + self.w_rx = nn.Parameter(th.empty(num_hidden, num_inputs)) + self.w_rh = nn.Parameter(th.empty(num_hidden, num_hidden)) + self.b_r = nn.Parameter(th.zeros(num_hidden)) + + # output projection weights and bias + self.w_px = nn.Parameter(th.empty(num_outputs, num_inputs)) + self.w_ph = nn.Parameter(th.empty(num_outputs, num_hidden)) + self.b_p = nn.Parameter(th.zeros(num_outputs)) + + # output gate weights and bias + self.w_ox = nn.Parameter(th.empty(num_outputs, num_inputs)) + self.w_oh = nn.Parameter(th.empty(num_outputs, num_hidden)) + self.b_o = nn.Parameter(th.zeros(num_outputs)) + + # input gate eligibilitiy traces + self.register_buffer("e_w_gx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False) + self.register_buffer("e_w_gh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False) + self.register_buffer("e_b_g", th.zeros(batch_size, num_hidden), persistent=False) + + # forget gate eligibilitiy traces + self.register_buffer("e_w_rx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False) + self.register_buffer("e_w_rh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False) + self.register_buffer("e_b_r", th.zeros(batch_size, num_hidden), persistent=False) + + # hidden state + self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False) + + self.register_buffer("openings", th.zeros(1), persistent=False) + + # initialize weights + stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden)) + stdv_hh = np.sqrt(3/self.num_hidden) + stdv_io = np.sqrt(6/(self.num_inputs + self.num_outputs)) + stdv_ho = np.sqrt(6/(self.num_hidden + self.num_outputs)) + + nn.init.uniform_(self.w_gx, -stdv_ih, stdv_ih) + nn.init.uniform_(self.w_gh, -stdv_hh, stdv_hh) + + nn.init.uniform_(self.w_rx, -stdv_ih, stdv_ih) + nn.init.uniform_(self.w_rh, -stdv_hh, stdv_hh) + + nn.init.uniform_(self.w_px, -stdv_io, stdv_io) + nn.init.uniform_(self.w_ph, -stdv_ho, stdv_ho) + + nn.init.uniform_(self.w_ox, -stdv_io, stdv_io) + nn.init.uniform_(self.w_oh, -stdv_ho, stdv_ho) + + self.backprop = False + + def reset_state(self): + self.h_last.zero_() + self.e_w_gx.zero_() + self.e_w_gh.zero_() + self.e_b_g.zero_() + self.e_w_rx.zero_() + self.e_w_rh.zero_() + self.e_b_r.zero_() + self.openings.zero_() + + def backprop_forward(self, x: th.Tensor): + + noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device) + g = self.retanh(x.mm(self.w_gx.t()) + self.h_last.mm(self.w_gh.t()) + self.b_g + noise) + r = th.tanh(x.mm(self.w_rx.t()) + self.h_last.mm(self.w_rh.t()) + self.b_r) + + self.h_last = g * r + (1 - g) * self.h_last + + # Haevisite step function + H_g = th.ceil(g).clamp(0, 1) + + self.openings = th.mean(H_g) + + p = th.tanh(x.mm(self.w_px.t()) + self.h_last.mm(self.w_ph.t()) + self.b_p) + o = th.sigmoid(x.mm(self.w_ox.t()) + self.h_last.mm(self.w_oh.t()) + self.b_o) + return o * p + + def activate_backprop(self): + self.backprop = True + + def deactivate_backprop(self): + self.backprop = False + + def detach(self): + self.h_last.detach_() + + def eprop_forward(self, x: th.Tensor): + h, openings = self.fcn( + x, self.h_last, + self.w_gx, self.w_gh, self.b_g, + self.w_rx, self.w_rh, self.b_r, + ( + self.e_w_gx, self.e_w_gh, self.e_b_g, + self.e_w_rx, self.e_w_rh, self.e_b_r, + self.reg, self.noise + ) + ) + + self.openings = openings + self.h_last = h + + p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p) + o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o) + return o * p + + def save_hidden(self): + self.h_last_saved = self.h_last.detach() + + def restore_hidden(self): + self.h_last = self.h_last_saved + + def get_hidden(self): + return self.h_last + + def set_hidden(self, h_last): + self.h_last = h_last + + def forward(self, x: th.Tensor): + if self.backprop: + return self.backprop_forward(x) + + return self.eprop_forward(x) + + +class EpropGateL0rdShared(EpropGateL0rd): + def __init__( + self, + num_inputs, + num_hidden, + num_outputs, + batch_size, + reg_lambda = 0, + gate_noise_level = 0, + ): + super().__init__(num_inputs, num_hidden, num_outputs, batch_size, reg_lambda, gate_noise_level) + + def backprop_forward(self, x: th.Tensor, h_last: th.Tensor): + + noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device) + g = self.retanh(x.mm(self.w_gx.t()) + h_last.mm(self.w_gh.t()) + self.b_g + noise) + r = th.tanh(x.mm(self.w_rx.t()) + h_last.mm(self.w_rh.t()) + self.b_r) + + h_last = g * r + (1 - g) * h_last + + # Haevisite step function + H_g = th.ceil(g).clamp(0, 1) + + self.openings = th.mean(H_g) + + p = th.tanh(x.mm(self.w_px.t()) + h_last.mm(self.w_ph.t()) + self.b_p) + o = th.sigmoid(x.mm(self.w_ox.t()) + h_last.mm(self.w_oh.t()) + self.b_o) + return o * p, h_last + + def eprop_forward(self, x: th.Tensor, h_last: th.Tensor): + h, openings = self.fcn( + x, h_last, + self.w_gx, self.w_gh, self.b_g, + self.w_rx, self.w_rh, self.b_r, + ( + self.e_w_gx, self.e_w_gh, self.e_b_g, + self.e_w_rx, self.e_w_rh, self.e_b_r, + self.reg, self.noise + ) + ) + + self.openings = openings + + p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p) + o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o) + return o * p, h + + def forward(self, x: th.Tensor, h_last: th.Tensor = None): + + if h_last is not None: + if self.backprop: + return self.backprop_forward(x, h_last) + + return self.eprop_forward(x, h_last) + + # backward compatibility + if self.backprop: + x, h = self.backprop_forward(x, self.h_last) + + x, h = self.eprop_forward(x, self.h_last) + + self.h_last = h + return x
\ No newline at end of file diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py new file mode 100644 index 0000000..4d89ec4 --- /dev/null +++ b/model/nn/eprop_transformer.py @@ -0,0 +1,76 @@ +import torch.nn as nn +import torch as th +from model.nn.eprop_gate_l0rd import EpropGateL0rd +from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding + +class EpropGateL0rdTransformer(nn.Module): + def __init__( + self, + channels, + multiplier, + num_objects, + batch_size, + heads, + depth, + reg_lambda, + dropout=0.0 + ): + super(EpropGateL0rdTransformer, self).__init__() + + num_inputs = channels + num_outputs = channels + num_hidden = channels + num_hidden = channels * multiplier + + print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})") + + + self.depth = depth + _layers = [] + _layers.append(InputEmbeding(num_inputs, num_hidden)) + + for i in range(depth): + _layers.append(AlphaAttention(num_hidden, num_objects, heads, dropout)) + _layers.append(EpropAlphaGateL0rd(num_hidden, batch_size * num_objects, reg_lambda)) + + _layers.append(OutputEmbeding(num_hidden, num_outputs)) + self.layers = nn.Sequential(*_layers) + + def get_openings(self): + openings = 0 + for i in range(self.depth): + openings += self.layers[2 * (i + 1)].l0rd.openings.item() + + return openings / self.depth + + def get_hidden(self): + states = [] + for i in range(self.depth): + states.append(self.layers[2 * (i + 1)].l0rd.get_hidden()) + + return th.cat(states, dim=1) + + def set_hidden(self, hidden): + states = th.chunk(hidden, self.depth, dim=1) + for i in range(self.depth): + self.layers[2 * (i + 1)].l0rd.set_hidden(states[i]) + + def forward(self, input: th.Tensor) -> th.Tensor: + return self.layers(input) + + +class EpropAlphaGateL0rd(nn.Module): + def __init__(self, num_hidden, batch_size, reg_lambda): + super(EpropAlphaGateL0rd, self).__init__() + + self.alpha = nn.Parameter(th.zeros(1)+1e-12) + self.l0rd = EpropGateL0rd( + num_inputs = num_hidden, + num_hidden = num_hidden, + num_outputs = num_hidden, + reg_lambda = reg_lambda, + batch_size = batch_size + ) + + def forward(self, input): + return input + self.alpha * self.l0rd(input)
\ No newline at end of file diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py new file mode 100644 index 0000000..23a1b0f --- /dev/null +++ b/model/nn/eprop_transformer_shared.py @@ -0,0 +1,92 @@ +import torch.nn as nn +import torch as th +from model.nn.eprop_gate_l0rd import EpropGateL0rdShared +from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding + +class EpropGateL0rdTransformerShared(nn.Module): + def __init__( + self, + channels, + multiplier, + num_objects, + batch_size, + heads, + depth, + reg_lambda, + dropout=0.0, + exchange_length = 48, + ): + super(EpropGateL0rdTransformerShared, self).__init__() + + num_inputs = channels + num_outputs = channels + num_hidden = channels * multiplier + num_hidden_gatelord = num_hidden + exchange_length + num_hidden_attention = num_hidden + exchange_length + num_hidden_gatelord + + self.num_hidden = num_hidden + self.num_hidden_gatelord = num_hidden_gatelord + + #print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})") + + self.register_buffer('hidden', th.zeros(batch_size * num_objects, num_hidden_gatelord), persistent=False) + self.register_buffer('exchange_code', th.zeros(batch_size * num_objects, exchange_length), persistent=False) + + self.depth = depth + self.input_embeding = InputEmbeding(num_inputs, num_hidden) + self.attention = nn.Sequential(*[AlphaAttention(num_hidden_attention, num_objects, heads, dropout) for _ in range(depth)]) + self.l0rds = nn.Sequential(*[EpropAlphaGateL0rdShared(num_hidden_gatelord, batch_size * num_objects, reg_lambda) for _ in range(depth)]) + self.output_embeding = OutputEmbeding(num_hidden, num_outputs) + + def get_openings(self): + openings = 0 + for i in range(self.depth): + openings += self.l0rds[i].l0rd.openings.item() + + return openings / self.depth + + def get_hidden(self): + return self.hidden + + def set_hidden(self, hidden): + self.hidden = hidden + + def detach(self): + self.hidden = self.hidden.detach() + + def reset_state(self): + self.hidden = th.zeros_like(self.hidden) + + def forward(self, x: th.Tensor) -> th.Tensor: + x = self.input_embeding(x) + exchange_code = self.exchange_code.clone() * 0.0 + x_ex = th.concat((x, exchange_code), dim=1) + + for i in range(self.depth): + # attention layer + att = self.attention(th.concat((x_ex, self.hidden), dim=1)) + x_ex = att[:, :self.num_hidden_gatelord] + + # gatelord layer + x_ex, self.hidden = self.l0rds[i](x_ex, self.hidden) + + # only yield x + x = x_ex[:, :self.num_hidden] + return self.output_embeding(x) + +class EpropAlphaGateL0rdShared(nn.Module): + def __init__(self, num_hidden, batch_size, reg_lambda): + super(EpropAlphaGateL0rdShared, self).__init__() + + self.alpha = nn.Parameter(th.zeros(1)+1e-12) + self.l0rd = EpropGateL0rdShared( + num_inputs = num_hidden, + num_hidden = num_hidden, + num_outputs = num_hidden, + reg_lambda = reg_lambda, + batch_size = batch_size + ) + + def forward(self, input, hidden): + output, hidden = self.l0rd(input, hidden) + return input + self.alpha * output, hidden
\ No newline at end of file diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py new file mode 100644 index 0000000..9e5e874 --- /dev/null +++ b/model/nn/eprop_transformer_utils.py @@ -0,0 +1,66 @@ +import torch.nn as nn +import torch as th +from model.utils.nn_utils import LambdaModule +from einops import rearrange, repeat, reduce + +class AlphaAttention(nn.Module): + def __init__( + self, + num_hidden, + num_objects, + heads, + dropout = 0.0 + ): + super(AlphaAttention, self).__init__() + + self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects)) + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects)) + + self.alpha = nn.Parameter(th.zeros(1)+1e-12) + self.attention = nn.MultiheadAttention( + num_hidden, + heads, + dropout = dropout, + batch_first = True + ) + + def forward(self, x: th.Tensor): + x = self.to_sequence(x) + x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0] + return self.to_batch(x) + +class InputEmbeding(nn.Module): + def __init__(self, num_inputs, num_hidden): + super(InputEmbeding, self).__init__() + + self.embeding = nn.Sequential( + nn.ReLU(), + nn.Linear(num_inputs, num_hidden), + nn.ReLU(), + nn.Linear(num_hidden, num_hidden), + ) + self.skip = LambdaModule( + lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs) + ) + self.alpha = nn.Parameter(th.zeros(1)+1e-12) + + def forward(self, input: th.Tensor): + return self.skip(input) + self.alpha * self.embeding(input) + +class OutputEmbeding(nn.Module): + def __init__(self, num_hidden, num_outputs): + super(OutputEmbeding, self).__init__() + + self.embeding = nn.Sequential( + nn.ReLU(), + nn.Linear(num_hidden, num_outputs), + nn.ReLU(), + nn.Linear(num_outputs, num_outputs), + ) + self.skip = LambdaModule( + lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs) + ) + self.alpha = nn.Parameter(th.zeros(1)+1e-12) + + def forward(self, input: th.Tensor): + return self.skip(input) + self.alpha * self.embeding(input)
\ No newline at end of file diff --git a/model/nn/percept_gate_controller.py b/model/nn/percept_gate_controller.py new file mode 100644 index 0000000..a548c20 --- /dev/null +++ b/model/nn/percept_gate_controller.py @@ -0,0 +1,59 @@ +import torch.nn as nn +import torch as th +from model.utils.nn_utils import LambdaModule +from einops import rearrange, repeat, reduce +from model.nn.eprop_gate_l0rd import ReTanh + +class PerceptGateController(nn.Module): + def __init__( + self, + num_inputs: int, + num_hidden: list, + bias: bool, + num_objects: int, + gate_noise_level: float = 0.1, + reg_lambda: float = 0.000005 + ): + super(PerceptGateController, self).__init__() + + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects)) + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o=num_objects)) + + self.layers = nn.Sequential( + nn.Linear(num_inputs, num_hidden[0], bias = bias), + nn.Tanh(), + nn.Linear(num_hidden[0], num_hidden[1], bias = bias), + nn.Tanh(), + nn.Linear(num_hidden[1], 2, bias = bias) + ) + self.output_function = ReTanh(reg_lambda) + self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False) + self.init_weights() + + def init_weights(self): + for layer in self.layers: + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform(layer.weight) + layer.bias.data.fill_(3.00) + + def forward(self, position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2, evaluate=False): + + position_cur = self.to_batch(position_cur) + gestalt_cur = self.to_batch(gestalt_cur) + priority_cur = self.to_batch(priority_cur) + position_last = self.to_batch(position_last) + gestalt_last = self.to_batch(gestalt_last) + priority_last = self.to_batch(priority_last) + slots_occlusionfactor_cur = self.to_batch(slots_occlusionfactor_cur).detach() + slots_occlusionfactor_last = self.to_batch(slots_occlusionfactor_last).detach() + position_last2 = self.to_batch(position_last2).detach() + + input = th.cat((position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2), dim=1) + output = self.layers(input) + if evaluate: + output = self.output_function(output) + else: + noise = th.normal(mean=0, std=self.noise, size=output.shape, device=output.device) + output = self.output_function(output + noise) + + return self.to_shared(output) diff --git a/model/nn/predictor.py b/model/nn/predictor.py new file mode 100644 index 0000000..94f13b8 --- /dev/null +++ b/model/nn/predictor.py @@ -0,0 +1,99 @@ +import torch.nn as nn +import torch as th +from model.nn.eprop_transformer import EpropGateL0rdTransformer +from model.nn.eprop_transformer_shared import EpropGateL0rdTransformerShared +from model.utils.nn_utils import LambdaModule, Binarize +from model.nn.residual import ResidualBlock +from einops import rearrange, repeat, reduce + +__author__ = "Manuel Traub" + +class LociPredictor(nn.Module): + def __init__( + self, + heads: int, + layers: int, + channels_multiplier: int, + reg_lambda: float, + num_objects: int, + gestalt_size: int, + batch_size: int, + bottleneck: str, + transformer_type = 'standard', + ): + super(LociPredictor, self).__init__() + self.num_objects = num_objects + self.std_alpha = nn.Parameter(th.zeros(1)+1e-16) + self.bottleneck_type = bottleneck + self.gestalt_size = gestalt_size + + self.reg_lambda = reg_lambda + Transformer = EpropGateL0rdTransformerShared if transformer_type == 'shared' else EpropGateL0rdTransformer + self.predictor = Transformer( + channels = gestalt_size + 3 + 1 + 2, + multiplier = channels_multiplier, + heads = heads, + depth = layers, + num_objects = num_objects, + reg_lambda = reg_lambda, + batch_size = batch_size, + ) + + if bottleneck == 'binar': + print("Binary bottleneck") + self.bottleneck = nn.Sequential( + LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')), + ResidualBlock(gestalt_size, gestalt_size, kernel_size=1), + Binarize(), + LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects)) + ) + + else: + print("unrestricted bottleneck") + self.bottleneck = nn.Sequential( + LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')), + ResidualBlock(gestalt_size, gestalt_size, kernel_size=1), + nn.Sigmoid(), + LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects)) + ) + + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects)) + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects)) + + def get_openings(self): + return self.predictor.get_openings() + + def get_hidden(self): + return self.predictor.get_hidden() + + def set_hidden(self, hidden): + self.predictor.set_hidden(hidden) + + def forward( + self, + gestalt: th.Tensor, + priority: th.Tensor, + position: th.Tensor, + slots_closed: th.Tensor, + ): + + position = self.to_batch(position) + gestalt_cur = self.to_batch(gestalt) + priority = self.to_batch(priority) + slots_closed = rearrange(slots_closed, 'b o c -> (b o) c').detach() + + input = th.cat((gestalt_cur, priority, position, slots_closed), dim=1) + output = self.predictor(input) + + gestalt = output[:, :self.gestalt_size] + priority = output[:,self.gestalt_size:(self.gestalt_size+1)] + xy = output[:,(self.gestalt_size+1):(self.gestalt_size+3)] + std = output[:,(self.gestalt_size+3):(self.gestalt_size+4)] + + position = th.cat((xy, std * self.std_alpha), dim=1) + + position = self.to_shared(position) + gestalt = self.bottleneck(gestalt) + priority = self.to_shared(priority) + + return position, gestalt, priority diff --git a/model/nn/residual.py b/model/nn/residual.py new file mode 100644 index 0000000..1602e16 --- /dev/null +++ b/model/nn/residual.py @@ -0,0 +1,396 @@ +import torch.nn as nn +import torch as th +import numpy as np +from einops import rearrange, repeat, reduce +from model.utils.nn_utils import LambdaModule + +from typing import Union, Tuple + +__author__ = "Manuel Traub" + +class DynamicLayerNorm(nn.Module): + + def __init__(self, eps: float = 1e-5): + super(DynamicLayerNorm, self).__init__() + self.eps = eps + + def forward(self, input: th.Tensor) -> th.Tensor: + return nn.functional.layer_norm(input, input.shape[2:], None, None, self.eps) + + +class SkipConnection(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + scale_factor: float = 1.0 + ): + super(SkipConnection, self).__init__() + assert scale_factor == 1 or int(scale_factor) > 1 or int(1 / scale_factor) > 1, f'invalid scale factor in SpikeFunction: {scale_factor}' + + self.in_channels = in_channels + self.out_channels = out_channels + self.scale_factor = scale_factor + + def channel_skip(self, input: th.Tensor): + in_channels = self.in_channels + out_channels = self.out_channels + + if in_channels == out_channels: + return input + + if in_channels % out_channels == 0 or out_channels % in_channels == 0: + + if in_channels > out_channels: + return reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // out_channels) + + if out_channels > in_channels: + return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // in_channels) + + mean_channels = np.gcd(in_channels, out_channels) + input = reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // mean_channels) + return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // mean_channels) + + def scale_skip(self, input: th.Tensor): + scale_factor = self.scale_factor + + if scale_factor == 1: + return input + + if scale_factor > 1: + return repeat( + input, + 'b c h w -> b c (h h2) (w w2)', + h2 = int(scale_factor), + w2 = int(scale_factor) + ) + + height = input.shape[2] + width = input.shape[3] + + # scale factor < 1 + scale_factor = int(1 / scale_factor) + + if width % scale_factor == 0 and height % scale_factor == 0: + return reduce( + input, + 'b c (h h2) (w w2) -> b c h w', + 'mean', + h2 = scale_factor, + w2 = scale_factor + ) + + if width >= scale_factor and height >= scale_factor: + return nn.functional.avg_pool2d( + input, + kernel_size = scale_factor, + stride = scale_factor + ) + + assert width > 1 or height > 1 + return reduce(input, 'b c h w -> b c 1 1', 'mean') + + + def forward(self, input: th.Tensor): + + if self.scale_factor > 1: + return self.scale_skip(self.channel_skip(input)) + + return self.channel_skip(self.scale_skip(input)) + +class DownScale(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + scale_factor: int, + groups: int = 1, + bias: bool = True + ): + + super(DownScale, self).__init__() + + assert(in_channels % groups == 0) + assert(out_channels % groups == 0) + + self.groups = groups + self.scale_factor = scale_factor + self.weight = nn.Parameter(th.empty((out_channels, in_channels // groups, scale_factor, scale_factor))) + self.bias = nn.Parameter(th.empty((out_channels,))) if bias else None + + nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5)) + + if self.bias is not None: + fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / np.sqrt(fan_in) + nn.init.uniform_(self.bias, -bound, bound) + + def forward(self, input: th.Tensor): + height = input.shape[2] + width = input.shape[3] + assert height > 1 or width > 1, "trying to dowscale 1x1" + + scale_factor = self.scale_factor + padding = [0, 0] + + if height < scale_factor: + padding[0] = scale_factor - height + + if width < scale_factor: + padding[1] = scale_factor - width + + return nn.functional.conv2d( + input, + self.weight, + bias=self.bias, + stride=scale_factor, + padding=padding, + groups=self.groups + ) + + +class ResidualBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]] = (3, 3), + scale_factor: int = 1, + groups: Union[int, Tuple[int, int]] = (1, 1), + bias: bool = True, + layer_norm: bool = False, + leaky_relu: bool = False, + residual: bool = True, + alpha_residual: bool = False, + input_nonlinearity = True + ): + + super(ResidualBlock, self).__init__() + self.residual = residual + self.alpha_residual = alpha_residual + self.skip = False + self.in_channels = in_channels + self.out_channels = out_channels + + if isinstance(kernel_size, int): + kernel_size = [kernel_size, kernel_size] + + if isinstance(groups, int): + groups = [groups, groups] + + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + + _layers = list() + if layer_norm: + _layers.append(DynamicLayerNorm()) + + if input_nonlinearity: + if leaky_relu: + _layers.append(nn.LeakyReLU()) + else: + _layers.append(nn.ReLU()) + + if scale_factor > 1: + _layers.append( + nn.ConvTranspose2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=scale_factor, + stride=scale_factor, + groups=groups[0], + bias=bias + ) + ) + elif scale_factor < 1: + _layers.append( + DownScale( + in_channels=in_channels, + out_channels=out_channels, + scale_factor=int(1.0/scale_factor), + groups=groups[0], + bias=bias + ) + ) + else: + _layers.append( + nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + groups=groups[0], + bias=bias + ) + ) + + if layer_norm: + _layers.append(DynamicLayerNorm()) + if leaky_relu: + _layers.append(nn.LeakyReLU()) + else: + _layers.append(nn.ReLU()) + _layers.append( + nn.Conv2d( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + groups=groups[1], + bias=bias + ) + ) + self.layers = nn.Sequential(*_layers) + + if self.residual: + self.skip_connection = SkipConnection( + in_channels=in_channels, + out_channels=out_channels, + scale_factor=scale_factor + ) + + if self.alpha_residual: + self.alpha = nn.Parameter(th.zeros(1) + 1e-12) + + def set_mode(self, **kwargs): + if 'skip' in kwargs: + self.skip = kwargs['skip'] + + if 'residual' in kwargs: + self.residual = kwargs['residual'] + + def forward(self, input: th.Tensor) -> th.Tensor: + if self.skip: + return self.skip_connection(input) + + if not self.residual: + return self.layers(input) + + if self.alpha_residual: + return self.alpha * self.layers(input) + self.skip_connection(input) + + return self.layers(input) + self.skip_connection(input) + +class LinearSkip(nn.Module): + def __init__(self, num_inputs: int, num_outputs: int): + super(LinearSkip, self).__init__() + + self.num_inputs = num_inputs + self.num_outputs = num_outputs + + if num_inputs % num_outputs != 0 and num_outputs % num_inputs != 0: + mean_channels = np.gcd(num_inputs, num_outputs) + print(f"[WW] gcd skip: {num_inputs} -> {mean_channels} -> {num_outputs}") + assert(False) + + def forward(self, input: th.Tensor): + num_inputs = self.num_inputs + num_outputs = self.num_outputs + + if num_inputs == num_outputs: + return input + + if num_inputs % num_outputs == 0 or num_outputs % num_inputs == 0: + + if num_inputs > num_outputs: + return reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // num_outputs) + + if num_outputs > num_inputs: + return repeat(input, 'b c -> b (c n)', n = num_outputs // num_inputs) + + mean_channels = np.gcd(num_inputs, num_outputs) + input = reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // mean_channels) + return repeat(input, 'b c -> b (c n)', n = num_outputs // mean_channels) + +class LinearResidual(nn.Module): + def __init__( + self, + num_inputs: int, + num_outputs: int, + num_hidden: int = None, + residual: bool = True, + alpha_residual: bool = False, + input_relu: bool = True + ): + super(LinearResidual, self).__init__() + + self.residual = residual + self.alpha_residual = alpha_residual + + if num_hidden is None: + num_hidden = num_outputs + + _layers = [] + if input_relu: + _layers.append(nn.ReLU()) + _layers.append(nn.Linear(num_inputs, num_hidden)) + _layers.append(nn.ReLU()) + _layers.append(nn.Linear(num_hidden, num_outputs)) + + self.layers = nn.Sequential(*_layers) + + if residual: + self.skip = LinearSkip(num_inputs, num_outputs) + + if alpha_residual: + self.alpha = nn.Parameter(th.zeros(1)+1e-16) + + def forward(self, input: th.Tensor): + if not self.residual: + return self.layers(input) + + if not self.alpha_residual: + return self.skip(input) + self.layers(input) + + return self.skip(input) + self.alpha * self.layers(input) + + +class EntityAttention(nn.Module): + def __init__(self, channels, num_objects, size, channels_per_head = 12, dropout = 0.0): + super(EntityAttention, self).__init__() + + assert channels % channels_per_head == 0 + heads = channels // channels_per_head + + self.alpha = nn.Parameter(th.zeros(1)+1e-16) + self.attention = nn.MultiheadAttention( + channels, + heads, + dropout = dropout, + batch_first = True + ) + + self.channel_attention = nn.Sequential( + LambdaModule(lambda x: rearrange(x, '(b o) c h w -> (b h w) o c', o = num_objects)), + LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]), + LambdaModule(lambda x: rearrange(x, '(b h w) o c-> (b o) c h w', h = size[0], w = size[1])), + ) + + + def forward(self, input: th.Tensor) -> th.Tensor: + return input + self.channel_attention(input) * self.alpha + +class ImageAttention(nn.Module): + def __init__(self, channels, gestalt_size, num_objects, size, channels_per_head = 12, dropout = 0.0): + super(ImageAttention, self).__init__() + + assert gestalt_size % channels_per_head == 0 + heads = gestalt_size // channels_per_head + + self.alpha = nn.Parameter(th.zeros(1)+1e-16) + self.attention = nn.MultiheadAttention( + gestalt_size, + heads, + dropout = dropout, + batch_first = True + ) + + self.image_attention = nn.Sequential( + nn.Conv2d(channels, gestalt_size, kernel_size=1), + LambdaModule(lambda x: rearrange(x, 'b c h w -> b (h w) c')), + LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]), + LambdaModule(lambda x: rearrange(x, 'b (h w) c -> b c h w', h = size[0], w = size[1])), + nn.Conv2d(gestalt_size, channels, kernel_size=1), + ) + + def forward(self, input: th.Tensor) -> th.Tensor: + return input + self.image_attention(input) * self.alpha
\ No newline at end of file diff --git a/model/utils/loss.py b/model/utils/loss.py new file mode 100644 index 0000000..829bfd3 --- /dev/null +++ b/model/utils/loss.py @@ -0,0 +1,97 @@ +import torch as th +from torch import nn +from model.utils.nn_utils import SharedObjectsToBatch, LambdaModule +from einops import rearrange, repeat, reduce + +__author__ = "Manuel Traub" +class PositionLoss(nn.Module): + def __init__(self, num_objects: int): + super(PositionLoss, self).__init__() + + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) + + def forward(self, position, position_last, slot_mask): + + slot_mask = rearrange(slot_mask, 'b o -> (b o) 1 1 1') + position = self.to_batch(position) + position_last = self.to_batch(position_last).detach() + + return th.mean(slot_mask * (position - position_last)**2) + +class ObjectModulator(nn.Module): + def __init__(self, num_objects: int): + super(ObjectModulator, self).__init__() + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects)) + self.position = None + self.gestalt = None + + def reset_state(self): + self.position = None + self.gestalt = None + + def forward(self, position: th.Tensor, gestalt: th.Tensor, slot_mask: th.Tensor): + + position = self.to_batch(position) + gestalt = self.to_batch(gestalt) + slot_mask = self.to_batch(slot_mask) + + if self.position is None or self.gestalt is None: + self.position = position.detach() + self.gestalt = gestalt.detach() + return self.to_shared(position), self.to_shared(gestalt) + + _position = slot_mask * position + (1 - slot_mask) * self.position + position = th.cat((position[:,:-1], _position[:,-1:]), dim=1) # keep the position of the objects fixed + gestalt = slot_mask * gestalt + (1 - slot_mask) * self.gestalt + + self.gestalt = gestalt.detach() + self.position = position.detach() + + return self.to_shared(position), self.to_shared(gestalt) + +class MoveToCenter(nn.Module): + def __init__(self, num_objects: int): + super(MoveToCenter, self).__init__() + + self.to_batch2d = SharedObjectsToBatch(num_objects) + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) + + def forward(self, input: th.Tensor, position: th.Tensor): + + input = self.to_batch2d(input) # b (o c) h w -> (b o) c h w + position = self.to_batch(position).detach() + position = th.stack((position[:,1], position[:,0]), dim=1) + + theta = th.tensor([1, 0, 0, 1], dtype=th.float, device=input.device).view(1,2,2) + theta = repeat(theta, '1 a b -> n a b', n=input.shape[0]) + + position = rearrange(position, 'b c -> b c 1') + theta = th.cat((theta, position), dim=2) + + grid = nn.functional.affine_grid(theta, input.shape, align_corners=False) + output = nn.functional.grid_sample(input, grid, align_corners=False) + + return output + +class TranslationInvariantObjectLoss(nn.Module): + def __init__(self, num_objects: int): + super(TranslationInvariantObjectLoss, self).__init__() + + self.move_to_center = MoveToCenter(num_objects) + self.to_batch = SharedObjectsToBatch(num_objects) + + def forward( + self, + slot_mask: th.Tensor, + object1: th.Tensor, + position1: th.Tensor, + object2: th.Tensor, + position2: th.Tensor, + ): + slot_mask = rearrange(slot_mask, 'b o -> (b o) 1 1 1') + object1 = self.move_to_center(th.sigmoid(object1 - 2.5), position1) + object2 = self.move_to_center(th.sigmoid(object2 - 2.5), position2) + + return th.mean(slot_mask * (object1 - object2)**2) + diff --git a/model/utils/nn_utils.py b/model/utils/nn_utils.py new file mode 100644 index 0000000..5116e14 --- /dev/null +++ b/model/utils/nn_utils.py @@ -0,0 +1,298 @@ +from typing import Tuple +import torch.nn as nn +import torch as th +import numpy as np +from torch.autograd import Function +from einops import rearrange, repeat, reduce + +class PushToInfFunction(Function): + @staticmethod + def forward(ctx, tensor): + ctx.save_for_backward(tensor) + return tensor.clone() + + @staticmethod + def backward(ctx, grad_output): + tensor = ctx.saved_tensors[0] + grad_input = -th.ones_like(grad_output) + return grad_input + +class PushToInf(nn.Module): + def __init__(self): + super(PushToInf, self).__init__() + + self.fcn = PushToInfFunction.apply + + def forward(self, input: th.Tensor): + return self.fcn(input) + +class ForcedAlpha(nn.Module): + def __init__(self, speed = 1): + super(ForcedAlpha, self).__init__() + + self.init = nn.Parameter(th.zeros(1)) + self.speed = speed + self.to_inf = PushToInf() + + def item(self): + return th.tanh(self.to_inf(self.init * self.speed)).item() + + def forward(self, input: th.Tensor): + return input * th.tanh(self.to_inf(self.init * self.speed)) + +class LinearInterpolation(nn.Module): + def __init__(self, num_objects): + super(LinearInterpolation, self).__init__() + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects)) + + def forward( + self, + tensor_cur: th.Tensor = None, + tensor_last: th.Tensor = None, + slot_interpolation_value: th.Tensor = None + ): + + slot_interpolation_value = rearrange(slot_interpolation_value, 'b o -> (b o) 1') + tensor_cur = slot_interpolation_value * self.to_batch(tensor_last) + (1 - slot_interpolation_value) * self.to_batch(tensor_cur) + + return self.to_shared(tensor_cur) + +class Gaus2D(nn.Module): + def __init__(self, size: Tuple[int, int]): + super(Gaus2D, self).__init__() + + self.size = size + + self.register_buffer("grid_x", th.arange(size[0]), persistent=False) + self.register_buffer("grid_y", th.arange(size[1]), persistent=False) + + self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1 + self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1 + + self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone() + self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone() + + def forward(self, input: th.Tensor): + + x = rearrange(input[:,0:1], 'b c -> b c 1 1') + y = rearrange(input[:,1:2], 'b c -> b c 1 1') + std = rearrange(input[:,2:3], 'b c -> b c 1 1') + + x = th.clip(x, -1, 1) + y = th.clip(y, -1, 1) + std = th.clip(std, 0, 1) + + max_size = max(self.size) + std_x = (1 + max_size * std) / self.size[0] + std_y = (1 + max_size * std) / self.size[1] + + return th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2))) + +class Vector2D(nn.Module): + def __init__(self, size: Tuple[int, int]): + super(Vector2D, self).__init__() + + self.size = size + + self.register_buffer("grid_x", th.arange(size[0]), persistent=False) + self.register_buffer("grid_y", th.arange(size[1]), persistent=False) + + self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1 + self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1 + + self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 3, *size).clone() + self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 3, *size).clone() + + def forward(self, input: th.Tensor, vector: th.Tensor = None): + + x = rearrange(input[:,0:1], 'b c -> b c 1 1') + y = rearrange(input[:,1:2], 'b c -> b c 1 1') + if vector is not None: + x_vec = rearrange(vector[:,0:1], 'b c -> b c 1 1') + y_vec = rearrange(vector[:,1:2], 'b c -> b c 1 1') + + x = th.clip(x, -1, 1) + y = th.clip(y, -1, 1) + std = 0.01 + + max_size = max(self.size) + std_x = (1 + max_size * std) / self.size[0] + std_y = (1 + max_size * std) / self.size[1] + grid = th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2))) + + # interpolating between start and end point + if vector is not None: + for length in np.linspace(0, 1, 11): + x_end = th.clip(x + x_vec * length, -1, 1) + y_end = th.clip(y + y_vec * length, -1, 1) + + grid_point = th.exp(-1 * ((self.grid_x - x_end)**2/(0.5 * std_x**2) + (self.grid_y - y_end)**2/(0.5 * std_y**2))) + grid_point[:, 0:2, :, :] = 0 + grid = th.max(grid, grid_point) + + return grid + +class SharedObjectsToBatch(nn.Module): + def __init__(self, num_objects): + super(SharedObjectsToBatch, self).__init__() + + self.num_objects = num_objects + + def forward(self, input: th.Tensor): + return rearrange(input, 'b (o c) h w -> (b o) c h w', o=self.num_objects) + +class BatchToSharedObjects(nn.Module): + def __init__(self, num_objects): + super(BatchToSharedObjects, self).__init__() + + self.num_objects = num_objects + + def forward(self, input: th.Tensor): + return rearrange(input, '(b o) c h w -> b (o c) h w', o=self.num_objects) + +class LambdaModule(nn.Module): + def __init__(self, lambd): + super().__init__() + import types + assert type(lambd) is types.LambdaType + self.lambd = lambd + + def forward(self, *x): + return self.lambd(*x) + +class PrintGradientFunction(Function): + @staticmethod + def forward(ctx, tensor, msg): + ctx.msg = msg + return tensor + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output.clone() + print(f"{ctx.msg}: {th.mean(grad_output).item()} +- {th.std(grad_output).item()}") + return grad_input, None + +class PrintGradient(nn.Module): + def __init__(self, msg = "PrintGradient"): + super(PrintGradient, self).__init__() + + self.fcn = PrintGradientFunction.apply + self.msg = msg + + def forward(self, input: th.Tensor): + return self.fcn(input, self.msg) + +class Prioritize(nn.Module): + def __init__(self, num_objects): + super(Prioritize, self).__init__() + + self.num_objects = num_objects + self.to_batch = SharedObjectsToBatch(num_objects) + + def forward(self, input: th.Tensor, priority: th.Tensor): + + if priority is None: + return input + + batch_size = input.shape[0] + weights = th.zeros((batch_size, self.num_objects, self.num_objects, 1, 1), device=input.device) + + for o in range(self.num_objects): + weights[:,o,:,0,0] = th.sigmoid(priority[:,:] - priority[:,o:o+1]) + weights[:,o,o,0,0] = weights[:,o,o,0,0] * 0 + + input = rearrange(input, 'b c h w -> 1 (b c) h w') + weights = rearrange(weights, 'b o i 1 1 -> (b o) i 1 1') + + output = th.relu(input - nn.functional.conv2d(input, weights, groups=batch_size)) + output = rearrange(output, '1 (b c) h w -> b c h w ', b=batch_size) + + return output + +class MultiArgSequential(nn.Sequential): + def __init__(self, *args, **kwargs): + super(MultiArgSequential, self).__init__(*args, **kwargs) + + def forward(self, *tensor): + + for n in range(len(self)): + if isinstance(tensor, th.Tensor) or tensor == None: + tensor = self[n](tensor) + else: + tensor = self[n](*tensor) + + return tensor + +def create_grid(size): + grid_x = th.arange(size[0]) + grid_y = th.arange(size[1]) + + grid_x = (grid_x / (size[0]-1)) * 2 - 1 + grid_y = (grid_y / (size[1]-1)) * 2 - 1 + + grid_x = grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone() + grid_y = grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone() + + return th.cat((grid_y, grid_x), dim=1) + +class Warp(nn.Module): + def __init__(self, size, padding = 0.1): + super(Warp, self).__init__() + + padding = int(max(size) * padding) + padded_size = (size[0] + 2 * padding, size[1] + 2 * padding) + + self.register_buffer('grid', create_grid(size)) + self.register_buffer('padded_grid', create_grid(padded_size)) + + self.replication_pad = nn.ReplicationPad2d(padding) + self.interpolate = nn.Sequential( + LambdaModule(lambda x: + th.nn.functional.interpolate(x, size=size, mode='bicubic', align_corners = True) + ), + LambdaModule(lambda x: x - self.grid), + nn.ConstantPad2d(padding, 0), + LambdaModule(lambda x: x + self.padded_grid), + LambdaModule(lambda x: rearrange(x, 'b c h w -> b h w c')) + ) + + self.warp = LambdaModule(lambda input, flow: + th.nn.functional.grid_sample(input, flow, mode='bicubic', align_corners=True) + ) + + self.un_pad = LambdaModule(lambda x: x[:,:,padding:-padding,padding:-padding]) + + def get_raw_flow(self, flow): + return flow - self.grid + + def forward(self, input, flow): + input = self.replication_pad(input) + flow = self.interpolate(flow) + return self.un_pad(self.warp(input, flow)) + +class Binarize(nn.Module): + def __init__(self): + super(Binarize, self).__init__() + + def forward(self, input: th.Tensor): + input = th.sigmoid(input) + if not self.training: + return th.round(input) + + return input + input * (1 - input) * th.randn_like(input) + +class TanhAlpha(nn.Module): + def __init__(self, start = 0, stepsize = 1e-4, max_value = 1): + super(TanhAlpha, self).__init__() + + self.register_buffer('init', th.zeros(1) + start) + self.stepsize = stepsize + self.max_value = max_value + + def get(self): + return (th.tanh(self.init) * self.max_value).item() + + def forward(self): + self.init = self.init.detach() + self.stepsize + return self.get()
\ No newline at end of file diff --git a/model/utils/slot_utils.py b/model/utils/slot_utils.py new file mode 100644 index 0000000..936248b --- /dev/null +++ b/model/utils/slot_utils.py @@ -0,0 +1,326 @@ +import torch.nn as nn +import torch as th +import torchvision.transforms as transforms +import torch.nn.functional as F +from einops import rearrange, repeat, reduce +from typing import Tuple, Union, List +from model.utils.nn_utils import Gaus2D, LambdaModule, TanhAlpha + +class InitialLatentStates(nn.Module): + def __init__( + self, + gestalt_size: int, + num_objects: int, + bottleneck: str, + size: Tuple[int, int], + teacher_forcing: int + ): + super(InitialLatentStates, self).__init__() + self.bottleneck = bottleneck + + self.num_objects = num_objects + self.gestalt_mean = nn.Parameter(th.zeros(1, gestalt_size)) + self.gestalt_std = nn.Parameter(th.ones(1, gestalt_size)) + self.std = nn.Parameter(th.zeros(1)) + self.gestalt_strength = 2 + self.teacher_forcing = teacher_forcing + + self.init = TanhAlpha(start = -1) + self.register_buffer('priority', th.arange(num_objects).float() * 25, persistent=False) + self.register_buffer('threshold', th.ones(1) * 0.8) + self.last_mask = None + self.binarize_first = round(gestalt_size * 0.8) + + self.gaus2d = nn.Sequential( + Gaus2D((size[0] // 16, size[1] // 16)), + Gaus2D((size[0] // 4, size[1] // 4)), + Gaus2D(size) + ) + + self.level = 1 + self.t = 0 + + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects)) + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects)) + + self.blur = transforms.GaussianBlur(13) + self.size = size + + def reset_state(self): + self.last_mask = None + self.t = 0 + self.to_next_spawn = 0 + + def set_level(self, level): + self.level = level + factor = int(4 / (level ** 2)) + self.to_position = ErrorToPosition((self.size[0] // factor, self.size[1] // factor)) + + def forward( + self, + error: th.Tensor, + mask: th.Tensor = None, + position: th.Tensor = None, + gestalt: th.Tensor = None, + priority: th.Tensor = None, + shuffleslots: bool = True, + slots_bounded_last: th.Tensor = None, + slots_occlusionfactor_last: th.Tensor = None, + allow_spawn: bool = True, + clean_slots: bool = False + ): + + batch_size = error.shape[0] + device = error.device + + if self.init.get() < 1: + self.gestalt_strength = self.init() + + if self.last_mask is None: + self.last_mask = th.zeros((batch_size * self.num_objects, 1), device = device) + if shuffleslots: + self.slots_assigned = th.ones((batch_size * self.num_objects, 1), device = device) + else: + self.slots_assigned = th.zeros((batch_size * self.num_objects, 1), device = device) + + if not allow_spawn: + unnassigned = self.slots_assigned - slots_bounded_last + self.slots_assigned = self.slots_assigned - unnassigned + + if clean_slots and (slots_occlusionfactor_last is not None): + occluded = self.slots_assigned * (self.to_batch(slots_occlusionfactor_last) > 0.1).float() + self.slots_assigned = self.slots_assigned - occluded + + if (slots_bounded_last is None) or (self.gestalt_strength < 1): + + if mask is not None: + # maximum berechnung --> slot gebunden c=o + mask2 = reduce(mask[:,:-1], 'b c h w -> (b c) 1' , 'max').detach() + + if self.gestalt_strength <= 0: + self.last_mask = mask2 + elif self.gestalt_strength < 1: + self.last_mask = th.maximum(self.last_mask, mask2) + self.last_mask = self.last_mask - th.relu(-1 * (mask2 - self.threshold) * (1 - self.gestalt_strength)) + else: + self.last_mask = th.maximum(self.last_mask, mask2) + + slots_bounded = (self.last_mask > self.threshold).float().detach() * self.slots_assigned + else: + slots_bounded = slots_bounded_last * self.slots_assigned + + if self.bottleneck == "binar": + gestalt_new = repeat(th.sigmoid(self.gestalt_mean), '1 c -> b c', b = batch_size * self.num_objects) + gestalt_new = gestalt_new + gestalt_new * (1 - gestalt_new) * th.randn_like(gestalt_new) + else: + gestalt_mean = repeat(self.gestalt_mean, '1 c -> b c', b = batch_size * self.num_objects) + gestalt_std = repeat(self.gestalt_std, '1 c -> b c', b = batch_size * self.num_objects) + gestalt_new = th.sigmoid(gestalt_mean + gestalt_std * th.randn_like(gestalt_std)) + + if gestalt is None: + gestalt = gestalt_new + else: + gestalt = self.to_batch(gestalt) * slots_bounded + gestalt_new * (1 - slots_bounded) + + if priority is None: + priority = repeat(self.priority, 'o -> (b o) 1', b = batch_size) + else: + priority = self.to_batch(priority) * slots_bounded + repeat(self.priority, 'o -> (b o) 1', b = batch_size) * (1 - slots_bounded) + + + if shuffleslots: + self.slots_assigned = th.ones_like(self.slots_assigned) + + xy_rand_new = th.rand((batch_size * self.num_objects * 10, 2), device = device) * 2 - 1 + std_new = th.zeros((batch_size * self.num_objects * 10, 1), device = device) + position_new = th.cat((xy_rand_new, std_new), dim=1) + + position2d = self.gaus2d[self.level](position_new) + position2d = rearrange(position2d, '(b o) 1 h w -> b o h w', b = batch_size) + + rand_error = reduce(position2d * error, 'b o h w -> (b o) 1', 'sum') + + xy_rand_new = rearrange(xy_rand_new, '(b r) c -> r b c', r = 10) + rand_error = rearrange(rand_error, '(b r) c -> r b c', r = 10) + + max_error = th.argmax(rand_error, dim=0, keepdim=True) + x, y = th.chunk(xy_rand_new, 2, dim=2) + x = th.gather(x, dim=0, index=max_error).detach().squeeze(dim=0) + y = th.gather(y, dim=0, index=max_error).detach().squeeze(dim=0) + std = repeat(self.std, '1 -> (b o) 1', b = batch_size, o=self.num_objects) + + if position is None: + position = th.cat((x, y, std), dim=1) + else: + position = self.to_batch(position) * slots_bounded + th.cat((x, y, std), dim=1) * (1 - slots_bounded) + + else: + + # set unassigned slots to empty position + empty_position = th.tensor([-1,-1,0]).to(device) + empty_position = repeat(empty_position, 'c -> (b o) c', b = batch_size, o=self.num_objects).detach() + + if position is None: + position = empty_position + else: + position = self.to_batch(position) * self.slots_assigned + empty_position * (1 - self.slots_assigned) + + + # blur errror, and set masked areas to zero + error = self.blur(error) + if mask is not None: + mask2 = mask[:,:-1] * rearrange(slots_bounded, '(b o) 1 -> b o 1 1', b = batch_size) + mask2 = th.sum(mask2, dim=1, keepdim=True) + error = error * (1-mask2) + max_error = reduce(error, 'b o h w -> (b o) 1', 'max') + + if self.to_next_spawn <= 0 and allow_spawn: + + self.to_next_spawn = 2 + + # calculate the position with the highest error + new_pos = self.to_position(error) + std = repeat(self.std, '1 -> b 1', b = batch_size) + new_pos = repeat(th.cat((new_pos, std), dim=1), 'b c -> (b o) c', o = self.num_objects) + + # calculate if an assigned slot is unbound (-->free) + n_slots_assigned = self.to_shared(self.slots_assigned).sum(dim=1, keepdim=True) + n_slots_bounded = self.to_shared(slots_bounded).sum(dim=1, keepdim=True) + free_slot_given = th.clip(n_slots_assigned - n_slots_bounded, 0, 1) + + # either spawn a new slot or use the one that is free + slots_new_index = n_slots_assigned * (1-free_slot_given) + n_slots_bounded * free_slot_given # reset the free slot each timespawn + + # new slot index + free_slot_required = (max_error > 0).float() + slots_new_index = F.one_hot(slots_new_index.long(), num_classes=self.num_objects+1).float().squeeze(dim=1)[:,:-1] + slots_new_index = self.to_batch(slots_new_index * free_slot_required) + + # place new free slot + position = new_pos * slots_new_index + position * (1 - slots_new_index) + self.slots_assigned = th.clip(self.slots_assigned + slots_new_index, 0, 1) + + self.to_next_spawn -= 1 + return self.to_shared(position), self.to_shared(gestalt), self.to_shared(priority), error + + def get_slots_unassigned(self): + return self.to_shared(1-self.slots_assigned) + + def get_slots_assigned(self): + return self.to_shared(self.slots_assigned) + + +class OcclusionTracker(nn.Module): + def __init__(self, batch_size, num_objects, device): + super(OcclusionTracker, self).__init__() + self.batch_size = batch_size + self.num_objects = num_objects + self.slots_bounded_all = th.zeros((batch_size * num_objects, 1)).to(device) + self.threshold = 0.8 + self.device = device + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects)) + self.slots_bounded_next_last = None + + def forward( + self, + mask: th.Tensor = None, + rawmask: th.Tensor = None, + reset_mask: bool = False, + update: bool = True + ): + + if mask is not None: + + # compute bounding mask + slots_bounded_smooth_cur = reduce(mask[:,:-1], 'b o h w -> (b o) 1' , 'max').detach() + slots_bounded_cur = (slots_bounded_smooth_cur > self.threshold).float().detach() + if reset_mask: + self.slots_bounded_next_last = slots_bounded_cur # allow immediate spawn + + if update: + slots_bounded_cur = slots_bounded_cur * th.clip(self.slots_bounded_next_last + self.slots_bounded_all, 0, 1) + else: + self.slots_bounded_next_last = slots_bounded_cur + + if reset_mask: + self.slots_bounded_smooth_all = slots_bounded_smooth_cur + self.slots_bounded_all = slots_bounded_cur + elif update: + self.slots_bounded_all = th.maximum(self.slots_bounded_all, slots_bounded_cur) + self.slots_bounded_smooth_all = th.maximum(self.slots_bounded_smooth_all, slots_bounded_smooth_cur) + + # compute occlusion mask + slots_occluded_cur = self.slots_bounded_all - slots_bounded_cur + + # compute partially occluded mask + mask = (mask[:,:-1] > self.threshold).float().detach() + rawmask = (rawmask[:,:-1] > self.threshold).float().detach() + masked = rawmask - mask + + masked = reduce(masked, 'b o h w -> (b o) 1' , 'sum') + rawmask = reduce(rawmask, 'b o h w -> (b o) 1' , 'sum') + + slots_occlusionfactor_cur = (masked / (rawmask + 1)) * (1-slots_occluded_cur) + slots_occluded_cur + slots_partially_occluded = (slots_occlusionfactor_cur > 0.1).float() #* slots_bounded_cur + slots_fully_visible = (slots_occlusionfactor_cur <= 0.1).float() * slots_bounded_cur + + if reset_mask: + self.slots_fully_visible_all = slots_fully_visible + elif update: + self.slots_fully_visible_all = th.maximum(self.slots_fully_visible_all, slots_fully_visible) + + return self.to_shared(self.slots_bounded_all), self.to_shared(self.slots_bounded_smooth_all), self.to_shared(slots_occluded_cur), self.to_shared(slots_partially_occluded), self.to_shared(slots_fully_visible), self.to_shared(slots_occlusionfactor_cur) + + def get_slots_fully_visible_all(self): + return self.to_shared(self.slots_fully_visible_all) + +class ErrorToPosition(nn.Module): + def __init__(self, size: Union[int, Tuple[int, int]]): + super(ErrorToPosition, self).__init__() + + self.register_buffer("grid_x", th.arange(size[0]), persistent=False) + self.register_buffer("grid_y", th.arange(size[1]), persistent=False) + + self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1 + self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1 + + self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone() + self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone() + + self.grid_x = self.grid_x.view(1, 1, -1) + self.grid_y = self.grid_y.view(1, 1, -1) + + self.size = size + + def forward(self, input: th.Tensor): + assert input.shape[1] == 1 + + input = rearrange(input, 'b c h w -> b c (h w)') + argmax = th.argmax(input, dim=2, keepdim=True) + + x = self.grid_x[0,0,argmax].squeeze(dim=2) + y = self.grid_y[0,0,argmax].squeeze(dim=2) + + return th.cat((x,y),dim=1) + + +def compute_rawmask(mask, bg_mask): + + num_objects = mask.shape[1] + + # d is a diagonal matrix which defines what to take the softmax over + d_mask = th.diag(th.ones(num_objects+1)).to(mask.device) + d_mask[:,-1] = 1 + d_mask[-1,-1] = 0 + + # take subset of rawmask with the diagonal matrix + rawmask = th.cat((mask, bg_mask), dim=1) + rawmask = repeat(rawmask, 'b o h w -> b r o h w', r = num_objects+1) + rawmask = rawmask[:,d_mask.bool()] + rawmask = rearrange(rawmask, 'b (o r) h w -> b o r h w', o = num_objects) + + # take softmax between each object mask and the background mask + rawmask = th.squeeze(th.softmax(rawmask, dim=2)[:,:,0], dim=2) + rawmask = th.cat((rawmask, bg_mask), dim=1) # add background mask + + return rawmask
\ No newline at end of file |