diff options
Diffstat (limited to 'model/nn/eprop_gate_l0rd.py')
-rw-r--r-- | model/nn/eprop_gate_l0rd.py | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py index 82a9895..c2473d1 100644 --- a/model/nn/eprop_gate_l0rd.py +++ b/model/nn/eprop_gate_l0rd.py @@ -49,7 +49,7 @@ class EpropGateL0rdFunction(Function): e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(), ) - return h, th.mean(H_g) + return h, H_g @staticmethod def backward(ctx, dh, _): @@ -169,6 +169,7 @@ class EpropGateL0rd(nn.Module): self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False) self.register_buffer("openings", th.zeros(1), persistent=False) + self.register_buffer("openings_perslot", th.zeros(batch_size), persistent=False) # initialize weights stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden)) @@ -294,7 +295,7 @@ class EpropGateL0rdShared(EpropGateL0rd): return o * p, h_last def eprop_forward(self, x: th.Tensor, h_last: th.Tensor): - h, openings = self.fcn( + h, H_g = self.fcn( x, h_last, self.w_gx, self.w_gh, self.b_g, self.w_rx, self.w_rh, self.b_r, @@ -305,7 +306,8 @@ class EpropGateL0rdShared(EpropGateL0rd): ) ) - self.openings = openings + self.openings = th.mean(H_g) + self.openings_perslot = th.mean(H_g, dim=1) p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p) o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o) |