author     fredeee                                    2023-11-02 10:47:21 +0100
committer  fredeee                                    2023-11-02 10:47:21 +0100
commit     f8302ee886ef9b631f11a52900dac964a61350e1   (patch)
tree       87288be6f851ab69405e524b81940c501c52789a   /model/nn/eprop_transformer_utils.py
parent     f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8   (diff)
initial commit
Diffstat (limited to 'model/nn/eprop_transformer_utils.py')
-rw-r--r--   model/nn/eprop_transformer_utils.py   66
1 file changed, 66 insertions, 0 deletions
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
new file mode 100644
index 0000000..9e5e874
--- /dev/null
+++ b/model/nn/eprop_transformer_utils.py
@@ -0,0 +1,66 @@
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import LambdaModule
+from einops import rearrange, repeat, reduce
+
+class AlphaAttention(nn.Module):
+    def __init__(
+        self,
+        num_hidden,
+        num_objects,
+        heads,
+        dropout = 0.0
+    ):
+        super(AlphaAttention, self).__init__()
+
+        self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects))
+        self.to_batch    = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects))
+
+        self.alpha     = nn.Parameter(th.zeros(1)+1e-12)
+        self.attention = nn.MultiheadAttention(
+            num_hidden,
+            heads,
+            dropout = dropout,
+            batch_first = True
+        )
+
+    def forward(self, x: th.Tensor):
+        x = self.to_sequence(x)
+        x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+        return self.to_batch(x)
+
+class InputEmbeding(nn.Module):
+    def __init__(self, num_inputs, num_hidden):
+        super(InputEmbeding, self).__init__()
+
+        self.embeding = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(num_inputs, num_hidden),
+            nn.ReLU(),
+            nn.Linear(num_hidden, num_hidden),
+        )
+        self.skip = LambdaModule(
+            lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs)
+        )
+        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+    def forward(self, input: th.Tensor):
+        return self.skip(input) + self.alpha * self.embeding(input)
+
+class OutputEmbeding(nn.Module):
+    def __init__(self, num_hidden, num_outputs):
+        super(OutputEmbeding, self).__init__()
+
+        self.embeding = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(num_hidden, num_outputs),
+            nn.ReLU(),
+            nn.Linear(num_outputs, num_outputs),
+        )
+        self.skip = LambdaModule(
+            lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs)
+        )
+        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+    def forward(self, input: th.Tensor):
+        return self.skip(input) + self.alpha * self.embeding(input)
\ No newline at end of file
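
Below is a minimal usage sketch (not part of the commit) showing how the three modules compose. It assumes the new file is importable as model.nn.eprop_transformer_utils, that num_hidden is a multiple of both the input/output size and the number of heads, and the tensor sizes are purely illustrative.

import torch as th
from model.nn.eprop_transformer_utils import AlphaAttention, InputEmbeding, OutputEmbeding

num_inputs, num_hidden, num_objects, heads = 8, 64, 4, 4  # hypothetical sizes

embed   = InputEmbeding(num_inputs, num_hidden)           # lift object states to the hidden size
attend  = AlphaAttention(num_hidden, num_objects, heads)  # objects in a scene attend to each other
readout = OutputEmbeding(num_hidden, num_inputs)          # project back to the object state size

x = th.randn(2 * num_objects, num_inputs)  # objects flattened into the batch: (b o) c
y = readout(attend(embed(x)))              # same shape as x: torch.Size([8, 8])

Because each module adds its learned transformation through an alpha gate initialised near zero (th.zeros(1)+1e-12), the stack starts out as an almost pure skip connection and only blends in attention and MLP output as alpha grows during training.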