diff options
author | fredeee | 2024-03-23 13:27:00 +0100 |
---|---|---|
committer | fredeee | 2024-03-23 13:27:00 +0100 |
commit | 6bcf6b8306ce4903734fb31824799a50281cea69 (patch) | |
tree | 0545ff1b8beb051993c2d75fd81306db1a22274d /model/nn/eprop_transformer.py | |
parent | ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff) |
add bouncingball experiment and ablation studies
Diffstat (limited to 'model/nn/eprop_transformer.py')
-rw-r--r-- | model/nn/eprop_transformer.py | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py index 4d89ec4..d2341dd 100644 --- a/model/nn/eprop_transformer.py +++ b/model/nn/eprop_transformer.py @@ -35,13 +35,21 @@ class EpropGateL0rdTransformer(nn.Module): _layers.append(OutputEmbeding(num_hidden, num_outputs)) self.layers = nn.Sequential(*_layers) + self.attention = [] + self.l0rds = [] + for l in self.layers: + if 'AlphaAttention' in type(l).__name__: + self.attention.append(l) + elif 'EpropAlphaGateL0rd' in type(l).__name__: + self.l0rds.append(l) def get_openings(self): - openings = 0 + openings = [] for i in range(self.depth): - openings += self.layers[2 * (i + 1)].l0rd.openings.item() + openings.append(self.l0rds[i].l0rd.openings_perslot) - return openings / self.depth + openings = th.mean(th.stack(openings, dim=0), dim=0) + return openings def get_hidden(self): states = [] |