about summary refs log tree commit diff
path: root/model/nn/eprop_transformer.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/nn/eprop_transformer.py')
-rw-r--r-- model/nn/eprop_transformer.py | 14
1 files changed, 11 insertions, 3 deletions
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py
index 4d89ec4..d2341dd 100644
--- a/model/nn/eprop_transformer.py
+++ b/model/nn/eprop_transformer.py
@@ -35,13 +35,21 @@ class EpropGateL0rdTransformer(nn.Module):
_layers.append(OutputEmbeding(num_hidden, num_outputs))
self.layers = nn.Sequential(*_layers)
+ self.attention = []
+ self.l0rds = []
+ for l in self.layers:
+ if 'AlphaAttention' in type(l).__name__:
+ self.attention.append(l)
+ elif 'EpropAlphaGateL0rd' in type(l).__name__:
+ self.l0rds.append(l)
def get_openings(self):
- openings = 0
+ openings = []
for i in range(self.depth):
- openings += self.layers[2 * (i + 1)].l0rd.openings.item()
+ openings.append(self.l0rds[i].l0rd.openings_perslot)
- return openings / self.depth
+ openings = th.mean(th.stack(openings, dim=0), dim=0)
+ return openings
def get_hidden(self):
states = []