path: root/model/nn/residual.py
author    fredeee 2023-11-02 10:47:21 +0100
committer fredeee 2023-11-02 10:47:21 +0100
commit    f8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree      87288be6f851ab69405e524b81940c501c52789a /model/nn/residual.py
parent    f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initial commit
Diffstat (limited to 'model/nn/residual.py')
-rw-r--r--    model/nn/residual.py    396
1 file changed, 396 insertions, 0 deletions
diff --git a/model/nn/residual.py b/model/nn/residual.py
new file mode 100644
index 0000000..1602e16
--- /dev/null
+++ b/model/nn/residual.py
@@ -0,0 +1,396 @@
+import torch.nn as nn
+import torch as th
+import numpy as np
+from einops import rearrange, repeat, reduce
+from model.utils.nn_utils import LambdaModule
+
+from typing import Union, Tuple
+
+__author__ = "Manuel Traub"
+
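+# Layer norm without learnable affine parameters; the normalized shape is taken
+# from the input at runtime (every dimension after batch and channel).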
+class DynamicLayerNorm(nn.Module):
+
+ def __init__(self, eps: float = 1e-5):
+ super(DynamicLayerNorm, self).__init__()
+ self.eps = eps
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return nn.functional.layer_norm(input, input.shape[2:], None, None, self.eps)
+
+
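+# Parameter-free skip path that matches the channel count and spatial resolution
+# of a residual branch, using only channel averaging/repetition and pixel
+# repetition/pooling.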
+class SkipConnection(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ scale_factor: float = 1.0
+ ):
+ super(SkipConnection, self).__init__()
+        assert scale_factor == 1 or int(scale_factor) > 1 or int(1 / scale_factor) > 1, f'invalid scale factor in SkipConnection: {scale_factor}'
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.scale_factor = scale_factor
+
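+    # Adjust the channel count without parameters: average groups of channels to
+    # shrink, repeat channels to grow; if neither count divides the other, go
+    # through their greatest common divisor (average down, then repeat up).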
+ def channel_skip(self, input: th.Tensor):
+ in_channels = self.in_channels
+ out_channels = self.out_channels
+
+ if in_channels == out_channels:
+ return input
+
+ if in_channels % out_channels == 0 or out_channels % in_channels == 0:
+
+ if in_channels > out_channels:
+ return reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // out_channels)
+
+ if out_channels > in_channels:
+ return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // in_channels)
+
+ mean_channels = np.gcd(in_channels, out_channels)
+ input = reduce(input, 'b (c n) h w -> b c h w', 'mean', n = in_channels // mean_channels)
+ return repeat(input, 'b c h w -> b (c n) h w', n = out_channels // mean_channels)
+
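+    # Adjust the spatial resolution without parameters: repeat pixels to upsample,
+    # average non-overlapping patches (or average-pool) to downsample, and fall
+    # back to a global mean if the input is smaller than the downscaling factor.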
+ def scale_skip(self, input: th.Tensor):
+ scale_factor = self.scale_factor
+
+ if scale_factor == 1:
+ return input
+
+ if scale_factor > 1:
+ return repeat(
+ input,
+ 'b c h w -> b c (h h2) (w w2)',
+ h2 = int(scale_factor),
+ w2 = int(scale_factor)
+ )
+
+ height = input.shape[2]
+ width = input.shape[3]
+
+ # scale factor < 1
+ scale_factor = int(1 / scale_factor)
+
+ if width % scale_factor == 0 and height % scale_factor == 0:
+ return reduce(
+ input,
+ 'b c (h h2) (w w2) -> b c h w',
+ 'mean',
+ h2 = scale_factor,
+ w2 = scale_factor
+ )
+
+ if width >= scale_factor and height >= scale_factor:
+ return nn.functional.avg_pool2d(
+ input,
+ kernel_size = scale_factor,
+ stride = scale_factor
+ )
+
+ assert width > 1 or height > 1
+ return reduce(input, 'b c h w -> b c 1 1', 'mean')
+
+
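+    # When upsampling, adapt the channels before enlarging the resolution;
+    # otherwise reduce the resolution first and then adapt the channels.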
+ def forward(self, input: th.Tensor):
+
+ if self.scale_factor > 1:
+ return self.scale_skip(self.channel_skip(input))
+
+ return self.channel_skip(self.scale_skip(input))
+
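+# Learnable downscaling: a strided convolution whose kernel size equals its stride
+# (scale_factor), with the same default initialization scheme as nn.Conv2d.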
+class DownScale(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ scale_factor: int,
+ groups: int = 1,
+ bias: bool = True
+ ):
+
+ super(DownScale, self).__init__()
+
+ assert(in_channels % groups == 0)
+ assert(out_channels % groups == 0)
+
+ self.groups = groups
+ self.scale_factor = scale_factor
+ self.weight = nn.Parameter(th.empty((out_channels, in_channels // groups, scale_factor, scale_factor)))
+ self.bias = nn.Parameter(th.empty((out_channels,))) if bias else None
+
+ nn.init.kaiming_uniform_(self.weight, a=np.sqrt(5))
+
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / np.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, input: th.Tensor):
+ height = input.shape[2]
+ width = input.shape[3]
+        assert height > 1 or width > 1, "trying to downscale 1x1"
+
+ scale_factor = self.scale_factor
+ padding = [0, 0]
+
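+        # If a spatial dimension is smaller than the kernel/stride, pad it so the
+        # convolution still produces at least one output element.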
+ if height < scale_factor:
+ padding[0] = scale_factor - height
+
+ if width < scale_factor:
+ padding[1] = scale_factor - width
+
+ return nn.functional.conv2d(
+ input,
+ self.weight,
+ bias=self.bias,
+ stride=scale_factor,
+ padding=padding,
+ groups=self.groups
+ )
+
+
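+# Pre-activation residual block: optional layer norm and (Leaky)ReLU, a first
+# convolution that can up- or downscale, a second convolution, and a parameter-free
+# SkipConnection added on top. With alpha_residual the convolutional branch is
+# scaled by a learnable alpha initialized near zero, so training starts close to
+# the identity mapping.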
+class ResidualBlock(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]] = (3, 3),
+ scale_factor: int = 1,
+ groups: Union[int, Tuple[int, int]] = (1, 1),
+ bias: bool = True,
+ layer_norm: bool = False,
+ leaky_relu: bool = False,
+ residual: bool = True,
+ alpha_residual: bool = False,
+ input_nonlinearity = True
+ ):
+
+ super(ResidualBlock, self).__init__()
+ self.residual = residual
+ self.alpha_residual = alpha_residual
+ self.skip = False
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ if isinstance(kernel_size, int):
+ kernel_size = [kernel_size, kernel_size]
+
+ if isinstance(groups, int):
+ groups = [groups, groups]
+
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
+
+ _layers = list()
+ if layer_norm:
+ _layers.append(DynamicLayerNorm())
+
+ if input_nonlinearity:
+ if leaky_relu:
+ _layers.append(nn.LeakyReLU())
+ else:
+ _layers.append(nn.ReLU())
+
+ if scale_factor > 1:
+ _layers.append(
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=scale_factor,
+ stride=scale_factor,
+ groups=groups[0],
+ bias=bias
+ )
+ )
+ elif scale_factor < 1:
+ _layers.append(
+ DownScale(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ scale_factor=int(1.0/scale_factor),
+ groups=groups[0],
+ bias=bias
+ )
+ )
+ else:
+ _layers.append(
+ nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=groups[0],
+ bias=bias
+ )
+ )
+
+ if layer_norm:
+ _layers.append(DynamicLayerNorm())
+ if leaky_relu:
+ _layers.append(nn.LeakyReLU())
+ else:
+ _layers.append(nn.ReLU())
+ _layers.append(
+ nn.Conv2d(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ padding=padding,
+ groups=groups[1],
+ bias=bias
+ )
+ )
+ self.layers = nn.Sequential(*_layers)
+
+ if self.residual:
+ self.skip_connection = SkipConnection(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ scale_factor=scale_factor
+ )
+
+ if self.alpha_residual:
+ self.alpha = nn.Parameter(th.zeros(1) + 1e-12)
+
+ def set_mode(self, **kwargs):
+ if 'skip' in kwargs:
+ self.skip = kwargs['skip']
+
+ if 'residual' in kwargs:
+ self.residual = kwargs['residual']
+
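+    # skip mode bypasses the convolutional branch entirely and applies only the
+    # parameter-free skip connection.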
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ if self.skip:
+ return self.skip_connection(input)
+
+ if not self.residual:
+ return self.layers(input)
+
+ if self.alpha_residual:
+ return self.alpha * self.layers(input) + self.skip_connection(input)
+
+ return self.layers(input) + self.skip_connection(input)
+
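+# Parameter-free skip path for fully connected layers: averages or repeats features
+# to match the output width; feature counts where neither divides the other are
+# rejected in the constructor.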
+class LinearSkip(nn.Module):
+ def __init__(self, num_inputs: int, num_outputs: int):
+ super(LinearSkip, self).__init__()
+
+ self.num_inputs = num_inputs
+ self.num_outputs = num_outputs
+
+ if num_inputs % num_outputs != 0 and num_outputs % num_inputs != 0:
+ mean_channels = np.gcd(num_inputs, num_outputs)
+ print(f"[WW] gcd skip: {num_inputs} -> {mean_channels} -> {num_outputs}")
+ assert(False)
+
+ def forward(self, input: th.Tensor):
+ num_inputs = self.num_inputs
+ num_outputs = self.num_outputs
+
+ if num_inputs == num_outputs:
+ return input
+
+ if num_inputs % num_outputs == 0 or num_outputs % num_inputs == 0:
+
+ if num_inputs > num_outputs:
+ return reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // num_outputs)
+
+ if num_outputs > num_inputs:
+ return repeat(input, 'b c -> b (c n)', n = num_outputs // num_inputs)
+
+ mean_channels = np.gcd(num_inputs, num_outputs)
+ input = reduce(input, 'b (c n) -> b c', 'mean', n = num_inputs // mean_channels)
+ return repeat(input, 'b c -> b (c n)', n = num_outputs // mean_channels)
+
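+# Two-layer MLP (Linear-ReLU-Linear, with an optional input ReLU) combined with a
+# LinearSkip residual; with alpha_residual the MLP branch is scaled by a learnable
+# alpha initialized near zero.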
+class LinearResidual(nn.Module):
+ def __init__(
+ self,
+ num_inputs: int,
+ num_outputs: int,
+ num_hidden: int = None,
+ residual: bool = True,
+ alpha_residual: bool = False,
+ input_relu: bool = True
+ ):
+ super(LinearResidual, self).__init__()
+
+ self.residual = residual
+ self.alpha_residual = alpha_residual
+
+ if num_hidden is None:
+ num_hidden = num_outputs
+
+ _layers = []
+ if input_relu:
+ _layers.append(nn.ReLU())
+ _layers.append(nn.Linear(num_inputs, num_hidden))
+ _layers.append(nn.ReLU())
+ _layers.append(nn.Linear(num_hidden, num_outputs))
+
+ self.layers = nn.Sequential(*_layers)
+
+ if residual:
+ self.skip = LinearSkip(num_inputs, num_outputs)
+
+ if alpha_residual:
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+
+ def forward(self, input: th.Tensor):
+ if not self.residual:
+ return self.layers(input)
+
+ if not self.alpha_residual:
+ return self.skip(input) + self.layers(input)
+
+ return self.skip(input) + self.alpha * self.layers(input)
+
+
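+# Multi-head self-attention across the object dimension, applied independently at
+# every spatial location; the attention output is gated by a learnable alpha
+# initialized near zero and added to the input.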
+class EntityAttention(nn.Module):
+ def __init__(self, channels, num_objects, size, channels_per_head = 12, dropout = 0.0):
+ super(EntityAttention, self).__init__()
+
+ assert channels % channels_per_head == 0
+ heads = channels // channels_per_head
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.attention = nn.MultiheadAttention(
+ channels,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ self.channel_attention = nn.Sequential(
+ LambdaModule(lambda x: rearrange(x, '(b o) c h w -> (b h w) o c', o = num_objects)),
+ LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
+            LambdaModule(lambda x: rearrange(x, '(b h w) o c -> (b o) c h w', h = size[0], w = size[1])),
+ )
+
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+ return input + self.channel_attention(input) * self.alpha
+
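+# Multi-head self-attention across spatial positions in a reduced gestalt_size
+# channel space (1x1 convolutions project in and out); the result is gated by a
+# learnable alpha initialized near zero and added to the input.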
+class ImageAttention(nn.Module):
+ def __init__(self, channels, gestalt_size, num_objects, size, channels_per_head = 12, dropout = 0.0):
+ super(ImageAttention, self).__init__()
+
+ assert gestalt_size % channels_per_head == 0
+ heads = gestalt_size // channels_per_head
+
+ self.alpha = nn.Parameter(th.zeros(1)+1e-16)
+ self.attention = nn.MultiheadAttention(
+ gestalt_size,
+ heads,
+ dropout = dropout,
+ batch_first = True
+ )
+
+ self.image_attention = nn.Sequential(
+ nn.Conv2d(channels, gestalt_size, kernel_size=1),
+ LambdaModule(lambda x: rearrange(x, 'b c h w -> b (h w) c')),
+ LambdaModule(lambda x: self.attention(x, x, x, need_weights=False)[0]),
+ LambdaModule(lambda x: rearrange(x, 'b (h w) c -> b c h w', h = size[0], w = size[1])),
+ nn.Conv2d(gestalt_size, channels, kernel_size=1),
+ )
+
+ def forward(self, input: th.Tensor) -> th.Tensor:
+        return input + self.image_attention(input) * self.alpha
\ No newline at end of file
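
Usage sketch (not part of the commit): a minimal example of how these blocks might be instantiated, assuming the repository root is on PYTHONPATH; shapes follow from the constructors above.

import torch as th
from model.nn.residual import ResidualBlock, LinearResidual

block = ResidualBlock(in_channels=32, out_channels=64, scale_factor=2)  # 2x upsampling via ConvTranspose2d
y = block(th.randn(8, 32, 16, 16))                                      # -> (8, 64, 32, 32)

mlp = LinearResidual(num_inputs=128, num_outputs=64)                    # Linear-ReLU-Linear plus LinearSkip
z = mlp(th.randn(8, 128))                                               # -> (8, 64)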