aboutsummaryrefslogtreecommitdiff
path: root/model/nn/background.py
blob: ff14e0260ba3d579c004059e59a2a0dc3bb83902 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.nn as nn
import torch as th
from model.nn.residual import ResidualBlock, SkipConnection, LinearResidual
from model.nn.encoder import PatchDownConv
from model.nn.encoder import AggressiveConvToGestalt
from model.nn.decoder import PatchUpscale
from model.utils.nn_utils import LambdaModule, Binarize
from einops import rearrange, repeat, reduce
from typing import Tuple

__author__ = "Manuel Traub"

class BackgroundEnhancer(nn.Module):
    """Learnable, input-independent spatial mask.

    Holds one full-resolution parameter map (``self.mask``). ``forward``
    average-pools it down to the spatial size of its input, broadcasts it over
    the input's batch and scales it by 0.1.

    Args:
        input_size: ``(height, width)`` of the full-resolution mask parameter.
        batch_size: nominal batch size, stored as ``self.batch_size``. The
            output batch dimension of ``forward`` follows the actual input.
    """

    def __init__(
        self, 
        input_size: Tuple[int, int], 
        batch_size,
    ):
        super(BackgroundEnhancer, self).__init__()

        self.batch_size = batch_size
        self.height = input_size[0]
        self.width  = input_size[1]
        # Initialised to 10; forward scales by 0.1, so the initial output is ~1.0 everywhere.
        self.mask   = nn.Parameter(th.ones(1, 1, *input_size) * 10)
        # Integer buffer: saved in state_dict and moved by .to(); this class never updates it.
        self.register_buffer('init', th.zeros(1).long())

    def get_init(self):
        """Return the value of the ``init`` buffer as a Python int."""
        return self.init.item()
    
    def forward(self, input: th.Tensor) -> th.Tensor:
        """Return the mask pooled to ``input``'s spatial size.

        Args:
            input: tensor whose ``shape[0]`` (batch), ``shape[2]`` and
                ``shape[3]`` (target height/width) are read; its values are
                not used.

        Returns:
            Tensor of shape ``(input.shape[0], 1, h, w)`` containing the
            block-wise mean of ``self.mask`` multiplied by 0.1.

        Raises:
            ValueError: if the target height/width is not positive or does not
                evenly divide the mask resolution.
        """
        batch, out_h, out_w = input.shape[0], input.shape[2], input.shape[3]
        full_h, full_w = self.mask.shape[-2], self.mask.shape[-1]

        if out_h <= 0 or out_w <= 0 or full_h % out_h or full_w % out_w:
            raise ValueError(
                f"mask size {full_h}x{full_w} is not divisible by input size {out_h}x{out_w}"
            )

        # Split each axis into (output cell, in-cell offset) in row-major order
        # and average over the offsets: non-overlapping block mean pooling.
        mask = self.mask.reshape(
            1, 1, out_h, full_h // out_h, out_w, full_w // out_w
        ).mean(dim=(3, 5))

        # Batch size comes from the input (not self.batch_size) so short or
        # single-sample batches get a matching mask; the multiply materialises
        # a fresh tensor rather than returning the stride-0 expanded view.
        return mask.expand(batch, -1, -1, -1) * 0.1