diff options
Diffstat (limited to 'model/nn/background.py')
-rw-r--r-- | model/nn/background.py | 149 |
1 files changed, 8 insertions, 141 deletions
diff --git a/model/nn/background.py b/model/nn/background.py index 38ec325..ff14e02 100644 --- a/model/nn/background.py +++ b/model/nn/background.py @@ -14,154 +14,21 @@ class BackgroundEnhancer(nn.Module): def __init__( self, input_size: Tuple[int, int], - img_channels: int, - level1_channels, - latent_channels, - gestalt_size, batch_size, - depth ): super(BackgroundEnhancer, self).__init__() - latent_size = [input_size[0] // 16, input_size[1] // 16] - self.input_size = input_size - - self.register_buffer('init', th.zeros(1).long()) - self.alpha = nn.Parameter(th.zeros(1)+1e-16) - - self.level = 1 - self.down_level2 = nn.Sequential( - PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16), - *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)] - ) - - self.down_level1 = nn.Sequential( - PatchDownConv(level1_channels, latent_channels, alpha = 1), - *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)] - ) - - self.down_level0 = nn.Sequential( - *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], - AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size), - LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')), - Binarize(), - ) - - self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size))) - - self.to_grid = nn.Sequential( - LinearResidual(gestalt_size, gestalt_size, input_relu = False), - LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')), - LambdaModule(lambda x: x + self.bias), - *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)], - ) - - - self.up_level0 = nn.Sequential( - ResidualBlock(gestalt_size, latent_channels), - *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)], - ) - - self.up_level1 = nn.Sequential( - *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)], - PatchUpscale(latent_channels, level1_channels, alpha = 1), - ) - - self.up_level2 = nn.Sequential( - *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)], - PatchUpscale(level1_channels, img_channels, alpha = 1e-16), - ) - - self.to_channels = nn.ModuleList([ - SkipConnection(img_channels*2+2, latent_channels), - SkipConnection(img_channels*2+2, level1_channels), - SkipConnection(img_channels*2+2, img_channels*2+2), - ]) - - self.to_img = nn.ModuleList([ - SkipConnection(latent_channels, img_channels), - SkipConnection(level1_channels, img_channels), - SkipConnection(img_channels, img_channels), - ]) - + self.batch_size = batch_size + self.height = input_size[0] + self.width = input_size[1] self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10) - self.object = nn.Parameter(th.ones(1, img_channels, *input_size)) - - self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False) + self.register_buffer('init', th.zeros(1).long()) def get_init(self): return self.init.item() - - def step_init(self): - self.init = self.init + 1 - - def detach(self): - self.latent = self.latent.detach() - - def reset_state(self): - self.latent = th.zeros_like(self.latent) - - def set_level(self, level): - self.level = level - - def encoder(self, input): - latent = self.to_channels[self.level](input) - - if self.level >= 2: - latent = self.down_level2(latent) - - if self.level >= 1: - latent = self.down_level1(latent) - - return self.down_level0(latent) - - def get_last_latent_gird(self): - return self.to_grid(self.latent) * self.alpha - - def decoder(self, latent, input): - grid = self.to_grid(latent) - latent = self.up_level0(grid) - - if self.level >= 1: - latent = self.up_level1(latent) - - if self.level >= 2: - latent = self.up_level2(latent) - - object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3]) - object = repeat(object, '1 c h w -> b c h w', b = input.shape[0]) - - return th.sigmoid(object + self.to_img[self.level](latent)), grid - - def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False): + + def forward(self, input: th.Tensor): - if only_mask: - mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3]) - mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1 - return mask - - last_bg = self.decoder(self.latent, input)[0] - - bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach() - bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach() - - if error is None or self.get_init() < 2: - error = bg_error - - if mask is None or self.get_init() < 2: - mask = bg_mask - - self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1)) - mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3]) - mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1 - - background, grid = self.decoder(self.latent, input) - - if self.get_init() < 1: - return mask, background - - if self.get_init() < 2: - return mask, th.zeros_like(background), th.zeros_like(grid), background - - return mask, background, grid * self.alpha, background + mask = repeat(mask, '1 1 h w -> b 1 h w', b = self.batch_size) * 0.1 + return mask |