Diffstat (limited to 'model/nn')
-rw-r--r-- | model/nn/background.py               | 167
-rw-r--r-- | model/nn/decoder.py                  | 169
-rw-r--r-- | model/nn/encoder.py                  | 269
-rw-r--r-- | model/nn/eprop_gate_l0rd.py          | 329
-rw-r--r-- | model/nn/eprop_transformer.py        |  76
-rw-r--r-- | model/nn/eprop_transformer_shared.py |  92
-rw-r--r-- | model/nn/eprop_transformer_utils.py  |  66
-rw-r--r-- | model/nn/percept_gate_controller.py  |  59
-rw-r--r-- | model/nn/predictor.py                |  99
-rw-r--r-- | model/nn/residual.py                 | 396
10 files changed, 1722 insertions, 0 deletions
diff --git a/model/nn/background.py b/model/nn/background.py
new file mode 100644
index 0000000..38ec325
--- /dev/null
+++ b/model/nn/background.py
@@ -0,0 +1,167 @@
import torch.nn as nn
import torch as th
from model.nn.residual import ResidualBlock, SkipConnection, LinearResidual
from model.nn.encoder import PatchDownConv
from model.nn.encoder import AggressiveConvToGestalt
from model.nn.decoder import PatchUpscale
from model.utils.nn_utils import LambdaModule, Binarize
from einops import rearrange, repeat, reduce
from typing import Tuple

__author__ = "Manuel Traub"

class BackgroundEnhancer(nn.Module):
    def __init__(
        self,
        input_size: Tuple[int, int],
        img_channels: int,
        level1_channels,
        latent_channels,
        gestalt_size,
        batch_size,
        depth
    ):
        super(BackgroundEnhancer, self).__init__()

        latent_size = [input_size[0] // 16, input_size[1] // 16]
        self.input_size = input_size

        self.register_buffer('init', th.zeros(1).long())
        self.alpha = nn.Parameter(th.zeros(1)+1e-16)

        self.level = 1
        self.down_level2 = nn.Sequential(
            PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)]
        )

        self.down_level1 = nn.Sequential(
            PatchDownConv(level1_channels, latent_channels, alpha = 1),
            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)]
        )

        self.down_level0 = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
            AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size),
            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
            Binarize(),
        )

        self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))

        self.to_grid = nn.Sequential(
            LinearResidual(gestalt_size, gestalt_size, input_relu = False),
            LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
            LambdaModule(lambda x: x + self.bias),
            *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)],
        )

        self.up_level0 = nn.Sequential(
            ResidualBlock(gestalt_size, latent_channels),
            *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
        )

        self.up_level1 = nn.Sequential(
            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)],
            PatchUpscale(latent_channels, level1_channels, alpha = 1),
        )

        self.up_level2 = nn.Sequential(
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)],
            PatchUpscale(level1_channels, img_channels, alpha = 1e-16),
        )

        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels*2+2, latent_channels),
            SkipConnection(img_channels*2+2, level1_channels),
            SkipConnection(img_channels*2+2, img_channels*2+2),
        ])

        self.to_img = nn.ModuleList([
            SkipConnection(latent_channels, img_channels),
            SkipConnection(level1_channels, img_channels),
            SkipConnection(img_channels, img_channels),
        ])

        self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
        self.object = nn.Parameter(th.ones(1, img_channels, *input_size))

        self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False)

    def get_init(self):
        return self.init.item()

    def step_init(self):
        self.init = self.init + 1

    def detach(self):
        self.latent = self.latent.detach()

    def reset_state(self):
        self.latent = th.zeros_like(self.latent)

    def set_level(self, level):
        self.level = level

    def encoder(self, input):
        latent = self.to_channels[self.level](input)

        if self.level >= 2:
            latent = self.down_level2(latent)

        if self.level >= 1:
            latent = self.down_level1(latent)

        return self.down_level0(latent)

    def get_last_latent_gird(self):
        return self.to_grid(self.latent) * self.alpha

    def decoder(self, latent, input):
        grid = self.to_grid(latent)
        latent = self.up_level0(grid)

        if self.level >= 1:
            latent = self.up_level1(latent)

        if self.level >= 2:
            latent = self.up_level2(latent)

        object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3])
        object = repeat(object, '1 c h w -> b c h w', b = input.shape[0])

        return th.sigmoid(object + self.to_img[self.level](latent)), grid

    def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False):

        if only_mask:
            mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
            mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
            return mask

        last_bg = self.decoder(self.latent, input)[0]

        bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach()
        bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach()

        if error is None or self.get_init() < 2:
            error = bg_error

        if mask is None or self.get_init() < 2:
            mask = bg_mask

        self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1))

        mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
        mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1

        background, grid = self.decoder(self.latent, input)

        if self.get_init() < 1:
            return mask, background

        if self.get_init() < 2:
            return mask, th.zeros_like(background), th.zeros_like(grid), background

        return mask, background, grid * self.alpha, background
diff --git a/model/nn/decoder.py b/model/nn/decoder.py
new file mode 100644
index 0000000..3e56640
--- /dev/null
+++ b/model/nn/decoder.py
@@ -0,0 +1,169 @@
import torch.nn as nn
import torch as th
from model.nn.residual import SkipConnection, ResidualBlock
from model.utils.nn_utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize
from einops import rearrange, repeat, reduce
from typing import Tuple, Union, List

__author__ = "Manuel Traub"

class PriorityEncoder(nn.Module):
    def __init__(self, num_objects, batch_size):
        super(PriorityEncoder, self).__init__()

        self.num_objects = num_objects
        self.register_buffer("indices", repeat(th.arange(num_objects), 'a -> b a', b=batch_size), persistent=False)

        self.index_factor = nn.Parameter(th.ones(1))
        self.priority_factor = nn.Parameter(th.ones(1))

    def forward(self, priority: th.Tensor) -> th.Tensor:

        if priority is None:
            return None

        priority = priority * self.num_objects + th.randn_like(priority) * 0.1
        priority = priority * self.priority_factor
        priority = priority + self.indices * self.index_factor

        return priority * 25


class GestaltPositionMerge(nn.Module):
    def __init__(
        self,
        latent_size: Union[int, Tuple[int, int]],
        num_objects: int,
        batch_size: int
    ):
        super(GestaltPositionMerge, self).__init__()
        self.num_objects = num_objects

        self.gaus2d = Gaus2D(size=latent_size)

        self.to_batch = SharedObjectsToBatch(num_objects)
        self.to_shared = BatchToSharedObjects(num_objects)

        self.prioritize = Prioritize(num_objects)

        self.priority_encoder = PriorityEncoder(num_objects, batch_size)

    def forward(self, position, gestalt, priority):

        position = rearrange(position, 'b (o c) -> (b o) c', o = self.num_objects)
        gestalt = rearrange(gestalt, 'b (o c) -> (b o) c 1 1', o = self.num_objects)
        priority = self.priority_encoder(priority)

        position = self.gaus2d(position)
        position = self.to_batch(self.prioritize(self.to_shared(position), priority))

        return position * gestalt

class PatchUpscale(nn.Module):
    def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
        super(PatchUpscale, self).__init__()
        assert in_channels % out_channels == 0

        self.skip = SkipConnection(in_channels, out_channels, scale_factor=scale_factor)

        self.residual = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(
                in_channels  = in_channels,
                out_channels = in_channels,
                kernel_size  = 3,
                padding      = 1
            ),
            nn.ReLU(),
            nn.ConvTranspose2d(
                in_channels  = in_channels,
                out_channels = out_channels,
                kernel_size  = scale_factor,
                stride       = scale_factor,
            ),
        )

        self.alpha = nn.Parameter(th.zeros(1) + alpha)

    def forward(self, input):
        return self.skip(input) + self.alpha * self.residual(input)


class LociDecoder(nn.Module):
    def __init__(
        self,
        latent_size: Union[int, Tuple[int, int]],
        gestalt_size: int,
        num_objects: int,
        img_channels: int,
        hidden_channels: int,
        level1_channels: int,
        num_layers: int,
        batch_size: int
    ):
        super(LociDecoder, self).__init__()
        self.to_batch = SharedObjectsToBatch(num_objects)
        self.to_shared = BatchToSharedObjects(num_objects)
        self.level = 1

        assert(level1_channels % img_channels == 0)
        level1_factor = level1_channels // img_channels
        print(f"Level1 channels: {level1_channels}")

        self.merge = GestaltPositionMerge(
            latent_size = latent_size,
            num_objects = num_objects,
            batch_size  = batch_size
        )

        self.layer0 = nn.Sequential(
            ResidualBlock(gestalt_size, hidden_channels, input_nonlinearity = False),
            *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers-1)],
        )

        self.to_mask_level0 = ResidualBlock(hidden_channels, hidden_channels)
        self.to_mask_level1 = PatchUpscale(hidden_channels, 1)

        self.to_mask_level2 = nn.Sequential(
            ResidualBlock(hidden_channels, hidden_channels),
            ResidualBlock(hidden_channels, hidden_channels),
            PatchUpscale(hidden_channels, level1_factor, alpha = 1),
            PatchUpscale(level1_factor, 1, alpha = 1)
        )

        self.to_object_level0 = ResidualBlock(hidden_channels, hidden_channels)
        self.to_object_level1 = PatchUpscale(hidden_channels, img_channels)

        self.to_object_level2 = nn.Sequential(
            ResidualBlock(hidden_channels, hidden_channels),
            ResidualBlock(hidden_channels, hidden_channels),
            PatchUpscale(hidden_channels, level1_channels, alpha = 1),
            PatchUpscale(level1_channels, img_channels, alpha = 1)
        )

        self.mask_alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.object_alpha = nn.Parameter(th.zeros(1)+1e-16)

    def set_level(self, level):
        self.level = level

    def forward(self, position, gestalt, priority = None):

        maps = self.layer0(self.merge(position, gestalt, priority))
        mask0 = self.to_mask_level0(maps)
        object0 = self.to_object_level0(maps)

        mask = self.to_mask_level1(mask0)
        object = self.to_object_level1(object0)

        if self.level > 1:
            mask = repeat(mask, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)
            object = repeat(object, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)

            mask = mask + self.to_mask_level2(mask0) * self.mask_alpha
            object = object + self.to_object_level2(object0) * self.object_alpha

        return self.to_shared(mask), self.to_shared(object)
diff --git a/model/nn/encoder.py b/model/nn/encoder.py
new file mode 100644
index 0000000..8eb2e7f
--- /dev/null
+++ b/model/nn/encoder.py
@@ -0,0 +1,269 @@
import torch.nn as nn
import torch as th
from model.utils.nn_utils import Gaus2D, BatchToSharedObjects, LambdaModule, ForcedAlpha, Binarize
from model.nn.residual import ResidualBlock, SkipConnection
from einops import rearrange, repeat, reduce
from typing import Tuple, Union, List

__author__ = "Manuel Traub"

class NeighbourChannels(nn.Module):
    def __init__(self, channels):
        super(NeighbourChannels, self).__init__()

        self.register_buffer("weights", th.ones(channels, channels, 1, 1), persistent=False)

        for i in range(channels):
            self.weights[i,i,0,0] = 0

    def forward(self, input: th.Tensor):
        return nn.functional.conv2d(input, self.weights)

class InputPreprocessing(nn.Module):
    def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]):
        super(InputPreprocessing, self).__init__()
        self.num_objects = num_objects
        self.neighbours = NeighbourChannels(num_objects)
        self.gaus2d = Gaus2D(size)
        self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared = BatchToSharedObjects(num_objects)

    def forward(
        self,
        input: th.Tensor,
        error: th.Tensor,
        mask: th.Tensor,
        object: th.Tensor,
        position: th.Tensor,
        rawmask: th.Tensor
    ):
        bg_mask = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects)
        mask = mask[:,:-1]
        mask_others = self.neighbours(mask)
        rawmask = rawmask[:,:-1]

        own_gaus2d = self.to_shared(self.gaus2d(self.to_batch(position)))

        input = repeat(input, 'b c h w -> b o c h w', o = self.num_objects)
        error = repeat(error, 'b 1 h w -> b o 1 h w', o = self.num_objects)
        bg_mask = rearrange(bg_mask, 'b o h w -> b o 1 h w')
        mask_others = rearrange(mask_others, 'b o h w -> b o 1 h w')
        mask = rearrange(mask, 'b o h w -> b o 1 h w')
        object = rearrange(object, 'b (o c) h w -> b o c h w', o = self.num_objects)
        own_gaus2d = rearrange(own_gaus2d, 'b o h w -> b o 1 h w')
        rawmask = rearrange(rawmask, 'b o h w -> b o 1 h w')

        output = th.cat((input, error, mask, mask_others, bg_mask, object, own_gaus2d, rawmask), dim=2)
        output = rearrange(output, 'b o c h w -> (b o) c h w')

        return output

class PatchDownConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
        super(PatchDownConv, self).__init__()
        assert out_channels % in_channels == 0

        self.layers = nn.Conv2d(
            in_channels  = in_channels,
            out_channels = out_channels,
            kernel_size  = kernel_size,
            stride       = kernel_size,
        )

        self.alpha = nn.Parameter(th.zeros(1) + alpha)
        self.kernel_size = 4
        self.channels_factor = out_channels // in_channels

    def forward(self, input: th.Tensor):
        k = self.kernel_size
        c = self.channels_factor
        skip = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=k, w2=k)
        skip = repeat(skip, 'b c h w -> b (c n) h w', n=c)
        return skip + self.alpha * self.layers(input)

class AggressiveConvToGestalt(nn.Module):
    def __init__(self, channels, gestalt_size, size: Union[int, Tuple[int, int]]):
        super(AggressiveConvToGestalt, self).__init__()

        assert gestalt_size % channels == 0 or channels % gestalt_size == 0

        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels  = channels,
                out_channels = gestalt_size,
                kernel_size  = 5,
                stride       = 3,
                padding      = 3
            ),
            nn.ReLU(),
            nn.Conv2d(
                in_channels  = gestalt_size,
                out_channels = gestalt_size,
                kernel_size  = ((size[0] + 1)//3 + 1, (size[1] + 1)//3 + 1)
            )
        )
        if gestalt_size > channels:
            self.skip = nn.Sequential(
                LambdaModule(lambda x: reduce(x, 'b c h w -> b c 1 1', 'mean')),
                LambdaModule(lambda x: repeat(x, 'b c 1 1 -> b (c n) 1 1', n = gestalt_size // channels))
            )
        else:
            self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c 1 1', 'mean', n = channels // gestalt_size))

    def forward(self, input: th.Tensor):
        return self.skip(input) + self.layers(input)

class PixelToPosition(nn.Module):
    def __init__(self, size: Union[int, Tuple[int, int]]):
        super(PixelToPosition, self).__init__()

        self.register_buffer("grid_x", th.arange(size[0]), persistent=False)
        self.register_buffer("grid_y", th.arange(size[1]), persistent=False)

        self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1
        self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1

        self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
        self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()

        self.size = size

    def forward(self, input: th.Tensor):
        assert input.shape[1] == 1

        input = rearrange(input, 'b c h w -> b c (h w)')
        input = th.softmax(input, dim=2)
        input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1])

        x = th.sum(input * self.grid_x, dim=(2,3))
        y = th.sum(input * self.grid_y, dim=(2,3))

        return th.cat((x,y),dim=1)

class PixelToSTD(nn.Module):
    def __init__(self):
        super(PixelToSTD, self).__init__()
        self.alpha = ForcedAlpha()

    def forward(self, input: th.Tensor):
        assert input.shape[1] == 1
        return self.alpha(reduce(th.sigmoid(input - 10), 'b c h w -> b c', 'mean'))

class PixelToPriority(nn.Module):
    def __init__(self):
        super(PixelToPriority, self).__init__()

    def forward(self, input: th.Tensor):
        assert input.shape[1] == 1
        return reduce(th.tanh(input), 'b c h w -> b c', 'mean')

class LociEncoder(nn.Module):
    def __init__(
        self,
        input_size: Union[int, Tuple[int, int]],
        latent_size: Union[int, Tuple[int, int]],
        num_objects: int,
        img_channels: int,
        hidden_channels: int,
        level1_channels: int,
        num_layers: int,
        gestalt_size: int,
        bottleneck: str
    ):
        super(LociEncoder, self).__init__()

        self.num_objects = num_objects
        self.latent_size = latent_size
        self.level = 1

        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects))

        print(f"Level1 channels: {level1_channels}")

        self.preprocess = nn.ModuleList([
            InputPreprocessing(num_objects, (input_size[0] // 16, input_size[1] // 16)),
            InputPreprocessing(num_objects, (input_size[0] // 4, input_size[1] // 4)),
            InputPreprocessing(num_objects, (input_size[0], input_size[1]))
        ])

        self.to_channels = nn.ModuleList([
            SkipConnection(img_channels, hidden_channels),
            SkipConnection(img_channels, level1_channels),
            SkipConnection(img_channels, img_channels)
        ])

        self.layers2 = nn.Sequential(
            PatchDownConv(img_channels, level1_channels, alpha = 1e-16),
            *[ResidualBlock(level1_channels, level1_channels, alpha_residual=True) for _ in range(num_layers)]
        )

        self.layers1 = PatchDownConv(level1_channels, hidden_channels)

        self.layers0 = nn.Sequential(
            *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)]
        )

        self.position_encoder = nn.Sequential(
            *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
            ResidualBlock(hidden_channels, 3),
        )

        self.xy_encoder = PixelToPosition(latent_size)
        self.std_encoder = PixelToSTD()
        self.priority_encoder = PixelToPriority()

        if bottleneck == "binar":
            print("Binary bottleneck")
            self.gestalt_encoder = nn.Sequential(
                *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
                AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
                LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
                Binarize(),
            )
        else:
            print("unrestricted bottleneck")
            self.gestalt_encoder = nn.Sequential(
                *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
                AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
                LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
                nn.Sigmoid(),
            )

    def set_level(self, level):
        self.level = level

    def forward(
        self,
        input: th.Tensor,
        error: th.Tensor,
        mask: th.Tensor,
        object: th.Tensor,
        position: th.Tensor,
        rawmask: th.Tensor
    ):
        latent = self.preprocess[self.level](input, error, mask, object, position, rawmask)
        latent = self.to_channels[self.level](latent)

        if self.level >= 2:
            latent = self.layers2(latent)

        if self.level >= 1:
            latent = self.layers1(latent)

        latent = self.layers0(latent)
        gestalt = self.gestalt_encoder(latent)

        latent = self.position_encoder(latent)
        std = self.std_encoder(latent[:,0:1])
        xy = self.xy_encoder(latent[:,1:2])
        priority = self.priority_encoder(latent[:,2:3])

        position = self.to_shared(th.cat((xy, std), dim=1))
        gestalt = self.to_shared(gestalt)
        priority = self.to_shared(priority)

        return position, gestalt, priority
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py
new file mode 100644
index 0000000..82a9895
--- /dev/null
+++ b/model/nn/eprop_gate_l0rd.py
@@ -0,0 +1,329 @@
import torch.nn as nn
import torch as th
import numpy as np
from torch.autograd import Function
from einops import rearrange, repeat, reduce

__author__ = "Manuel Traub"

class EpropGateL0rdFunction(Function):
    @staticmethod
    def forward(ctx, x, h_last, w_gx, w_gh, b_g, w_rx, w_rh, b_r, args):

        e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r, reg, noise_level = args

        noise = th.normal(mean=0, std=noise_level, size=b_g.shape, device=b_g.device)
        g = th.relu(th.tanh(x.mm(w_gx.t()) + h_last.mm(w_gh.t()) + b_g + noise))
        r = th.tanh(x.mm(w_rx.t()) + h_last.mm(w_rh.t()) + b_r)

        h = g * r + (1 - g) * h_last

        # Heaviside step function
        H_g = th.ceil(g).clamp(0, 1)

        dg = (1 - g**2) * H_g
        dr = (1 - r**2)

        delta_h = r - h_last

        g_j = g.unsqueeze(dim=2)
        dg_j = dg.unsqueeze(dim=2)
        dr_j = dr.unsqueeze(dim=2)

        x_i = x.unsqueeze(dim=1)
        h_last_i = h_last.unsqueeze(dim=1)
        delta_h_j = delta_h.unsqueeze(dim=2)

        e_w_gh.copy_(e_w_gh * (1 - g_j) + dg_j * h_last_i * delta_h_j)
        e_w_gx.copy_(e_w_gx * (1 - g_j) + dg_j * x_i * delta_h_j)
        e_b_g.copy_(e_b_g * (1 - g) + dg * delta_h)

        e_w_rh.copy_(e_w_rh * (1 - g_j) + dr_j * h_last_i * g_j)
        e_w_rx.copy_(e_w_rx * (1 - g_j) + dr_j * x_i * g_j)
        e_b_r.copy_(e_b_r * (1 - g) + dr * g)

        ctx.save_for_backward(
            g.clone(), dg.clone(), dg_j.clone(), dr.clone(), x_i.clone(), h_last_i.clone(),
            reg.clone(), H_g.clone(), delta_h.clone(), w_gx.clone(), w_gh.clone(), w_rx.clone(), w_rh.clone(),
            e_w_gx.clone(), e_w_gh.clone(), e_b_g.clone(),
            e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
        )

        return h, th.mean(H_g)

    @staticmethod
    def backward(ctx, dh, _):

        g, dg, dg_j, dr, x_i, h_last_i, reg, H_g, delta_h, w_gx, w_gh, w_rx, w_rh, \
        e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r = ctx.saved_tensors

        dh_j = dh.unsqueeze(dim=2)
        H_g_reg = reg * H_g
        H_g_reg_j = H_g_reg.unsqueeze(dim=2)

        dw_gx = th.sum(dh_j * e_w_gx + H_g_reg_j * dg_j * x_i, dim=0)
        dw_gh = th.sum(dh_j * e_w_gh + H_g_reg_j * dg_j * h_last_i, dim=0)
        db_g = th.sum(dh * e_b_g + H_g_reg * dg, dim=0)

        dw_rx = th.sum(dh_j * e_w_rx, dim=0)
        dw_rh = th.sum(dh_j * e_w_rh, dim=0)
        db_r = th.sum(dh * e_b_r, dim=0)

        dh_dg = (dh * delta_h + H_g_reg) * dg
        dh_dr = dh * g * dr

        dx = dh_dg.mm(w_gx) + dh_dr.mm(w_rx)
        dh = dh * (1 - g) + dh_dg.mm(w_gh) + dh_dr.mm(w_rh)

        return dx, dh, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None

class ReTanhFunction(Function):
    @staticmethod
    def forward(ctx, x, reg):

        g = th.relu(th.tanh(x))

        # Heaviside step function
        H_g = th.ceil(g).clamp(0, 1)

        dg = (1 - g**2) * H_g

        ctx.save_for_backward(g, dg, H_g, reg)
        return g, th.mean(H_g)

    @staticmethod
    def backward(ctx, dh, _):

        g, dg, H_g, reg = ctx.saved_tensors

        dx = (dh + reg * H_g) * dg

        return dx, None

class ReTanh(nn.Module):
    def __init__(self, reg_lambda):
        super(ReTanh, self).__init__()

        self.re_tanh = ReTanhFunction().apply
        self.register_buffer("reg_lambda", th.tensor(reg_lambda), persistent=False)

    def forward(self, input):
        h, openings = self.re_tanh(input, self.reg_lambda)
        self.openings = openings.item()

        return h


class EpropGateL0rd(nn.Module):
    def __init__(
        self,
        num_inputs,
        num_hidden,
        num_outputs,
        batch_size,
        reg_lambda = 0,
        gate_noise_level = 0,
    ):
        super(EpropGateL0rd, self).__init__()

        self.register_buffer("reg", th.tensor(reg_lambda).view(1,1), persistent=False)
        self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs

        self.fcn = EpropGateL0rdFunction().apply
        self.retanh = ReTanh(reg_lambda)

        # gate weights and biases
        self.w_gx = nn.Parameter(th.empty(num_hidden, num_inputs))
        self.w_gh = nn.Parameter(th.empty(num_hidden, num_hidden))
        self.b_g = nn.Parameter(th.zeros(num_hidden))

        # candidate weights and biases
        self.w_rx = nn.Parameter(th.empty(num_hidden, num_inputs))
        self.w_rh = nn.Parameter(th.empty(num_hidden, num_hidden))
        self.b_r = nn.Parameter(th.zeros(num_hidden))

        # output projection weights and bias
        self.w_px = nn.Parameter(th.empty(num_outputs, num_inputs))
        self.w_ph = nn.Parameter(th.empty(num_outputs, num_hidden))
        self.b_p = nn.Parameter(th.zeros(num_outputs))

        # output gate weights and bias
        self.w_ox = nn.Parameter(th.empty(num_outputs, num_inputs))
        self.w_oh = nn.Parameter(th.empty(num_outputs, num_hidden))
        self.b_o = nn.Parameter(th.zeros(num_outputs))

        # input gate eligibility traces
        self.register_buffer("e_w_gx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
        self.register_buffer("e_w_gh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
        self.register_buffer("e_b_g", th.zeros(batch_size, num_hidden), persistent=False)

        # forget gate eligibility traces
        self.register_buffer("e_w_rx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
        self.register_buffer("e_w_rh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
        self.register_buffer("e_b_r", th.zeros(batch_size, num_hidden), persistent=False)

        # hidden state
        self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)

        self.register_buffer("openings", th.zeros(1), persistent=False)

        # initialize weights
        stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
        stdv_hh = np.sqrt(3/self.num_hidden)
        stdv_io = np.sqrt(6/(self.num_inputs + self.num_outputs))
        stdv_ho = np.sqrt(6/(self.num_hidden + self.num_outputs))

        nn.init.uniform_(self.w_gx, -stdv_ih, stdv_ih)
        nn.init.uniform_(self.w_gh, -stdv_hh, stdv_hh)

        nn.init.uniform_(self.w_rx, -stdv_ih, stdv_ih)
        nn.init.uniform_(self.w_rh, -stdv_hh, stdv_hh)

        nn.init.uniform_(self.w_px, -stdv_io, stdv_io)
        nn.init.uniform_(self.w_ph, -stdv_ho, stdv_ho)

        nn.init.uniform_(self.w_ox, -stdv_io, stdv_io)
        nn.init.uniform_(self.w_oh, -stdv_ho, stdv_ho)

        self.backprop = False

    def reset_state(self):
        self.h_last.zero_()
        self.e_w_gx.zero_()
        self.e_w_gh.zero_()
        self.e_b_g.zero_()
        self.e_w_rx.zero_()
        self.e_w_rh.zero_()
        self.e_b_r.zero_()
        self.openings.zero_()

    def backprop_forward(self, x: th.Tensor):

        noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
        g = self.retanh(x.mm(self.w_gx.t()) + self.h_last.mm(self.w_gh.t()) + self.b_g + noise)
        r = th.tanh(x.mm(self.w_rx.t()) + self.h_last.mm(self.w_rh.t()) + self.b_r)

        self.h_last = g * r + (1 - g) * self.h_last

        # Heaviside step function
        H_g = th.ceil(g).clamp(0, 1)

        self.openings = th.mean(H_g)

        p = th.tanh(x.mm(self.w_px.t()) + self.h_last.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + self.h_last.mm(self.w_oh.t()) + self.b_o)
        return o * p

    def activate_backprop(self):
        self.backprop = True

    def deactivate_backprop(self):
        self.backprop = False

    def detach(self):
        self.h_last.detach_()

    def eprop_forward(self, x: th.Tensor):
        h, openings = self.fcn(
            x, self.h_last,
            self.w_gx, self.w_gh, self.b_g,
            self.w_rx, self.w_rh, self.b_r,
            (
                self.e_w_gx, self.e_w_gh, self.e_b_g,
                self.e_w_rx, self.e_w_rh, self.e_b_r,
                self.reg, self.noise
            )
        )

        self.openings = openings
        self.h_last = h

        p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
        return o * p

    def save_hidden(self):
        self.h_last_saved = self.h_last.detach()

    def restore_hidden(self):
        self.h_last = self.h_last_saved

    def get_hidden(self):
        return self.h_last

    def set_hidden(self, h_last):
        self.h_last = h_last

    def forward(self, x: th.Tensor):
        if self.backprop:
            return self.backprop_forward(x)

        return self.eprop_forward(x)


class EpropGateL0rdShared(EpropGateL0rd):
    def __init__(
        self,
        num_inputs,
        num_hidden,
        num_outputs,
        batch_size,
        reg_lambda = 0,
        gate_noise_level = 0,
    ):
        super().__init__(num_inputs, num_hidden, num_outputs, batch_size, reg_lambda, gate_noise_level)

    def backprop_forward(self, x: th.Tensor, h_last: th.Tensor):

        noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
        g = self.retanh(x.mm(self.w_gx.t()) + h_last.mm(self.w_gh.t()) + self.b_g + noise)
        r = th.tanh(x.mm(self.w_rx.t()) + h_last.mm(self.w_rh.t()) + self.b_r)

        h_last = g * r + (1 - g) * h_last

        # Heaviside step function
        H_g = th.ceil(g).clamp(0, 1)

        self.openings = th.mean(H_g)

        p = th.tanh(x.mm(self.w_px.t()) + h_last.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + h_last.mm(self.w_oh.t()) + self.b_o)
        return o * p, h_last

    def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
        h, openings = self.fcn(
            x, h_last,
            self.w_gx, self.w_gh, self.b_g,
            self.w_rx, self.w_rh, self.b_r,
            (
                self.e_w_gx, self.e_w_gh, self.e_b_g,
                self.e_w_rx, self.e_w_rh, self.e_b_r,
                self.reg, self.noise
            )
        )

        self.openings = openings

        p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
        return o * p, h

    def forward(self, x: th.Tensor, h_last: th.Tensor = None):

        if h_last is not None:
            if self.backprop:
                return self.backprop_forward(x, h_last)

            return self.eprop_forward(x, h_last)

        # backward compatibility: fall back to the internally stored hidden state
        if self.backprop:
            x, h = self.backprop_forward(x, self.h_last)
        else:
            x, h = self.eprop_forward(x, self.h_last)

        self.h_last = h
        return x
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py
new file mode 100644
index 0000000..4d89ec4
--- /dev/null
+++ b/model/nn/eprop_transformer.py
@@ -0,0 +1,76 @@
import torch.nn as nn
import torch as th
from model.nn.eprop_gate_l0rd import EpropGateL0rd
from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding

class EpropGateL0rdTransformer(nn.Module):
    def __init__(
        self,
        channels,
        multiplier,
        num_objects,
        batch_size,
        heads,
        depth,
        reg_lambda,
        dropout=0.0
    ):
        super(EpropGateL0rdTransformer, self).__init__()

        num_inputs = channels
        num_outputs = channels
        num_hidden = channels * multiplier

        print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")

        self.depth = depth
        _layers = []
        _layers.append(InputEmbeding(num_inputs, num_hidden))

        for i in range(depth):
            _layers.append(AlphaAttention(num_hidden, num_objects, heads, dropout))
            _layers.append(EpropAlphaGateL0rd(num_hidden, batch_size * num_objects, reg_lambda))

        _layers.append(OutputEmbeding(num_hidden, num_outputs))
        self.layers = nn.Sequential(*_layers)

    def get_openings(self):
        openings = 0
        for i in range(self.depth):
            openings += self.layers[2 * (i + 1)].l0rd.openings.item()

        return openings / self.depth

    def get_hidden(self):
        states = []
        for i in range(self.depth):
            states.append(self.layers[2 * (i + 1)].l0rd.get_hidden())

        return th.cat(states, dim=1)

    def set_hidden(self, hidden):
        states = th.chunk(hidden, self.depth, dim=1)
        for i in range(self.depth):
            self.layers[2 * (i + 1)].l0rd.set_hidden(states[i])

    def forward(self, input: th.Tensor) -> th.Tensor:
        return self.layers(input)


class EpropAlphaGateL0rd(nn.Module):
    def __init__(self, num_hidden, batch_size, reg_lambda):
        super(EpropAlphaGateL0rd, self).__init__()

        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
        self.l0rd = EpropGateL0rd(
            num_inputs  = num_hidden,
            num_hidden  = num_hidden,
            num_outputs = num_hidden,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size
        )

    def forward(self, input):
        return input + self.alpha * self.l0rd(input)
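The block below is a minimal smoke-test sketch for the predictor stack added above, not part of the diff itself. It assumes the repository's model.nn package is importable; the sizes (2 samples, 4 object slots, 8 channels) are illustrative only. Objects are flattened into the batch dimension, so the input is of shape (batch_size * num_objects, channels).

    # Hypothetical usage sketch (illustrative sizes, not from the diff)
    import torch as th
    from model.nn.eprop_transformer import EpropGateL0rdTransformer

    batch_size, num_objects, channels = 2, 4, 8
    model = EpropGateL0rdTransformer(
        channels=channels, multiplier=4, num_objects=num_objects,
        batch_size=batch_size, heads=4, depth=2, reg_lambda=1e-6,
    )

    x = th.randn(batch_size * num_objects, channels)  # object slots flattened into the batch
    y = model(x)                                      # output keeps the input shape
    print(y.shape, model.get_openings())              # openings = mean fraction of open gates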
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py
new file mode 100644
index 0000000..23a1b0f
--- /dev/null
+++ b/model/nn/eprop_transformer_shared.py
@@ -0,0 +1,92 @@
import torch.nn as nn
import torch as th
from model.nn.eprop_gate_l0rd import EpropGateL0rdShared
from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding

class EpropGateL0rdTransformerShared(nn.Module):
    def __init__(
        self,
        channels,
        multiplier,
        num_objects,
        batch_size,
        heads,
        depth,
        reg_lambda,
        dropout=0.0,
        exchange_length = 48,
    ):
        super(EpropGateL0rdTransformerShared, self).__init__()

        num_inputs = channels
        num_outputs = channels
        num_hidden = channels * multiplier
        num_hidden_gatelord = num_hidden + exchange_length
        num_hidden_attention = num_hidden + exchange_length + num_hidden_gatelord

        self.num_hidden = num_hidden
        self.num_hidden_gatelord = num_hidden_gatelord

        #print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")

        self.register_buffer('hidden', th.zeros(batch_size * num_objects, num_hidden_gatelord), persistent=False)
        self.register_buffer('exchange_code', th.zeros(batch_size * num_objects, exchange_length), persistent=False)

        self.depth = depth
        self.input_embeding = InputEmbeding(num_inputs, num_hidden)
        self.attention = nn.Sequential(*[AlphaAttention(num_hidden_attention, num_objects, heads, dropout) for _ in range(depth)])
        self.l0rds = nn.Sequential(*[EpropAlphaGateL0rdShared(num_hidden_gatelord, batch_size * num_objects, reg_lambda) for _ in range(depth)])
        self.output_embeding = OutputEmbeding(num_hidden, num_outputs)

    def get_openings(self):
        openings = 0
        for i in range(self.depth):
            openings += self.l0rds[i].l0rd.openings.item()

        return openings / self.depth

    def get_hidden(self):
        return self.hidden

    def set_hidden(self, hidden):
        self.hidden = hidden

    def detach(self):
        self.hidden = self.hidden.detach()

    def reset_state(self):
        self.hidden = th.zeros_like(self.hidden)

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = self.input_embeding(x)
        exchange_code = self.exchange_code.clone() * 0.0
        x_ex = th.concat((x, exchange_code), dim=1)

        for i in range(self.depth):
            # attention layer
            att = self.attention(th.concat((x_ex, self.hidden), dim=1))
            x_ex = att[:, :self.num_hidden_gatelord]

            # gatelord layer
            x_ex, self.hidden = self.l0rds[i](x_ex, self.hidden)

        # only yield x
        x = x_ex[:, :self.num_hidden]
        return self.output_embeding(x)

class EpropAlphaGateL0rdShared(nn.Module):
    def __init__(self, num_hidden, batch_size, reg_lambda):
        super(EpropAlphaGateL0rdShared, self).__init__()

        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
        self.l0rd = EpropGateL0rdShared(
            num_inputs  = num_hidden,
            num_hidden  = num_hidden,
            num_outputs = num_hidden,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size
        )

    def forward(self, input, hidden):
        output, hidden = self.l0rd(input, hidden)
        return input + self.alpha * output, hidden
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
new file mode 100644
index 0000000..9e5e874
--- /dev/null
+++ b/model/nn/eprop_transformer_utils.py
@@ -0,0 +1,66 @@
import torch.nn as nn
import torch as th
from model.utils.nn_utils import LambdaModule
from einops import rearrange, repeat, reduce

class AlphaAttention(nn.Module):
    def __init__(
        self,
        num_hidden,
        num_objects,
        heads,
        dropout = 0.0
    ):
        super(AlphaAttention, self).__init__()

        self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects))
        self.to_batch = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects))

        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
        self.attention = nn.MultiheadAttention(
            num_hidden,
            heads,
            dropout = dropout,
            batch_first = True
        )

    def forward(self, x: th.Tensor):
        x = self.to_sequence(x)
        x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
        return self.to_batch(x)

class InputEmbeding(nn.Module):
    def __init__(self, num_inputs, num_hidden):
        super(InputEmbeding, self).__init__()

        self.embeding = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_inputs, num_hidden),
            nn.ReLU(),
            nn.Linear(num_hidden, num_hidden),
        )
        self.skip = LambdaModule(
            lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs)
        )
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)

    def forward(self, input: th.Tensor):
        return self.skip(input) + self.alpha * self.embeding(input)

class OutputEmbeding(nn.Module):
    def __init__(self, num_hidden, num_outputs):
        super(OutputEmbeding, self).__init__()

        self.embeding = nn.Sequential(
            nn.ReLU(),
            nn.Linear(num_hidden, num_outputs),
            nn.ReLU(),
            nn.Linear(num_outputs, num_outputs),
        )
        self.skip = LambdaModule(
            lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs)
        )
        self.alpha = nn.Parameter(th.zeros(1)+1e-12)

    def forward(self, input: th.Tensor):
        return self.skip(input) + self.alpha * self.embeding(input)
diff --git a/model/nn/percept_gate_controller.py b/model/nn/percept_gate_controller.py
new file mode 100644
index 0000000..a548c20
--- /dev/null
+++ b/model/nn/percept_gate_controller.py
@@ -0,0 +1,59 @@
import torch.nn as nn
import torch as th
from model.utils.nn_utils import LambdaModule
from einops import rearrange, repeat, reduce
from model.nn.eprop_gate_l0rd import ReTanh

class PerceptGateController(nn.Module):
    def __init__(
        self,
        num_inputs: int,
        num_hidden: list,
        bias: bool,
        num_objects: int,
        gate_noise_level: float = 0.1,
        reg_lambda: float = 0.000005
    ):
        super(PerceptGateController, self).__init__()

        self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o=num_objects))

        self.layers = nn.Sequential(
            nn.Linear(num_inputs, num_hidden[0], bias = bias),
            nn.Tanh(),
            nn.Linear(num_hidden[0], num_hidden[1], bias = bias),
            nn.Tanh(),
            nn.Linear(num_hidden[1], 2, bias = bias)
        )
        self.output_function = ReTanh(reg_lambda)
        self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
        self.init_weights()

    def init_weights(self):
        for layer in self.layers:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight)
                layer.bias.data.fill_(3.00)

    def forward(self, position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2, evaluate=False):

        position_cur = self.to_batch(position_cur)
        gestalt_cur = self.to_batch(gestalt_cur)
        priority_cur = self.to_batch(priority_cur)
        position_last = self.to_batch(position_last)
        gestalt_last = self.to_batch(gestalt_last)
        priority_last = self.to_batch(priority_last)
        slots_occlusionfactor_cur = self.to_batch(slots_occlusionfactor_cur).detach()
        slots_occlusionfactor_last = self.to_batch(slots_occlusionfactor_last).detach()
        position_last2 = self.to_batch(position_last2).detach()

        input = th.cat((position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2), dim=1)
        output = self.layers(input)
        if evaluate:
            output = self.output_function(output)
        else:
            noise = th.normal(mean=0, std=self.noise, size=output.shape, device=output.device)
            output = self.output_function(output + noise)

        return self.to_shared(output)
diff --git a/model/nn/predictor.py b/model/nn/predictor.py
new file mode 100644
index 0000000..94f13b8
--- /dev/null
+++ b/model/nn/predictor.py
@@ -0,0 +1,99 @@
import torch.nn as nn
import torch as th
from model.nn.eprop_transformer import EpropGateL0rdTransformer
from model.nn.eprop_transformer_shared import EpropGateL0rdTransformerShared
from model.utils.nn_utils import LambdaModule, Binarize
from model.nn.residual import ResidualBlock
from einops import rearrange, repeat, reduce

__author__ = "Manuel Traub"

class LociPredictor(nn.Module):
    def __init__(
        self,
        heads: int,
        layers: int,
        channels_multiplier: int,
        reg_lambda: float,
        num_objects: int,
        gestalt_size: int,
        batch_size: int,
        bottleneck: str,
        transformer_type = 'standard',
    ):
        super(LociPredictor, self).__init__()
        self.num_objects = num_objects
        self.std_alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.bottleneck_type = bottleneck
        self.gestalt_size = gestalt_size

        self.reg_lambda = reg_lambda
        Transformer = EpropGateL0rdTransformerShared if transformer_type == 'shared' else EpropGateL0rdTransformer
        self.predictor = Transformer(
            channels    = gestalt_size + 3 + 1 + 2,
            multiplier  = channels_multiplier,
            heads       = heads,
            depth       = layers,
            num_objects = num_objects,
            reg_lambda  = reg_lambda,
            batch_size  = batch_size,
        )

        if bottleneck == 'binar':
            print("Binary bottleneck")
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                Binarize(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )
        else:
            print("unrestricted bottleneck")
            self.bottleneck = nn.Sequential(
                LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
                ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
                nn.Sigmoid(),
                LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
            )

        self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))

    def get_openings(self):
        return self.predictor.get_openings()

    def get_hidden(self):
        return self.predictor.get_hidden()

    def set_hidden(self, hidden):
        self.predictor.set_hidden(hidden)

    def forward(
        self,
        gestalt: th.Tensor,
        priority: th.Tensor,
        position: th.Tensor,
        slots_closed: th.Tensor,
    ):
        position = self.to_batch(position)
        gestalt_cur = self.to_batch(gestalt)
        priority = self.to_batch(priority)
        slots_closed = rearrange(slots_closed, 'b o c -> (b o) c').detach()

        input = th.cat((gestalt_cur, priority, position, slots_closed), dim=1)
        output = self.predictor(input)

        gestalt = output[:, :self.gestalt_size]
        priority = output[:,self.gestalt_size:(self.gestalt_size+1)]
        xy = output[:,(self.gestalt_size+1):(self.gestalt_size+3)]
        std = output[:,(self.gestalt_size+3):(self.gestalt_size+4)]

        position = th.cat((xy, std * self.std_alpha), dim=1)

        position = self.to_shared(position)
        gestalt = self.bottleneck(gestalt)
        priority = self.to_shared(priority)

        return position, gestalt, priority
diff --git a/model/nn/residual.py b/model/nn/residual.py
new file mode 100644
index 0000000..1602e16
--- /dev/null
+++ b/model/nn/residual.py
@@ -0,0 +1,396 @@
import torch.nn as nn
import torch as th
import numpy as np
from einops import rearrange, repeat, reduce
from model.utils.nn_utils import LambdaModule

from typing import Union, Tuple

__author__ = "Manuel Traub"

class DynamicLayerNorm(nn.Module):

    def __init__(self, eps: float = 1e-5):
        super(DynamicLayerNorm, self).__init__()
        self.eps = eps

    def forward(self, input: th.Tensor) -> th.Tensor:
        return nn.functional.layer_norm(input, input.shape[2:], None, None, self.eps)


class SkipConnection(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: float = 1.0
    ):
        super(SkipConnection, self).__init__()
        assert scale_factor == 1 or int(scale_factor) > 1 or int(1 / scale_factor) > 1, f'invalid scale factor in SpikeFunction: {scale_factor}'

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor

    def channel_skip(self, input: th.Tensor):
        in_channels = self.in_channels
        out_channels = self.out_channels

        if in_channels == out_channels:
            return input

        if in_channels % out_channels == 0 or out_channels % in_channels == 0:

            if in_channels > out_channels:
                return reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // out_channels)

            if out_channels > in_channels:
                return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // in_channels)

        mean_channels = np.gcd(in_channels, out_channels)
        input = reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // mean_channels)
        return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // mean_channels)

    def scale_skip(self, input: th.Tensor):
        scale_factor = self.scale_factor

        if scale_factor == 1:
            return input

        if scale_factor > 1:
            return repeat(
                input,
                'b c h w -> b c (h h2) (w w2)',
                h2 = int(scale_factor),
                w2 = int(scale_factor)
            )

        height = input.shape[2]
        width = input.shape[3]

        # scale factor < 1
        scale_factor = int(1 / scale_factor)

        if width % scale_factor == 0 and height % scale_factor == 0:
            return reduce(
                input,
                'b c (h h2) (w w2) -> b c h w',
                'mean',
                h2 = scale_factor,
                w2 = scale_factor
            )

        if width >= scale_factor and height >= scale_factor:
            return nn.functional.avg_pool2d(
                input,
                kernel_size = scale_factor,
                stride = scale_factor
            )

        assert width > 1 or height > 1
        return reduce(input, 'b c h w -> b c 1 1', 'mean')

    def forward(self, input: th.Tensor):

        if self.scale_factor > 1:
            return self.scale_skip(self.channel_skip(input))

        return self.channel_skip(self.scale_skip(input))

class DownScale(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        scale_factor: int,
        groups: int = 1,
        bias: bool = True
    ):
        super(DownScale, self).__init__()

        assert(in_channels % groups == 0)
        assert(out_channels % groups == 0)

        self.groups = groups
        self.scale_factor = scale_factor
        self.weight = nn.Parameter(th.empty((out_channels, in_channels // groups, scale_factor, scale_factor)))
        self.bias = nn.Parameter(th.empty((out_channels,))) if bias else None

        nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))

        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / np.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input: th.Tensor):
        height = input.shape[2]
        width = input.shape[3]
        assert height > 1 or width > 1, "trying to downscale 1x1"

        scale_factor = self.scale_factor
        padding = [0, 0]

        if height < scale_factor:
            padding[0] = scale_factor - height

        if width < scale_factor:
            padding[1] = scale_factor - width

        return nn.functional.conv2d(
            input,
            self.weight,
            bias=self.bias,
            stride=scale_factor,
            padding=padding,
            groups=self.groups
        )


class ResidualBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = (3, 3),
        scale_factor: int = 1,
        groups: Union[int, Tuple[int, int]] = (1, 1),
        bias: bool = True,
        layer_norm: bool = False,
        leaky_relu: bool = False,
        residual: bool = True,
        alpha_residual: bool = False,
        input_nonlinearity = True
    ):
        super(ResidualBlock, self).__init__()
        self.residual = residual
        self.alpha_residual = alpha_residual
        self.skip = False
        self.in_channels = in_channels
        self.out_channels = out_channels

        if isinstance(kernel_size, int):
            kernel_size = [kernel_size, kernel_size]

        if isinstance(groups, int):
            groups = [groups, groups]

        padding = (kernel_size[0] // 2, kernel_size[1] // 2)

        _layers = list()
        if layer_norm:
            _layers.append(DynamicLayerNorm())

        if input_nonlinearity:
            if leaky_relu:
                _layers.append(nn.LeakyReLU())
            else:
                _layers.append(nn.ReLU())

        if scale_factor > 1:
            _layers.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=scale_factor,
                    stride=scale_factor,
                    groups=groups[0],
                    bias=bias
                )
            )
        elif scale_factor < 1:
            _layers.append(
                DownScale(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    scale_factor=int(1.0/scale_factor),
                    groups=groups[0],
                    bias=bias
                )
            )
        else:
            _layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    padding=padding,
                    groups=groups[0],
                    bias=bias
                )
            )

        if layer_norm:
            _layers.append(DynamicLayerNorm())
        if leaky_relu:
            _layers.append(nn.LeakyReLU())
        else:
            _layers.append(nn.ReLU())
        _layers.append(
            nn.Conv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding=padding,
                groups=groups[1],
                bias=bias
            )
        )
        self.layers = nn.Sequential(*_layers)

        if self.residual:
            self.skip_connection = SkipConnection(
                in_channels=in_channels,
                out_channels=out_channels,
                scale_factor=scale_factor
            )

        if self.alpha_residual:
            self.alpha = nn.Parameter(th.zeros(1) + 1e-12)

    def set_mode(self, **kwargs):
        if 'skip' in kwargs:
            self.skip = kwargs['skip']

        if 'residual' in kwargs:
            self.residual = kwargs['residual']

    def forward(self, input: th.Tensor) -> th.Tensor:
        if self.skip:
            return self.skip_connection(input)

        if not self.residual:
            return self.layers(input)

        if self.alpha_residual:
            return self.alpha * self.layers(input) + self.skip_connection(input)

        return self.layers(input) + self.skip_connection(input)

class LinearSkip(nn.Module):
    def __init__(self, num_inputs: int, num_outputs: int):
        super(LinearSkip, self).__init__()

        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

        if num_inputs % num_outputs != 0 and num_outputs % num_inputs != 0:
            mean_channels = np.gcd(num_inputs, num_outputs)
            print(f"[WW] gcd skip: {num_inputs} -> {mean_channels} -> {num_outputs}")
            assert(False)

    def forward(self, input: th.Tensor):
        num_inputs = self.num_inputs
        num_outputs = self.num_outputs

        if num_inputs == num_outputs:
            return input

        if num_inputs % num_outputs == 0 or num_outputs % num_inputs == 0:

            if num_inputs > num_outputs:
                return reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // num_outputs)

            if num_outputs > num_inputs:
                return repeat(input, 'b c -> b (c n)', n = num_outputs // num_inputs)

        mean_channels = np.gcd(num_inputs, num_outputs)
        input = reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // mean_channels)
        return repeat(input, 'b c -> b (c n)', n = num_outputs // mean_channels)

class LinearResidual(nn.Module):
    def __init__(
        self,
        num_inputs: int,
        num_outputs: int,
        num_hidden: int = None,
        residual: bool = True,
        alpha_residual: bool = False,
        input_relu: bool = True
    ):
        super(LinearResidual, self).__init__()

        self.residual = residual
        self.alpha_residual = alpha_residual

        if num_hidden is None:
            num_hidden = num_outputs

        _layers = []
        if input_relu:
            _layers.append(nn.ReLU())
        _layers.append(nn.Linear(num_inputs, num_hidden))
        _layers.append(nn.ReLU())
        _layers.append(nn.Linear(num_hidden, num_outputs))

        self.layers = nn.Sequential(*_layers)

        if residual:
            self.skip = LinearSkip(num_inputs, num_outputs)

        if alpha_residual:
            self.alpha = nn.Parameter(th.zeros(1)+1e-16)

    def forward(self, input: th.Tensor):
        if not self.residual:
            return self.layers(input)

        if not self.alpha_residual:
            return self.skip(input) + self.layers(input)

        return self.skip(input) + self.alpha * self.layers(input)


class EntityAttention(nn.Module):
    def __init__(self, channels, num_objects, size, channels_per_head = 12, dropout = 0.0):
        super(EntityAttention, self).__init__()

        assert channels % channels_per_head == 0
        heads = channels // channels_per_head

        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.attention = nn.MultiheadAttention(
            channels,
            heads,
            dropout = dropout,
            batch_first = True
        )

        self.channel_attention = nn.Sequential(
            LambdaModule(lambda x: rearrange(x, '(b o) c h w -> (b h w) o c', o = num_objects)),
            LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
            LambdaModule(lambda x: rearrange(x, '(b h w) o c -> (b o) c h w', h = size[0], w = size[1])),
        )

    def forward(self, input: th.Tensor) -> th.Tensor:
        return input + self.channel_attention(input) * self.alpha

class ImageAttention(nn.Module):
    def __init__(self, channels, gestalt_size, num_objects, size, channels_per_head = 12, dropout = 0.0):
        super(ImageAttention, self).__init__()

        assert gestalt_size % channels_per_head == 0
        heads = gestalt_size // channels_per_head

        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
        self.attention = nn.MultiheadAttention(
            gestalt_size,
            heads,
            dropout = dropout,
            batch_first = True
        )

        self.image_attention = nn.Sequential(
            nn.Conv2d(channels, gestalt_size, kernel_size=1),
            LambdaModule(lambda x: rearrange(x, 'b c h w -> b (h w) c')),
            LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
            LambdaModule(lambda x: rearrange(x, 'b (h w) c -> b c h w', h = size[0], w = size[1])),
            nn.Conv2d(gestalt_size, channels, kernel_size=1),
        )

    def forward(self, input: th.Tensor) -> th.Tensor:
        return input + self.image_attention(input) * self.alpha
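To illustrate the parameter-free skip paths used throughout these modules, here is a minimal shape-check sketch. It is not part of the diff; it assumes model.nn is importable, and the tensor sizes are arbitrary examples.

    # Hypothetical shape checks for the skip/residual building blocks
    import torch as th
    from model.nn.residual import ResidualBlock, SkipConnection

    x = th.randn(1, 8, 16, 16)
    print(SkipConnection(8, 4)(x).shape)                     # channel mean-pooling  -> (1, 4, 16, 16)
    print(SkipConnection(8, 8, scale_factor=0.25)(x).shape)  # spatial mean-pooling  -> (1, 8, 4, 4)
    print(ResidualBlock(8, 16, scale_factor=2)(x).shape)     # learned branch + skip -> (1, 16, 32, 32)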