aboutsummaryrefslogtreecommitdiff
path: root/model/nn/eprop_gate_l0rd.py
diff options
context:
space:
mode:
Diffstat (limited to 'model/nn/eprop_gate_l0rd.py')
-rw-r--r--model/nn/eprop_gate_l0rd.py8
1 files changed, 5 insertions, 3 deletions
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py
index 82a9895..c2473d1 100644
--- a/model/nn/eprop_gate_l0rd.py
+++ b/model/nn/eprop_gate_l0rd.py
@@ -49,7 +49,7 @@ class EpropGateL0rdFunction(Function):
e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
)
- return h, th.mean(H_g)
+ return h, H_g
@staticmethod
def backward(ctx, dh, _):
@@ -169,6 +169,7 @@ class EpropGateL0rd(nn.Module):
self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)
self.register_buffer("openings", th.zeros(1), persistent=False)
+ self.register_buffer("openings_perslot", th.zeros(batch_size), persistent=False)
# initialize weights
stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
@@ -294,7 +295,7 @@ class EpropGateL0rdShared(EpropGateL0rd):
return o * p, h_last
def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
- h, openings = self.fcn(
+ h, H_g = self.fcn(
x, h_last,
self.w_gx, self.w_gh, self.b_g,
self.w_rx, self.w_rh, self.b_r,
@@ -305,7 +306,8 @@ class EpropGateL0rdShared(EpropGateL0rd):
)
)
- self.openings = openings
+ self.openings = th.mean(H_g)
+ self.openings_perslot = th.mean(H_g, dim=1)
p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)