aboutsummaryrefslogtreecommitdiff
path: root/model/utils/nn_utils.py
blob: 5116e141614dea89f75396b6fa1bcb9c97147b32 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
from typing import Tuple
import torch.nn as nn
import torch as th
import numpy as np
from torch.autograd import Function
from einops import rearrange, repeat, reduce

class PushToInfFunction(Function):
    """Autograd function: identity in forward, constant ``-1`` gradient in backward.

    The gradient arriving from downstream is used only as a template for
    shape, dtype and device; its values are discarded. With a plain
    gradient-descent update (``x -= lr * grad``) a constant negative gradient
    makes the input grow on every step.
    """

    @staticmethod
    def forward(ctx, tensor):
        """Return a copy of ``tensor``.

        Nothing is stored on ``ctx``: the backward pass does not depend on any
        forward value, so saving the input would only keep it alive longer.
        """
        return tensor.clone()

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-1`` for every element, ignoring the values of ``grad_output``."""
        return -th.ones_like(grad_output)

class PushToInf(nn.Module):
    """Module wrapper that applies ``PushToInfFunction`` to its input."""

    def __init__(self):
        super().__init__()

        # Store the autograd function's ``apply`` so ``forward`` is a plain call.
        self.fcn = PushToInfFunction.apply

    def forward(self, input: th.Tensor):
        """Pass ``input`` through the stored autograd function and return the result."""
        result = self.fcn(input)
        return result

class ForcedAlpha(nn.Module):
    """Scale an input by a single learnable gate ``tanh(init * speed)``.

    ``init`` is a one-element parameter initialised to zero, so the gate
    starts at ``tanh(0) == 0``. The pre-activation is routed through a
    ``PushToInf`` module before the ``tanh``.

    Args:
        speed: Constant factor multiplied onto ``init`` before the ``tanh``.
    """

    def __init__(self, speed = 1):
        super().__init__()

        self.init   = nn.Parameter(th.zeros(1))
        self.speed  = speed
        self.to_inf = PushToInf()

    def _gate(self):
        # Shared by ``item`` and ``forward`` so both see the same expression.
        scaled = self.init * self.speed
        return th.tanh(self.to_inf(scaled))

    def item(self):
        """Return the current gate value as a Python number."""
        return self._gate().item()

    def forward(self, input: th.Tensor):
        """Return ``input`` multiplied by the current gate value."""
        return input * self._gate()

class LinearInterpolation(nn.Module):
    """Blend two per-object tensors with a per-object interpolation weight.

    Inputs of shape ``(b, o * c)`` are unfolded to ``(b * o, c)``, mixed as
    ``w * last + (1 - w) * cur`` with one weight ``w`` per ``(b, o)`` pair,
    and folded back to ``(b, o * c)``.

    Args:
        num_objects: Number of objects ``o`` packed into the second dimension.
    """

    def __init__(self, num_objects):
        super().__init__()
        self.to_batch  = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o = num_objects))
        self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o = num_objects))

    def forward(
        self, 
        tensor_cur: th.Tensor = None,
        tensor_last: th.Tensor = None,
        slot_interpolation_value: th.Tensor = None
    ):
        """Return the interpolated tensor in shared ``(b, o * c)`` layout.

        Args:
            tensor_cur: Tensor weighted by ``1 - slot_interpolation_value``.
            tensor_last: Tensor weighted by ``slot_interpolation_value``.
            slot_interpolation_value: Weights of shape ``(b, o)``.
        """
        # One weight per (batch, object) row, broadcast over the feature axis.
        weight_last = rearrange(slot_interpolation_value, 'b o -> (b o) 1')

        last_rows = self.to_batch(tensor_last)
        cur_rows  = self.to_batch(tensor_cur)
        blended   = weight_last * last_rows + (1 - weight_last) * cur_rows

        return self.to_shared(blended)

class Gaus2D(nn.Module):
    """Render one 2D Gaussian per batch row on a fixed ``[-1, 1]`` grid.

    Two non-persistent buffers of shape ``(1, 1, size[0], size[1])`` hold the
    coordinates: ``grid_x`` varies along dimension 2 and ``grid_y`` along
    dimension 3, each running linearly from -1 to 1.

    Args:
        size: Grid extent ``(size[0], size[1])``.
    """

    def __init__(self, size: Tuple[int, int]):
        super().__init__()

        self.size = size

        # Normalised coordinates along each axis: index / (n - 1) * 2 - 1.
        axis_x = (th.arange(size[0]) / (size[0]-1)) * 2 - 1
        axis_y = (th.arange(size[1]) / (size[1]-1)) * 2 - 1

        # Materialise full-size grids (clone makes the expanded view contiguous).
        full_x = axis_x.view(1, 1, -1, 1).expand(1, 1, *size).clone()
        full_y = axis_y.view(1, 1, 1, -1).expand(1, 1, *size).clone()

        self.register_buffer("grid_x", full_x, persistent=False)
        self.register_buffer("grid_y", full_y, persistent=False)

    def forward(self, input: th.Tensor):
        """Return Gaussian maps of shape ``(b, 1, size[0], size[1])``.

        Args:
            input: 2D tensor whose columns 0, 1, 2 are the centre along
                ``grid_x``, the centre along ``grid_y`` and a width parameter.
                Centres are clipped to ``[-1, 1]``, the width to ``[0, 1]``.
        """
        center_x, center_y, width = (
            rearrange(input[:, k:k+1], 'b c -> b c 1 1') for k in range(3)
        )

        center_x = th.clip(center_x, -1, 1)
        center_y = th.clip(center_y, -1, 1)
        width    = th.clip(width, 0, 1)

        # Per-axis standard deviation: one cell plus the width scaled by the
        # longer grid side, expressed relative to that axis' cell count.
        longest = max(self.size)
        std_x = (1 + longest * width) / self.size[0]
        std_y = (1 + longest * width) / self.size[1]

        term_x = (self.grid_x - center_x)**2/(2 * std_x**2)
        term_y = (self.grid_y - center_y)**2/(2 * std_y**2)

        return th.exp(-1 * (term_x + term_y))

class Vector2D(nn.Module):
    """Render a point, and optionally a line segment, as a 3-channel map.

    Two non-persistent buffers of shape ``(1, 3, size[0], size[1])`` hold
    coordinates in ``[-1, 1]``: ``grid_x`` varies along dimension 2 and
    ``grid_y`` along dimension 3.

    Args:
        size: Grid extent ``(size[0], size[1])``.
    """

    def __init__(self, size: Tuple[int, int]):
        super(Vector2D, self).__init__()

        self.size = size

        # Normalised coordinates along each axis: index / (n - 1) * 2 - 1.
        axis_x = (th.arange(size[0]) / (size[0]-1)) * 2 - 1
        axis_y = (th.arange(size[1]) / (size[1]-1)) * 2 - 1

        self.register_buffer(
            "grid_x", axis_x.view(1, 1, -1, 1).expand(1, 3, *size).clone(), persistent=False
        )
        self.register_buffer(
            "grid_y", axis_y.view(1, 1, 1, -1).expand(1, 3, *size).clone(), persistent=False
        )

    def forward(self, input: th.Tensor, vector: th.Tensor = None):
        """Return a map of shape ``(b, 3, size[0], size[1])``.

        A Gaussian blob centred at the (clipped) position is written
        identically into all three channels. If ``vector`` is given, 11
        narrower blobs are placed at evenly spaced fractions 0..1 along the
        segment from the position to ``position + vector`` (end points clipped
        to ``[-1, 1]``); these contribute to channel 2 only and are combined
        with the base map by element-wise maximum.

        Args:
            input: 2D tensor; columns 0 and 1 are the position along
                ``grid_x`` and ``grid_y``.
            vector: Optional 2D tensor; columns 0 and 1 are the displacement
                along ``grid_x`` and ``grid_y``.
        """
        x   = rearrange(input[:,0:1], 'b c -> b c 1 1')
        y   = rearrange(input[:,1:2], 'b c -> b c 1 1')

        x   = th.clip(x, -1, 1)
        y   = th.clip(y, -1, 1)
        std = 0.01

        max_size = max(self.size)
        std_x = (1 + max_size * std) / self.size[0]
        std_y = (1 + max_size * std) / self.size[1]
        grid = th.exp(-1 * ((self.grid_x - x)**2/(2 * std_x**2) + (self.grid_y - y)**2/(2 * std_y**2)))

        if vector is None:
            return grid

        x_vec = rearrange(vector[:,0:1], 'b c -> b c 1 1')
        y_vec = rearrange(vector[:,1:2], 'b c -> b c 1 1')

        # interpolating between start and end point
        for length in np.linspace(0, 1, 11):
            x_end = th.clip(x + x_vec * length, -1, 1)
            y_end = th.clip(y + y_vec * length, -1, 1)

            grid_point = th.exp(-1 * ((self.grid_x - x_end)**2/(0.5 * std_x**2) + (self.grid_y - y_end)**2/(0.5 * std_y**2)))

            # Keep only channel 2 of the segment blob. Built out-of-place:
            # exp() keeps its output for the backward pass, so zeroing
            # channels in place would make backward raise when gradients
            # are required.
            grid_point = th.cat(
                (th.zeros_like(grid_point[:, 0:2]), grid_point[:, 2:3]), dim=1
            )
            grid = th.max(grid, grid_point)

        return grid

class SharedObjectsToBatch(nn.Module):
    """Move the object axis from the channel dimension into the batch dimension.

    Reshapes ``(b, o * c, h, w)`` to ``(b * o, c, h, w)``.

    Args:
        num_objects: Number of objects ``o`` packed into the channel dimension.
    """

    def __init__(self, num_objects):
        super().__init__()

        self.num_objects = num_objects

    def forward(self, input: th.Tensor):
        """Return ``input`` rearranged to ``(b * o, c, h, w)``."""
        pattern = 'b (o c) h w -> (b o) c h w'
        return rearrange(input, pattern, o=self.num_objects)

class BatchToSharedObjects(nn.Module):
    """Move the object axis from the batch dimension into the channel dimension.

    Reshapes ``(b * o, c, h, w)`` to ``(b, o * c, h, w)``.

    Args:
        num_objects: Number of objects ``o`` folded into the batch dimension.
    """

    def __init__(self, num_objects):
        super().__init__()

        self.num_objects = num_objects

    def forward(self, input: th.Tensor):
        """Return ``input`` rearranged to ``(b, o * c, h, w)``."""
        pattern = '(b o) c h w -> b (o c) h w'
        return rearrange(input, pattern, o=self.num_objects)

class LambdaModule(nn.Module):
    """Wrap an arbitrary callable so it can be used as an ``nn.Module``.

    Args:
        lambd: Callable invoked by ``forward`` with the positional arguments
            it receives. Lambdas, named functions, builtins, partials and
            callable objects are all accepted.

    Raises:
        TypeError: If ``lambd`` is not callable.
    """

    def __init__(self, lambd):
        super().__init__()
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would let a non-callable slip through
        # and fail only later, inside ``forward``.
        if not callable(lambd):
            raise TypeError(
                f"LambdaModule expects a callable, got {type(lambd).__name__}"
            )
        self.lambd = lambd

    def forward(self, *x):
        """Call the wrapped callable with ``*x`` and return its result."""
        return self.lambd(*x)

class PrintGradientFunction(Function):
    """Autograd function that passes data through and prints gradient statistics.

    Forward returns the input tensor; backward prints the mean and standard
    deviation of the incoming gradient, prefixed by a message, and forwards a
    copy of that gradient unchanged.
    """

    @staticmethod
    def forward(ctx, tensor, msg):
        """Remember ``msg`` on the context and return ``tensor``."""
        ctx.msg = msg
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        """Print ``"<msg>: <mean> +- <std>"`` and pass the gradient on.

        The second return value is ``None`` because ``msg`` takes no gradient.
        """
        grad_mean = th.mean(grad_output).item()
        grad_std  = th.std(grad_output).item()
        print(f"{ctx.msg}: {grad_mean} +- {grad_std}")
        return grad_output.clone(), None

class PrintGradient(nn.Module):
    """Module wrapper around ``PrintGradientFunction``.

    Args:
        msg: Label passed to the autograd function on every forward call.
    """

    def __init__(self, msg = "PrintGradient"):
        super().__init__()

        self.msg = msg
        self.fcn = PrintGradientFunction.apply

    def forward(self, input: th.Tensor):
        """Apply the stored autograd function to ``input`` with ``self.msg``."""
        label = self.msg
        return self.fcn(input, label)

class Prioritize(nn.Module):
    """Suppress each object channel by the other channels, weighted by priority.

    For every batch element, channel ``o`` of the output is
    ``relu(input[o] - sum_i w[o, i] * input[i])`` with
    ``w[o, i] = sigmoid(priority[i] - priority[o])`` for ``i != o`` and
    ``w[o, o] = 0``. The weighted sum is evaluated as a grouped 1x1
    convolution with one group per batch element, so ``input`` must carry
    exactly ``num_objects`` channels.

    Args:
        num_objects: Number of object channels.
    """

    def __init__(self, num_objects):
        super().__init__()

        self.num_objects = num_objects
        self.to_batch    = SharedObjectsToBatch(num_objects)

    def forward(self, input: th.Tensor, priority: th.Tensor):
        """Return the prioritised tensor, or ``input`` itself if ``priority`` is None.

        Args:
            input: Tensor of shape ``(b, num_objects, h, w)``.
            priority: Tensor of shape ``(b, num_objects)`` or ``None``.
        """
        if priority is None:
            return input

        batch_size  = input.shape[0]
        num_objects = self.num_objects
        weights     = th.zeros((batch_size, num_objects, num_objects, 1, 1), device=input.device)

        # weights[b, o, i] = sigmoid(priority[b, i] - priority[b, o]), all pairs at once.
        weights[:, :, :, 0, 0] = th.sigmoid(priority[:, None, :] - priority[:, :, None])

        # A channel never suppresses itself: zero the diagonal.
        for idx in range(num_objects):
            weights[:, idx, idx, 0, 0] = weights[:, idx, idx, 0, 0] * 0

        # Fold the batch into the channel axis so that one grouped convolution
        # applies a different 1x1 kernel to every batch element.
        stacked = rearrange(input, 'b c h w -> 1 (b c) h w')
        kernels = rearrange(weights, 'b o i 1 1 -> (b o) i 1 1')

        suppression = nn.functional.conv2d(stacked, kernels, groups=batch_size)
        output = th.relu(stacked - suppression)

        return rearrange(output, '1 (b c) h w -> b c h w ', b=batch_size)

class MultiArgSequential(nn.Sequential):
    """``nn.Sequential`` variant whose stages may take and return several values.

    The positional arguments of ``forward`` are handed to the first module.
    After each stage, a single tensor or ``None`` is passed on as one
    argument; any other result (e.g. a tuple) is unpacked into positional
    arguments for the next module.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, *tensor):
        """Run all modules in order and return the last module's output.

        With no modules registered, the tuple of input arguments is returned.
        """
        for module in self:
            # Identity check rather than ``== None``: equality would dispatch
            # to the intermediate value's own ``__eq__``, which for array-like
            # results is element-wise and has no single truth value.
            if isinstance(tensor, th.Tensor) or tensor is None:
                tensor = module(tensor)
            else:
                tensor = module(*tensor)

        return tensor

def create_grid(size):
    """Build a two-channel coordinate grid with values in ``[-1, 1]``.

    Args:
        size: Grid extent ``(size[0], size[1])``.

    Returns:
        Tensor of shape ``(1, 2, size[0], size[1])``. Channel 0 varies along
        the last dimension, channel 1 along the second-to-last dimension.
    """
    def _axis(count):
        # index / (count - 1) * 2 - 1 maps 0..count-1 linearly onto -1..1
        return (th.arange(count) / (count - 1)) * 2 - 1

    along_rows = _axis(size[0]).view(1, 1, -1, 1).expand(1, 1, *size).clone()
    along_cols = _axis(size[1]).view(1, 1, 1, -1).expand(1, 1, *size).clone()

    return th.cat((along_cols, along_rows), dim=1)

class Warp(nn.Module):
    """Warp an image with a sampling grid, using a padded border.

    The input is replication-padded, the flow is resized to ``size``,
    converted to an offset from the base grid, zero-padded to the padded
    size and added to the padded base grid, and the padded input is then
    sampled with bicubic ``grid_sample``. Finally the border is cropped away
    so the result has spatial extent ``size`` again.

    Args:
        size: Spatial extent ``(size[0], size[1])`` of the images to warp.
        padding: Border width as a fraction of ``max(size)``; converted to a
            whole number of pixels with ``int(...)``, which may yield 0.
    """

    def __init__(self, size, padding = 0.1):
        super(Warp, self).__init__()

        # From here on ``padding`` is a pixel count, not a fraction.
        padding = int(max(size) * padding)
        padded_size = (size[0] + 2 * padding, size[1] + 2 * padding)

        self.register_buffer('grid', create_grid(size))
        self.register_buffer('padded_grid', create_grid(padded_size))

        self.replication_pad = nn.ReplicationPad2d(padding)
        self.interpolate = nn.Sequential(
            # resize the flow to the working resolution
            LambdaModule(lambda x:
                th.nn.functional.interpolate(x, size=size, mode='bicubic', align_corners = True)
            ),
            # absolute coordinates -> offsets relative to the base grid
            LambdaModule(lambda x: x - self.grid),
            # zero offset in the border region
            nn.ConstantPad2d(padding, 0),
            # offsets -> absolute coordinates on the padded grid
            LambdaModule(lambda x: x + self.padded_grid),
            # channels-last layout, as expected by grid_sample
            LambdaModule(lambda x: rearrange(x, 'b c h w -> b h w c'))
        )

        self.warp = LambdaModule(lambda input, flow:
            th.nn.functional.grid_sample(input, flow, mode='bicubic', align_corners=True)
        )

        # Crop by explicit extent. Slicing with ``padding:-padding`` becomes
        # ``0:-0`` == ``0:0`` when the pixel padding rounds down to zero and
        # would return an empty tensor.
        self.un_pad = LambdaModule(
            lambda x: x[:, :, padding:padding + size[0], padding:padding + size[1]]
        )
    
    def get_raw_flow(self, flow):
        """Return ``flow`` minus the base grid."""
        return flow - self.grid

    def forward(self, input, flow):
        """Warp ``input`` with ``flow`` and return the cropped result.

        Args:
            input: Image batch ``(b, c, size[0], size[1])``.
            flow: Sampling grid in channels-first layout; it is resized to
                ``size`` before use. # NOTE(review): presumably 2 channels in
                the same order as ``create_grid`` — confirm against callers.
        """
        input = self.replication_pad(input)
        flow  = self.interpolate(flow)
        return self.un_pad(self.warp(input, flow))

class Binarize(nn.Module):
    """Sigmoid activation that is noisy while training and hard in evaluation.

    In training mode the output is ``p + p * (1 - p) * eps`` with
    ``p = sigmoid(input)`` and ``eps`` standard normal noise; in evaluation
    mode it is ``round(p)``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input: th.Tensor):
        """Return the noisy (training) or rounded (evaluation) sigmoid of ``input``."""
        prob = th.sigmoid(input)

        if self.training:
            # Noise amplitude p * (1 - p) vanishes as p approaches 0 or 1.
            noise = th.randn_like(prob)
            return prob + prob * (1 - prob) * noise

        return th.round(prob)
    
class TanhAlpha(nn.Module):
    """Scheduled scalar ``tanh(init) * max_value`` that advances on every call.

    ``init`` is a one-element buffer starting at ``start``; each ``forward``
    call adds ``stepsize`` to it.

    Args:
        start: Initial value of the ``init`` buffer.
        stepsize: Amount added to ``init`` per ``forward`` call.
        max_value: Factor applied to ``tanh(init)``.
    """

    def __init__(self, start = 0, stepsize = 1e-4, max_value = 1):
        super().__init__()

        self.register_buffer('init', th.zeros(1) + start)
        self.stepsize  = stepsize
        self.max_value = max_value

    def get(self):
        """Return the current value as a Python float without advancing it."""
        scaled = th.tanh(self.init) * self.max_value
        return scaled.item()

    def forward(self):
        """Advance ``init`` by one step and return the new value."""
        advanced = self.init.detach() + self.stepsize
        self.init = advanced
        return self.get()