"""GateL0RD recurrent cells trained with eligibility traces or backpropagation."""

import numpy as np
import torch as th
import torch.nn as nn
from torch.autograd import Function

try:
    # None of these names is used by the code in this module; the import is
    # kept only so that they remain importable from here when einops exists.
    from einops import rearrange, repeat, reduce  # noqa: F401
except ImportError:  # einops is optional for this module
    rearrange = repeat = reduce = None

__author__ = "Manuel Traub"


class EpropGateL0rdFunction(Function):
    """Autograd function for one gated recurrent step with eligibility traces.

    ``forward`` computes the new hidden state and updates the per-sample
    eligibility traces in place.  ``backward`` builds the weight and bias
    gradients from those traces (plus a ``reg * H_g`` term on the gate)
    instead of relying on autograd's recorded graph.
    """

    @staticmethod
    def forward(ctx, x, h_last, w_gx, w_gh, b_g, w_rx, w_rh, b_r, args):
        """Compute the next hidden state and update the eligibility traces.

        Args:
            x: Input, 2-D ``(batch, num_inputs)`` (it is used with ``mm``).
            h_last: Previous hidden state, 2-D ``(batch, num_hidden)``.
            w_gx, w_gh, b_g: Gate weights (input, recurrent) and bias.
            w_rx, w_rh, b_r: Candidate weights (input, recurrent) and bias.
            args: Tuple ``(e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r,
                reg, noise_level)``.  The six ``e_*`` tensors are the
                eligibility traces and are overwritten in place.

        Returns:
            Tuple ``(h, H_g)``: the new hidden state and the 0/1 indicator
            of which gate units are open (``g > 0``).
        """
        e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r, reg, noise_level = args

        # One noise vector of the bias' shape; it broadcasts over the batch.
        noise = th.normal(mean=0, std=noise_level, size=b_g.shape, device=b_g.device)
        g = th.relu(th.tanh(x.mm(w_gx.t()) + h_last.mm(w_gh.t()) + b_g + noise))
        r = th.tanh(x.mm(w_rx.t()) + h_last.mm(w_rh.t()) + b_r)

        h = g * r + (1 - g) * h_last

        # Heaviside step function: 1 where the gate is open, 0 elsewhere.
        H_g = th.ceil(g).clamp(0, 1)

        # Local derivatives of the gate (zero where closed) and the candidate.
        dg = (1 - g**2) * H_g
        dr = (1 - r**2)

        delta_h = r - h_last

        # Index convention: ``_j`` adds a trailing axis (per hidden unit),
        # ``_i`` adds a middle axis (per presynaptic input), so products
        # broadcast to the (batch, hidden, input) shape of the weight traces.
        g_j = g.unsqueeze(dim=2)
        dg_j = dg.unsqueeze(dim=2)
        dr_j = dr.unsqueeze(dim=2)

        x_i = x.unsqueeze(dim=1)
        h_last_i = h_last.unsqueeze(dim=1)
        delta_h_j = delta_h.unsqueeze(dim=2)

        # Gate traces: decay by (1 - g), then add the new local contribution.
        e_w_gh.copy_(e_w_gh * (1 - g_j) + dg_j * h_last_i * delta_h_j)
        e_w_gx.copy_(e_w_gx * (1 - g_j) + dg_j * x_i * delta_h_j)
        e_b_g.copy_(e_b_g * (1 - g) + dg * delta_h)

        # Candidate traces: same decay, contribution scaled by the gate value.
        e_w_rh.copy_(e_w_rh * (1 - g_j) + dr_j * h_last_i * g_j)
        e_w_rx.copy_(e_w_rx * (1 - g_j) + dr_j * x_i * g_j)
        e_b_r.copy_(e_b_r * (1 - g) + dr * g)

        # The traces are cloned because the originals are overwritten in
        # place on the next forward call.
        ctx.save_for_backward(
            g.clone(), dg.clone(), dg_j.clone(), dr.clone(), x_i.clone(), h_last_i.clone(),
            reg.clone(), H_g.clone(), delta_h.clone(), w_gx.clone(), w_gh.clone(),
            w_rx.clone(), w_rh.clone(),
            e_w_gx.clone(), e_w_gh.clone(), e_b_g.clone(),
            e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
        )

        return h, H_g

    @staticmethod
    def backward(ctx, dh, _):
        """Return gradients for the nine ``forward`` inputs.

        Args:
            dh: Gradient with respect to the returned hidden state ``h``.
            _: Gradient with respect to ``H_g``; ignored.

        Returns:
            ``(dx, dh_last, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None)``;
            the trailing ``None`` belongs to the non-tensor ``args`` tuple.
        """
        g, dg, dg_j, dr, x_i, h_last_i, reg, H_g, delta_h, w_gx, w_gh, w_rx, w_rh, \
            e_w_gx, e_w_gh, e_b_g, e_w_rx, e_w_rh, e_b_r = ctx.saved_tensors

        dh_j = dh.unsqueeze(dim=2)
        H_g_reg = reg * H_g
        H_g_reg_j = H_g_reg.unsqueeze(dim=2)

        # Parameter gradients: trace-weighted error summed over the batch;
        # the gate parameters additionally receive the regularisation term.
        dw_gx = th.sum(dh_j * e_w_gx + H_g_reg_j * dg_j * x_i, dim=0)
        dw_gh = th.sum(dh_j * e_w_gh + H_g_reg_j * dg_j * h_last_i, dim=0)
        db_g = th.sum(dh * e_b_g + H_g_reg * dg, dim=0)

        dw_rx = th.sum(dh_j * e_w_rx, dim=0)
        dw_rh = th.sum(dh_j * e_w_rh, dim=0)
        db_r = th.sum(dh * e_b_r, dim=0)

        # Gradients with respect to the gate and candidate pre-activations.
        dh_dg = (dh * delta_h + H_g_reg) * dg
        dh_dr = dh * g * dr

        dx = dh_dg.mm(w_gx) + dh_dr.mm(w_rx)
        dh = dh * (1 - g) + dh_dg.mm(w_gh) + dh_dr.mm(w_rh)

        return dx, dh, dw_gx, dw_gh, db_g, dw_rx, dw_rh, db_r, None


class ReTanhFunction(Function):
    """``relu(tanh(x))`` whose backward adds ``reg`` to the gradient of open units."""

    @staticmethod
    def forward(ctx, x, reg):
        """Return ``(relu(tanh(x)), mean of the open-unit indicator)``."""
        g = th.relu(th.tanh(x))

        # Heaviside step function: 1 where the unit is open, 0 elsewhere.
        H_g = th.ceil(g).clamp(0, 1)

        dg = (1 - g**2) * H_g

        ctx.save_for_backward(g, dg, H_g, reg)
        return g, th.mean(H_g)

    @staticmethod
    def backward(ctx, dh, _):
        """Return the gradient for ``x``; ``reg`` receives no gradient."""
        g, dg, H_g, reg = ctx.saved_tensors

        dx = (dh + reg * H_g) * dg

        return dx, None


class ReTanh(nn.Module):
    """Module wrapper around :class:`ReTanhFunction`.

    After each call, ``self.openings`` holds the fraction of open units of
    that call as a Python float.
    """

    def __init__(self, reg_lambda):
        super().__init__()

        # ``apply`` is taken from the class: instantiating an autograd
        # Function is deprecated in PyTorch.
        self.re_tanh = ReTanhFunction.apply
        self.register_buffer("reg_lambda", th.tensor(reg_lambda), persistent=False)

    def forward(self, input):
        h, openings = self.re_tanh(input, self.reg_lambda)
        self.openings = openings.item()

        return h


class EpropGateL0rd(nn.Module):
    """Gated recurrent cell that keeps its hidden state inside the module.

    ``forward`` dispatches to ``backprop_forward`` (plain autograd through
    :class:`ReTanh`) when backprop is activated, and to ``eprop_forward``
    (eligibility traces via :class:`EpropGateL0rdFunction`) otherwise.
    Both return ``o * p``, a gated output of shape ``(batch, num_outputs)``,
    and store the new hidden state in ``self.h_last``.

    The eligibility traces and the hidden state are allocated for a fixed
    ``batch_size``.
    """

    def __init__(
        self,
        num_inputs,
        num_hidden,
        num_outputs,
        batch_size,
        reg_lambda = 0,
        gate_noise_level = 0,
    ):
        super().__init__()

        self.register_buffer("reg", th.tensor(reg_lambda).view(1,1), persistent=False)
        self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False)
        self.num_inputs = num_inputs
        self.num_hidden = num_hidden
        self.num_outputs = num_outputs

        # ``apply`` is taken from the class: instantiating an autograd
        # Function is deprecated in PyTorch.
        self.fcn = EpropGateL0rdFunction.apply
        self.retanh = ReTanh(reg_lambda)

        # gate weights and biases
        self.w_gx = nn.Parameter(th.empty(num_hidden, num_inputs))
        self.w_gh = nn.Parameter(th.empty(num_hidden, num_hidden))
        self.b_g = nn.Parameter(th.zeros(num_hidden))

        # candidate weights and biases
        self.w_rx = nn.Parameter(th.empty(num_hidden, num_inputs))
        self.w_rh = nn.Parameter(th.empty(num_hidden, num_hidden))
        self.b_r = nn.Parameter(th.zeros(num_hidden))

        # output projection weights and bias
        self.w_px = nn.Parameter(th.empty(num_outputs, num_inputs))
        self.w_ph = nn.Parameter(th.empty(num_outputs, num_hidden))
        self.b_p = nn.Parameter(th.zeros(num_outputs))

        # output gate weights and bias
        self.w_ox = nn.Parameter(th.empty(num_outputs, num_inputs))
        self.w_oh = nn.Parameter(th.empty(num_outputs, num_hidden))
        self.b_o = nn.Parameter(th.zeros(num_outputs))

        # gate eligibility traces
        self.register_buffer("e_w_gx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
        self.register_buffer("e_w_gh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
        self.register_buffer("e_b_g", th.zeros(batch_size, num_hidden), persistent=False)

        # candidate eligibility traces
        self.register_buffer("e_w_rx", th.zeros(batch_size, num_hidden, num_inputs), persistent=False)
        self.register_buffer("e_w_rh", th.zeros(batch_size, num_hidden, num_hidden), persistent=False)
        self.register_buffer("e_b_r", th.zeros(batch_size, num_hidden), persistent=False)

        # hidden state
        self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)

        self.register_buffer("openings", th.zeros(1), persistent=False)
        self.register_buffer("openings_perslot", th.zeros(batch_size), persistent=False)

        # initialize weights (the values below are uniform-distribution bounds)
        stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
        stdv_hh = np.sqrt(3/self.num_hidden)
        stdv_io = np.sqrt(6/(self.num_inputs + self.num_outputs))
        stdv_ho = np.sqrt(6/(self.num_hidden + self.num_outputs))

        nn.init.uniform_(self.w_gx, -stdv_ih, stdv_ih)
        nn.init.uniform_(self.w_gh, -stdv_hh, stdv_hh)

        nn.init.uniform_(self.w_rx, -stdv_ih, stdv_ih)
        nn.init.uniform_(self.w_rh, -stdv_hh, stdv_hh)

        nn.init.uniform_(self.w_px, -stdv_io, stdv_io)
        nn.init.uniform_(self.w_ph, -stdv_ho, stdv_ho)

        nn.init.uniform_(self.w_ox, -stdv_io, stdv_io)
        nn.init.uniform_(self.w_oh, -stdv_ho, stdv_ho)

        self.backprop = False

    def reset_state(self):
        """Zero the hidden state, all eligibility traces and ``openings`` in place."""
        self.h_last.zero_()
        self.e_w_gx.zero_()
        self.e_w_gh.zero_()
        self.e_b_g.zero_()
        self.e_w_rx.zero_()
        self.e_w_rh.zero_()
        self.e_b_r.zero_()
        self.openings.zero_()

    def backprop_forward(self, x: th.Tensor):
        """One step using ordinary autograd; updates ``h_last`` and ``openings``."""
        noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
        g = self.retanh(x.mm(self.w_gx.t()) + self.h_last.mm(self.w_gh.t()) + self.b_g + noise)
        r = th.tanh(x.mm(self.w_rx.t()) + self.h_last.mm(self.w_rh.t()) + self.b_r)

        self.h_last = g * r + (1 - g) * self.h_last

        # Heaviside step function: 1 where the gate is open, 0 elsewhere.
        H_g = th.ceil(g).clamp(0, 1)

        self.openings = th.mean(H_g)

        p = th.tanh(x.mm(self.w_px.t()) + self.h_last.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + self.h_last.mm(self.w_oh.t()) + self.b_o)
        return o * p

    def activate_backprop(self):
        self.backprop = True

    def deactivate_backprop(self):
        self.backprop = False

    def detach(self):
        """Detach the stored hidden state from the autograd graph in place."""
        self.h_last.detach_()

    def eprop_forward(self, x: th.Tensor):
        """One step using eligibility traces; updates ``h_last`` and ``openings``."""
        h, H_g = self.fcn(
            x, self.h_last,
            self.w_gx, self.w_gh, self.b_g,
            self.w_rx, self.w_rh, self.b_r,
            (
                self.e_w_gx, self.e_w_gh, self.e_b_g,
                self.e_w_rx, self.e_w_rh, self.e_b_r,
                self.reg, self.noise
            )
        )

        # The function returns the full open-gate indicator; store its mean
        # so that ``openings`` means the same thing as in backprop_forward.
        self.openings = th.mean(H_g)
        self.h_last = h

        p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
        return o * p

    def save_hidden(self):
        """Remember a detached reference to the current hidden state."""
        self.h_last_saved = self.h_last.detach()

    def restore_hidden(self):
        """Restore the state stored by ``save_hidden`` (which must be called first)."""
        self.h_last = self.h_last_saved

    def get_hidden(self):
        return self.h_last

    def set_hidden(self, h_last):
        self.h_last = h_last

    def forward(self, x: th.Tensor):
        if self.backprop:
            return self.backprop_forward(x)

        return self.eprop_forward(x)


class EpropGateL0rdShared(EpropGateL0rd):
    """Variant whose step functions take and return the hidden state explicitly.

    ``backprop_forward`` and ``eprop_forward`` here have the signature
    ``(x, h_last) -> (output, h)`` and do not touch ``self.h_last``.
    ``forward`` without ``h_last`` falls back to the module-held state.
    """

    def __init__(
        self,
        num_inputs,
        num_hidden,
        num_outputs,
        batch_size,
        reg_lambda = 0,
        gate_noise_level = 0,
    ):
        super().__init__(num_inputs, num_hidden, num_outputs, batch_size, reg_lambda, gate_noise_level)

    def backprop_forward(self, x: th.Tensor, h_last: th.Tensor):
        """One autograd step from ``h_last``; returns ``(output, new hidden state)``."""
        noise = th.normal(mean=0, std=self.noise, size=self.b_g.shape, device=self.b_g.device)
        g = self.retanh(x.mm(self.w_gx.t()) + h_last.mm(self.w_gh.t()) + self.b_g + noise)
        r = th.tanh(x.mm(self.w_rx.t()) + h_last.mm(self.w_rh.t()) + self.b_r)

        h_last = g * r + (1 - g) * h_last

        # Heaviside step function: 1 where the gate is open, 0 elsewhere.
        H_g = th.ceil(g).clamp(0, 1)

        self.openings = th.mean(H_g)

        p = th.tanh(x.mm(self.w_px.t()) + h_last.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + h_last.mm(self.w_oh.t()) + self.b_o)
        return o * p, h_last

    def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
        """One eligibility-trace step from ``h_last``; returns ``(output, new hidden state)``.

        Also records the overall mean of the open-gate indicator in
        ``openings`` and its per-row mean in ``openings_perslot``.
        """
        h, H_g = self.fcn(
            x, h_last,
            self.w_gx, self.w_gh, self.b_g,
            self.w_rx, self.w_rh, self.b_r,
            (
                self.e_w_gx, self.e_w_gh, self.e_b_g,
                self.e_w_rx, self.e_w_rh, self.e_b_r,
                self.reg, self.noise
            )
        )

        self.openings = th.mean(H_g)
        self.openings_perslot = th.mean(H_g, dim=1)

        p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
        o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
        return o * p, h

    def forward(self, x: th.Tensor, h_last: th.Tensor = None):
        """Run one step.

        With ``h_last`` given, returns ``(output, new hidden state)`` and
        leaves ``self.h_last`` untouched.  Without it, uses and updates
        ``self.h_last`` and returns only the output.
        """
        if h_last is not None:
            if self.backprop:
                return self.backprop_forward(x, h_last)

            return self.eprop_forward(x, h_last)

        # backward compatibility: exactly one of the two paths must run.
        # Running eprop_forward unconditionally after backprop_forward would
        # feed the cell's output back in as its input.
        if self.backprop:
            out, h = self.backprop_forward(x, self.h_last)
        else:
            out, h = self.eprop_forward(x, self.h_last)

        self.h_last = h
        return out