path: root/model/nn/eprop_transformer_utils.py
author     fredeee    2023-11-02 10:47:21 +0100
committer  fredeee    2023-11-02 10:47:21 +0100
commit     f8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree       87288be6f851ab69405e524b81940c501c52789a /model/nn/eprop_transformer_utils.py
parent     f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initial commit
Diffstat (limited to 'model/nn/eprop_transformer_utils.py')
-rw-r--r--    model/nn/eprop_transformer_utils.py    66
1 file changed, 66 insertions, 0 deletions
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
new file mode 100644
index 0000000..9e5e874
--- /dev/null
+++ b/model/nn/eprop_transformer_utils.py
@@ -0,0 +1,66 @@
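+# Utility modules for the e-prop transformer: object-wise gated self-attention
+# (AlphaAttention) and gated input/output embeddings with skip connections.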
+import torch.nn as nn
+import torch as th
+from model.utils.nn_utils import LambdaModule
+from einops import rearrange, repeat, reduce
+
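+# Multi-head self-attention across the object dimension. The residual gate
+# `alpha` is initialised close to zero, so the attention branch contributes
+# almost nothing at the start of training and is blended in as it is learned.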
+class AlphaAttention(nn.Module):
+    def __init__(
+        self,
+        num_hidden,
+        num_objects,
+        heads,
+        dropout = 0.0
+    ):
+        super(AlphaAttention, self).__init__()
+
+        self.to_sequence = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o = num_objects))
+        self.to_batch = LambdaModule(lambda x: rearrange(x, 'b o c -> (b o) c', o = num_objects))
+
+        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+        self.attention = nn.MultiheadAttention(
+            num_hidden,
+            heads,
+            dropout = dropout,
+            batch_first = True
+        )
+
+    def forward(self, x: th.Tensor):
+        x = self.to_sequence(x)
+        x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+        return self.to_batch(x)
+
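+# Projects inputs to the hidden size with a small MLP. The skip path tiles the
+# input up to the hidden width (num_hidden must be a multiple of num_inputs),
+# and `alpha` gates the learned embedding on top of that skip connection.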
+class InputEmbeding(nn.Module):
+    def __init__(self, num_inputs, num_hidden):
+        super(InputEmbeding, self).__init__()
+
+        self.embeding = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(num_inputs, num_hidden),
+            nn.ReLU(),
+            nn.Linear(num_hidden, num_hidden),
+        )
+        self.skip = LambdaModule(
+            lambda x: repeat(x, 'b c -> b (n c)', n = num_hidden // num_inputs)
+        )
+        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+    def forward(self, input: th.Tensor):
+        return self.skip(input) + self.alpha * self.embeding(input)
+
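+# Maps hidden features to the output size. The skip path averages groups of
+# hidden channels down to num_outputs (num_hidden must be a multiple of
+# num_outputs), and `alpha` gates the learned embedding on top of it.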
+class OutputEmbeding(nn.Module):
+    def __init__(self, num_hidden, num_outputs):
+        super(OutputEmbeding, self).__init__()
+
+        self.embeding = nn.Sequential(
+            nn.ReLU(),
+            nn.Linear(num_hidden, num_outputs),
+            nn.ReLU(),
+            nn.Linear(num_outputs, num_outputs),
+        )
+        self.skip = LambdaModule(
+            lambda x: reduce(x, 'b (n c) -> b c', 'mean', n = num_hidden // num_outputs)
+        )
+        self.alpha = nn.Parameter(th.zeros(1)+1e-12)
+
+    def forward(self, input: th.Tensor):
+        return self.skip(input) + self.alpha * self.embeding(input)
\ No newline at end of file