diff options
Diffstat (limited to 'model/nn/eprop_transformer_shared.py')
-rw-r--r-- | model/nn/eprop_transformer_shared.py | 7 |
1 files changed, 4 insertions, 3 deletions
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py index 23a1b0f..79223c1 100644 --- a/model/nn/eprop_transformer_shared.py +++ b/model/nn/eprop_transformer_shared.py @@ -39,11 +39,12 @@ class EpropGateL0rdTransformerShared(nn.Module): self.output_embeding = OutputEmbeding(num_hidden, num_outputs) def get_openings(self): - openings = 0 + openings = [] for i in range(self.depth): - openings += self.l0rds[i].l0rd.openings.item() + openings.append(self.l0rds[i].l0rd.openings_perslot) - return openings / self.depth + openings = th.mean(th.stack(openings, dim=0), dim=0) + return openings def get_hidden(self): return self.hidden |