1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
|
import torch.nn as nn
import torch as th
from model.nn.residual import ResidualBlock, SkipConnection, LinearResidual
from model.nn.encoder import PatchDownConv
from model.nn.encoder import AggressiveConvToGestalt
from model.nn.decoder import PatchUpscale
from model.utils.nn_utils import LambdaModule, Binarize
from einops import rearrange, repeat, reduce
from typing import Tuple
__author__ = "Manuel Traub"
class BackgroundEnhancer(nn.Module):
def __init__(
self,
input_size: Tuple[int, int],
batch_size,
):
super(BackgroundEnhancer, self).__init__()
self.batch_size = batch_size
self.height = input_size[0]
self.width = input_size[1]
self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
self.register_buffer('init', th.zeros(1).long())
def get_init(self):
return self.init.item()
def forward(self, input: th.Tensor):
mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
mask = repeat(mask, '1 1 h w -> b 1 h w', b = self.batch_size) * 0.1
return mask
|