author    fredeee 2023-11-02 10:47:21 +0100
committer fredeee 2023-11-02 10:47:21 +0100
commit    f8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree      87288be6f851ab69405e524b81940c501c52789a /model/nn/background.py
parent    f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initial commit
Diffstat (limited to 'model/nn/background.py')
-rw-r--r--  model/nn/background.py  189
1 file changed, 189 insertions, 0 deletions
diff --git a/model/nn/background.py b/model/nn/background.py
new file mode 100644
index 0000000..38ec325
--- /dev/null
+++ b/model/nn/background.py
@@ -0,0 +1,189 @@
+import torch.nn as nn
+import torch as th
+from model.nn.residual import ResidualBlock, SkipConnection, LinearResidual
+from model.nn.encoder import PatchDownConv
+from model.nn.encoder import AggressiveConvToGestalt
+from model.nn.decoder import PatchUpscale
+from model.utils.nn_utils import LambdaModule, Binarize
+from einops import rearrange, repeat, reduce
+from typing import Optional, Tuple
+
+__author__ = "Manuel Traub"
+
+class BackgroundEnhancer(nn.Module):
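+    """
+    Encodes the image together with the last background, error and mask into a
+    binarized gestalt code and decodes it back into a background image; `level`
+    selects how many of the patch down-/up-scaling stages are active.
+    """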
+    def __init__(
+        self,
+        input_size: Tuple[int, int],
+        img_channels: int,
+        level1_channels,
+        latent_channels,
+        gestalt_size,
+        batch_size,
+        depth
+    ):
+        super(BackgroundEnhancer, self).__init__()
+
+        latent_size = [input_size[0] // 16, input_size[1] // 16]
+        self.input_size = input_size
+
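+        # counter for the initialisation phases and a learnable gate (starting
+        # near zero) that scales the decoded gestalt grid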
+        self.register_buffer('init', th.zeros(1).long())
+        self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+        self.level = 1
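+        # encoder: patch-downscaling stages (levels 2 and 1) followed by
+        # compression of the latent map into a binary gestalt vector (level 0)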
+        self.down_level2 = nn.Sequential(
+            PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16),
+            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)]
+        )
+
+        self.down_level1 = nn.Sequential(
+            PatchDownConv(level1_channels, latent_channels, alpha = 1),
+            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)]
+        )
+
+        self.down_level0 = nn.Sequential(
+            *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
+            AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size),
+            LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
+            Binarize(),
+        )
+
+        self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
+
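+        # expands the gestalt vector back into a spatial latent grid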
+        self.to_grid = nn.Sequential(
+            LinearResidual(gestalt_size, gestalt_size, input_relu = False),
+            LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
+            LambdaModule(lambda x: x + self.bias),
+            *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)],
+        )
+
+
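+        # decoder: mirrors the encoder, upscaling back towards image resolution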
+        self.up_level0 = nn.Sequential(
+            ResidualBlock(gestalt_size, latent_channels),
+            *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
+        )
+
+        self.up_level1 = nn.Sequential(
+            *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)],
+            PatchUpscale(latent_channels, level1_channels, alpha = 1),
+        )
+
+        self.up_level2 = nn.Sequential(
+            *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)],
+            PatchUpscale(level1_channels, img_channels, alpha = 1e-16),
+        )
+
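+        # per-level adapters: map the concatenated input to the channels expected
+        # by the active encoder stage, and the decoder output to image channels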
+        self.to_channels = nn.ModuleList([
+            SkipConnection(img_channels*2+2, latent_channels),
+            SkipConnection(img_channels*2+2, level1_channels),
+            SkipConnection(img_channels*2+2, img_channels*2+2),
+        ])
+
+        self.to_img = nn.ModuleList([
+            SkipConnection(latent_channels, img_channels),
+            SkipConnection(level1_channels, img_channels),
+            SkipConnection(img_channels, img_channels),
+        ])
+
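+        # learned mask and background image priors at full input resolution,
+        # downsampled to the current input size in forward() and decoder()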
+        self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
+        self.object = nn.Parameter(th.ones(1, img_channels, *input_size))
+
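+        # last gestalt code, carried across time steps (not saved in checkpoints)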
+        self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False)
+
+    def get_init(self):
+        return self.init.item()
+
+    def step_init(self):
+        self.init = self.init + 1
+
+    def detach(self):
+        self.latent = self.latent.detach()
+
+    def reset_state(self):
+        self.latent = th.zeros_like(self.latent)
+
+    def set_level(self, level):
+        self.level = level
+
+    def encoder(self, input):
+        latent = self.to_channels[self.level](input)
+
+        if self.level >= 2:
+            latent = self.down_level2(latent)
+
+        if self.level >= 1:
+            latent = self.down_level1(latent)
+
+        return self.down_level0(latent)
+
+    def get_last_latent_gird(self):
+        return self.to_grid(self.latent) * self.alpha
+
+    def decoder(self, latent, input):
+        grid = self.to_grid(latent)
+        latent = self.up_level0(grid)
+
+        if self.level >= 1:
+            latent = self.up_level1(latent)
+
+        if self.level >= 2:
+            latent = self.up_level2(latent)
+
+        object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3])
+        object = repeat(object, '1 c h w -> b c h w', b = input.shape[0])
+
+        return th.sigmoid(object + self.to_img[self.level](latent)), grid
+
+    def forward(self, input: th.Tensor, error: Optional[th.Tensor] = None, mask: Optional[th.Tensor] = None, only_mask: bool = False):
+
+        if only_mask:
+            mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
+            mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
+            return mask
+
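+        # reconstruct the background from the previous gestalt code; pixels whose
+        # reconstruction error is below mean + std count as background and serve
+        # as fallback error / mask inputs while get_init() < 2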
+        last_bg = self.decoder(self.latent, input)[0]
+
+        bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+        bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach()
+
+        if error is None or self.get_init() < 2:
+            error = bg_error
+
+        if mask is None or self.get_init() < 2:
+            mask = bg_mask
+
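+        # update the gestalt code from image, last background, error and mask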
+        self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1))
+
+        mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
+        mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
+
+        background, grid = self.decoder(self.latent, input)
+
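+        # staged outputs: only mask and background at first, then zero placeholders
+        # for background and grid, finally (mask, background, gated grid, background)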
+        if self.get_init() < 1:
+            return mask, background
+
+        if self.get_init() < 2:
+            return mask, th.zeros_like(background), th.zeros_like(grid), background
+
+        return mask, background, grid * self.alpha, background