Diffstat (limited to 'model/nn/encoder.py')
-rw-r--r--  model/nn/encoder.py  269
1 files changed, 269 insertions, 0 deletions
diff --git a/model/nn/encoder.py b/model/nn/encoder.py
new file mode 100644
index 0000000..8eb2e7f
--- /dev/null
+++ b/model/nn/encoder.py
@@ -0,0 +1,269 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import Gaus2D, BatchToSharedObjects, LambdaModule, ForcedAlpha, Binarize
+from model.nn.residual import ResidualBlock, SkipConnection
+from einops import rearrange, repeat, reduce
+from typing import Tuple, Union, List
+
+__author__ = "Manuel Traub"
+
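+# 1x1 convolution whose all-ones kernel has a zeroed diagonal: for every object
+# slot it sums the mask channels of all *other* slots.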
+class NeighbourChannels(nn.Module):
+ def __init__(self, channels):
+ super(NeighbourChannels, self).__init__()
+
+ self.register_buffer("weights", th.ones(channels, channels, 1, 1), persistent=False)
+
+ for i in range(channels):
+ self.weights[i,i,0,0] = 0
+
+ def forward(self, input: th.Tensor):
+ return nn.functional.conv2d(input, self.weights)
+
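+# Builds the per-slot encoder input: for every object slot the image, the error
+# signal, the slot's own mask, the combined masks of the other slots, the
+# background mask, the slot's object channels, a 2D Gaussian rendered from the
+# slot's position, and the raw mask are stacked along the channel dimension.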
+class InputPreprocessing(nn.Module):
+ def __init__(self, num_objects: int, size: Union[int, Tuple[int, int]]):
+ super(InputPreprocessing, self).__init__()
+ self.num_objects = num_objects
+ self.neighbours = NeighbourChannels(num_objects)
+ self.gaus2d = Gaus2D(size)
+ self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
+ self.to_shared = BatchToSharedObjects(num_objects)
+
+ def forward(
+ self,
+ input: th.Tensor,
+ error: th.Tensor,
+ mask: th.Tensor,
+ object: th.Tensor,
+ position: th.Tensor,
+ rawmask: th.Tensor
+ ):
+ bg_mask = repeat(mask[:,-1:], 'b 1 h w -> b c h w', c = self.num_objects)
+ mask = mask[:,:-1]
+ mask_others = self.neighbours(mask)
+ rawmask = rawmask[:,:-1]
+
+ own_gaus2d = self.to_shared(self.gaus2d(self.to_batch(position)))
+
+ input = repeat(input, 'b c h w -> b o c h w', o = self.num_objects)
+ error = repeat(error, 'b 1 h w -> b o 1 h w', o = self.num_objects)
+ bg_mask = rearrange(bg_mask, 'b o h w -> b o 1 h w')
+ mask_others = rearrange(mask_others, 'b o h w -> b o 1 h w')
+ mask = rearrange(mask, 'b o h w -> b o 1 h w')
+ object = rearrange(object, 'b (o c) h w -> b o c h w', o = self.num_objects)
+ own_gaus2d = rearrange(own_gaus2d, 'b o h w -> b o 1 h w')
+ rawmask = rearrange(rawmask, 'b o h w -> b o 1 h w')
+
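+ # Per-slot channel stack: image, error, own mask, other-object masks, background
+ # mask, object channels, position Gaussian, raw mask; slots are then folded into the batch.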
+ output = th.cat((input, error, mask, mask_others, bg_mask, object, own_gaus2d, rawmask), dim=2)
+ output = rearrange(output, 'b o c h w -> (b o) c h w')
+
+ return output
+
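+# Patch-wise downsampling (kernel = stride) with a residual-style skip: the input
+# is mean-pooled per patch and channel-repeated, and the strided convolution is
+# added on top, scaled by a learnable alpha.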
+class PatchDownConv(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size = 4, alpha = 1):
+ super(PatchDownConv, self).__init__()
+ assert out_channels % in_channels == 0
+
+ self.layers = nn.Conv2d(
+ in_channels = in_channels,
+ out_channels = out_channels,
+ kernel_size = kernel_size,
+ stride = kernel_size,
+ )
+
+ self.alpha = nn.Parameter(th.zeros(1) + alpha)
+ self.kernel_size = kernel_size
+ self.channels_factor = out_channels // in_channels
+
+ def forward(self, input: th.Tensor):
+ k = self.kernel_size
+ c = self.channels_factor
+ skip = reduce(input, 'b c (h h2) (w w2) -> b c h w', 'mean', h2=k, w2=k)
+ skip = repeat(skip, 'b c h w -> b (c n) h w', n=c)
+ return skip + self.alpha * self.layers(input)
+
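+# Collapses a spatial feature map into a 1x1 gestalt vector: a 5x5 convolution
+# with stride 3 followed by a convolution whose kernel covers the remaining
+# spatial extent, plus a global-mean skip path whose channels are repeated or
+# group-averaged to match gestalt_size.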
+class AggressiveConvToGestalt(nn.Module):
+ def __init__(self, channels, gestalt_size, size: Union[int, Tuple[int, int]]):
+ super(AggressiveConvToGestalt, self).__init__()
+
+ assert gestalt_size % channels == 0 or channels % gestalt_size == 0
+
+ self.layers = nn.Sequential(
+ nn.Conv2d(
+ in_channels = channels,
+ out_channels = gestalt_size,
+ kernel_size = 5,
+ stride = 3,
+ padding = 3
+ ),
+ nn.ReLU(),
+ nn.Conv2d(
+ in_channels = gestalt_size,
+ out_channels = gestalt_size,
+ kernel_size = ((size[0] + 1)//3 + 1, (size[1] + 1)//3 + 1)
+ )
+ )
+ if gestalt_size > channels:
+ self.skip = nn.Sequential(
+ LambdaModule(lambda x: reduce(x, 'b c h w -> b c 1 1', 'mean')),
+ LambdaModule(lambda x: repeat(x, 'b c 1 1 -> b (c n) 1 1', n = gestalt_size // channels))
+ )
+ else:
+ self.skip = LambdaModule(lambda x: reduce(x, 'b (c n) h w -> b c 1 1', 'mean', n = channels // gestalt_size))
+
+
+ def forward(self, input: th.Tensor):
+ return self.skip(input) + self.layers(input)
+
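+# Soft-argmax: a softmax over all pixels of a single-channel map is used as a
+# spatial distribution, and the expected x/y coordinates (in [-1, 1]) are returned.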
+class PixelToPosition(nn.Module):
+ def __init__(self, size: Union[int, Tuple[int, int]]):
+ super(PixelToPosition, self).__init__()
+
+ grid_x = (th.arange(size[0]) / (size[0]-1)) * 2 - 1
+ grid_y = (th.arange(size[1]) / (size[1]-1)) * 2 - 1
+
+ self.register_buffer("grid_x", grid_x.view(1, 1, -1, 1).expand(1, 1, *size).clone(), persistent=False)
+ self.register_buffer("grid_y", grid_y.view(1, 1, 1, -1).expand(1, 1, *size).clone(), persistent=False)
+
+ self.size = size
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+
+ input = rearrange(input, 'b c h w -> b c (h w)')
+ input = th.softmax(input, dim=2)
+ input = rearrange(input, 'b c (h w) -> b c h w', h = self.size[0], w = self.size[1])
+
+ x = th.sum(input * self.grid_x, dim=(2,3))
+ y = th.sum(input * self.grid_y, dim=(2,3))
+
+ return th.cat((x,y),dim=1)
+
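+# Reduces a single-channel map to one std value per slot: the mean of
+# sigmoid(x - 10) over all pixels (the offset keeps the initial value near zero),
+# passed through ForcedAlpha.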
+class PixelToSTD(nn.Module):
+ def __init__(self):
+ super(PixelToSTD, self).__init__()
+ self.alpha = ForcedAlpha()
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+ return self.alpha(reduce(th.sigmoid(input - 10), 'b c h w -> b c', 'mean'))
+
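+# Reduces a single-channel map to one priority scalar per slot in (-1, 1)
+# by averaging a tanh over all pixels.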
+class PixelToPriority(nn.Module):
+ def __init__(self):
+ super(PixelToPriority, self).__init__()
+
+ def forward(self, input: th.Tensor):
+ assert input.shape[1] == 1
+ return reduce(th.tanh(input), 'b c h w -> b c', 'mean')
+
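+# The Loci encoder: maps the image, error, mask, object and position inputs to a
+# per-slot position code (x, y, std), gestalt code, and priority. set_level selects
+# the working resolution (0: 1/16, 1: 1/4, 2: full).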
+class LociEncoder(nn.Module):
+ def __init__(
+ self,
+ input_size: Union[int, Tuple[int, int]],
+ latent_size: Union[int, Tuple[int, int]],
+ num_objects: int,
+ img_channels: int,
+ hidden_channels: int,
+ level1_channels: int,
+ num_layers: int,
+ gestalt_size: int,
+ bottleneck: str
+ ):
+ super(LociEncoder, self).__init__()
+
+ self.num_objects = num_objects
+ self.latent_size = latent_size
+ self.level = 1
+
+ self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = self.num_objects))
+
+ print(f"Level1 channels: {level1_channels}")
+
+ self.preprocess = nn.ModuleList([
+ InputPreprocessing(num_objects, (input_size[0] // 16, input_size[1] // 16)),
+ InputPreprocessing(num_objects, (input_size[0] // 4, input_size[1] // 4)),
+ InputPreprocessing(num_objects, (input_size[0], input_size[1]))
+ ])
+
+ self.to_channels = nn.ModuleList([
+ SkipConnection(img_channels, hidden_channels),
+ SkipConnection(img_channels, level1_channels),
+ SkipConnection(img_channels, img_channels)
+ ])
+
+ self.layers2 = nn.Sequential(
+ PatchDownConv(img_channels, level1_channels, alpha = 1e-16),
+ *[ResidualBlock(level1_channels, level1_channels, alpha_residual=True) for _ in range(num_layers)]
+ )
+
+ self.layers1 = PatchDownConv(level1_channels, hidden_channels)
+
+ self.layers0 = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)]
+ )
+
+ self.position_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ ResidualBlock(hidden_channels, 3),
+ )
+
+ self.xy_encoder = PixelToPosition(latent_size)
+ self.std_encoder = PixelToSTD()
+ self.priority_encoder = PixelToPriority()
+
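+ # Gestalt bottleneck: binary codes via Binarize, or unrestricted codes squashed into (0, 1) by a sigmoid.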
+ if bottleneck == "binar":
+ print("Binary bottleneck")
+ self.gestalt_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ Binarize(),
+ )
+
+ else:
+ print("unrestricted bottleneck")
+ self.gestalt_encoder = nn.Sequential(
+ *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers)],
+ AggressiveConvToGestalt(hidden_channels, gestalt_size, latent_size),
+ LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+ nn.Sigmoid(),
+ )
+
+ def set_level(self, level):
+ self.level = level
+
+ def forward(
+ self,
+ input: th.Tensor,
+ error: th.Tensor,
+ mask: th.Tensor,
+ object: th.Tensor,
+ position: th.Tensor,
+ rawmask: th.Tensor
+ ):
+
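+ # The active level picks the preprocessing resolution and which down-convolution stages run:
+ # level 2 starts from the full-resolution input, level 1 from 1/4, level 0 from 1/16.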
+ latent = self.preprocess[self.level](input, error, mask, object, position, rawmask)
+ latent = self.to_channels[self.level](latent)
+
+ if self.level >= 2:
+ latent = self.layers2(latent)
+
+ if self.level >= 1:
+ latent = self.layers1(latent)
+
+ latent = self.layers0(latent)
+ gestalt = self.gestalt_encoder(latent)
+
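+ # The position encoder outputs three channels: 0 -> std, 1 -> xy (soft-argmax), 2 -> priority.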
+ latent = self.position_encoder(latent)
+ std = self.std_encoder(latent[:,0:1])
+ xy = self.xy_encoder(latent[:,1:2])
+ priority = self.priority_encoder(latent[:,2:3])
+
+ position = self.to_shared(th.cat((xy, std), dim=1))
+ gestalt = self.to_shared(gestalt)
+ priority = self.to_shared(priority)
+
+ return position, gestalt, priority
+