aboutsummaryrefslogtreecommitdiff
path: root/model/nn/eprop_transformer_shared.py
diff options
context:
space:
mode:
authorfredeee2024-03-23 13:27:00 +0100
committerfredeee2024-03-23 13:27:00 +0100
commit6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree0545ff1b8beb051993c2d75fd81306db1a22274d /model/nn/eprop_transformer_shared.py
parentad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'model/nn/eprop_transformer_shared.py')
-rw-r--r--model/nn/eprop_transformer_shared.py7
1 files changed, 4 insertions, 3 deletions
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py
index 23a1b0f..79223c1 100644
--- a/model/nn/eprop_transformer_shared.py
+++ b/model/nn/eprop_transformer_shared.py
@@ -39,11 +39,12 @@ class EpropGateL0rdTransformerShared(nn.Module):
self.output_embeding = OutputEmbeding(num_hidden, num_outputs)
def get_openings(self):
- openings = 0
+ openings = []
for i in range(self.depth):
- openings += self.l0rds[i].l0rd.openings.item()
+ openings.append(self.l0rds[i].l0rd.openings_perslot)
- return openings / self.depth
+ openings = th.mean(th.stack(openings, dim=0), dim=0)
+ return openings
def get_hidden(self):
return self.hidden