Diffstat (limited to 'model/nn')
-rw-r--r--  model/nn/background.py                167
-rw-r--r--  model/nn/decoder.py                   169
-rw-r--r--  model/nn/encoder.py                   269
-rw-r--r--  model/nn/eprop_gate_l0rd.py           329
-rw-r--r--  model/nn/eprop_transformer.py          76
-rw-r--r--  model/nn/eprop_transformer_shared.py   92
-rw-r--r--  model/nn/eprop_transformer_utils.py    66
-rw-r--r--  model/nn/percept_gate_controller.py    59
-rw-r--r--  model/nn/predictor.py                  99
-rw-r--r--  model/nn/residual.py                  396
10 files changed, 1722 insertions, 0 deletions
diff --git a/model/nn/background.py b/model/nn/background.py
new file mode 100644
index 0000000..38ec325
--- /dev/null
+++ b/model/nn/background.py
@@ -0,0 +1,167 @@
+import torch.nn as nn
+import torch as th
+from model.nn.residual import ResidualBlock, SkipConnection, LinearResidual
+from model.nn.encoder import PatchDownConv
+from model.nn.encoder import AggressiveConvToGestalt
+from model.nn.decoder import PatchUpscale
+from model.utils.nn_utils import LambdaModule, Binarize
+from einops import rearrange, repeat, reduce
+from typing import Tuple
+
+__author__ = "Manuel Traub"
+
+class BackgroundEnhancer(nn.Module):
+ def __init__(
+ self,
+ input_size: Tuple[int, int],
+ img_channels: int,
+ level1_channels,
+ latent_channels,
+ gestalt_size,
+ batch_size,
+ depth
+ ):
+ super(BackgroundEnhancer, self).__init__()
+
+ latent_size = [input_size[0] // 16, input_size[1] // 16]
+ self.input_size = input_size
+
+ self.register_buffer('init', th.zeros(1).long())
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+ self.level = 1
+ self.down_level2 = nn.Sequential(
+ PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16),
+ *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)]
+ )
+
+ self.down_level1 = nn.Sequential(
+ PatchDownConv(level1_channels, latent_channels, alpha = 1),
+ *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)]
+ )
+
+ self.down_level0 = nn.Sequential(
+ *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
+ AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ Binarize(),
+ )
+
+ self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
+
+ self.to_grid = nn.Sequential(
+ LinearResidual(gestalt_size, gestalt_size, input_relu = False),
+ LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
+ LambdaModule(lambda x: x + self.bias),
+ *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)],
+ )
+
+
+ self.up_level0 = nn.Sequential(
+ ResidualBlock(gestalt_size, latent_channels),
+ *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
+ )
+
+ self.up_level1 = nn.Sequential(
+ *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)],
+ PatchUpscale(latent_channels, level1_channels, alpha = 1),
+ )
+
+ self.up_level2 = nn.Sequential(
+ *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)],
+ PatchUpscale(level1_channels, img_channels, alpha = 1e-16),
+ )
+
+ self.to_channels = nn.ModuleList([
+ SkipConnection(img_channels*2+2, latent_channels),
+ SkipConnection(img_channels*2+2, level1_channels),
+ SkipConnection(img_channels*2+2, img_channels*2+2),
+ ])
+
+ self.to_img = nn.ModuleList([
+ SkipConnection(latent_channels, img_channels),
+ SkipConnection(level1_channels, img_channels),
+ SkipConnection(img_channels, img_channels),
+ ])
+
+ self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
+ self.object = nn.Parameter(th.ones(1, img_channels, *input_size))
+
+ self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False)
+
+ def get_init(self):
+ return self.init.item()
+
+ def step_init(self):
+ self.init = self.init + 1
+
+ def detach(self):
+ self.latent = self.latent.detach()
+
+ def reset_state(self):
+ self.latent = th.zeros_like(self.latent)
+
+ def set_level(self, level):
+ self.level = level
+
+ def encoder(self, input):
+ latent = self.to_channels[self.level](input)
+
+ if self.level >= 2:
+ latent = self.down_level2(latent)
+
+ if self.level >= 1:
+ latent = self.down_level1(latent)
+
+ return self.down_level0(latent)
+
+ def get_last_latent_gird(self):
+ return self.to_grid(self.latent) * self.alpha
+
+ def decoder(self, latent, input):
+ grid = self.to_grid(latent)
+ latent = self.up_level0(grid)
+
+ if self.level >= 1:
+ latent = self.up_level1(latent)
+
+ if self.level >= 2:
+ latent = self.up_level2(latent)
+
+ object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3])
+ object = repeat(object, '1 c h w -> b c h w', b = input.shape[0])
+
+ return th.sigmoid(object + self.to_img[self.level](latent)), grid
+
+ def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False):
+
+ if only_mask:
+ mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
+ mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
+ return mask
+
+ last_bg = self.decoder(self.latent, input)[0]
+
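+        # per-pixel RMS error w.r.t. the previous background estimate; pixels whose
+        # error stays below mean + std are treated as background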
+ bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach()
+
+ if error is None or self.get_init() < 2:
+ error = bg_error
+
+ if mask is None or self.get_init() < 2:
+ mask = bg_mask
+
+ self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1))
+
+ mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
+ mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
+
+ background, grid = self.decoder(self.latent, input)
+
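+        # warm-up schedule controlled by the init counter: first only mask and background,
+        # then a zeroed gestalt grid, and finally the alpha-gated latent grid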
+ if self.get_init() < 1:
+ return mask, background
+
+ if self.get_init() < 2:
+ return mask, th.zeros_like(background), th.zeros_like(grid), background
+
+ return mask, background, grid * self.alpha, background
diff --git a/model/nn/decoder.py b/model/nn/decoder.py
new file mode 100644
index 0000000..3e56640
--- /dev/null
+++ b/model/nn/decoder.py
@@ -0,0 +1,169 @@
+import torch.nn as nn
+import torch as th
+from model.nn.residual import SkipConnection, ResidualBlock
+from model.utils.nn_utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize
+from einops import rearrange, repeat, reduce
+from typing import Tuple, Union, List
+
+__author__ = "Manuel Traub"
+
+class PriorityEncoder(nn.Module):
+ def __init__(self, num_objects, batch_size):
+ super(PriorityEncoder, self).__init__()
+
+ self.num_objects = num_objects
+ self.register_buffer("indices", repeat(th.arange(num_objects), 'a -> b a', b=batch_size), persistent=False)
+
+ self.index_factor = nn.Parameter(th.ones(1))
+ self.priority_factor = nn.Parameter(th.ones(1))
+
+ def forward(self, priority: th.Tensor) -> th.Tensor:
+
+ if priority is None:
+ return None
+
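+        # scale the predicted priority, add a small amount of noise, and mix in a
+        # learned per-slot index bias; the result is amplified by a factor of 25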
+ priority = priority * self.num_objects + th.randn_like(priority) * 0.1
+ priority = priority * self.priority_factor
+ priority = priority + self.indices * self.index_factor
+
+ return priority * 25
+
+
+class GestaltPositionMerge(nn.Module):
+ def __init__(
+ self,
+ latent_size: Union[int, Tuple[int, int]],
+ num_objects: int,
+ batch_size: int
+ ):
+
+ super(GestaltPositionMerge, self).__init__()
+ self.num_objects = num_objects
+
+ self.gaus2d = Gaus2D(size=latent_size)
+
+ self.to_batch = SharedObjectsToBatch(num_objects)
+ self.to_shared = BatchToSharedObjects(num_objects)
+
+ self.prioritize = Prioritize(num_objects)
+
+ self.priority_encoder = PriorityEncoder(num_objects, batch_size)
+
+ def forward(self, position, gestalt, priority):
+
+ position = rearrange(position, 'b (o c) -> (b o) c', o = self.num_objects)
+ gestalt = rearrange(gestalt, 'b (o c) -> (b o) c 1 1', o = self.num_objects)
+ priority = self.priority_encoder(priority)
+
+ position = self.gaus2d(position)
+ position = self.to_batch(self.prioritize(self.to_shared(position), priority))
+
+ return position * gestalt
+
+class PatchUpscale(nn.Module):
+ def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
+ super(PatchUpscale, self).__init__()
+ assert in_channels % out_channels == 0
+
+ self.skip = SkipConnection(in_channels, out_channels, scale_factor=scale_factor)
+
+ self.residual = nn.Sequential(
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels = in_channels,
+ out_channels = in_channels,
+ kernel_size = 3,
+ padding = 1
+ ),
+ nn.ReLU(),
+ nn.ConvTranspose2d(
+ in_channels = in_channels,
+ out_channels = out_channels,
+ kernel_size = scale_factor,
+ stride = scale_factor,
+ ),
+ )
+
+ self.alpha = nn.Parameter(th.zeros(1) + alpha)
+
+ def forward(self, input):
+ return self.skip(input) + self.alpha * self.residual(input)
+
+
+class LociDecoder(nn.Module):
+ def __init__(
+ self,
+ latent_size: Union[int, Tuple[int, int]],
+ gestalt_size: int,
+ num_objects: int,
+ img_channels: int,
+ hidden_channels: int,
+ level1_channels: int,
+ num_layers: int,
+ batch_size: int
+ ):
+
+ super(LociDecoder, self).__init__()
+ self.to_batch = SharedObjectsToBatch(num_objects)
+ self.to_shared = BatchToSharedObjects(num_objects)
+ self.level = 1
+
+ assert(level1_channels % img_channels == 0)
+ level1_factor = level1_channels // img_channels
+ print(f"Level1 channels: {level1_channels}")
+
+ self.merge = GestaltPositionMerge(
+ latent_size = latent_size,
+ num_objects = num_objects,
+ batch_size = batch_size
+ )
+
+ self.layer0 = nn.Sequential(
+ ResidualBlock(gestalt_size, hidden_channels, input_nonlinearity = False),
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers-1)],
+ )
+
+ self.to_mask_level0 = ResidualBlock(hidden_channels, hidden_channels)
+ self.to_mask_level1 = PatchUpscale(hidden_channels, 1)
+
+ self.to_mask_level2 = nn.Sequential(
+ ResidualBlock(hidden_channels, hidden_channels),
+ ResidualBlock(hidden_channels, hidden_channels),
+ PatchUpscale(hidden_channels, level1_factor, alpha = 1),
+ PatchUpscale(level1_factor, 1, alpha = 1)
+ )
+
+ self.to_object_level0 = ResidualBlock(hidden_channels, hidden_channels)
+ self.to_object_level1 = PatchUpscale(hidden_channels, img_channels)
+
+ self.to_object_level2 = nn.Sequential(
+ ResidualBlock(hidden_channels, hidden_channels),
+ ResidualBlock(hidden_channels, hidden_channels),
+ PatchUpscale(hidden_channels, level1_channels, alpha = 1),
+ PatchUpscale(level1_channels, img_channels, alpha = 1)
+ )
+
+ self.mask_alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.object_alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+
+ def set_level(self, level):
+ self.level = level
+
+ def forward(self, position, gestalt, priority = None):
+
+ maps = self.layer0(self.merge(position, gestalt, priority))
+ mask0 = self.to_mask_level0(maps)
+ object0 = self.to_object_level0(maps)
+
+ mask = self.to_mask_level1(mask0)
+ object = self.to_object_level1(object0)
+
+ if self.level > 1:
+ mask = repeat(mask, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)
+ object = repeat(object, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)
+
+ mask = mask + self.to_mask_level2(mask0) * self.mask_alpha
+ object = object + self.to_object_level2(object0) * self.object_alpha
+
+ return self.to_shared(mask), self.to_shared(object)
diff --git a/model/nn/encoder.py b/model/nn/encoder.py
new file mode 100644
index 0000000..8eb2e7f
--- /dev/null
+++ b/model/nn/encoder.py
@@ -0,0 +1,269 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import Gaus2D, BatchToSharedObjects, LambdaModule, ForcedAlpha, Binarize
+from model.nn.residual import ResidualBlock, SkipConnection
+from einops import rearrange, repeat, reduce
+from typing import Tuple, Union, List
+
+__author__ = "Manuel Traub"
+
+class NeighbourChannels(nn.Module):
+ def __init__(self, channels):
+ super(NeighbourChannels, self).__init__()
+
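+        # all-ones 1x1 convolution kernel with a zeroed diagonal: each output channel
+        # sums the masks of all *other* objects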
+ self.register_buffer("weights", th.ones(channels, channels, 1, 1), persistent=False)
+
+ for i in range(channels):
+ self.weights[i,i,0,0] = 0
+
+ def forward(self, input: th.Tensor):
+ return nn.functional.conv2d(input, self.weights)
+
+class InputPreprocessing(nn.Module):
+ def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]):
+ super(InputPreprocessing, self).__init__()
+ self.num_objects = num_objects
+ self.neighbours = NeighbourChannels(num_objects)
+ self.gaus2d = Gaus2D(size)
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+ self.to_shared = BatchToSharedObjects(num_objects)
+
+ def forward(
+ self,
+ input: th.Tensor,
+ error: th.Tensor,
+ mask: th.Tensor,
+ object: th.Tensor,
+ position: th.Tensor,
+ rawmask: th.Tensor
+ ):
+ bg_mask = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects)
+ mask = mask[:,:-1]
+ mask_others = self.neighbours(mask)
+ rawmask = rawmask[:,:-1]
+
+ own_gaus2d = self.to_shared(self.gaus2d(self.to_batch(position)))
+
+ input = repeat(input, 'b c h w -> b o c h w', o = self.num_objects)
+ error = repeat(error, 'b 1 h w -> b o 1 h w', o = self.num_objects)
+ bg_mask = rearrange(bg_mask, 'b o h w -> b o 1 h w')
+ mask_others = rearrange(mask_others, 'b o h w -> b o 1 h w')
+ mask = rearrange(mask, 'b o h w -> b o 1 h w')
+ object = rearrange(object, 'b (o c) h w -> b o c h w', o = self.num_objects)
+ own_gaus2d = rearrange(own_gaus2d, 'b o h w -> b o 1 h w')
+ rawmask = rearrange(rawmask, 'b o h w -> b o 1 h w')
+
+ output = th.cat((input, error, mask, mask_others, bg_mask, object, own_gaus2d, rawmask), dim=2)
+ output = rearrange(output, 'b o c h w -> (b o) c h w')
+
+ return output
+
+class PatchDownConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
+ super(PatchDownConv, self).__init__()
+ assert out_channels % in_channels == 0
+
+ self.layers = nn.Conv2d(
+ in_channels = in_channels,
+ out_channels = out_channels,
+ kernel_size = kernel_size,
+ stride = kernel_size,
+ )
+
+ self.alpha = nn.Parameter(th.zeros(1) + alpha)
+        self.kernel_size = kernel_size
+ self.channels_factor = out_channels // in_channels
+
+ def forward(self, input: th.Tensor):
+ k = self.kernel_size
+ c = self.channels_factor
+ skip = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=k, w2=k)
+ skip = repeat(skip, 'b c h w -> b (c n) h w', n=c)
+ return skip + self.alpha * self.layers(input)
+
+class AggressiveConvToGestalt(nn.Module):
+ def __init__(self, channels, gestalt_size, size: Union[int, Tuple[int, int]]):
+ super(AggressiveConvToGestalt, self).__init__()
+
+ assert gestalt_size % channels == 0 or channels % gestalt_size == 0
+
+ self.layers = nn.Sequential(
+ nn.Conv2d(
+ in_channels = channels,
+ out_channels = gestalt_size,
+ kernel_size = 5,
+ stride = 3,
+ padding = 3
+ ),
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels = gestalt_size,
+ out_channels = gestalt_size,
+ kernel_size = ((size[0] + 1)//3 + 1, (size[1] + 1)//3 + 1)
+ )
+ )
+ if gestalt_size > channels:
+ self.skip = nn.Sequential(
+ LambdaModule(lambda x: reduce(x, 'b c h w -> b c 1 1', 'mean')),
+ LambdaModule(lambda x: repeat(x, 'b c 1 1 -> b (c n) 1 1', n = gestalt_size // channels))
+ )
+ else:
+ self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c 1 1', 'mean', n = channels // gestalt_size))
+
+
+ def forward(self, input: th.Tensor):
+ return self.skip(input) + self.layers(input)
+
+class PixelToPosition(nn.Module):
+ def __init__(self, size: Union[int, Tuple[int, int]]):
+ super(PixelToPosition, self).__init__()
+
+ self.register_buffer("grid_x", th.arange(size[0]), persistent=False)
+ self.register_buffer("grid_y", th.arange(size[1]), persistent=False)
+
+ self.grid_x = (self.grid_x / (size[0]-1)) * 2 - 1
+ self.grid_y = (self.grid_y / (size[1]-1)) * 2 - 1
+
+ self.grid_x = self.grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
+ self.grid_y = self.grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()
+
+ self.size = size
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+
+ input = rearrange(input, 'b c h w -> b c (h w)')
+ input = th.softmax(input, dim=2)
+ input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1])
+
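+        # expected position under the softmax attention map, in [-1, 1] image coordinates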
+ x = th.sum(input * self.grid_x, dim=(2,3))
+ y = th.sum(input * self.grid_y, dim=(2,3))
+
+ return th.cat((x,y),dim=1)
+
+class PixelToSTD(nn.Module):
+ def __init__(self):
+ super(PixelToSTD, self).__init__()
+ self.alpha = ForcedAlpha()
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+ return self.alpha(reduce(th.sigmoid(input - 10), 'b c h w -> b c', 'mean'))
+
+class PixelToPriority(nn.Module):
+ def __init__(self):
+ super(PixelToPriority, self).__init__()
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+ return reduce(th.tanh(input), 'b c h w -> b c', 'mean')
+
+class LociEncoder(nn.Module):
+ def __init__(
+ self,
+ input_size: Union[int, Tuple[int, int]],
+ latent_size: Union[int, Tuple[int, int]],
+ num_objects: int,
+ img_channels: int,
+ hidden_channels: int,
+ level1_channels: int,
+ num_layers: int,
+ gestalt_size: int,
+ bottleneck: str
+ ):
+ super(LociEncoder, self).__init__()
+
+ self.num_objects = num_objects
+ self.latent_size = latent_size
+ self.level = 1
+
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects))
+
+ print(f"Level1 channels: {level1_channels}")
+
+ self.preprocess = nn.ModuleList([
+ InputPreprocessing(num_objects, (input_size[0] // 16, input_size[1] // 16)),
+ InputPreprocessing(num_objects, (input_size[0] // 4, input_size[1] // 4)),
+ InputPreprocessing(num_objects, (input_size[0], input_size[1]))
+ ])
+
+ self.to_channels = nn.ModuleList([
+ SkipConnection(img_channels, hidden_channels),
+ SkipConnection(img_channels, level1_channels),
+ SkipConnection(img_channels, img_channels)
+ ])
+
+ self.layers2 = nn.Sequential(
+ PatchDownConv(img_channels, level1_channels, alpha = 1e-16),
+ *[ResidualBlock(level1_channels, level1_channels, alpha_residual=True) for _ in range(num_layers)]
+ )
+
+ self.layers1 = PatchDownConv(level1_channels, hidden_channels)
+
+ self.layers0 = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)]
+ )
+
+ self.position_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ ResidualBlock(hidden_channels, 3),
+ )
+
+ self.xy_encoder = PixelToPosition(latent_size)
+ self.std_encoder = PixelToSTD()
+ self.priority_encoder = PixelToPriority()
+
+ if bottleneck == "binar":
+ print("Binary bottleneck")
+ self.gestalt_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ Binarize(),
+ )
+
+ else:
+ print("unrestricted bottleneck")
+ self.gestalt_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ nn.Sigmoid(),
+ )
+
+ def set_level(self, level):
+ self.level = level
+
+ def forward(
+ self,
+ input: th.Tensor,
+ error: th.Tensor,
+ mask: th.Tensor,
+ object: th.Tensor,
+ position: th.Tensor,
+ rawmask: th.Tensor
+ ):
+
+ latent = self.preprocess[self.level](input, error, mask, object, position, rawmask)
+ latent = self.to_channels[self.level](latent)
+
+ if self.level >= 2:
+ latent = self.layers2(latent)
+
+ if self.level >= 1:
+ latent = self.layers1(latent)
+
+ latent = self.layers0(latent)
+ gestalt = self.gestalt_encoder(latent)
+
+ latent = self.position_encoder(latent)
+ std = self.std_encoder(latent[:,0:1])
+ xy = self.xy_encoder(latent[:,1:2])
+ priority = self.priority_encoder(latent[:,2:3])
+
+ position = self.to_shared(th.cat((xy, std), dim=1))
+ gestalt = self.to_shared(gestalt)
+ priority = self.to_shared(priority)
+
+ return position, gestalt, priority
+
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py
new file mode 100644
index 0000000..82a9895
--- /dev/null
+++ b/model/nn/eprop_gate_l0rd.py
@@ -0,0 +1,329 @@
+import torch.nn as nn
+import torch as th
+import numpy as np
+from torch.autograd import Function
+from einops import rearrange, repeat, reduce
+
+__author__ = "Manuel Traub"
+
+class EpropGateL0rdFunction(Function):
+ @staticmethod
+ def forward(ctx, x, h_last, w_gx, w_gh, b_g, w_rx, w_rh, b_r, args):
+
+ e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r, reg, noise_level = args
+
+ noise = th.normal(mean=0, std=noise_level, size=b_g.shape, device=b_g.device)
+ g = th.relu(th.tanh(x.mm(w_gx.t()) + h_last.mm(w_gh.t()) + b_g + noise))
+ r = th.tanh(x.mm(w_rx.t()) + h_last.mm(w_rh.t()) + b_r)
+
+ h = g * r + (1 - g) * h_last
+
+        # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ dg = (1 - g**2) * H_g
+ dr = (1 - r**2)
+
+ delta_h = r - h_last
+
+ g_j = g.unsqueeze(dim=2)
+ dg_j = dg.unsqueeze(dim=2)
+ dr_j = dr.unsqueeze(dim=2)
+
+ x_i = x.unsqueeze(dim=1)
+ h_last_i = h_last.unsqueeze(dim=1)
+ delta_h_j = delta_h.unsqueeze(dim=2)
+
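+        # in-place eligibility trace updates (forward-mode accumulation as in e-prop):
+        # traces decay with (1 - g) and accumulate the local gate / candidate gradient terms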
+ e_w_gh.copy_(e_w_gh * (1 - g_j) + dg_j * h_last_i * delta_h_j)
+ e_w_gx.copy_(e_w_gx * (1 - g_j) + dg_j * x_i * delta_h_j)
+ e_b_g.copy_( e_b_g * (1 - g) + dg * delta_h )
+
+ e_w_rh.copy_(e_w_rh * (1 - g_j) + dr_j * h_last_i * g_j)
+ e_w_rx.copy_(e_w_rx * (1 - g_j) + dr_j * x_i * g_j)
+ e_b_r.copy_( e_b_r * (1 - g) + dr * g )
+
+ ctx.save_for_backward(
+ g.clone(), dg.clone(), dg_j.clone(), dr.clone(), x_i.clone(), h_last_i.clone(),
+ reg.clone(), H_g.clone(), delta_h.clone(), w_gx.clone(), w_gh.clone(), w_rx.clone(), w_rh.clone(),
+ e_w_gx.clone(), e_w_gh.clone(), e_b_g.clone(),
+ e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
+ )
+
+ return h, th.mean(H_g)
+
+ @staticmethod
+ def backward(ctx, dh, _):
+
+ g, dg, dg_j, dr, x_i, h_last_i, reg, H_g, delta_h, w_gx, w_gh, w_rx, w_rh, \
+ e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r = ctx.saved_tensors
+
+ dh_j = dh.unsqueeze(dim=2)
+ H_g_reg = reg * H_g
+ H_g_reg_j = H_g_reg.unsqueeze(dim=2)
+
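+        # weight gradients combine the eligibility traces with the incoming gradient dh,
+        # plus the L0-style regularization term on open gates (H_g)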
+ dw_gx = th.sum(dh_j * e_w_gx + H_g_reg_j * dg_j * x_i, dim=0)
+ dw_gh = th.sum(dh_j * e_w_gh + H_g_reg_j * dg_j * h_last_i, dim=0)
+ db_g = th.sum(dh * e_b_g + H_g_reg * dg, dim=0)
+
+ dw_rx = th.sum(dh_j * e_w_rx, dim=0)
+ dw_rh = th.sum(dh_j * e_w_rh, dim=0)
+ db_r = th.sum(dh * e_b_r , dim=0)
+
+ dh_dg = (dh * delta_h + H_g_reg) * dg
+ dh_dr = dh * g * dr
+
+ dx = dh_dg.mm(w_gx) + dh_dr.mm(w_rx)
+ dh = dh * (1 - g) + dh_dg.mm(w_gh) + dh_dr.mm(w_rh)
+
+ return dx, dh, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None
+
+class ReTanhFunction(Function):
+ @staticmethod
+ def forward(ctx, x, reg):
+
+ g = th.relu(th.tanh(x))
+
+        # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ dg = (1 - g**2) * H_g
+
+ ctx.save_for_backward(g, dg, H_g, reg)
+ return g, th.mean(H_g)
+
+ @staticmethod
+ def backward(ctx, dh, _):
+
+ g, dg, H_g, reg = ctx.saved_tensors
+
+ dx = (dh + reg * H_g) * dg
+
+ return dx, None
+
+class ReTanh(nn.Module):
+ def __init__(self, reg_lambda):
+ super(ReTanh, self).__init__()
+
+ self.re_tanh = ReTanhFunction().apply
+ self.register_buffer("reg_lambda", th.tensor(reg_lambda), persistent=False)
+
+ def forward(self, input):
+ h, openings = self.re_tanh(input, self.reg_lambda)
+ self.openings = openings.item()
+
+ return h
+
+
+class EpropGateL0rd(nn.Module):
+ def __init__(
+ self,
+ num_inputs,
+ num_hidden,
+ num_outputs,
+ batch_size,
+ reg_lambda = 0,
+ gate_noise_level = 0,
+ ):
+ super(EpropGateL0rd, self).__init__()
+
+ self.register_buffer("reg", th.tensor(reg_lambda).view(1,1), persistent=False)
+ self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
+ self.num_inputs = num_inputs
+ self.num_hidden = num_hidden
+ self.num_outputs = num_outputs
+
+ self.fcn = EpropGateL0rdFunction().apply
+ self.retanh = ReTanh(reg_lambda)
+
+ # gate weights and biases
+ self.w_gx = nn.Parameter(th.empty(num_hidden, num_inputs))
+ self.w_gh = nn.Parameter(th.empty(num_hidden, num_hidden))
+ self.b_g = nn.Parameter(th.zeros(num_hidden))
+
+ # candidate weights and biases
+ self.w_rx = nn.Parameter(th.empty(num_hidden, num_inputs))
+ self.w_rh = nn.Parameter(th.empty(num_hidden, num_hidden))
+ self.b_r = nn.Parameter(th.zeros(num_hidden))
+
+ # output projection weights and bias
+ self.w_px = nn.Parameter(th.empty(num_outputs, num_inputs))
+ self.w_ph = nn.Parameter(th.empty(num_outputs, num_hidden))
+ self.b_p = nn.Parameter(th.zeros(num_outputs))
+
+ # output gate weights and bias
+ self.w_ox = nn.Parameter(th.empty(num_outputs, num_inputs))
+ self.w_oh = nn.Parameter(th.empty(num_outputs, num_hidden))
+ self.b_o = nn.Parameter(th.zeros(num_outputs))
+
+        # input gate eligibility traces
+ self.register_buffer("e_w_gx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
+ self.register_buffer("e_w_gh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
+ self.register_buffer("e_b_g", th.zeros(batch_size, num_hidden), persistent=False)
+
+        # candidate eligibility traces
+ self.register_buffer("e_w_rx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
+ self.register_buffer("e_w_rh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
+ self.register_buffer("e_b_r", th.zeros(batch_size, num_hidden), persistent=False)
+
+ # hidden state
+ self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)
+
+ self.register_buffer("openings", th.zeros(1), persistent=False)
+
+ # initialize weights
+ stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
+ stdv_hh = np.sqrt(3/self.num_hidden)
+ stdv_io = np.sqrt(6/(self.num_inputs + self.num_outputs))
+ stdv_ho = np.sqrt(6/(self.num_hidden + self.num_outputs))
+
+ nn.init.uniform_(self.w_gx, -stdv_ih, stdv_ih)
+ nn.init.uniform_(self.w_gh, -stdv_hh, stdv_hh)
+
+ nn.init.uniform_(self.w_rx, -stdv_ih, stdv_ih)
+ nn.init.uniform_(self.w_rh, -stdv_hh, stdv_hh)
+
+ nn.init.uniform_(self.w_px, -stdv_io, stdv_io)
+ nn.init.uniform_(self.w_ph, -stdv_ho, stdv_ho)
+
+ nn.init.uniform_(self.w_ox, -stdv_io, stdv_io)
+ nn.init.uniform_(self.w_oh, -stdv_ho, stdv_ho)
+
+ self.backprop = False
+
+ def reset_state(self):
+ self.h_last.zero_()
+ self.e_w_gx.zero_()
+ self.e_w_gh.zero_()
+ self.e_b_g.zero_()
+ self.e_w_rx.zero_()
+ self.e_w_rh.zero_()
+ self.e_b_r.zero_()
+ self.openings.zero_()
+
+ def backprop_forward(self, x: th.Tensor):
+
+ noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
+ g = self.retanh(x.mm(self.w_gx.t()) + self.h_last.mm(self.w_gh.t()) + self.b_g + noise)
+ r = th.tanh(x.mm(self.w_rx.t()) + self.h_last.mm(self.w_rh.t()) + self.b_r)
+
+ self.h_last = g * r + (1 - g) * self.h_last
+
+        # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ self.openings = th.mean(H_g)
+
+ p = th.tanh(x.mm(self.w_px.t()) + self.h_last.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + self.h_last.mm(self.w_oh.t()) + self.b_o)
+ return o * p
+
+ def activate_backprop(self):
+ self.backprop = True
+
+ def deactivate_backprop(self):
+ self.backprop = False
+
+ def detach(self):
+ self.h_last.detach_()
+
+ def eprop_forward(self, x: th.Tensor):
+ h, openings = self.fcn(
+ x, self.h_last,
+ self.w_gx, self.w_gh, self.b_g,
+ self.w_rx, self.w_rh, self.b_r,
+ (
+ self.e_w_gx, self.e_w_gh, self.e_b_g,
+ self.e_w_rx, self.e_w_rh, self.e_b_r,
+ self.reg, self.noise
+ )
+ )
+
+ self.openings = openings
+ self.h_last = h
+
+ p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
+ return o * p
+
+ def save_hidden(self):
+ self.h_last_saved = self.h_last.detach()
+
+ def restore_hidden(self):
+ self.h_last = self.h_last_saved
+
+ def get_hidden(self):
+ return self.h_last
+
+ def set_hidden(self, h_last):
+ self.h_last = h_last
+
+ def forward(self, x: th.Tensor):
+ if self.backprop:
+ return self.backprop_forward(x)
+
+ return self.eprop_forward(x)
+
+
+class EpropGateL0rdShared(EpropGateL0rd):
+ def __init__(
+ self,
+ num_inputs,
+ num_hidden,
+ num_outputs,
+ batch_size,
+ reg_lambda = 0,
+ gate_noise_level = 0,
+ ):
+ super().__init__(num_inputs, num_hidden, num_outputs, batch_size, reg_lambda, gate_noise_level)
+
+ def backprop_forward(self, x: th.Tensor, h_last: th.Tensor):
+
+ noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
+ g = self.retanh(x.mm(self.w_gx.t()) + h_last.mm(self.w_gh.t()) + self.b_g + noise)
+ r = th.tanh(x.mm(self.w_rx.t()) + h_last.mm(self.w_rh.t()) + self.b_r)
+
+ h_last = g * r + (1 - g) * h_last
+
+        # Heaviside step function
+ H_g = th.ceil(g).clamp(0, 1)
+
+ self.openings = th.mean(H_g)
+
+ p = th.tanh(x.mm(self.w_px.t()) + h_last.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + h_last.mm(self.w_oh.t()) + self.b_o)
+ return o * p, h_last
+
+ def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
+ h, openings = self.fcn(
+ x, h_last,
+ self.w_gx, self.w_gh, self.b_g,
+ self.w_rx, self.w_rh, self.b_r,
+ (
+ self.e_w_gx, self.e_w_gh, self.e_b_g,
+ self.e_w_rx, self.e_w_rh, self.e_b_r,
+ self.reg, self.noise
+ )
+ )
+
+ self.openings = openings
+
+ p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
+ o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
+ return o * p, h
+
+ def forward(self, x: th.Tensor, h_last: th.Tensor = None):
+
+ if h_last is not None:
+ if self.backprop:
+ return self.backprop_forward(x, h_last)
+
+ return self.eprop_forward(x, h_last)
+
+ # backward compatibility
+        if self.backprop:
+            x, h = self.backprop_forward(x, self.h_last)
+        else:
+            x, h = self.eprop_forward(x, self.h_last)
+
+ self.h_last = h
+        return x
\ No newline at end of file
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py
new file mode 100644
index 0000000..4d89ec4
--- /dev/null
+++ b/model/nn/eprop_transformer.py
@@ -0,0 +1,76 @@
+import torch.nn as nn
+import torch as th
+from model.nn.eprop_gate_l0rd import EpropGateL0rd
+from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding
+
+class EpropGateL0rdTransformer(nn.Module):
+ def __init__(
+ self,
+ channels,
+ multiplier,
+ num_objects,
+ batch_size,
+ heads,
+ depth,
+ reg_lambda,
+ dropout=0.0
+ ):
+ super(EpropGateL0rdTransformer, self).__init__()
+
+ num_inputs = channels
+ num_outputs = channels
+        num_hidden  = channels * multiplier
+
+ print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")
+
+
+ self.depth = depth
+ _layers = []
+ _layers.append(InputEmbeding(num_inputs, num_hidden))
+
+ for i in range(depth):
+ _layers.append(AlphaAttention(num_hidden, num_objects, heads, dropout))
+ _layers.append(EpropAlphaGateL0rd(num_hidden, batch_size * num_objects, reg_lambda))
+
+ _layers.append(OutputEmbeding(num_hidden, num_outputs))
+ self.layers = nn.Sequential(*_layers)
+
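+    # self.layers = [InputEmbeding, (AlphaAttention, EpropAlphaGateL0rd) * depth, OutputEmbeding],
+    # so the i-th L0rd block sits at index 2 * (i + 1)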
+ def get_openings(self):
+ openings = 0
+ for i in range(self.depth):
+ openings += self.layers[2 * (i + 1)].l0rd.openings.item()
+
+ return openings / self.depth
+
+ def get_hidden(self):
+ states = []
+ for i in range(self.depth):
+ states.append(self.layers[2 * (i + 1)].l0rd.get_hidden())
+
+ return th.cat(states, dim=1)
+
+ def set_hidden(self, hidden):
+ states = th.chunk(hidden, self.depth, dim=1)
+ for i in range(self.depth):
+ self.layers[2 * (i + 1)].l0rd.set_hidden(states[i])
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return self.layers(input)
+
+
+class EpropAlphaGateL0rd(nn.Module):
+ def __init__(self, num_hidden, batch_size, reg_lambda):
+ super(EpropAlphaGateL0rd, self).__init__()
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+ self.l0rd = EpropGateL0rd(
+ num_inputs = num_hidden,
+ num_hidden = num_hidden,
+ num_outputs = num_hidden,
+ reg_lambda = reg_lambda,
+ batch_size = batch_size
+ )
+
+ def forward(self, input):
+        return input + self.alpha * self.l0rd(input)
\ No newline at end of file
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py
new file mode 100644
index 0000000..23a1b0f
--- /dev/null
+++ b/model/nn/eprop_transformer_shared.py
@@ -0,0 +1,92 @@
+import torch.nn as nn
+import torch as th
+from model.nn.eprop_gate_l0rd import EpropGateL0rdShared
+from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding
+
+class EpropGateL0rdTransformerShared(nn.Module):
+ def __init__(
+ self,
+ channels,
+ multiplier,
+ num_objects,
+ batch_size,
+ heads,
+ depth,
+ reg_lambda,
+ dropout=0.0,
+ exchange_length = 48,
+ ):
+ super(EpropGateL0rdTransformerShared, self).__init__()
+
+ num_inputs = channels
+ num_outputs = channels
+ num_hidden = channels * multiplier
+ num_hidden_gatelord = num_hidden + exchange_length
+ num_hidden_attention = num_hidden + exchange_length + num_hidden_gatelord
+
+ self.num_hidden = num_hidden
+ self.num_hidden_gatelord = num_hidden_gatelord
+
+ #print(f"Predictor channels: {num_hidden}@({num_hidden // heads}x{heads})")
+
+ self.register_buffer('hidden', th.zeros(batch_size * num_objects, num_hidden_gatelord), persistent=False)
+ self.register_buffer('exchange_code', th.zeros(batch_size * num_objects, exchange_length), persistent=False)
+
+ self.depth = depth
+ self.input_embeding = InputEmbeding(num_inputs, num_hidden)
+ self.attention = nn.Sequential(*[AlphaAttention(num_hidden_attention, num_objects, heads, dropout) for _ in range(depth)])
+ self.l0rds = nn.Sequential(*[EpropAlphaGateL0rdShared(num_hidden_gatelord, batch_size * num_objects, reg_lambda) for _ in range(depth)])
+ self.output_embeding = OutputEmbeding(num_hidden, num_outputs)
+
+ def get_openings(self):
+ openings = 0
+ for i in range(self.depth):
+ openings += self.l0rds[i].l0rd.openings.item()
+
+ return openings / self.depth
+
+ def get_hidden(self):
+ return self.hidden
+
+ def set_hidden(self, hidden):
+ self.hidden = hidden
+
+ def detach(self):
+ self.hidden = self.hidden.detach()
+
+ def reset_state(self):
+ self.hidden = th.zeros_like(self.hidden)
+
+ def forward(self, x: th.Tensor) -> th.Tensor:
+ x = self.input_embeding(x)
+ exchange_code = self.exchange_code.clone() * 0.0
+ x_ex = th.concat((x, exchange_code), dim=1)
+
+ for i in range(self.depth):
+ # attention layer
+ att = self.attention(th.concat((x_ex, self.hidden), dim=1))
+ x_ex = att[:, :self.num_hidden_gatelord]
+
+ # gatelord layer
+ x_ex, self.hidden = self.l0rds[i](x_ex, self.hidden)
+
+ # only yield x
+ x = x_ex[:, :self.num_hidden]
+ return self.output_embeding(x)
+
+class EpropAlphaGateL0rdShared(nn.Module):
+ def __init__(self, num_hidden, batch_size, reg_lambda):
+ super(EpropAlphaGateL0rdShared, self).__init__()
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+ self.l0rd = EpropGateL0rdShared(
+ num_inputs = num_hidden,
+ num_hidden = num_hidden,
+ num_outputs = num_hidden,
+ reg_lambda = reg_lambda,
+ batch_size = batch_size
+ )
+
+ def forward(self, input, hidden):
+ output, hidden = self.l0rd(input, hidden)
+        return input + self.alpha * output, hidden
\ No newline at end of file
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
new file mode 100644
index 0000000..9e5e874
--- /dev/null
+++ b/model/nn/eprop_transformer_utils.py
@@ -0,0 +1,66 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import LambdaModule
+from einops import rearrange, repeat, reduce
+
+class AlphaAttention(nn.Module):
+ def __init__(
+ self,
+ num_hidden,
+ num_objects,
+ heads,
+ dropout = 0.0
+ ):
+ super(AlphaAttention, self).__init__()
+
+ self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects))
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects))
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+ self.attention = nn.MultiheadAttention(
+ num_hidden,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ def forward(self, x: th.Tensor):
+ x = self.to_sequence(x)
+ x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+ return self.to_batch(x)
+
+class InputEmbeding(nn.Module):
+ def __init__(self, num_inputs, num_hidden):
+ super(InputEmbeding, self).__init__()
+
+ self.embeding = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(num_inputs, num_hidden),
+ nn.ReLU(),
+ nn.Linear(num_hidden, num_hidden),
+ )
+ self.skip = LambdaModule(
+ lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs)
+ )
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+ def forward(self, input: th.Tensor):
+ return self.skip(input) + self.alpha * self.embeding(input)
+
+class OutputEmbeding(nn.Module):
+ def __init__(self, num_hidden, num_outputs):
+ super(OutputEmbeding, self).__init__()
+
+ self.embeding = nn.Sequential(
+ nn.ReLU(),
+ nn.Linear(num_hidden, num_outputs),
+ nn.ReLU(),
+ nn.Linear(num_outputs, num_outputs),
+ )
+ self.skip = LambdaModule(
+ lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs)
+ )
+ self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+ def forward(self, input: th.Tensor):
+        return self.skip(input) + self.alpha * self.embeding(input)
\ No newline at end of file
diff --git a/model/nn/percept_gate_controller.py b/model/nn/percept_gate_controller.py
new file mode 100644
index 0000000..a548c20
--- /dev/null
+++ b/model/nn/percept_gate_controller.py
@@ -0,0 +1,59 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import LambdaModule
+from einops import rearrange, repeat, reduce
+from model.nn.eprop_gate_l0rd import ReTanh
+
+class PerceptGateController(nn.Module):
+ def __init__(
+ self,
+ num_inputs: int,
+ num_hidden: list,
+ bias: bool,
+ num_objects: int,
+ gate_noise_level: float = 0.1,
+ reg_lambda: float = 0.000005
+ ):
+ super(PerceptGateController, self).__init__()
+
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o=num_objects))
+
+ self.layers = nn.Sequential(
+ nn.Linear(num_inputs, num_hidden[0], bias = bias),
+ nn.Tanh(),
+ nn.Linear(num_hidden[0], num_hidden[1], bias = bias),
+ nn.Tanh(),
+ nn.Linear(num_hidden[1], 2, bias = bias)
+ )
+ self.output_function = ReTanh(reg_lambda)
+ self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
+ self.init_weights()
+
+ def init_weights(self):
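+        # biases start at +3 so the gate outputs begin close to fully open after the ReTanh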
+ for layer in self.layers:
+ if isinstance(layer, nn.Linear):
+                nn.init.xavier_uniform_(layer.weight)
+                if layer.bias is not None:
+                    layer.bias.data.fill_(3.00)
+
+ def forward(self, position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2, evaluate=False):
+
+ position_cur = self.to_batch(position_cur)
+ gestalt_cur = self.to_batch(gestalt_cur)
+ priority_cur = self.to_batch(priority_cur)
+ position_last = self.to_batch(position_last)
+ gestalt_last = self.to_batch(gestalt_last)
+ priority_last = self.to_batch(priority_last)
+ slots_occlusionfactor_cur = self.to_batch(slots_occlusionfactor_cur).detach()
+ slots_occlusionfactor_last = self.to_batch(slots_occlusionfactor_last).detach()
+ position_last2 = self.to_batch(position_last2).detach()
+
+ input = th.cat((position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2), dim=1)
+ output = self.layers(input)
+ if evaluate:
+ output = self.output_function(output)
+ else:
+ noise = th.normal(mean=0, std=self.noise, size=output.shape, device=output.device)
+ output = self.output_function(output + noise)
+
+ return self.to_shared(output)
diff --git a/model/nn/predictor.py b/model/nn/predictor.py
new file mode 100644
index 0000000..94f13b8
--- /dev/null
+++ b/model/nn/predictor.py
@@ -0,0 +1,99 @@
+import torch.nn as nn
+import torch as th
+from model.nn.eprop_transformer import EpropGateL0rdTransformer
+from model.nn.eprop_transformer_shared import EpropGateL0rdTransformerShared
+from model.utils.nn_utils import LambdaModule, Binarize
+from model.nn.residual import ResidualBlock
+from einops import rearrange, repeat, reduce
+
+__author__ = "Manuel Traub"
+
+class LociPredictor(nn.Module):
+ def __init__(
+ self,
+ heads: int,
+ layers: int,
+ channels_multiplier: int,
+ reg_lambda: float,
+ num_objects: int,
+ gestalt_size: int,
+ batch_size: int,
+ bottleneck: str,
+ transformer_type = 'standard',
+ ):
+ super(LociPredictor, self).__init__()
+ self.num_objects = num_objects
+ self.std_alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.bottleneck_type = bottleneck
+ self.gestalt_size = gestalt_size
+
+ self.reg_lambda = reg_lambda
+ Transformer = EpropGateL0rdTransformerShared if transformer_type == 'shared' else EpropGateL0rdTransformer
+ self.predictor = Transformer(
+ channels = gestalt_size + 3 + 1 + 2,
+ multiplier = channels_multiplier,
+ heads = heads,
+ depth = layers,
+ num_objects = num_objects,
+ reg_lambda = reg_lambda,
+ batch_size = batch_size,
+ )
+
+ if bottleneck == 'binar':
+ print("Binary bottleneck")
+ self.bottleneck = nn.Sequential(
+ LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
+ ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
+ Binarize(),
+ LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
+ )
+
+ else:
+ print("unrestricted bottleneck")
+ self.bottleneck = nn.Sequential(
+ LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
+ ResidualBlock(gestalt_size, gestalt_size, kernel_size=1),
+ nn.Sigmoid(),
+ LambdaModule(lambda x: rearrange(x, '(b o) c 1 1 -> b (o c)', o=num_objects))
+ )
+
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects))
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))
+
+ def get_openings(self):
+ return self.predictor.get_openings()
+
+ def get_hidden(self):
+ return self.predictor.get_hidden()
+
+ def set_hidden(self, hidden):
+ self.predictor.set_hidden(hidden)
+
+ def forward(
+ self,
+ gestalt: th.Tensor,
+ priority: th.Tensor,
+ position: th.Tensor,
+ slots_closed: th.Tensor,
+ ):
+
+ position = self.to_batch(position)
+ gestalt_cur = self.to_batch(gestalt)
+ priority = self.to_batch(priority)
+ slots_closed = rearrange(slots_closed, 'b o c -> (b o) c').detach()
+
+ input = th.cat((gestalt_cur, priority, position, slots_closed), dim=1)
+ output = self.predictor(input)
+
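+        # the predictor output is laid out as [gestalt | priority | xy | std]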
+ gestalt = output[:, :self.gestalt_size]
+ priority = output[:,self.gestalt_size:(self.gestalt_size+1)]
+ xy = output[:,(self.gestalt_size+1):(self.gestalt_size+3)]
+ std = output[:,(self.gestalt_size+3):(self.gestalt_size+4)]
+
+ position = th.cat((xy, std * self.std_alpha), dim=1)
+
+ position = self.to_shared(position)
+ gestalt = self.bottleneck(gestalt)
+ priority = self.to_shared(priority)
+
+ return position, gestalt, priority
diff --git a/model/nn/residual.py b/model/nn/residual.py
new file mode 100644
index 0000000..1602e16
--- /dev/null
+++ b/model/nn/residual.py
@@ -0,0 +1,396 @@
+import torch.nn as nn
+import torch as th
+import numpy as np
+from einops import rearrange, repeat, reduce
+from model.utils.nn_utils import LambdaModule
+
+from typing import Union, Tuple
+
+__author__ = "Manuel Traub"
+
+class DynamicLayerNorm(nn.Module):
+
+ def __init__(self, eps: float = 1e-5):
+ super(DynamicLayerNorm, self).__init__()
+ self.eps = eps
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return nn.functional.layer_norm(input, input.shape[2:], None, None, self.eps)
+
+
+class SkipConnection(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ scale_factor: float = 1.0
+ ):
+ super(SkipConnection, self).__init__()
+        assert scale_factor == 1 or int(scale_factor) > 1 or int(1 / scale_factor) > 1, f'invalid scale factor in SkipConnection: {scale_factor}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+
+ def channel_skip(self, input: th.Tensor):
+ in_channels = self.in_channels
+ out_channels = self.out_channels
+
+ if in_channels == out_channels:
+ return input
+
+ if in_channels % out_channels == 0 or out_channels % in_channels == 0:
+
+ if in_channels > out_channels:
+ return reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // out_channels)
+
+ if out_channels > in_channels:
+ return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // in_channels)
+
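+        # channel counts are not multiples of each other: average down to their gcd,
+        # then repeat back up to the requested number of output channels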
+ mean_channels = np.gcd(in_channels, out_channels)
+ input = reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // mean_channels)
+ return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // mean_channels)
+
+ def scale_skip(self, input: th.Tensor):
+ scale_factor = self.scale_factor
+
+ if scale_factor == 1:
+ return input
+
+ if scale_factor > 1:
+ return repeat(
+ input,
+ 'b c h w -> b c (h h2) (w w2)',
+ h2 = int(scale_factor),
+ w2 = int(scale_factor)
+ )
+
+ height = input.shape[2]
+ width = input.shape[3]
+
+ # scale factor < 1
+ scale_factor = int(1 / scale_factor)
+
+ if width % scale_factor == 0 and height % scale_factor == 0:
+ return reduce(
+ input,
+ 'b c (h h2) (w w2) -> b c h w',
+ 'mean',
+ h2 = scale_factor,
+ w2 = scale_factor
+ )
+
+ if width >= scale_factor and height >= scale_factor:
+ return nn.functional.avg_pool2d(
+ input,
+ kernel_size = scale_factor,
+ stride = scale_factor
+ )
+
+ assert width > 1 or height > 1
+ return reduce(input, 'b c h w -> b c 1 1', 'mean')
+
+
+ def forward(self, input: th.Tensor):
+
+ if self.scale_factor > 1:
+ return self.scale_skip(self.channel_skip(input))
+
+ return self.channel_skip(self.scale_skip(input))
+
+class DownScale(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ scale_factor: int,
+ groups: int = 1,
+ bias: bool = True
+ ):
+
+ super(DownScale, self).__init__()
+
+ assert(in_channels % groups == 0)
+ assert(out_channels % groups == 0)
+
+ self.groups = groups
+ self.scale_factor = scale_factor
+ self.weight = nn.Parameter(th.empty((out_channels, in_channels // groups, scale_factor, scale_factor)))
+ self.bias = nn.Parameter(th.empty((out_channels,))) if bias else None
+
+ nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
+
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / np.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input: th.Tensor):
+ height = input.shape[2]
+ width = input.shape[3]
+        assert height > 1 or width > 1, "trying to downscale 1x1"
+
+ scale_factor = self.scale_factor
+ padding = [0, 0]
+
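+        # pad inputs smaller than the kernel so a single strided convolution still applies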
+ if height < scale_factor:
+ padding[0] = scale_factor - height
+
+ if width < scale_factor:
+ padding[1] = scale_factor - width
+
+ return nn.functional.conv2d(
+ input,
+ self.weight,
+ bias=self.bias,
+ stride=scale_factor,
+ padding=padding,
+ groups=self.groups
+ )
+
+
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = (3, 3),
+ scale_factor: int = 1,
+ groups: Union[int, Tuple[int, int]] = (1, 1),
+ bias: bool = True,
+ layer_norm: bool = False,
+ leaky_relu: bool = False,
+ residual: bool = True,
+ alpha_residual: bool = False,
+ input_nonlinearity = True
+ ):
+
+ super(ResidualBlock, self).__init__()
+ self.residual = residual
+ self.alpha_residual = alpha_residual
+ self.skip = False
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if isinstance(kernel_size, int):
+ kernel_size = [kernel_size, kernel_size]
+
+ if isinstance(groups, int):
+ groups = [groups, groups]
+
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
+
+ _layers = list()
+ if layer_norm:
+ _layers.append(DynamicLayerNorm())
+
+ if input_nonlinearity:
+ if leaky_relu:
+ _layers.append(nn.LeakyReLU())
+ else:
+ _layers.append(nn.ReLU())
+
+ if scale_factor > 1:
+ _layers.append(
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=scale_factor,
+ stride=scale_factor,
+ groups=groups[0],
+ bias=bias
+ )
+ )
+ elif scale_factor < 1:
+ _layers.append(
+ DownScale(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ scale_factor=int(1.0/scale_factor),
+ groups=groups[0],
+ bias=bias
+ )
+ )
+ else:
+ _layers.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=groups[0],
+ bias=bias
+ )
+ )
+
+ if layer_norm:
+ _layers.append(DynamicLayerNorm())
+ if leaky_relu:
+ _layers.append(nn.LeakyReLU())
+ else:
+ _layers.append(nn.ReLU())
+ _layers.append(
+ nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=groups[1],
+ bias=bias
+ )
+ )
+ self.layers = nn.Sequential(*_layers)
+
+ if self.residual:
+ self.skip_connection = SkipConnection(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ scale_factor=scale_factor
+ )
+
+ if self.alpha_residual:
+ self.alpha = nn.Parameter(th.zeros(1) + 1e-12)
+
+ def set_mode(self, **kwargs):
+ if 'skip' in kwargs:
+ self.skip = kwargs['skip']
+
+ if 'residual' in kwargs:
+ self.residual = kwargs['residual']
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ if self.skip:
+ return self.skip_connection(input)
+
+ if not self.residual:
+ return self.layers(input)
+
+ if self.alpha_residual:
+ return self.alpha * self.layers(input) + self.skip_connection(input)
+
+ return self.layers(input) + self.skip_connection(input)
+
+class LinearSkip(nn.Module):
+ def __init__(self, num_inputs: int, num_outputs: int):
+ super(LinearSkip, self).__init__()
+
+ self.num_inputs = num_inputs
+ self.num_outputs = num_outputs
+
+ if num_inputs % num_outputs != 0 and num_outputs % num_inputs != 0:
+ mean_channels = np.gcd(num_inputs, num_outputs)
+ print(f"[WW] gcd skip: {num_inputs} -> {mean_channels} -> {num_outputs}")
+ assert(False)
+
+ def forward(self, input: th.Tensor):
+ num_inputs = self.num_inputs
+ num_outputs = self.num_outputs
+
+ if num_inputs == num_outputs:
+ return input
+
+ if num_inputs % num_outputs == 0 or num_outputs % num_inputs == 0:
+
+ if num_inputs > num_outputs:
+ return reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // num_outputs)
+
+ if num_outputs > num_inputs:
+ return repeat(input, 'b c -> b (c n)', n = num_outputs // num_inputs)
+
+ mean_channels = np.gcd(num_inputs, num_outputs)
+ input = reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // mean_channels)
+ return repeat(input, 'b c -> b (c n)', n = num_outputs // mean_channels)
+
+class LinearResidual(nn.Module):
+ def __init__(
+ self,
+ num_inputs: int,
+ num_outputs: int,
+ num_hidden: int = None,
+ residual: bool = True,
+ alpha_residual: bool = False,
+ input_relu: bool = True
+ ):
+ super(LinearResidual, self).__init__()
+
+ self.residual = residual
+ self.alpha_residual = alpha_residual
+
+ if num_hidden is None:
+ num_hidden = num_outputs
+
+ _layers = []
+ if input_relu:
+ _layers.append(nn.ReLU())
+ _layers.append(nn.Linear(num_inputs, num_hidden))
+ _layers.append(nn.ReLU())
+ _layers.append(nn.Linear(num_hidden, num_outputs))
+
+ self.layers = nn.Sequential(*_layers)
+
+ if residual:
+ self.skip = LinearSkip(num_inputs, num_outputs)
+
+ if alpha_residual:
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+ def forward(self, input: th.Tensor):
+ if not self.residual:
+ return self.layers(input)
+
+ if not self.alpha_residual:
+ return self.skip(input) + self.layers(input)
+
+ return self.skip(input) + self.alpha * self.layers(input)
+
+
+class EntityAttention(nn.Module):
+ def __init__(self, channels, num_objects, size, channels_per_head = 12, dropout = 0.0):
+ super(EntityAttention, self).__init__()
+
+ assert channels % channels_per_head == 0
+ heads = channels // channels_per_head
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.attention = nn.MultiheadAttention(
+ channels,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ self.channel_attention = nn.Sequential(
+ LambdaModule(lambda x: rearrange(x, '(b o) c h w -> (b h w) o c', o = num_objects)),
+ LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
+ LambdaModule(lambda x: rearrange(x, '(b h w) o c-> (b o) c h w', h = size[0], w = size[1])),
+ )
+
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return input + self.channel_attention(input) * self.alpha
+
+class ImageAttention(nn.Module):
+ def __init__(self, channels, gestalt_size, num_objects, size, channels_per_head = 12, dropout = 0.0):
+ super(ImageAttention, self).__init__()
+
+ assert gestalt_size % channels_per_head == 0
+ heads = gestalt_size // channels_per_head
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.attention = nn.MultiheadAttention(
+ gestalt_size,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ self.image_attention = nn.Sequential(
+ nn.Conv2d(channels, gestalt_size, kernel_size=1),
+ LambdaModule(lambda x: rearrange(x, 'b c h w -> b (h w) c')),
+ LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
+ LambdaModule(lambda x: rearrange(x, 'b (h w) c -> b c h w', h = size[0], w = size[1])),
+ nn.Conv2d(gestalt_size, channels, kernel_size=1),
+ )
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+        return input + self.image_attention(input) * self.alpha
\ No newline at end of file