diff options
Diffstat (limited to 'model/nn/percept_gate_controller.py')
-rw-r--r-- | model/nn/percept_gate_controller.py | 59 |
1 files changed, 59 insertions, 0 deletions
diff --git a/model/nn/percept_gate_controller.py b/model/nn/percept_gate_controller.py new file mode 100644 index 0000000..a548c20 --- /dev/null +++ b/model/nn/percept_gate_controller.py @@ -0,0 +1,59 @@ +import torch.nn as nn +import torch as th +from model.utils.nn_utils import LambdaModule +from einops import rearrange, repeat, reduce +from model.nn.eprop_gate_l0rd import ReTanh + +class PerceptGateController(nn.Module): + def __init__( + self, + num_inputs: int, + num_hidden: list, + bias: bool, + num_objects: int, + gate_noise_level: float = 0.1, + reg_lambda: float = 0.000005 + ): + super(PerceptGateController, self).__init__() + + self.to_batch = LambdaModule(lambda x: rearrange(x, 'b (o c) -> (b o) c', o=num_objects)) + self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b o c', o=num_objects)) + + self.layers = nn.Sequential( + nn.Linear(num_inputs, num_hidden[0], bias = bias), + nn.Tanh(), + nn.Linear(num_hidden[0], num_hidden[1], bias = bias), + nn.Tanh(), + nn.Linear(num_hidden[1], 2, bias = bias) + ) + self.output_function = ReTanh(reg_lambda) + self.register_buffer("noise", th.tensor(gate_noise_level), persistent=False) + self.init_weights() + + def init_weights(self): + for layer in self.layers: + if isinstance(layer, nn.Linear): + nn.init.xavier_uniform(layer.weight) + layer.bias.data.fill_(3.00) + + def forward(self, position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2, evaluate=False): + + position_cur = self.to_batch(position_cur) + gestalt_cur = self.to_batch(gestalt_cur) + priority_cur = self.to_batch(priority_cur) + position_last = self.to_batch(position_last) + gestalt_last = self.to_batch(gestalt_last) + priority_last = self.to_batch(priority_last) + slots_occlusionfactor_cur = self.to_batch(slots_occlusionfactor_cur).detach() + slots_occlusionfactor_last = self.to_batch(slots_occlusionfactor_last).detach() + position_last2 = self.to_batch(position_last2).detach() + + input = th.cat((position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, position_last2), dim=1) + output = self.layers(input) + if evaluate: + output = self.output_function(output) + else: + noise = th.normal(mean=0, std=self.noise, size=output.shape, device=output.device) + output = self.output_function(output + noise) + + return self.to_shared(output) |