path: root/model/nn/eprop_transformer_utils.py
author    fredeee  2024-03-23 13:27:00 +0100
committer fredeee  2024-03-23 13:27:00 +0100
commit    6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree      0545ff1b8beb051993c2d75fd81306db1a22274d /model/nn/eprop_transformer_utils.py
parent    ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'model/nn/eprop_transformer_utils.py')
-rw-r--r--  model/nn/eprop_transformer_utils.py  8
1 file changed, 6 insertions, 2 deletions
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
index 9e5e874..3219cd0 100644
--- a/model/nn/eprop_transformer_utils.py
+++ b/model/nn/eprop_transformer_utils.py
@@ -9,7 +9,8 @@ class AlphaAttention(nn.Module):
num_hidden,
num_objects,
heads,
- dropout = 0.0
+ dropout = 0.0,
+ need_weights = False
):
super(AlphaAttention, self).__init__()
@@ -23,10 +24,13 @@ class AlphaAttention(nn.Module):
dropout = dropout,
batch_first = True
)
+ self.need_weights = need_weights
+ self.att_weights = None
def forward(self, x: th.Tensor):
x = self.to_sequence(x)
- x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+ att, self.att_weights = self.attention(x, x, x, need_weights=self.need_weights)
+ x = x + self.alpha * att
return self.to_batch(x)
class InputEmbeding(nn.Module):
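
Note: the hunks above make AlphaAttention optionally keep its attention map in
self.att_weights. Below is a minimal, self-contained sketch of the torch
behaviour this relies on; the tensor sizes and variable names are illustrative
and not taken from the repository. torch.nn.MultiheadAttention returns
(output, None) when need_weights=False and (output, weights) when
need_weights=True, which is why the new code stores the second return value.

# Illustrative demo (assumed sizes); shows why the need_weights flag controls
# whether attention weights are available for later inspection.
import torch as th
import torch.nn as nn

num_hidden, heads, num_objects = 64, 4, 8
attention = nn.MultiheadAttention(num_hidden, heads, dropout=0.0, batch_first=True)

x = th.randn(2, num_objects, num_hidden)               # (batch, sequence, features)

out, weights = attention(x, x, x, need_weights=False)
print(weights)                                         # None -> nothing to store

out, weights = attention(x, x, x, need_weights=True)
print(weights.shape)                                   # (2, 8, 8): per-query weights, averaged over heads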