| author | fredeee | 2024-03-23 13:27:00 +0100 |
|---|---|---|
| committer | fredeee | 2024-03-23 13:27:00 +0100 |
| commit | 6bcf6b8306ce4903734fb31824799a50281cea69 (patch) | |
| tree | 0545ff1b8beb051993c2d75fd81306db1a22274d /model/nn/eprop_transformer_utils.py | |
| parent | ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff) | |
add bouncingball experiment and ablation studies
Diffstat (limited to 'model/nn/eprop_transformer_utils.py')
-rw-r--r-- | model/nn/eprop_transformer_utils.py | 8 |
1 file changed, 6 insertions, 2 deletions
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
index 9e5e874..3219cd0 100644
--- a/model/nn/eprop_transformer_utils.py
+++ b/model/nn/eprop_transformer_utils.py
@@ -9,7 +9,8 @@ class AlphaAttention(nn.Module):
         num_hidden,
         num_objects,
         heads,
-        dropout = 0.0
+        dropout = 0.0,
+        need_weights = False
     ):
         super(AlphaAttention, self).__init__()
 
@@ -23,10 +24,13 @@ class AlphaAttention(nn.Module):
             dropout = dropout,
             batch_first = True
         )
+        self.need_weights = need_weights
+        self.att_weights = None
 
     def forward(self, x: th.Tensor):
         x = self.to_sequence(x)
-        x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+        att, self.att_weights = self.attention(x, x, x, need_weights=self.need_weights)
+        x = x + self.alpha * att
         return self.to_batch(x)
 
 class InputEmbeding(nn.Module):
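For context on the change: PyTorch's `nn.MultiheadAttention` returns a tuple `(output, weights)`, and with `need_weights=False` the weights entry is `None`, so the module can cache attention weights only when explicitly requested. Below is a minimal, self-contained sketch of that pattern (not the repository's code); the class name `AttentionWithOptionalWeights` and the zero-initialized `alpha` residual gate are illustrative assumptions.

```python
import torch as th
import torch.nn as nn

class AttentionWithOptionalWeights(nn.Module):
    """Illustrative stand-in for the AlphaAttention change: optionally cache attention weights."""

    def __init__(self, num_hidden, heads, dropout=0.0, need_weights=False):
        super().__init__()
        # Gated residual scale; zero-init is an assumption for this sketch.
        self.alpha = nn.Parameter(th.zeros(1))
        self.attention = nn.MultiheadAttention(
            num_hidden, heads, dropout=dropout, batch_first=True
        )
        self.need_weights = need_weights
        self.att_weights = None  # populated only when need_weights=True

    def forward(self, x: th.Tensor) -> th.Tensor:
        # MultiheadAttention returns (output, weights); weights is None if need_weights=False.
        att, self.att_weights = self.attention(x, x, x, need_weights=self.need_weights)
        return x + self.alpha * att

# Usage: with need_weights=True the cached weights have shape (batch, seq, seq).
m = AttentionWithOptionalWeights(num_hidden=16, heads=4, need_weights=True)
y = m(th.randn(2, 5, 16))
print(y.shape, m.att_weights.shape)  # torch.Size([2, 5, 16]) torch.Size([2, 5, 5])
```

Storing the weights on the module (rather than returning them) keeps the `forward` signature unchanged for existing callers, which appears to be the motivation for caching them in `self.att_weights`.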