diff options
Diffstat (limited to 'model/nn/eprop_gate_l0rd.py')
-rw-r--r-- | model/nn/eprop_gate_l0rd.py | 329 |
1 files changed, 329 insertions, 0 deletions
"""Gated recurrent cell trainable with eligibility traces or plain backprop.

Source path: model/nn/eprop_gate_l0rd.py
"""
import torch.nn as nn
import torch as th
import numpy as np
from torch.autograd import Function

try:
    # Not used by the code in this module; imported optionally so the module
    # still loads in environments where einops is not installed.
    from einops import rearrange, repeat, reduce
except ImportError:
    pass

__author__ = "Manuel Traub"


class EpropGateL0rdFunction(Function):
    """Autograd function for one recurrent step using eligibility traces.

    ``forward`` updates the trace tensors it receives in place; ``backward``
    builds the parameter gradients from snapshots of those traces.
    """

    @staticmethod
    def forward(ctx, x, h_last, w_gx, w_gh, b_g, w_rx, w_rh, b_r, args):
        """Compute the next hidden state and update the traces in place.

        Args:
            x: 2-D input (rows are batch entries).
            h_last: 2-D previous hidden state.
            w_gx, w_gh, b_g: gate weights (input, recurrent) and bias.
            w_rx, w_rh, b_r: candidate weights (input, recurrent) and bias.
            args: tuple ``(e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r,
                reg, noise_level)``; the six ``e_*`` tensors are overwritten.

        Returns:
            ``(h, openings)`` where ``openings`` is the mean of the gate's
            step function (fraction of gate units with ``g > 0``).
        """
        e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r, reg, noise_level = args

        noise = th.normal(mean=0, std=noise_level, size=b_g.shape, device=b_g.device)
        g = th.relu(th.tanh(x.mm(w_gx.t()) + h_last.mm(w_gh.t()) + b_g + noise))
        r = th.tanh(x.mm(w_rx.t()) + h_last.mm(w_rh.t()) + b_r)

        h = g * r + (1 - g) * h_last

        # Heaviside step function of the gate: 1 where g > 0, else 0.
        H_g = th.ceil(g).clamp(0, 1)

        # Derivatives of the two activations w.r.t. their pre-activations.
        dg = (1 - g**2) * H_g
        dr = (1 - r**2)

        delta_h = r - h_last

        # Suffix _j: trailing axis added (per hidden unit);
        # suffix _i: middle axis added (per input / previous hidden unit).
        g_j = g.unsqueeze(dim=2)
        dg_j = dg.unsqueeze(dim=2)
        dr_j = dr.unsqueeze(dim=2)

        x_i = x.unsqueeze(dim=1)
        h_last_i = h_last.unsqueeze(dim=1)
        delta_h_j = delta_h.unsqueeze(dim=2)

        # Traces decay with (1 - g) and accumulate the current local derivative.
        e_w_gh.copy_(e_w_gh * (1 - g_j) + dg_j * h_last_i * delta_h_j)
        e_w_gx.copy_(e_w_gx * (1 - g_j) + dg_j * x_i * delta_h_j)
        e_b_g.copy_(e_b_g * (1 - g) + dg * delta_h)

        e_w_rh.copy_(e_w_rh * (1 - g_j) + dr_j * h_last_i * g_j)
        e_w_rx.copy_(e_w_rx * (1 - g_j) + dr_j * x_i * g_j)
        e_b_r.copy_(e_b_r * (1 - g) + dr * g)

        # Clones are saved because the traces are mutated in place on later steps.
        ctx.save_for_backward(
            g.clone(), dg.clone(), dg_j.clone(), dr.clone(), x_i.clone(), h_last_i.clone(),
            reg.clone(), H_g.clone(), delta_h.clone(),
            w_gx.clone(), w_gh.clone(), w_rx.clone(), w_rh.clone(),
            e_w_gx.clone(), e_w_gh.clone(), e_b_g.clone(),
            e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
        )

        return h, th.mean(H_g)

    @staticmethod
    def backward(ctx, dh, _):
        """Return gradients for the nine ``forward`` inputs (``args`` gets None).

        The gradient arriving for the ``openings`` output is ignored.
        """
        g, dg, dg_j, dr, x_i, h_last_i, reg, H_g, delta_h, w_gx, w_gh, w_rx, w_rh, \
            e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r = ctx.saved_tensors

        dh_j = dh.unsqueeze(dim=2)
        # Regularisation term, added only where the gate is open.
        H_g_reg = reg * H_g
        H_g_reg_j = H_g_reg.unsqueeze(dim=2)

        # Parameter gradients: incoming gradient times trace, summed over dim 0.
        dw_gx = th.sum(dh_j * e_w_gx + H_g_reg_j * dg_j * x_i, dim=0)
        dw_gh = th.sum(dh_j * e_w_gh + H_g_reg_j * dg_j * h_last_i, dim=0)
        db_g = th.sum(dh * e_b_g + H_g_reg * dg, dim=0)

        dw_rx = th.sum(dh_j * e_w_rx, dim=0)
        dw_rh = th.sum(dh_j * e_w_rh, dim=0)
        db_r = th.sum(dh * e_b_r, dim=0)

        dh_dg = (dh * delta_h + H_g_reg) * dg
        dh_dr = dh * g * dr

        dx = dh_dg.mm(w_gx) + dh_dr.mm(w_rx)
        dh_last = dh * (1 - g) + dh_dg.mm(w_gh) + dh_dr.mm(w_rh)

        return dx, dh_last, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None


class ReTanhFunction(Function):
    """``relu(tanh(x))`` whose backward adds ``reg`` wherever the output is > 0."""

    @staticmethod
    def forward(ctx, x, reg):
        """Return ``(g, openings)``: the activation and the mean of its step function."""
        g = th.relu(th.tanh(x))

        # Heaviside step function of the activation: 1 where g > 0, else 0.
        H_g = th.ceil(g).clamp(0, 1)

        dg = (1 - g**2) * H_g

        ctx.save_for_backward(g, dg, H_g, reg)
        return g, th.mean(H_g)

    @staticmethod
    def backward(ctx, dh, _):
        """Return the gradient for ``x``; ``reg`` receives no gradient."""
        g, dg, H_g, reg = ctx.saved_tensors

        dx = (dh + reg * H_g) * dg

        return dx, None


class ReTanh(nn.Module):
    """Module wrapper around ``ReTanhFunction`` with a fixed ``reg_lambda``."""

    def __init__(self, reg_lambda):
        super().__init__()

        # ``apply`` is a classmethod; no Function instance is needed.
        self.re_tanh = ReTanhFunction.apply
        self.register_buffer("reg_lambda", th.tensor(reg_lambda), persistent=False)

    def forward(self, input):
        """Apply the activation and store the opening rate as a Python number."""
        h, openings = self.re_tanh(input, self.reg_lambda)
        self.openings = openings.item()

        return h


class EpropGateL0rd(nn.Module):
    """Stateful gated recurrent cell with a gated output readout.

    The hidden state and the eligibility traces are non-persistent buffers
    whose leading dimension is ``batch_size``. ``forward`` runs the
    trace-based step unless ``activate_backprop()`` has been called.
    """

    def __init__(
        self,
        num_inputs,
        num_hidden,
        num_outputs,
        batch_size,
        reg_lambda = 0,
        gate_noise_level = 0,
    ):
        super().__init__()

        self.register_buffer("reg", th.tensor(reg_lambda).view(1, 1), persistent=False)
        self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs

        # ``apply`` is a classmethod; no Function instance is needed.
        self.fcn = EpropGateL0rdFunction.apply
        self.retanh = ReTanh(reg_lambda)

        # gate weights and biases
        self.w_gx = nn.Parameter(th.empty(num_hidden, num_inputs))
        self.w_gh = nn.Parameter(th.empty(num_hidden, num_hidden))
        self.b_g = nn.Parameter(th.zeros(num_hidden))

        # candidate weights and biases
        self.w_rx = nn.Parameter(th.empty(num_hidden, num_inputs))
        self.w_rh = nn.Parameter(th.empty(num_hidden, num_hidden))
        self.b_r = nn.Parameter(th.zeros(num_hidden))

        # output projection weights and bias
        self.w_px = nn.Parameter(th.empty(num_outputs, num_inputs))
        self.w_ph = nn.Parameter(th.empty(num_outputs, num_hidden))
        self.b_p = nn.Parameter(th.zeros(num_outputs))

        # output gate weights and bias
        self.w_ox = nn.Parameter(th.empty(num_outputs, num_inputs))
        self.w_oh = nn.Parameter(th.empty(num_outputs, num_hidden))
        self.b_o = nn.Parameter(th.zeros(num_outputs))

        # eligibility traces of the gate parameters
        self.register_buffer("e_w_gx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
        self.register_buffer("e_w_gh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
        self.register_buffer("e_b_g", th.zeros(batch_size, num_hidden), persistent=False)

        # eligibility traces of the candidate parameters
        self.register_buffer("e_w_rx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
        self.register_buffer("e_w_rh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
        self.register_buffer("e_b_r", th.zeros(batch_size, num_hidden), persistent=False)

        # hidden state
        self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)

        self.register_buffer("openings", th.zeros(1), persistent=False)

        # initialize weights (bounds of the uniform distributions)
        bound_ih = np.sqrt(6 / (self.num_inputs + self.num_hidden))
        bound_hh = np.sqrt(3 / self.num_hidden)
        bound_io = np.sqrt(6 / (self.num_inputs + self.num_outputs))
        bound_ho = np.sqrt(6 / (self.num_hidden + self.num_outputs))

        nn.init.uniform_(self.w_gx, -bound_ih, bound_ih)
        nn.init.uniform_(self.w_gh, -bound_hh, bound_hh)

        nn.init.uniform_(self.w_rx, -bound_ih, bound_ih)
        nn.init.uniform_(self.w_rh, -bound_hh, bound_hh)

        nn.init.uniform_(self.w_px, -bound_io, bound_io)
        nn.init.uniform_(self.w_ph, -bound_ho, bound_ho)

        nn.init.uniform_(self.w_ox, -bound_io, bound_io)
        nn.init.uniform_(self.w_oh, -bound_ho, bound_ho)

        self.backprop = False

    def reset_state(self):
        """Zero the hidden state, all eligibility traces and ``openings`` in place."""
        self.h_last.zero_()
        self.e_w_gx.zero_()
        self.e_w_gh.zero_()
        self.e_b_g.zero_()
        self.e_w_rx.zero_()
        self.e_w_rh.zero_()
        self.e_b_r.zero_()
        self.openings.zero_()

    def _readout(self, x, h):
        """Return the gated output ``sigmoid(...) * tanh(...)`` computed from ``x`` and ``h``."""
        p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
        return o * p

    def _backprop_step(self, x, h_last):
        """Return the next hidden state via ordinary autograd; records ``openings``."""
        noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
        g = self.retanh(x.mm(self.w_gx.t()) + h_last.mm(self.w_gh.t()) + self.b_g + noise)
        r = th.tanh(x.mm(self.w_rx.t()) + h_last.mm(self.w_rh.t()) + self.b_r)

        h = g * r + (1 - g) * h_last

        # Heaviside step function of the gate: 1 where g > 0, else 0.
        H_g = th.ceil(g).clamp(0, 1)
        self.openings = th.mean(H_g)

        return h

    def _eprop_step(self, x, h_last):
        """Return the next hidden state via ``EpropGateL0rdFunction``; records ``openings``."""
        h, openings = self.fcn(
            x, h_last,
            self.w_gx, self.w_gh, self.b_g,
            self.w_rx, self.w_rh, self.b_r,
            (
                self.e_w_gx, self.e_w_gh, self.e_b_g,
                self.e_w_rx, self.e_w_rh, self.e_b_r,
                self.reg, self.noise
            )
        )

        self.openings = openings
        return h

    def backprop_forward(self, x: th.Tensor):
        """Advance the stored hidden state with plain autograd and return the output."""
        self.h_last = self._backprop_step(x, self.h_last)
        return self._readout(x, self.h_last)

    def activate_backprop(self):
        """Make ``forward`` use the plain-autograd step."""
        self.backprop = True

    def deactivate_backprop(self):
        """Make ``forward`` use the eligibility-trace step (the default)."""
        self.backprop = False

    def detach(self):
        """Detach the stored hidden state from the autograd graph in place."""
        self.h_last.detach_()

    def eprop_forward(self, x: th.Tensor):
        """Advance the stored hidden state with the trace-based step and return the output."""
        h = self._eprop_step(x, self.h_last)
        self.h_last = h
        return self._readout(x, h)

    def save_hidden(self):
        """Remember a detached reference to the current hidden state."""
        self.h_last_saved = self.h_last.detach()

    def restore_hidden(self):
        """Restore the hidden state stored by ``save_hidden`` (which must be called first)."""
        self.h_last = self.h_last_saved

    def get_hidden(self):
        """Return the stored hidden state."""
        return self.h_last

    def set_hidden(self, h_last):
        """Replace the stored hidden state."""
        self.h_last = h_last

    def forward(self, x: th.Tensor):
        """Run one step in the currently selected mode and return the output."""
        if self.backprop:
            return self.backprop_forward(x)

        return self.eprop_forward(x)


class EpropGateL0rdShared(EpropGateL0rd):
    """Variant whose step methods take the hidden state as an explicit argument.

    The constructor is inherited unchanged from ``EpropGateL0rd``.
    """

    def backprop_forward(self, x: th.Tensor, h_last: th.Tensor):
        """Plain-autograd step; returns ``(output, new_hidden)`` without storing the state."""
        h = self._backprop_step(x, h_last)
        return self._readout(x, h), h

    def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
        """Trace-based step; returns ``(output, new_hidden)`` without storing the state."""
        h = self._eprop_step(x, h_last)
        return self._readout(x, h), h

    def forward(self, x: th.Tensor, h_last: th.Tensor = None):
        """Run one step in the currently selected mode.

        With ``h_last`` given, returns ``(output, new_hidden)``. Without it
        (backward compatibility), the stored hidden state is used and
        updated, and only the output is returned.
        """
        # Exactly one of the two steps runs; running the e-prop step after the
        # backprop step would feed the output back in as the input.
        step = self.backprop_forward if self.backprop else self.eprop_forward

        if h_last is not None:
            return step(x, h_last)

        # backward compatibility
        out, h = step(x, self.h_last)
        self.h_last = h
        return out
\ No newline at end of file |