Diffstat (limited to 'model/nn/decoder.py')
-rw-r--r-- | model/nn/decoder.py | 169
1 files changed, 169 insertions, 0 deletions
diff --git a/model/nn/decoder.py b/model/nn/decoder.py
new file mode 100644
index 0000000..3e56640
--- /dev/null
+++ b/model/nn/decoder.py
@@ -0,0 +1,169 @@
+import torch.nn as nn
+import torch as th
+from model.nn.residual import SkipConnection, ResidualBlock
+from model.utils.nn_utils import Gaus2D, SharedObjectsToBatch, BatchToSharedObjects, Prioritize
+from einops import rearrange, repeat, reduce
+from typing import Tuple, Union, List
+
+__author__ = "Manuel Traub"
+
+class PriorityEncoder(nn.Module):
+    def __init__(self, num_objects, batch_size):
+        super(PriorityEncoder, self).__init__()
+
+        self.num_objects = num_objects
+        self.register_buffer("indices", repeat(th.arange(num_objects), 'a -> b a', b=batch_size), persistent=False)
+
+        self.index_factor    = nn.Parameter(th.ones(1))
+        self.priority_factor = nn.Parameter(th.ones(1))
+
+    def forward(self, priority: th.Tensor) -> th.Tensor:
+
+        if priority is None:
+            return None
+
+        priority = priority * self.num_objects + th.randn_like(priority) * 0.1
+        priority = priority * self.priority_factor
+        priority = priority + self.indices * self.index_factor
+
+        return priority * 25
+
+
+class GestaltPositionMerge(nn.Module):
+    def __init__(
+        self,
+        latent_size: Union[int, Tuple[int, int]],
+        num_objects: int,
+        batch_size: int
+    ):
+
+        super(GestaltPositionMerge, self).__init__()
+        self.num_objects = num_objects
+
+        self.gaus2d = Gaus2D(size=latent_size)
+
+        self.to_batch  = SharedObjectsToBatch(num_objects)
+        self.to_shared = BatchToSharedObjects(num_objects)
+
+        self.prioritize = Prioritize(num_objects)
+
+        self.priority_encoder = PriorityEncoder(num_objects, batch_size)
+
+    def forward(self, position, gestalt, priority):
+
+        position = rearrange(position, 'b (o c) -> (b o) c', o = self.num_objects)
+        gestalt  = rearrange(gestalt, 'b (o c) -> (b o) c 1 1', o = self.num_objects)
+        priority = self.priority_encoder(priority)
+
+        position = self.gaus2d(position)
+        position = self.to_batch(self.prioritize(self.to_shared(position), priority))
+
+        return position * gestalt
+
+class PatchUpscale(nn.Module):
+    def __init__(self, in_channels, out_channels, scale_factor = 4, alpha = 1):
+        super(PatchUpscale, self).__init__()
+        assert in_channels % out_channels == 0
+
+        self.skip = SkipConnection(in_channels, out_channels, scale_factor=scale_factor)
+
+        self.residual = nn.Sequential(
+            nn.ReLU(),
+            nn.Conv2d(
+                in_channels  = in_channels,
+                out_channels = in_channels,
+                kernel_size  = 3,
+                padding      = 1
+            ),
+            nn.ReLU(),
+            nn.ConvTranspose2d(
+                in_channels  = in_channels,
+                out_channels = out_channels,
+                kernel_size  = scale_factor,
+                stride       = scale_factor,
+            ),
+        )
+
+        self.alpha = nn.Parameter(th.zeros(1) + alpha)
+
+    def forward(self, input):
+        return self.skip(input) + self.alpha * self.residual(input)
+
+
+class LociDecoder(nn.Module):
+    def __init__(
+        self,
+        latent_size: Union[int, Tuple[int, int]],
+        gestalt_size: int,
+        num_objects: int,
+        img_channels: int,
+        hidden_channels: int,
+        level1_channels: int,
+        num_layers: int,
+        batch_size: int
+    ):
+
+        super(LociDecoder, self).__init__()
+        self.to_batch  = SharedObjectsToBatch(num_objects)
+        self.to_shared = BatchToSharedObjects(num_objects)
+        self.level     = 1
+
+        assert(level1_channels % img_channels == 0)
+        level1_factor = level1_channels // img_channels
+        print(f"Level1 channels: {level1_channels}")
+
+        self.merge = GestaltPositionMerge(
+            latent_size = latent_size,
+            num_objects = num_objects,
+            batch_size  = batch_size
+        )
+
+        self.layer0 = nn.Sequential(
+            ResidualBlock(gestalt_size, hidden_channels, input_nonlinearity = False),
+            *[ResidualBlock(hidden_channels, hidden_channels) for _ in range(num_layers-1)],
+        )
+
+        self.to_mask_level0 = ResidualBlock(hidden_channels, hidden_channels)
+        self.to_mask_level1 = PatchUpscale(hidden_channels, 1)
+
+        self.to_mask_level2 = nn.Sequential(
+            ResidualBlock(hidden_channels, hidden_channels),
+            ResidualBlock(hidden_channels, hidden_channels),
+            PatchUpscale(hidden_channels, level1_factor, alpha = 1),
+            PatchUpscale(level1_factor, 1, alpha = 1)
+        )
+
+        self.to_object_level0 = ResidualBlock(hidden_channels, hidden_channels)
+        self.to_object_level1 = PatchUpscale(hidden_channels, img_channels)
+
+        self.to_object_level2 = nn.Sequential(
+            ResidualBlock(hidden_channels, hidden_channels),
+            ResidualBlock(hidden_channels, hidden_channels),
+            PatchUpscale(hidden_channels, level1_channels, alpha = 1),
+            PatchUpscale(level1_channels, img_channels, alpha = 1)
+        )
+
+        self.mask_alpha   = nn.Parameter(th.zeros(1)+1e-16)
+        self.object_alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+
+    def set_level(self, level):
+        self.level = level
+
+    def forward(self, position, gestalt, priority = None):
+
+        maps    = self.layer0(self.merge(position, gestalt, priority))
+        mask0   = self.to_mask_level0(maps)
+        object0 = self.to_object_level0(maps)
+
+        mask   = self.to_mask_level1(mask0)
+        object = self.to_object_level1(object0)
+
+        if self.level > 1:
+            mask   = repeat(mask,   'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)
+            object = repeat(object, 'b c h w -> b c (h h2) (w w2)', h2 = 4, w2 = 4)
+
+            mask   = mask   + self.to_mask_level2(mask0)     * self.mask_alpha
+            object = object + self.to_object_level2(object0) * self.object_alpha
+
+        return self.to_shared(mask), self.to_shared(object)
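
For orientation, a minimal usage sketch of the new decoder follows. It is not part of the commit: the hyperparameter values are illustrative assumptions, and the 3-component position code (x, y, sigma) is an assumption that depends on the Gaus2D implementation in model/utils/nn_utils, which this diff does not show.

# Hypothetical usage sketch for LociDecoder; the values below are
# assumptions for illustration, not taken from the repository's configs.
import torch as th
from model.nn.decoder import LociDecoder

batch_size, num_objects = 2, 4
gestalt_size, position_size = 96, 3   # position_size assumed from Gaus2D

decoder = LociDecoder(
    latent_size     = (8, 8),   # spatial size of the level-0 feature maps
    gestalt_size    = gestalt_size,
    num_objects     = num_objects,
    img_channels    = 3,
    hidden_channels = 64,
    level1_channels = 12,       # must be divisible by img_channels
    num_layers      = 2,
    batch_size      = batch_size,  # must match the runtime batch (PriorityEncoder buffer)
)
decoder.set_level(2)            # enable the level-2 refinement branches

# Per-object latents are flattened into the channel dimension: 'b (o c)'.
position = th.randn(batch_size, num_objects * position_size)
gestalt  = th.rand(batch_size, num_objects * gestalt_size)
priority = th.rand(batch_size, num_objects)

mask, obj = decoder(position, gestalt, priority)

At level 1 each head applies a single PatchUpscale (scale_factor = 4), so with latent_size = (8, 8) the outputs are 32x32; at level 2 the nearest-neighbour repeat plus the two extra PatchUpscale stages yield 128x128, with the level-2 branches gated by mask_alpha and object_alpha, which are initialised near zero.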