aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorfredeee2024-03-23 13:27:00 +0100
committerfredeee2024-03-23 13:27:00 +0100
commit6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree0545ff1b8beb051993c2d75fd81306db1a22274d
parentad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
-rw-r--r--.gitignore2
-rw-r--r--cfg/bouncingballs/bb-level1-run1.json87
-rw-r--r--cfg/bouncingballs/bb-level1-run2.json87
-rw-r--r--cfg/bouncingballs/bb-level1-run3.json87
-rw-r--r--data/datasets/BOUNCINGBALLS/dataset.py195
-rw-r--r--evaluation/adept/evaluation_gswm.ipynb428
-rw-r--r--evaluation/adept/evaluation_loci_looped.ipynb607
-rw-r--r--evaluation/adept/evaluation_loci_looped_visibility.ipynb475
-rw-r--r--evaluation/adept/plots/loci_looped_complete.pdfbin0 -> 45683 bytes
-rw-r--r--evaluation/adept/plots/loci_looped_gates.pdfbin0 -> 26325 bytes
-rw-r--r--evaluation/adept/plots/loci_looped_gates2.pdfbin0 -> 26037 bytes
-rw-r--r--evaluation/adept/plots/loci_looped_surprise.pdfbin0 -> 12746 bytes
-rw-r--r--evaluation/adept/plots/loci_looped_voe.pdfbin0 -> 29604 bytes
-rw-r--r--evaluation/adept/plots/loci_looped_voe2.pdfbin0 -> 26244 bytes
-rw-r--r--evaluation/adept/plots/loci_looped_voe3.pdfbin0 -> 25852 bytes
-rw-r--r--evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb479
-rw-r--r--evaluation/adept_ablation/lambda/evaluation_loci_looped_abl2.ipynb469
-rw-r--r--evaluation/adept_ablation/lambda/evaluation_loci_looped_abl3.ipynb469
-rw-r--r--evaluation/adept_ablation/recon/evaluation_loci_looped_norecon.ipynb464
-rw-r--r--evaluation/adept_ablation/recon/evaluation_loci_looped_recon.ipynb466
-rw-r--r--model/loci.py31
-rw-r--r--model/nn/background.py149
-rw-r--r--model/nn/eprop_gate_l0rd.py8
-rw-r--r--model/nn/eprop_transformer.py14
-rw-r--r--model/nn/eprop_transformer_shared.py7
-rw-r--r--model/nn/eprop_transformer_utils.py8
-rw-r--r--model/nn/predictor.py22
-rw-r--r--scripts/evaluation_adept.py11
-rw-r--r--scripts/evaluation_adept_baselines.py (renamed from scripts/evaluation_adept_savi.py)143
-rw-r--r--scripts/evaluation_bb.py385
-rw-r--r--scripts/evaluation_clevrer.py45
-rw-r--r--scripts/exec/eval.py32
-rw-r--r--scripts/exec/eval_baselines.py (renamed from scripts/exec/eval_savi.py)5
-rw-r--r--scripts/exec/train.py15
-rw-r--r--scripts/training.py324
-rw-r--r--scripts/utils/eval_adept.py169
-rw-r--r--scripts/utils/eval_metrics.py51
-rw-r--r--scripts/utils/eval_utils.py92
-rw-r--r--scripts/utils/io.py72
-rw-r--r--scripts/utils/plot_utils.py162
-rw-r--r--scripts/validation.py270
41 files changed, 5859 insertions, 471 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..15495d3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+*.DS_Store
+*.pyc \ No newline at end of file
diff --git a/cfg/bouncingballs/bb-level1-run1.json b/cfg/bouncingballs/bb-level1-run1.json
new file mode 100644
index 0000000..1e6f21b
--- /dev/null
+++ b/cfg/bouncingballs/bb-level1-run1.json
@@ -0,0 +1,87 @@
+{
+ "model_path": "bb_baseline",
+ "datatype": "bouncingballs",
+ "dataset": "BOUNCINGBALLS_OCCLUSION",
+ "scenario": "occlusion",
+ "num_updates": 0,
+ "max_epochs": 1000,
+ "max_updates": 400000,
+ "learning_rate": {
+ "lr": 0.0004,
+ "deacrease_lr_every": 100000,
+ "deacrease_lr_factor": 0.75
+ },
+ "blackout": {
+ "blackout_start_timestep": 10,
+ "blackout_rate": 0.1,
+ "blackout_rate_max": 0.45,
+ "blackout_increase_every": 10000,
+ "blackout_increase_rate": 0.01
+ },
+ "phases": {
+ "start_inner_loop": 40000,
+ "shufleslots_end": 3000000,
+ "entity_pretraining_phase2_end": 30000,
+ "entity_pretraining_phase1_end": 15000,
+ "background_pretraining_end": 0
+ },
+ "defaults": {
+ "num_workers": 2,
+ "prefetch_factor": 2,
+ "statistics_offset": 10,
+ "load_optimizers": false,
+ "teacher_forcing": 5,
+ "skip_frames": 1,
+ "error_dropout": 0.1,
+ "seed": 67
+ },
+ "bptt": {
+ "bptt_start_timestep": 0,
+ "bptt_steps": 1,
+ "bptt_steps_max": 3,
+ "increase_bptt_steps_every": 100000
+ },
+ "model": {
+ "level": 2,
+ "batch_size": 128,
+ "num_objects": 3,
+ "img_channels": 3,
+ "input_size": [
+ 64,
+ 64
+ ],
+ "latent_size": [
+ 4,
+ 4
+ ],
+ "gestalt_size": 12,
+ "bottleneck": "binar",
+ "position_regularizer": 0.01,
+ "time_regularizer": 0.25,
+ "encoder_regularizer": 0.333333,
+ "latent_regularizer": 0.0,
+ "inner_loop_enabled": false,
+ "latent_loss_enabled": false,
+ "encoder": {
+ "channels": 24,
+ "level1_channels": 12,
+ "num_layers": 2,
+ "reg_lambda": 1e-10
+ },
+ "predictor": {
+ "heads": 2,
+ "layers": 3,
+ "channels_multiplier": 2,
+ "reg_lambda": 1e-10,
+ "transformer_type": "shared"
+ },
+ "decoder": {
+ "channels": 24,
+ "level1_channels": 12,
+ "num_layers": 2
+ },
+ "update_module": {
+ "reg_lambda": 1e-05
+ }
+ }
+ } \ No newline at end of file
diff --git a/cfg/bouncingballs/bb-level1-run2.json b/cfg/bouncingballs/bb-level1-run2.json
new file mode 100644
index 0000000..a257408
--- /dev/null
+++ b/cfg/bouncingballs/bb-level1-run2.json
@@ -0,0 +1,87 @@
+{
+ "model_path": "bb_baseline",
+ "datatype": "bouncingballs",
+ "dataset": "BOUNCINGBALLS_OCCLUSION",
+ "scenario": "occlusion",
+ "num_updates": 400000,
+ "max_epochs": 1000,
+ "max_updates": 800000,
+ "learning_rate": {
+ "lr": 0.0001,
+ "deacrease_lr_every": 100000,
+ "deacrease_lr_factor": 0.75
+ },
+ "blackout": {
+ "blackout_start_timestep": 10,
+ "blackout_rate": 0.45,
+ "blackout_rate_max": 0.45,
+ "blackout_increase_every": 10000,
+ "blackout_increase_rate": 0.01
+ },
+ "phases": {
+ "start_inner_loop": 40000,
+ "shufleslots_end": 3000000,
+ "entity_pretraining_phase2_end": 30000,
+ "entity_pretraining_phase1_end": 15000,
+ "background_pretraining_end": 0
+ },
+ "defaults": {
+ "num_workers": 2,
+ "prefetch_factor": 2,
+ "statistics_offset": 10,
+ "load_optimizers": false,
+ "teacher_forcing": 5,
+ "skip_frames": 1,
+ "error_dropout": 0.1,
+ "seed": 67
+ },
+ "bptt": {
+ "bptt_start_timestep": 0,
+ "bptt_steps": 3,
+ "bptt_steps_max": 3,
+ "increase_bptt_steps_every": 100000
+ },
+ "model": {
+ "level": 2,
+ "batch_size": 128,
+ "num_objects": 3,
+ "img_channels": 3,
+ "input_size": [
+ 64,
+ 64
+ ],
+ "latent_size": [
+ 4,
+ 4
+ ],
+ "gestalt_size": 12,
+ "bottleneck": "binar",
+ "position_regularizer": 0.01,
+ "time_regularizer": 0.25,
+ "encoder_regularizer": 0.333333,
+ "latent_regularizer": 0.0,
+ "inner_loop_enabled": false,
+ "latent_loss_enabled": false,
+ "encoder": {
+ "channels": 24,
+ "level1_channels": 12,
+ "num_layers": 2,
+ "reg_lambda": 1e-10
+ },
+ "predictor": {
+ "heads": 2,
+ "layers": 3,
+ "channels_multiplier": 2,
+ "reg_lambda": 1e-10,
+ "transformer_type": "shared"
+ },
+ "decoder": {
+ "channels": 24,
+ "level1_channels": 12,
+ "num_layers": 2
+ },
+ "update_module": {
+ "reg_lambda": 1e-05
+ }
+ }
+ } \ No newline at end of file
diff --git a/cfg/bouncingballs/bb-level1-run3.json b/cfg/bouncingballs/bb-level1-run3.json
new file mode 100644
index 0000000..cb4be46
--- /dev/null
+++ b/cfg/bouncingballs/bb-level1-run3.json
@@ -0,0 +1,87 @@
+{
+ "model_path": "bb_baseline",
+ "datatype": "bouncingballs",
+ "dataset": "BOUNCINGBALLS_OCCLUSION",
+ "scenario": "occlusion",
+ "num_updates": 800000,
+ "max_epochs": 1000,
+ "max_updates": 1200000,
+ "learning_rate": {
+ "lr": 5e-05,
+ "deacrease_lr_every": 100000,
+ "deacrease_lr_factor": 0.75
+ },
+ "blackout": {
+ "blackout_start_timestep": 10,
+ "blackout_rate": 0.45,
+ "blackout_rate_max": 0.45,
+ "blackout_increase_every": 10000,
+ "blackout_increase_rate": 0.01
+ },
+ "phases": {
+ "start_inner_loop": 40000,
+ "shufleslots_end": 3000000,
+ "entity_pretraining_phase2_end": 30000,
+ "entity_pretraining_phase1_end": 15000,
+ "background_pretraining_end": 0
+ },
+ "defaults": {
+ "num_workers": 2,
+ "prefetch_factor": 2,
+ "statistics_offset": 10,
+ "load_optimizers": false,
+ "teacher_forcing": 5,
+ "skip_frames": 1,
+ "error_dropout": 0.1,
+ "seed": 5
+ },
+ "bptt": {
+ "bptt_start_timestep": 0,
+ "bptt_steps": 3,
+ "bptt_steps_max": 3,
+ "increase_bptt_steps_every": 100000
+ },
+ "model": {
+ "level": 2,
+ "batch_size": 128,
+ "num_objects": 3,
+ "img_channels": 3,
+ "input_size": [
+ 64,
+ 64
+ ],
+ "latent_size": [
+ 4,
+ 4
+ ],
+ "gestalt_size": 12,
+ "bottleneck": "binar",
+ "position_regularizer": 0.01,
+ "time_regularizer": 0.25,
+ "encoder_regularizer": 0.333333,
+ "latent_regularizer": 0.0,
+ "inner_loop_enabled": false,
+ "latent_loss_enabled": false,
+ "encoder": {
+ "channels": 24,
+ "level1_channels": 12,
+ "num_layers": 2,
+ "reg_lambda": 1e-10
+ },
+ "predictor": {
+ "heads": 2,
+ "layers": 3,
+ "channels_multiplier": 2,
+ "reg_lambda": 1e-10,
+ "transformer_type": "shared"
+ },
+ "decoder": {
+ "channels": 24,
+ "level1_channels": 12,
+ "num_layers": 2
+ },
+ "update_module": {
+ "reg_lambda": 1e-05
+ }
+ }
+ } \ No newline at end of file
diff --git a/data/datasets/BOUNCINGBALLS/dataset.py b/data/datasets/BOUNCINGBALLS/dataset.py
new file mode 100644
index 0000000..3460d02
--- /dev/null
+++ b/data/datasets/BOUNCINGBALLS/dataset.py
@@ -0,0 +1,195 @@
+from pickletools import int4
+from torch.utils import data
+from typing import Tuple, Union, List
+import numpy as np
+import json
+import math
+import cv2
+import h5py
+import os
+import pickle
+import sys
+import yaml
+import warnings
+from PIL import Image
+from einops import reduce, rearrange, repeat
+import torch as th
+
+
class BouncingBallDataset(data.Dataset):
    """Bouncing-ball video sequences read from a single HDF5 file.

    Training samples are ``(rgb_images, background)``. Validation and test
    samples additionally carry object positions, a segmentation mask, the
    per-object in-camera flags and a segmentation mask restricted to pixels
    where several instance masks overlap (see ``__getitem__``).
    """

    # Set to True to step through every evaluation sample in an OpenCV window.
    debug_visualize = False

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], type_name: str = None, full_size: Tuple[int, int] = None, create_dataset: bool = False):
        """Open the HDF5 file that belongs to the requested split.

        Args:
            root_path: Directory that contains ``data/data/video/<dataset_name>``.
            dataset_name: Name of the dataset folder.
            type: Split, one of ``"train"``, ``"test"``, ``"val"``.
            size: Frame size; both entries are part of the file name and
                define the shape of the all-zero background.
            type_name: Scenario name; it is part of the file name and
                determines the number of objects per frame.
            full_size: Unused by this class.
            create_dataset: Unused by this class.

        Raises:
            AssertionError: If ``type`` or ``type_name`` is not a known value.
            FileNotFoundError: If the HDF5 file does not exist or holds no
                sequences.
        """
        assert type in ["train", "test", "val"]
        assert type_name in ["interaction", "occlusion", "twolayer", "twolayerdense", "twolayer_ood", "threelayer_ood", "twolayer_ood_3balls"]

        data_path = os.path.join(root_path, f'data/data/video/{dataset_name}')
        self.file = os.path.join(data_path, f'balls_{type_name}-{type}-{size[0]}x{size[1]}-v1.hdf5')
        self.train = (type == "train")
        self.samples = []

        # Fail with the intended error instead of an AttributeError on
        # ``self.length`` further down when the file is missing.
        if not os.path.exists(self.file):
            raise FileNotFoundError(f'Found no dataset at {data_path}')

        self.hdf5_file = h5py.File(self.file, "r")

        # one row of 'sequence_indices' per sequence
        self.length = self.hdf5_file['sequence_indices'].shape[0]
        self.background = np.zeros((3, size[0], size[1]), dtype=np.uint8)
        self.num_objects = self._resolve_num_objects(type_name, type)

        if len(self) == 0:
            raise FileNotFoundError(f'Found no dataset at {data_path}')

    @staticmethod
    def _resolve_num_objects(type_name, split):
        """Return the number of objects per frame for a scenario and split."""
        is_test = (split == "test")

        if type_name == "twolayer" or (type_name == "threelayer_ood" and not is_test):
            return 6
        if is_test and type_name in ("twolayer_ood", "twolayer_ood_3balls"):
            return 4
        if type_name == "twolayer_ood":
            return 2
        if type_name == "twolayer_ood_3balls":
            return 3
        if type_name == "threelayer_ood" and is_test:
            return 9
        return 3

    def add_one_timestep(self, x):
        """Append one all-zero entry along the first axis of ``x``."""
        return np.concatenate((x, np.zeros_like(x[:1])), axis=0)

    def __len__(self):
        """Return the number of sequences in the file."""
        return self.length

    def __getitem__(self, index: int):
        """Load one sequence.

        Returns:
            For the training split ``(rgb_images, background)`` with
            ``rgb_images`` as a torch tensor. Otherwise the tuple
            ``(rgb_images, background, instance_positions, segmentation_mask,
            instance_pres, segmentation_mask_hidden)``, where every entry
            except ``background`` has one zero-filled dummy timestep appended
            (which turns ``rgb_images`` into a numpy array).
        """
        index_start, length = self.hdf5_file['sequence_indices'][index]
        rgb_images = self.hdf5_file["rgb_images"][index_start:index_start+length]

        # uint8 entries are encoded image buffers: decode to float CHW in [0, 1]
        if rgb_images[0].dtype == np.uint8:
            rgb_images = np.stack([
                cv2.imdecode(encoded, 1).transpose(2, 0, 1).astype(np.float32) / 255.0
                for encoded in rgb_images
            ])

        rgb_images = th.from_numpy(rgb_images)

        if self.train:
            return (
                rgb_images,
                self.background
            )

        # EVALUATION
        # Per-object records are stored flat as (time * object); slice the
        # rows of this sequence and unfold them again.
        num_objects = self.num_objects
        obj_start = index_start * num_objects
        obj_end = (index_start + length) * num_objects

        instance_positions = self.hdf5_file['instance_positions'][obj_start:obj_end]
        instance_positions = rearrange(instance_positions, '(t o) c -> t o c', o=num_objects)
        instance_positions = instance_positions[:, :, ::-1] # IMPORTANT: flip x and y axis

        instance_pres = self.hdf5_file['instance_incamera'][obj_start:obj_end]
        instance_pres = rearrange(instance_pres, '(t o) c -> t o c', o=num_objects).squeeze(-1)

        instance_bounding_boxes = self.hdf5_file['instance_mask_bboxes'][obj_start:obj_end]
        instance_bounding_boxes = rearrange(instance_bounding_boxes, '(t o) c -> t o c', o=num_objects)
        instance_bounding_boxes = instance_bounding_boxes[:, :, [1, 0, 3, 2]]

        foreground_mask = self.hdf5_file['foreground_mask'][index_start:(index_start+length)]
        foreground_mask = rearrange(foreground_mask, 't 1 h w -> t h w')/255

        # Drop the singleton channel inside the pattern: an axis-less
        # squeeze() would also remove the time or object axis whenever the
        # sequence has length 1 or there is a single object.
        instance_masks = self.hdf5_file['instance_masks'][obj_start:obj_end]
        instance_masks = rearrange(instance_masks, '(t o) 1 h w -> t o h w', o=num_objects)/255

        # CUSTOM
        # hidden mask: pixels covered by more than one instance mask
        hidden_mask = reduce(instance_masks, 't o h w -> t h w', 'sum')
        hidden_mask = (hidden_mask > 1).astype(np.uint8)

        # segmentation_masks: index gives which object is visible at that pixel
        segmentation_mask = np.argmax(instance_masks[:, ::-1], axis=1) + 1
        segmentation_mask = foreground_mask * segmentation_mask

        # segmentation mask but only for hidden objects
        segementation_mask_hidden = np.argmax(instance_masks[:, :3], axis=1) + 1 # TODO only works for 6 objects
        segementation_mask_hidden = hidden_mask * segementation_mask_hidden

        # add one dummy timestep at the end
        instance_positions = self.add_one_timestep(instance_positions)
        rgb_images = self.add_one_timestep(rgb_images)
        foreground_mask = self.add_one_timestep(foreground_mask)
        hidden_mask = self.add_one_timestep(hidden_mask)
        instance_pres = self.add_one_timestep(instance_pres)
        instance_bounding_boxes = self.add_one_timestep(instance_bounding_boxes)
        instance_masks = self.add_one_timestep(instance_masks)
        segmentation_mask = self.add_one_timestep(segmentation_mask)
        segementation_mask_hidden = self.add_one_timestep(segementation_mask_hidden)

        if self.debug_visualize:
            self._show_debug_frames(
                rgb_images, instance_positions, foreground_mask,
                instance_bounding_boxes, hidden_mask, segmentation_mask,
                segementation_mask_hidden, instance_masks
            )

        return (
            rgb_images,
            self.background,
            instance_positions,
            segmentation_mask,
            instance_pres,
            segementation_mask_hidden
        )

    def _show_debug_frames(self, rgb_images, instance_positions, foreground_mask, instance_bounding_boxes, hidden_mask, segmentation_mask, segmentation_mask_hidden, instance_masks):
        """Show every timestep with its annotations in a blocking OpenCV window.

        Debugging aid only; it runs when ``debug_visualize`` is True.
        """
        video = np.array(rgb_images)
        locations = np.array(instance_positions)
        fg_masks = np.array(foreground_mask)
        bb = np.array(instance_bounding_boxes)
        h_masks = np.array(hidden_mask)

        for t in range(video.shape[0]):
            frame = rearrange(video[t], 'c h w -> h w c') * 255

            # mark every object position with a small square
            # (x indexes rows and y indexes columns of the frame)
            for loc in locations[t]:
                x, y = loc
                x_max = int(min(int(x + 2), frame.shape[0]))
                x_min = int(max(int(x - 2), 0))
                y_max = int(min(int(y + 2), frame.shape[1]))
                y_min = int(max(int(y - 2), 0))
                frame[x_min:x_max, y_min:y_max] = [255, 0, 0]

            # draw the bounding boxes into the frame, one pixel at a time
            for b in bb[t]:
                x_min, y_min, x_max, y_max = b

                x_min = int(max(x_min, 0))
                y_min = int(max(y_min, 0))
                x_max = int(min(x_max, frame.shape[0]-1))
                y_max = int(min(y_max, frame.shape[1]-1))

                for pixel in range(x_min, x_max):
                    frame[pixel, y_min, 0] = 255
                    frame[pixel, y_max, 0] = 255
                for pixel in range(y_min, y_max):
                    frame[x_min, pixel, 0] = 255
                    frame[x_max, pixel, 0] = 255

            fg_mask = repeat(fg_masks[t], 'h w -> h w 3') * 255
            h_mask = repeat(h_masks[t], 'h w -> h w 3') * 255
            s_mask = repeat(segmentation_mask[t], 'h w -> h w 3') * (255/6)
            s_mask_hidden = repeat(segmentation_mask_hidden[t], 'h w -> h w 3') * (255/3)
            frame = np.concatenate((frame, fg_mask, s_mask, h_mask, s_mask_hidden), axis=1)

            # append every instance mask to the visualisation
            for mask in instance_masks[t]:
                mask = repeat(mask, 'h w -> h w 3') * 255
                # add border to the right side
                mask[:, -1, 0] = 255
                frame = np.concatenate((frame, mask), axis=1)

            frame = frame.astype(np.uint8)
            cv2.imshow('frame', frame)
            cv2.waitKey(0)
+
+
+#a = BouncingBallDataset("./", 'BOUNCINGBALLS', "train", (64,64)) \ No newline at end of file
diff --git a/evaluation/adept/evaluation_gswm.ipynb b/evaluation/adept/evaluation_gswm.ipynb
new file mode 100644
index 0000000..31abb91
--- /dev/null
+++ b/evaluation/adept/evaluation_gswm.ipynb
@@ -0,0 +1,428 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import scipy.stats as stats\n",
+ "import os\n",
+ "\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "pd.options.mode.chained_assignment = None \n",
+ "plt.style.use('ggplot')\n",
+ "sns.color_palette(\"Paired\");\n",
+ "sns.set_theme();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data Loading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "FileNotFoundError",
+ "evalue": "[Errno 2] No such file or directory: '../../out/pretrained/adept/gswm/results/'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
+ "Cell \u001b[0;32mIn[2], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m root_path \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124m../../out/pretrained/adept/gswm/results/\u001b[39m\u001b[38;5;124m'\u001b[39m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# list all folders in root path that don't stat with a dot\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m nets \u001b[38;5;241m=\u001b[39m [f \u001b[38;5;28;01mfor\u001b[39;00m f \u001b[38;5;129;01min\u001b[39;00m \u001b[43mos\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mlistdir\u001b[49m\u001b[43m(\u001b[49m\u001b[43mroot_path\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m f\u001b[38;5;241m.\u001b[39mstartswith(\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m.\u001b[39m\u001b[38;5;124m'\u001b[39m)]\n\u001b[1;32m 7\u001b[0m \u001b[38;5;66;03m# read pickle file\u001b[39;00m\n\u001b[1;32m 8\u001b[0m sf \u001b[38;5;241m=\u001b[39m pd\u001b[38;5;241m.\u001b[39mDataFrame()\n",
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../../out/pretrained/adept/gswm/results/'"
+ ]
+ }
+ ],
+ "source": [
+ "# setting path to results folder\n",
+ "root_path = '../../out/pretrained/adept/gswm/results/'\n",
+ "\n",
+ "# list all folders in root path that don't start with a dot\n",
+ "nets = [f for f in os.listdir(root_path) if not f.startswith('.')]\n",
+ "\n",
+ "# read pickle file\n",
+ "sf = pd.DataFrame()\n",
+ "af = pd.DataFrame()\n",
+ "\n",
+ "# load statistics files from nets\n",
+ "for net in nets:\n",
+ " #path = os.path.join(root_path, net, 'control', 'statistics',)\n",
+ " path = os.path.join(root_path, 'statistics',)\n",
+ " with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:\n",
+ " sf_temp = pd.read_csv(f, index_col=0)\n",
+ " sf_temp['net'] = net\n",
+ " sf = pd.concat([sf,sf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'accframe.csv'), 'rb') as f:\n",
+ " af_temp = pd.read_csv(f, index_col=0)\n",
+ " af_temp['net'] = net\n",
+ " af = pd.concat([af,af_temp])\n",
+ "\n",
+ "# cast variables\n",
+ "sf['visible'] = sf['visible'].astype(bool)\n",
+ "sf['bound'] = sf['bound'].astype(bool)\n",
+ "sf['occluder'] = sf['occluder'].astype(bool)\n",
+ "sf['inimage'] = sf['inimage'].astype(bool)\n",
+ "sf['alpha_pos'] = 1-sf['alpha_pos']\n",
+ "sf['alpha_ges'] = 1-sf['alpha_ges']\n",
+ "\n",
+ "# scale to percentage\n",
+ "sf['TE'] = sf['TE'] * 100\n",
+ "\n",
+ "# add surprise as dummy code\n",
+ "sf['control'] = [('control' in set) for set in sf['set']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Tracking Error (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tracking Error when visible: M: 30.3 , STD: 14.1, Count: 507\n",
+ "Tracking Error when occluded: M: 25.8 , STD: 16.1, Count: 15\n",
+ "Tracking Error Overall: M: 30.2 , STD: 14.1, Count: 522\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "\n",
+ "def get_stats(col):\n",
+ " return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'\n",
+ "\n",
+ "# When Visible\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Tracking Error when visible:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Occluded\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Tracking Error when occluded:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Overall\n",
+ "temp = sf[grouping]\n",
+ "print(f'Tracking Error Overall:' + get_stats(temp['TE']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Successful Trackings (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>set</th>\n",
+ " <th>evalmode</th>\n",
+ " <th>tracked_pos</th>\n",
+ " <th>tracked_neg</th>\n",
+ " <th>tracked_pos_pro</th>\n",
+ " <th>tracked_neg_pro</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>control</td>\n",
+ " <td>control</td>\n",
+ " <td>1</td>\n",
+ " <td>19</td>\n",
+ " <td>0.05</td>\n",
+ " <td>0.95</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " set evalmode tracked_pos tracked_neg tracked_pos_pro \\\n",
+ "0 control control 1 19 0.05 \n",
+ "\n",
+ " tracked_neg_pro \n",
+ "0 0.95 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target\n",
+ "# determine last visible frame numeric\n",
+ "grouping_factors = ['net','set','evalmode','scene','slot']\n",
+ "ff = sf[sf.visible].groupby(grouping_factors).max()\n",
+ "ff.rename(columns = {'frame':'last_visible'}, inplace = True)\n",
+ "ff = ff[['last_visible']]\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf = sf.merge(ff, on=grouping_factors, how='left')\n",
+ "sf['last_visible'] = (sf['last_visible'] == sf['frame'])\n",
+ "\n",
+ "# same for first bound frame\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()\n",
+ "ff.rename(columns = {'frame':'first_visible'}, inplace = True)\n",
+ "ff = ff[['first_visible']]\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf = sf.merge(ff, on=grouping_factors, how='left')\n",
+ "\n",
+ "# extract the trials where the target was last visible and threshold the TE\n",
+ "ff = sf[sf['last_visible']] \n",
+ "ff['tracked_pos'] = (ff['TE'] < 10)\n",
+ "ff['tracked_neg'] = (ff['TE'] >= 10)\n",
+ "\n",
+ "# fill NaN with 0\n",
+ "sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')\n",
+ "sf['tracked_pos'].fillna(False, inplace=True)\n",
+ "sf['tracked_neg'].fillna(False, inplace=True)\n",
+ "\n",
+ "# Aggregate over all scenes\n",
+ "temp = sf[(sf['frame']== 15) & ~sf.occluder & sf.control & (sf.first_visible < 20)]\n",
+ "temp = temp.groupby(['set', 'evalmode']).sum()\n",
+ "temp = temp[['tracked_pos', 'tracked_neg']]\n",
+ "temp = temp.reset_index()\n",
+ "\n",
+ "temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "\n",
+ "temp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Mostly Tracked stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "temp = af[af.index == 'OVERALL']\n",
+ "temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']\n",
+ "temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']\n",
+ "temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']\n",
+ "g = temp[['mostly_tracked', 'partially_tracked', 'mostly_lost','set']].set_index(['set']).groupby(['set']).mean().plot(kind='bar', stacked=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MOTA "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>idf1</th>\n",
+ " <th>idp</th>\n",
+ " <th>idr</th>\n",
+ " <th>recall</th>\n",
+ " <th>precision</th>\n",
+ " <th>num_unique_objects</th>\n",
+ " <th>mostly_tracked</th>\n",
+ " <th>partially_tracked</th>\n",
+ " <th>mostly_lost</th>\n",
+ " <th>num_false_positives</th>\n",
+ " <th>num_misses</th>\n",
+ " <th>num_switches</th>\n",
+ " <th>num_fragmentations</th>\n",
+ " <th>mota</th>\n",
+ " <th>motp</th>\n",
+ " <th>num_transfer</th>\n",
+ " <th>num_ascend</th>\n",
+ " <th>num_migrate</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>set</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>control</th>\n",
+ " <td>0.54837</td>\n",
+ " <td>0.487418</td>\n",
+ " <td>0.626745</td>\n",
+ " <td>0.70247</td>\n",
+ " <td>0.546309</td>\n",
+ " <td>86.0</td>\n",
+ " <td>35.0</td>\n",
+ " <td>44.0</td>\n",
+ " <td>7.0</td>\n",
+ " <td>4345.0</td>\n",
+ " <td>2216.0</td>\n",
+ " <td>240.0</td>\n",
+ " <td>73.0</td>\n",
+ " <td>0.086869</td>\n",
+ " <td>0.07704</td>\n",
+ " <td>119.0</td>\n",
+ " <td>55.0</td>\n",
+ " <td>19.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " idf1 idp idr recall precision num_unique_objects \\\n",
+ "set \n",
+ "control 0.54837 0.487418 0.626745 0.70247 0.546309 86.0 \n",
+ "\n",
+ " mostly_tracked partially_tracked mostly_lost num_false_positives \\\n",
+ "set \n",
+ "control 35.0 44.0 7.0 4345.0 \n",
+ "\n",
+ " num_misses num_switches num_fragmentations mota motp \\\n",
+ "set \n",
+ "control 2216.0 240.0 73.0 0.086869 0.07704 \n",
+ "\n",
+ " num_transfer num_ascend num_migrate \n",
+ "set \n",
+ "control 119.0 55.0 19.0 "
+ ]
+ },
+ "execution_count": 8,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af[af.index == 'OVERALL'].groupby(['set']).mean()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "loci23",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/evaluation/adept/evaluation_loci_looped.ipynb b/evaluation/adept/evaluation_loci_looped.ipynb
index 1de5e2b..8b39ac1 100644
--- a/evaluation/adept/evaluation_loci_looped.ipynb
+++ b/evaluation/adept/evaluation_loci_looped.ipynb
@@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 47,
+ "execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@@ -30,7 +30,7 @@
},
{
"cell_type": "code",
- "execution_count": 48,
+ "execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
@@ -89,7 +89,7 @@
},
{
"cell_type": "code",
- "execution_count": 49,
+ "execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -125,7 +125,7 @@
},
{
"cell_type": "code",
- "execution_count": 50,
+ "execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -179,7 +179,7 @@
"0 0.033333 "
]
},
- "execution_count": 50,
+ "execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
@@ -495,9 +495,22 @@
},
{
"cell_type": "code",
- "execution_count": 55,
+ "execution_count": 17,
"metadata": {},
- "outputs": [],
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 360x720 with 4 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
"source": [
"def get_plot_complete():\n",
"\n",
@@ -534,9 +547,9 @@
" vslot = vanish_slot\n",
" if error == 'TE':\n",
" vslot = vanish_slot[(vanish_slot['TE'] > 0) | (vanish_slot['frame'] < 10)]\n",
- " sns.lineplot(x=\"frame\", y=error, data=control_slot, ax=axs[row][col], label=control_label)\n",
+ " sns.lineplot(x=\"frame\", y=error, data=control_slot, ax=ax, label=control_label)\n",
" sns.lineplot(x=\"frame\", y=error, data=vslot, ax=axs[row][col], label=vanish_label)\n",
- " axs[row][col].set_xlabel('')\n",
+ " ax[col].set_xlabel('')\n",
" axs[row][col].set_ylabel(ylabel)\n",
" axs[row][col].legend()\n",
"\n",
@@ -682,7 +695,7 @@
" (sf['first_visible'] < 20)]\n",
" control_slot['visible'] = (control_slot['visible'] & control_slot['bound'])\n",
" control_slot['slot_error'] = control_slot['slot_error'] * 10000000\n",
- " control_label = 'Reappearing'\n",
+ " control_label = 'Reappearing objects'\n",
"\n",
" # line plot of mean slot error as a function of frames per jumping mode\n",
" vanish_slot = sf[(sf['vanishing'] == True) &\n",
@@ -693,62 +706,62 @@
" (sf['first_visible'] < 20)]\n",
" vanish_slot['visible'] = (vanish_slot['visible'] & vanish_slot['bound'])\n",
" vanish_slot['slot_error'] = vanish_slot['slot_error'] * 10000000\n",
- " vanish_label = 'Vanishing'\n",
+ " vanish_label = 'Vanishing objects'\n",
"\n",
" a = control_slot.groupby(['frame']).mean()['visible']\n",
" b = vanish_slot.groupby(['frame']).mean()['visible']\n",
" #b[28:] = 0\n",
"\n",
" # combined plot of slot error and tracking error\n",
- " fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(10, 4), gridspec_kw={'height_ratios': [3, 1]})\n",
+ " fig, axs = plt.subplots(ncols=1, nrows=5, figsize=(5, 10), gridspec_kw={'height_ratios': [3, 1, 0.5, 3, 1]})\n",
" alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']\n",
"\n",
" for i, (error,ylabel) in enumerate(zip(['alpha_pos', 'alpha_ges'],['Position Updates', 'Gestalt Updates'])):\n",
"\n",
" # set row and column\n",
- " row = int(i/2) \n",
- " col = i % 2\n",
+ " ax = axs[3*i]\n",
"\n",
" # line plot of slot error as a function of frames\n",
" vslot = vanish_slot\n",
" if error == 'TE':\n",
" vslot = vanish_slot[(vanish_slot['TE'] > 0) | (vanish_slot['frame'] < 10)]\n",
- " sns.lineplot(x=\"frame\", y=error, data=control_slot, ax=axs[row][col], label=control_label)\n",
- " sns.lineplot(x=\"frame\", y=error, data=vslot, ax=axs[row][col], label=vanish_label)\n",
- " axs[row][col].set_xlabel('')\n",
- " axs[row][col].set_ylabel(ylabel, fontsize=12)\n",
- " axs[row][col].legend()\n",
+ " sns.lineplot(x=\"frame\", y=error, data=control_slot, ax=ax, label=control_label)\n",
+ " sns.lineplot(x=\"frame\", y=error, data=vslot, ax=ax, label=vanish_label)\n",
+ " ax.set_xlabel('')\n",
+ " ax.set_ylabel(ylabel, fontsize=12)\n",
+ " ax.legend()\n",
"\n",
" # set y lim from 0 to 0.6\n",
- " axs[row][col].set_ylim(0, 0.6)\n",
+ " ax.set_ylim(0, 0.6)\n",
"\n",
- " axs[row][col].set_yticks([0, 0.3, 0.6])\n",
- " axs[row][col].set_yticklabels(['0%', '30%', '60%'])\n",
+ " ax.set_yticks([0, 0.3, 0.6])\n",
+ " ax.set_yticklabels(['0%', '30%', '60%'])\n",
"\n",
" # add a,b,c ... labels to the top left conrner of each plot\n",
- " axs[row][col].text(-0.2, 1, alphabet[i], transform=axs[row][col].transAxes, size=15, weight='bold')\n",
+ " ax.text(-0.2, 1, alphabet[i], transform=ax.transAxes, size=15, weight='bold')\n",
"\n",
" # remove x axis ticks from top plot but keep grid lines in x axis\n",
- " if row == 0:\n",
- " axs[row][col].set_xticklabels([])\n",
- " axs[row][col].set_xlabel('')\n",
+ " ax.set_xticklabels([])\n",
+ " ax.set_xlabel('')\n",
"\n",
- " for col in [0,1]:\n",
+ " for row in [1,4]:\n",
+ "\n",
+ " ax = axs[row]\n",
"\n",
" # add occulusion plot\n",
- " axs[1][col].plot(range(len(a)),a.to_list(), label=control_label)\n",
- " axs[1][col].plot(range(len(a)),b.to_list(), label=vanish_label)\n",
- " axs[1][col].set_xlabel('Frame', fontsize=12)\n",
- " axs[1][col].set_ylabel('Visiblity', fontsize=12)\n",
+ " ax.plot(range(len(a)),a.to_list(), label=control_label)\n",
+ " ax.plot(range(len(a)),b.to_list(), label=vanish_label)\n",
+ " ax.set_xlabel('Frame', fontsize=12)\n",
+ " ax.set_ylabel('Visiblity', fontsize=12)\n",
"\n",
" # set y axis ticks as 0% to 100%\n",
- " axs[1][col].set_yticks([0, 1])\n",
+ " ax.set_yticks([0, 1])\n",
" #axs[3][col].set_yticks([])\n",
- " axs[1][col].set_yticklabels(['0%', '100%'])\n",
+ " ax.set_yticklabels(['0%', '100%'])\n",
"\n",
" # add same text as above but center the text\n",
- " axs[1][col].text(28, 0.5, 'Object\\nvanishes', fontsize=9, color='grey', rotation=90, ha='center', va='center')\n",
- " axs[1][col].text(60, 0.5, 'Occluder\\nrotates', fontsize=9, color='grey', rotation=90, ha='center', va='center')\n",
+ " ax.text(29, 0.5, 'Reappear', fontsize=9, color='grey', rotation=90, ha='center', va='center')\n",
+ " ax.text(60, 0.5, 'Occluder\\nrotates', fontsize=9, color='grey', rotation=90, ha='center', va='center')\n",
"\n",
" # Update Gates\n",
" #axs[2][col].set_yticks([0, 0.2,0.4])\n",
@@ -758,6 +771,12 @@
"\n",
" plt.tight_layout()\n",
"\n",
+ " # reduce space between rows 0 and 1, and rows 3 and 4 (row 2 is the spacer removed below)\n",
+ " plt.subplots_adjust(hspace=0.1)\n",
+ " \n",
+ " # delete row 2\n",
+ " axs[2].remove()\n",
+ "\n",
" pass"
]
},
@@ -805,29 +824,537 @@
},
{
"cell_type": "code",
- "execution_count": 58,
+ "execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
- "image/png": "",
+ "image/png": "",
"text/plain": [
- "<Figure size 720x288 with 4 Axes>"
+ "<Figure size 360x720 with 4 Axes>"
]
},
- "metadata": {},
+ "metadata": {
+ "needs_background": "light"
+ },
"output_type": "display_data"
}
],
"source": [
"get_plot_gates()\n",
- "plt.savefig('plots/loci_looped_gates.pdf')"
+ "plt.savefig('plots/loci_looped_gates2.pdf')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Accuracy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 77,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "[Text(0, 0, 'Expected moment of reappearance'), Text(1, 0, 'Occluder rotates')]"
+ ]
+ },
+ "execution_count": 77,
+ "metadata": {},
+ "output_type": "execute_result"
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 576x432 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "control_slot = sf[(sf['vanishing'] == False) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_control']) &\n",
+ " (sf['tracked_pos']) &\n",
+ " (sf['first_visible'] < 20)]\n",
+ "control_slot['visible'] = (control_slot['visible'] & control_slot['bound'])\n",
+ "\n",
+ "# line plot of mean slot error as a function of frames per jumping mode\n",
+ "vanish_slot = sf[(sf['vanishing'] == True) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_surprise']) &\n",
+ " (sf['tracked_pos']) & \n",
+ " (sf['first_visible'] < 20)]\n",
+ "vanish_slot['visible'] = (vanish_slot['visible'] & vanish_slot['bound'])\n",
+ "\n",
+ "# concat the two dataframes\n",
+ "temp = pd.concat([control_slot, vanish_slot])\n",
+ "temp1 = temp[(temp['frame'] > 30) & (temp['frame'] < 50)]\n",
+ "temp2 = temp[(temp['frame'] > 50)]\n",
+ "\n",
+ "temp1 = pd.Series.to_frame(temp1.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ "temp1['name'] = 30\n",
+ "\n",
+ "temp2 = pd.Series.to_frame(temp2.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ "temp2['slot_error'] = temp2['slot_error'] \n",
+ "temp2['name'] = 50\n",
+ "\n",
+ "# concatenate the two dataframes\n",
+ "temp = pd.concat([temp1, temp2]).reset_index()\n",
+ "\n",
+ "# make box plot of the slot_error in one figure grouped by set and then by name\n",
+ "fig, ax = plt.subplots(figsize=(8, 6))\n",
+ "sns.boxplot(x='name', y='slot_error', hue='set', data=temp, ax=ax)\n",
+ "ax.set_xlabel('')\n",
+ "ax.set_ylabel('Surprise')\n",
+ "ax.legend(title='Frame', loc='upper right')\n",
+ "plt.tight_layout()\n",
+ "\n",
+ "# set legend labels to ['Reappearing', 'Vanishing']\n",
+ "handles, labels = ax.get_legend_handles_labels()\n",
+ "ax.legend(handles, ['Reappearing', 'Vanishing'], title='', loc='upper right')\n",
+ "\n",
+ "# set x-axis labels to [Expected moment of reappearance, Occluder rotates]\n",
+ "ax.set_xticklabels(['Expected moment of reappearance', 'Occluder rotates'])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 146,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def get_plot_voe():\n",
+ "\n",
+ " ########################################\n",
+ " # line plot of mean slot error as a function of frames per jumping mode\n",
+ " control_slot = sf[(sf['vanishing'] == False) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_control']) &\n",
+ " #(sf['net'] == net) &\n",
+ " (sf['tracked_pos']) &\n",
+ " (sf['first_visible'] < 20)]\n",
+ " control_slot['visible'] = (control_slot['visible'] & control_slot['bound'])\n",
+ " control_slot['slot_error'] = control_slot['slot_error'] * 10000000\n",
+ " control_label = 'Reappearing'\n",
+ "\n",
+ " # line plot of mean slot error as a function of frames per jumping mode\n",
+ " vanish_slot = sf[(sf['vanishing'] == True) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_surprise']) &\n",
+ " #(sf['net'] == net) & \n",
+ " (sf['tracked_pos']) & \n",
+ " (sf['first_visible'] < 20)]\n",
+ " vanish_slot['visible'] = (vanish_slot['visible'] & vanish_slot['bound'])\n",
+ " vanish_slot['slot_error'] = vanish_slot['slot_error'] * 10000000\n",
+ " vanish_label = 'Vanishing'\n",
+ "\n",
+ " a = control_slot.groupby(['frame']).mean()['visible']\n",
+ " b = vanish_slot.groupby(['frame']).mean()['visible']\n",
+ " #b[28:] = 0\n",
+ "\n",
+ " # stack 4 axes vertically with height ratios 3:0.5:2:1 (axs[1] is only a spacer)\n",
+ " fig, axs = plt.subplots(ncols=1, nrows=4, figsize=(5, 9), gridspec_kw={'height_ratios': [3, 0.5, 2, 1]})\n",
+ " alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']\n",
+ "\n",
+ " error = 'slot_error'\n",
+ " ylabel = 'Slot Error'\n",
+ " ax = axs[2]\n",
+ "\n",
+ " # line plot of slot error as a function of frames\n",
+ " vslot = vanish_slot\n",
+ " if error == 'TE':\n",
+ " vslot = vanish_slot[(vanish_slot['TE'] > 0) | (vanish_slot['frame'] < 10)]\n",
+ " sns.lineplot(x=\"frame\", y=error, data=control_slot, ax=ax, label=control_label)\n",
+ " sns.lineplot(x=\"frame\", y=error, data=vslot, ax=ax, label=vanish_label)\n",
+ " ax.set_xlabel('')\n",
+ " ax.set_ylabel(ylabel, fontsize=12)\n",
+ " ax.legend()\n",
+ "\n",
+ " # add a,b,c ... labels to the top left corner of each plot\n",
+ " ax.text(-0.15, 1, alphabet[1], transform=ax.transAxes, size=15, weight='bold')\n",
+ " # draw an arrow from (25, 0.6) to (33, 0.45) and label it 'Expectation'\n",
+ " ax.annotate('', xy=(33, 0.45), xytext=(25, 0.6), arrowprops={'arrowstyle': '->', 'color': 'black'})\n",
+ " ax.text(21, 0.62, 'Expectation', fontsize=9, color='black', ha='center', va='center')\n",
+ "\n",
+ " ax.annotate('', xy=(63, 0.31), xytext=(56, 0.45), arrowprops={'arrowstyle': '->', 'color': 'black'})\n",
+ " ax.text(52, 0.47, 'Expectation', fontsize=9, color='black', ha='center', va='center')\n",
+ "\n",
+ " # remove x axis ticks from top plot but keep grid lines in x axis\n",
+ " ax.set_xticklabels([])\n",
+ " ax.set_xlabel('')\n",
+ "\n",
+ " # remove legend\n",
+ " ax.legend().remove()\n",
+ "\n",
+ " ########################################\n",
+ " ax = axs[3]\n",
+ "\n",
+ " # add occlusion (visibility) plot\n",
+ " ax.plot(range(len(a)),a.to_list(), label=control_label)\n",
+ " ax.plot(range(len(a)),b.to_list(), label=vanish_label)\n",
+ " ax.set_xlabel('Frame', fontsize=12)\n",
+ " ax.set_ylabel('Visiblity', fontsize=12)\n",
+ "\n",
+ " # move the y-label a bit more to the right\n",
+ " ax.yaxis.set_label_coords(-0.09, 0.5)\n",
+ "\n",
+ " # set y axis ticks as 0% to 100%\n",
+ " ax.set_yticks([0, 1])\n",
+ " ax.set_yticklabels(['0%', '100%'])\n",
+ "\n",
+ " # add same text as above but center the text\n",
+ " ax.text(28, 0.5, 'Object\\nvanishes', fontsize=9, color='grey', rotation=90, ha='center', va='center')\n",
+ " ax.text(60, 0.5, 'Occluder\\nrotates', fontsize=9, color='grey', rotation=90, ha='center', va='center')\n",
+ "\n",
+ "\n",
+ " ########################################\n",
+ " axs[1].axis('off')\n",
+ "\n",
+ " ax = axs[0]\n",
+ " ax.text(-0.15, 1, alphabet[0], transform=ax.transAxes, size=15, weight='bold')\n",
+ "\n",
+ " # concat the two dataframes\n",
+ " temp = pd.concat([control_slot, vanish_slot])\n",
+ " temp1 = temp[(temp['frame'] > 30) & (temp['frame'] < 50)]\n",
+ " temp2 = temp[(temp['frame'] > 50)]\n",
+ "\n",
+ " temp1 = pd.Series.to_frame(temp1.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ " temp1['name'] = 30\n",
+ "\n",
+ " temp2 = pd.Series.to_frame(temp2.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ " temp2['name'] = 50\n",
+ "\n",
+ " # concatenate the two dataframes\n",
+ " temp = pd.concat([temp1, temp2]).reset_index()\n",
+ "\n",
+ " # make box plot of the slot_error in one figure grouped by set and then by name\n",
+ " sns.boxplot(x='name', y='slot_error', hue='set', data=temp, ax=ax)\n",
+ " ax.set_xlabel('')\n",
+ " ax.set_ylabel('Surprise')\n",
+ "\n",
+ " # tighten the layout\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " # set x-axis labels to [Expected moment of reappearance, Occluder rotates]\n",
+ " ax.set_xticklabels(['Expected moment\\nof reappearance', 'Occluder rotates'])\n",
+ "\n",
+ " # set legend labels to ['Reappearing', 'Vanishing']\n",
+ " handles, labels = ax.get_legend_handles_labels()\n",
+ " ax.legend(handles, ['Reappearing', 'Vanishing'], title='', loc='upper center', bbox_to_anchor=(0.5, 1.105), ncol=2)\n",
+ "\n",
+ " # reduce the space between the second and the third plot\n",
+ " plt.subplots_adjust(hspace=0.05)\n",
+ " \n",
+ "\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 148,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 360x648 with 4 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "get_plot_voe()\n",
+ "plt.savefig('plots/loci_looped_voe2.pdf')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 138,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Ttest_indResult(statistic=1.691198428694689, pvalue=0.09495066166428383)\n",
+ "Ttest_indResult(statistic=3.6836011726202966, pvalue=0.00043158165864250897)\n"
+ ]
+ }
+ ],
+ "source": [
+ "control_slot = sf[(sf['vanishing'] == False) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_control']) &\n",
+ " (sf['tracked_pos']) &\n",
+ " (sf['first_visible'] < 20)]\n",
+ "control_slot['visible'] = (control_slot['visible'] & control_slot['bound'])\n",
+ "control_slot['slot_error'] = control_slot['slot_error'] * 10000000\n",
+ "\n",
+ "# line plot of mean slot error as a function of frames per jumping mode\n",
+ "vanish_slot = sf[(sf['vanishing'] == True) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_surprise']) &\n",
+ " (sf['tracked_pos']) & \n",
+ " (sf['first_visible'] < 20)]\n",
+ "vanish_slot['visible'] = (vanish_slot['visible'] & vanish_slot['bound'])\n",
+ "vanish_slot['slot_error'] = vanish_slot['slot_error'] * 10000000\n",
+ "\n",
+ "\n",
+ "temp = pd.concat([control_slot, vanish_slot])\n",
+ "temp1 = temp[(temp['frame'] > 30) & (temp['frame'] < 50)]\n",
+ "temp2 = temp[(temp['frame'] > 50)]\n",
+ "\n",
+ "temp1 = pd.Series.to_frame(temp1.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ "temp2 = pd.Series.to_frame(temp2.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ "\n",
+ "# \n",
+ "temp1 = temp1.groupby(['set', 'scene', 'slot']).max()['slot_error']\n",
+ "temp2 = temp2.groupby(['set', 'scene', 'slot']).max()['slot_error']\n",
+ "\n",
+ "# t-test between reappearing and vanishing\n",
+ "print(stats.ttest_ind(temp1['createdown_surprise'], temp1['createdown_control']))\n",
+ "print(stats.ttest_ind(temp2['createdown_surprise'], temp2['createdown_control']))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 141,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "0.047"
+ ]
+ },
+ "execution_count": 141,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "0.094/2"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 108,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# vertical aligned\n",
+ "def get_plot_voe():\n",
+ "\n",
+ " # increase font size of all plots\n",
+ " sns.set(font_scale=1.2)\n",
+ "\n",
+ "\n",
+ " ########################################\n",
+ " # line plot of mean slot error as a function of frames per jumping mode\n",
+ " control_slot = sf[(sf['vanishing'] == False) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_control']) &\n",
+ " #(sf['net'] == net) &\n",
+ " (sf['tracked_pos']) &\n",
+ " (sf['first_visible'] < 20)]\n",
+ " control_slot['visible'] = (control_slot['visible'] & control_slot['bound'])\n",
+ " control_slot['slot_error'] = control_slot['slot_error'] * 10000000\n",
+ " control_label = 'Reappearing'\n",
+ "\n",
+ " # line plot of mean slot error as a function of frames per jumping mode\n",
+ " vanish_slot = sf[(sf['vanishing'] == True) &\n",
+ " (sf['occluder'] == False) &\n",
+ " sf['set'].isin(['createdown_surprise']) &\n",
+ " #(sf['net'] == net) & \n",
+ " (sf['tracked_pos']) & \n",
+ " (sf['first_visible'] < 20)]\n",
+ " vanish_slot['visible'] = (vanish_slot['visible'] & vanish_slot['bound'])\n",
+ " vanish_slot['slot_error'] = vanish_slot['slot_error'] * 10000000\n",
+ " vanish_label = 'Vanishing'\n",
+ "\n",
+ " a = control_slot.groupby(['frame']).mean()['visible']\n",
+ " b = vanish_slot.groupby(['frame']).mean()['visible']\n",
+ " #b[28:] = 0\n",
+ "\n",
+ " # layout: 4x8 grid with three axes - box plot on the left (all rows, first 3 columns), slot-error plot (top 3 rows) above the visibility plot (bottom row) on the right\n",
+ " #fig, axs = plt.subplots(ncols=2, nrows=2, figsize=(9, 5), gridspec_kw={'height_ratios': [3, 0.5, 2, 1]})\n",
+ " fig = plt.figure(constrained_layout=True, figsize=(9, 5))\n",
+ " gs = fig.add_gridspec(4, 8)\n",
+ "\n",
+ " alphabet = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h']\n",
+ "\n",
+ " error = 'slot_error'\n",
+ " ylabel = 'Slot Error'\n",
+ " ax = fig.add_subplot(gs[:3, 3:])\n",
+ "\n",
+ " # line plot of slot error as a function of frames\n",
+ " vslot = vanish_slot\n",
+ " if error == 'TE':\n",
+ " vslot = vanish_slot[(vanish_slot['TE'] > 0) | (vanish_slot['frame'] < 10)]\n",
+ " sns.lineplot(x=\"frame\", y=error, data=control_slot, ax=ax, label=control_label)\n",
+ " sns.lineplot(x=\"frame\", y=error, data=vslot, ax=ax, label=vanish_label)\n",
+ " ax.set_xlabel('')\n",
+ " ax.set_ylabel(ylabel)\n",
+ " ax.legend()\n",
+ "\n",
+ " # add a,b,c ... labels to the top left corner of each plot\n",
+ " ax.text(-0.15, 1, alphabet[1], transform=ax.transAxes, size=15, weight='bold')\n",
+ " # draw an arrow from (25, 0.6) to (33, 0.45) and label it 'Expectation'\n",
+ " ax.annotate('', xy=(33, 0.45), xytext=(25, 0.6), arrowprops={'arrowstyle': '->', 'color': 'black'})\n",
+ " ax.text(21, 0.62, 'Expectation', fontsize=11, color='black', ha='center', va='center')\n",
+ "\n",
+ " ax.annotate('', xy=(63, 0.31), xytext=(56, 0.45), arrowprops={'arrowstyle': '->', 'color': 'black'})\n",
+ " ax.text(52, 0.47, 'Expectation', fontsize=11, color='black', ha='center', va='center')\n",
+ "\n",
+ " # remove x axis ticks from top plot but keep grid lines in x axis\n",
+ " ax.set_xticklabels([])\n",
+ " ax.set_xlabel('')\n",
+ "\n",
+ " # set legend labels to ['Reappearing', 'Vanishing']\n",
+ " handles, labels = ax.get_legend_handles_labels()\n",
+ " ax.legend(handles, ['Reappearing objects', 'Vanishing objects'], title='', loc='upper center', bbox_to_anchor=(0.5, 1.18), ncol=2)\n",
+ "\n",
+ " # remove legend\n",
+ " #ax.legend().remove()\n",
+ "\n",
+ "\n",
+ " ########################################\n",
+ " ax = fig.add_subplot(gs[3:, 3:])\n",
+ "\n",
+ " # add occlusion (visibility) plot\n",
+ " ax.plot(range(len(a)),a.to_list(), label=control_label)\n",
+ " ax.plot(range(len(a)),b.to_list(), label=vanish_label)\n",
+ " ax.set_xlabel('Frame')\n",
+ " ax.set_ylabel('Visiblity', fontsize=12)\n",
+ "\n",
+ " # move the y-label a bit more to the right\n",
+ " ax.yaxis.set_label_coords(-0.09, 0.5)\n",
+ "\n",
+ " # set y axis ticks as 0% to 100%\n",
+ " ax.set_yticks([0, 1])\n",
+ " ax.set_yticklabels(['0%', '100%'])\n",
+ "\n",
+ " # add same text as above but center the text\n",
+ " ax.text(29, 0.5, 'Reappear', fontsize=11, color='grey', rotation=90, ha='center', va='center')\n",
+ " ax.text(60, 0.5, 'Occluder\\nrotates', fontsize=11, color='grey', rotation=90, ha='center', va='center')\n",
+ "\n",
+ "\n",
+ " ########################################\n",
+ " #axs[0][1].axis('off')\n",
+ "\n",
+ " ax = fig.add_subplot(gs[:, :3])\n",
+ " ax.text(-0.15, 1, alphabet[0], transform=ax.transAxes, size=17, weight='bold')\n",
+ "\n",
+ " # concat the two dataframes\n",
+ " temp = pd.concat([control_slot, vanish_slot])\n",
+ " temp1 = temp[(temp['frame'] > 30) & (temp['frame'] < 50)]\n",
+ " temp2 = temp[(temp['frame'] > 50)]\n",
+ "\n",
+ " temp1 = pd.Series.to_frame(temp1.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ " temp1['name'] = 30\n",
+ "\n",
+ " temp2 = pd.Series.to_frame(temp2.groupby(['set', 'scene', 'slot']).max()['slot_error'])\n",
+ " temp2['name'] = 50\n",
+ "\n",
+ " # concatenate the two dataframes\n",
+ " temp = pd.concat([temp1, temp2]).reset_index()\n",
+ "\n",
+ " # make box plot of the slot_error in one figure grouped by set and then by name\n",
+ " sns.boxplot(x='name', y='slot_error', hue='set', data=temp, ax=ax)\n",
+ " ax.set_xlabel('')\n",
+ " ax.set_ylabel('Surprise')\n",
+ "\n",
+ " # tighten the layout\n",
+ " plt.tight_layout()\n",
+ "\n",
+ " # set x-axis labels to [Moment of expected reappearance, Occluder rotates]\n",
+ " ax.set_xticklabels(['Moment of \\nexpected \\nreappearance', 'Occluder \\nrotates'])\n",
+ "\n",
+ " ax.legend().remove() \n",
+ "\n",
+ " # reduce the space between the second and the third plot\n",
+ " plt.subplots_adjust(hspace=0.05)\n",
+ " \n",
+ "\n",
+ " pass"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 110,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "<ipython-input-108-43cbf80ac39b>:124: UserWarning: This figure was using constrained_layout, but that is incompatible with subplots_adjust and/or tight_layout; disabling constrained_layout.\n",
+ " plt.tight_layout()\n"
+ ]
+ },
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 648x360 with 3 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "get_plot_voe()\n",
+ "plt.savefig('plots/loci_looped_voe3.pdf')"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
}
],
"metadata": {
diff --git a/evaluation/adept/evaluation_loci_looped_visibility.ipynb b/evaluation/adept/evaluation_loci_looped_visibility.ipynb
new file mode 100644
index 0000000..cbb2b7c
--- /dev/null
+++ b/evaluation/adept/evaluation_loci_looped_visibility.ipynb
@@ -0,0 +1,475 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import scipy.stats as stats\n",
+ "import os\n",
+ "\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "pd.options.mode.chained_assignment = None \n",
+ "plt.style.use('ggplot')\n",
+ "sns.color_palette(\"Paired\");\n",
+ "sns.set_theme();"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data Loading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# setting path to results folder\n",
+ "root_path = '../../out/pretrained/adept/loci_looped/results_visibility'\n",
+ "\n",
+ "# list all folders in root path that don't start with a dot\n",
+ "nets = [f for f in os.listdir(root_path) if not f.startswith('.')]\n",
+ "\n",
+ "# empty frames that accumulate the per-net CSV statistics\n",
+ "tf = pd.DataFrame()\n",
+ "sf = pd.DataFrame()\n",
+ "af = pd.DataFrame()\n",
+ "\n",
+ "# load statistics files from nets\n",
+ "for net in nets:\n",
+ " path = os.path.join(root_path, net, 'surprise', 'statistics',)\n",
+ " with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:\n",
+ " tf_temp = pd.read_csv(f, index_col=0)\n",
+ " tf_temp['net'] = net\n",
+ " tf = pd.concat([tf,tf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:\n",
+ " sf_temp = pd.read_csv(f, index_col=0)\n",
+ " sf_temp['net'] = net\n",
+ " sf = pd.concat([sf,sf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'accframe.csv'), 'rb') as f:\n",
+ " af_temp = pd.read_csv(f, index_col=0)\n",
+ " af_temp['net'] = net\n",
+ " af = pd.concat([af,af_temp])\n",
+ "\n",
+ "# cast variables\n",
+ "sf['visible'] = sf['visible'].astype(bool)\n",
+ "sf['bound'] = sf['bound'].astype(bool)\n",
+ "sf['occluder'] = sf['occluder'].astype(bool)\n",
+ "sf['inimage'] = sf['inimage'].astype(bool)\n",
+ "sf['vanishing'] = sf['vanishing'].astype(bool)\n",
+ "sf['alpha_pos'] = 1-sf['alpha_pos']\n",
+ "sf['alpha_ges'] = 1-sf['alpha_ges']\n",
+ "\n",
+ "# scale to percentage\n",
+ "sf['TE'] = sf['TE'] * 100\n",
+ "\n",
+ "# add surprise as dummy code\n",
+ "tf['control'] = [('control' in set) for set in tf['set']]\n",
+ "sf['control'] = [('control' in set) for set in sf['set']]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Tracking Error (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tracking Error when visible: M: 7.65 , STD: 10.6, Count: 8266\n",
+ "Tracking Error when occluded: M: 6.72 , STD: 6.25, Count: 2236\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "\n",
+ "def get_stats(col):\n",
+ " return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'\n",
+ "\n",
+ "# When Visible\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Tracking Error when visible:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Occluded\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Tracking Error when occluded:' + get_stats(temp['TE']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Successful Trackings (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>set</th>\n",
+ " <th>evalmode</th>\n",
+ " <th>tracked_pos</th>\n",
+ " <th>tracked_neg</th>\n",
+ " <th>tracked_pos_pro</th>\n",
+ " <th>tracked_neg_pro</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>control</td>\n",
+ " <td>open</td>\n",
+ " <td>89</td>\n",
+ " <td>115</td>\n",
+ " <td>0.436275</td>\n",
+ " <td>0.563725</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " set evalmode tracked_pos tracked_neg tracked_pos_pro \\\n",
+ "0 control open 89 115 0.436275 \n",
+ "\n",
+ " tracked_neg_pro \n",
+ "0 0.563725 "
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target\n",
+ "# determine last visible frame numeric\n",
+ "grouping_factors = ['net','set','evalmode','scene','slot']\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()\n",
+ "ff.rename(columns = {'frame':'last_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# same for first visible frame\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()\n",
+ "ff.rename(columns = {'frame':'first_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf['last_visible'] = (sf['last_visible'] == sf['frame'])\n",
+ "\n",
+ "# extract the trials where the target was last visible and threshold the TE\n",
+ "ff = sf[sf['last_visible']] \n",
+ "ff['tracked_pos'] = (ff['TE'] < 10)\n",
+ "ff['tracked_neg'] = (ff['TE'] >= 10)\n",
+ "\n",
+ "# fill NaN with False\n",
+ "sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')\n",
+ "sf['tracked_pos'].fillna(False, inplace=True)\n",
+ "sf['tracked_neg'].fillna(False, inplace=True)\n",
+ "\n",
+ "# Aggregate over all scenes\n",
+ "temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]\n",
+ "temp = temp.groupby(['set', 'evalmode']).sum()\n",
+ "temp = temp[['tracked_pos', 'tracked_neg']]\n",
+ "temp = temp.reset_index()\n",
+ "\n",
+ "temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "\n",
+ "temp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Mostly Tracked stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "temp = af[af.index == 'OVERALL']\n",
+ "temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']\n",
+ "temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']\n",
+ "temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']\n",
+ "g = temp[['mostly_tracked', 'partially_tracked', 'mostly_lost','set']].set_index(['set']).groupby(['set']).mean().plot(kind='bar', stacked=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MOTA "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>idf1</th>\n",
+ " <th>idp</th>\n",
+ " <th>idr</th>\n",
+ " <th>recall</th>\n",
+ " <th>precision</th>\n",
+ " <th>num_unique_objects</th>\n",
+ " <th>mostly_tracked</th>\n",
+ " <th>partially_tracked</th>\n",
+ " <th>mostly_lost</th>\n",
+ " <th>num_false_positives</th>\n",
+ " <th>num_misses</th>\n",
+ " <th>num_switches</th>\n",
+ " <th>num_fragmentations</th>\n",
+ " <th>mota</th>\n",
+ " <th>motp</th>\n",
+ " <th>num_transfer</th>\n",
+ " <th>num_ascend</th>\n",
+ " <th>num_migrate</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>set</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>control</th>\n",
+ " <td>0.797811</td>\n",
+ " <td>0.712093</td>\n",
+ " <td>0.907404</td>\n",
+ " <td>0.964629</td>\n",
+ " <td>0.756651</td>\n",
+ " <td>86.0</td>\n",
+ " <td>80.333333</td>\n",
+ " <td>5.666667</td>\n",
+ " <td>0.0</td>\n",
+ " <td>1544.333333</td>\n",
+ " <td>175.333333</td>\n",
+ " <td>48.333333</td>\n",
+ " <td>23.666667</td>\n",
+ " <td>0.643333</td>\n",
+ " <td>0.043260</td>\n",
+ " <td>3.0</td>\n",
+ " <td>43.000000</td>\n",
+ " <td>0.666667</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>surprise</th>\n",
+ " <td>0.775593</td>\n",
+ " <td>0.672416</td>\n",
+ " <td>0.917448</td>\n",
+ " <td>0.966854</td>\n",
+ " <td>0.708470</td>\n",
+ " <td>33.0</td>\n",
+ " <td>30.666667</td>\n",
+ " <td>2.333333</td>\n",
+ " <td>0.0</td>\n",
+ " <td>642.666667</td>\n",
+ " <td>53.000000</td>\n",
+ " <td>18.666667</td>\n",
+ " <td>5.000000</td>\n",
+ " <td>0.553262</td>\n",
+ " <td>0.043718</td>\n",
+ " <td>4.0</td>\n",
+ " <td>13.333333</td>\n",
+ " <td>0.000000</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " idf1 idp idr recall precision \\\n",
+ "set \n",
+ "control 0.797811 0.712093 0.907404 0.964629 0.756651 \n",
+ "surprise 0.775593 0.672416 0.917448 0.966854 0.708470 \n",
+ "\n",
+ " num_unique_objects mostly_tracked partially_tracked mostly_lost \\\n",
+ "set \n",
+ "control 86.0 80.333333 5.666667 0.0 \n",
+ "surprise 33.0 30.666667 2.333333 0.0 \n",
+ "\n",
+ " num_false_positives num_misses num_switches num_fragmentations \\\n",
+ "set \n",
+ "control 1544.333333 175.333333 48.333333 23.666667 \n",
+ "surprise 642.666667 53.000000 18.666667 5.000000 \n",
+ "\n",
+ " mota motp num_transfer num_ascend num_migrate \n",
+ "set \n",
+ "control 0.643333 0.043260 3.0 43.000000 0.666667 \n",
+ "surprise 0.553262 0.043718 4.0 13.333333 0.000000 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af[af.index == 'OVERALL'].groupby(['set']).mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Gate Openings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Percept gate openings when visible: M: 1.01 , STD: 0.928, Count: 8266\n",
+ "Percept gate openings when occluded: M: 0.0118 , STD: 0.133, Count: 2236\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "loci23",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/evaluation/adept/plots/loci_looped_complete.pdf b/evaluation/adept/plots/loci_looped_complete.pdf
new file mode 100644
index 0000000..306c3cf
--- /dev/null
+++ b/evaluation/adept/plots/loci_looped_complete.pdf
Binary files differ
diff --git a/evaluation/adept/plots/loci_looped_gates.pdf b/evaluation/adept/plots/loci_looped_gates.pdf
new file mode 100644
index 0000000..13d2630
--- /dev/null
+++ b/evaluation/adept/plots/loci_looped_gates.pdf
Binary files differ
diff --git a/evaluation/adept/plots/loci_looped_gates2.pdf b/evaluation/adept/plots/loci_looped_gates2.pdf
new file mode 100644
index 0000000..d4513e3
--- /dev/null
+++ b/evaluation/adept/plots/loci_looped_gates2.pdf
Binary files differ
diff --git a/evaluation/adept/plots/loci_looped_surprise.pdf b/evaluation/adept/plots/loci_looped_surprise.pdf
new file mode 100644
index 0000000..8faea0b
--- /dev/null
+++ b/evaluation/adept/plots/loci_looped_surprise.pdf
Binary files differ
diff --git a/evaluation/adept/plots/loci_looped_voe.pdf b/evaluation/adept/plots/loci_looped_voe.pdf
new file mode 100644
index 0000000..9ae7116
--- /dev/null
+++ b/evaluation/adept/plots/loci_looped_voe.pdf
Binary files differ
diff --git a/evaluation/adept/plots/loci_looped_voe2.pdf b/evaluation/adept/plots/loci_looped_voe2.pdf
new file mode 100644
index 0000000..c3190da
--- /dev/null
+++ b/evaluation/adept/plots/loci_looped_voe2.pdf
Binary files differ
diff --git a/evaluation/adept/plots/loci_looped_voe3.pdf b/evaluation/adept/plots/loci_looped_voe3.pdf
new file mode 100644
index 0000000..52bbff8
--- /dev/null
+++ b/evaluation/adept/plots/loci_looped_voe3.pdf
Binary files differ
diff --git a/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb b/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb
new file mode 100644
index 0000000..8fde028
--- /dev/null
+++ b/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb
@@ -0,0 +1,479 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import scipy.stats as stats\n",
+ "import os\n",
+ "\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "pd.options.mode.chained_assignment = None \n",
+ "plt.style.use('ggplot')\n",
+ "sns.color_palette(\"Paired\");\n",
+ "sns.set_theme();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data Loading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "ename": "FileNotFoundError",
+ "evalue": "[Errno 2] No such file or directory: '../../../out/pretrained/adept_ablations/lambda/results/adept_level1_ablation_lambda.run311/trialframe.csv'",
+ "output_type": "error",
+ "traceback": [
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+ "\u001b[0;31mFileNotFoundError\u001b[0m Traceback (most recent call last)",
+ "\u001b[1;32m/Users/fredericbecker/Documents/Master/Masterarbeit/code/local/loci_public/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb Cell 3\u001b[0m line \u001b[0;36m1\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/fredericbecker/Documents/Master/Masterarbeit/code/local/loci_public/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb#W2sZmlsZQ%3D%3D?line=12'>13</a>\u001b[0m \u001b[39mfor\u001b[39;00m net \u001b[39min\u001b[39;00m nets:\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/fredericbecker/Documents/Master/Masterarbeit/code/local/loci_public/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb#W2sZmlsZQ%3D%3D?line=13'>14</a>\u001b[0m path \u001b[39m=\u001b[39m os\u001b[39m.\u001b[39mpath\u001b[39m.\u001b[39mjoin(root_path, \u001b[39m'\u001b[39m\u001b[39mresults\u001b[39m\u001b[39m'\u001b[39m, net)\n\u001b[0;32m---> <a href='vscode-notebook-cell:/Users/fredericbecker/Documents/Master/Masterarbeit/code/local/loci_public/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb#W2sZmlsZQ%3D%3D?line=14'>15</a>\u001b[0m \u001b[39mwith\u001b[39;00m \u001b[39mopen\u001b[39;49m(os\u001b[39m.\u001b[39;49mpath\u001b[39m.\u001b[39;49mjoin(path, \u001b[39m'\u001b[39;49m\u001b[39mtrialframe.csv\u001b[39;49m\u001b[39m'\u001b[39;49m), \u001b[39m'\u001b[39;49m\u001b[39mrb\u001b[39;49m\u001b[39m'\u001b[39;49m) \u001b[39mas\u001b[39;00m f:\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/fredericbecker/Documents/Master/Masterarbeit/code/local/loci_public/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb#W2sZmlsZQ%3D%3D?line=15'>16</a>\u001b[0m tf_temp \u001b[39m=\u001b[39m pd\u001b[39m.\u001b[39mread_csv(f, index_col\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m <a href='vscode-notebook-cell:/Users/fredericbecker/Documents/Master/Masterarbeit/code/local/loci_public/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl1.ipynb#W2sZmlsZQ%3D%3D?line=16'>17</a>\u001b[0m 
tf_temp[\u001b[39m'\u001b[39m\u001b[39mnet\u001b[39m\u001b[39m'\u001b[39m] \u001b[39m=\u001b[39m net\n",
+ "File \u001b[0;32m~/miniconda3/envs/loci23/lib/python3.9/site-packages/IPython/core/interactiveshell.py:282\u001b[0m, in \u001b[0;36m_modified_open\u001b[0;34m(file, *args, **kwargs)\u001b[0m\n\u001b[1;32m 275\u001b[0m \u001b[39mif\u001b[39;00m file \u001b[39min\u001b[39;00m {\u001b[39m0\u001b[39m, \u001b[39m1\u001b[39m, \u001b[39m2\u001b[39m}:\n\u001b[1;32m 276\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 277\u001b[0m \u001b[39mf\u001b[39m\u001b[39m\"\u001b[39m\u001b[39mIPython won\u001b[39m\u001b[39m'\u001b[39m\u001b[39mt let you open fd=\u001b[39m\u001b[39m{\u001b[39;00mfile\u001b[39m}\u001b[39;00m\u001b[39m by default \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 278\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mas it is likely to crash IPython. If you know what you are doing, \u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 279\u001b[0m \u001b[39m\"\u001b[39m\u001b[39myou can use builtins\u001b[39m\u001b[39m'\u001b[39m\u001b[39m open.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 280\u001b[0m )\n\u001b[0;32m--> 282\u001b[0m \u001b[39mreturn\u001b[39;00m io_open(file, \u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwargs)\n",
+ "\u001b[0;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../../../out/pretrained/adept_ablations/lambda/results/adept_level1_ablation_lambda.run311/trialframe.csv'"
+ ]
+ }
+ ],
+ "source": [
+ "# setting path to results folder\n",
+ "root_path = '../../../out/pretrained/adept_ablations/lambda'\n",
+ "\n",
+ "# list all folders in root path that don't start with a dot\n",
+ "nets = ['adept_level1_ablation_lambda.run311']\n",
+ "\n",
+ "# empty frames that accumulate the CSV results of all nets\n",
+ "tf = pd.DataFrame()\n",
+ "sf = pd.DataFrame()\n",
+ "af = pd.DataFrame()\n",
+ "\n",
+ "# load statistics files from nets\n",
+ "for net in nets:\n",
+ " path = os.path.join(root_path, net, 'results')\n",
+ " with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:\n",
+ " tf_temp = pd.read_csv(f, index_col=0)\n",
+ " tf_temp['net'] = net\n",
+ " tf = pd.concat([tf,tf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:\n",
+ " sf_temp = pd.read_csv(f, index_col=0)\n",
+ " sf_temp['net'] = net\n",
+ " sf = pd.concat([sf,sf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'accframe.csv'), 'rb') as f:\n",
+ " af_temp = pd.read_csv(f, index_col=0)\n",
+ " af_temp['net'] = net\n",
+ " af = pd.concat([af,af_temp])\n",
+ "\n",
+ "# cast variables\n",
+ "sf['visible'] = sf['visible'].astype(bool)\n",
+ "sf['bound'] = sf['bound'].astype(bool)\n",
+ "sf['occluder'] = sf['occluder'].astype(bool)\n",
+ "sf['inimage'] = sf['inimage'].astype(bool)\n",
+ "sf['vanishing'] = sf['vanishing'].astype(bool)\n",
+ "sf['alpha_pos'] = 1-sf['alpha_pos']\n",
+ "sf['alpha_ges'] = 1-sf['alpha_ges']\n",
+ "\n",
+ "# scale to percentage\n",
+ "sf['TE'] = sf['TE'] * 100\n",
+ "\n",
+ "# add surprise as dummy code\n",
+ "tf['control'] = [('control' in set) for set in tf['set']]\n",
+ "sf['control'] = [('control' in set) for set in sf['set']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Tracking Error (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tracking Error when visible: M: 11.6 , STD: 13.5, Count: 2281\n",
+ "Tracking Error when occluded: M: 7.1 , STD: 4.02, Count: 564\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "\n",
+ "def get_stats(col):\n",
+ " return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'\n",
+ "\n",
+ "# When Visible\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Tracking Error when visible:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Occluded\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Tracking Error when occluded:' + get_stats(temp['TE']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Successful Trackings (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>set</th>\n",
+ " <th>evalmode</th>\n",
+ " <th>tracked_pos</th>\n",
+ " <th>tracked_neg</th>\n",
+ " <th>tracked_pos_pro</th>\n",
+ " <th>tracked_neg_pro</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>control</td>\n",
+ " <td>open</td>\n",
+ " <td>0</td>\n",
+ " <td>54</td>\n",
+ " <td>0.0</td>\n",
+ " <td>1.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " set evalmode tracked_pos tracked_neg tracked_pos_pro \\\n",
+ "0 control open 0 54 0.0 \n",
+ "\n",
+ " tracked_neg_pro \n",
+ "0 1.0 "
+ ]
+ },
+ "execution_count": 5,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target\n",
+ "# determine last visible frame numeric\n",
+ "grouping_factors = ['net','set','evalmode','scene','slot']\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()\n",
+ "ff.rename(columns = {'frame':'last_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# same for first bound frame\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()\n",
+ "ff.rename(columns = {'frame':'first_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf['last_visible'] = (sf['last_visible'] == sf['frame'])\n",
+ "\n",
+ "# extract the trials where the target was last visible and threshold the TE\n",
+ "ff = sf[sf['last_visible']] \n",
+ "ff['tracked_pos'] = (ff['TE'] < 10)\n",
+ "ff['tracked_neg'] = (ff['TE'] >= 10)\n",
+ "\n",
+ "# fill NaN with 0\n",
+ "sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')\n",
+ "sf['tracked_pos'].fillna(False, inplace=True)\n",
+ "sf['tracked_neg'].fillna(False, inplace=True)\n",
+ "\n",
+ "# Aggregate over all scenes\n",
+ "temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]\n",
+ "temp = temp.groupby(['set', 'evalmode']).sum()\n",
+ "temp = temp[['tracked_pos', 'tracked_neg']]\n",
+ "temp = temp.reset_index()\n",
+ "\n",
+ "temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "\n",
+ "temp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Mostly Tracked stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "temp = af[af.index == 'OVERALL']\n",
+ "temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']\n",
+ "temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']\n",
+ "temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']\n",
+ "g = temp[['mostly_tracked', 'partially_tracked', 'mostly_lost','set']].set_index(['set']).groupby(['set']).mean().plot(kind='bar', stacked=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MOTA "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>idf1</th>\n",
+ " <th>idp</th>\n",
+ " <th>idr</th>\n",
+ " <th>recall</th>\n",
+ " <th>precision</th>\n",
+ " <th>num_unique_objects</th>\n",
+ " <th>mostly_tracked</th>\n",
+ " <th>partially_tracked</th>\n",
+ " <th>mostly_lost</th>\n",
+ " <th>num_false_positives</th>\n",
+ " <th>num_misses</th>\n",
+ " <th>num_switches</th>\n",
+ " <th>num_fragmentations</th>\n",
+ " <th>mota</th>\n",
+ " <th>motp</th>\n",
+ " <th>num_transfer</th>\n",
+ " <th>num_ascend</th>\n",
+ " <th>num_migrate</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>set</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>control</th>\n",
+ " <td>0.722042</td>\n",
+ " <td>0.879472</td>\n",
+ " <td>0.612416</td>\n",
+ " <td>0.658158</td>\n",
+ " <td>0.945161</td>\n",
+ " <td>86.0</td>\n",
+ " <td>35.0</td>\n",
+ " <td>22.0</td>\n",
+ " <td>29.0</td>\n",
+ " <td>187.0</td>\n",
+ " <td>1674.0</td>\n",
+ " <td>26.0</td>\n",
+ " <td>38.0</td>\n",
+ " <td>0.614662</td>\n",
+ " <td>0.033957</td>\n",
+ " <td>5.0</td>\n",
+ " <td>21.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>surprise</th>\n",
+ " <td>0.738122</td>\n",
+ " <td>0.923004</td>\n",
+ " <td>0.614946</td>\n",
+ " <td>0.646612</td>\n",
+ " <td>0.970532</td>\n",
+ " <td>33.0</td>\n",
+ " <td>11.0</td>\n",
+ " <td>9.0</td>\n",
+ " <td>13.0</td>\n",
+ " <td>31.0</td>\n",
+ " <td>558.0</td>\n",
+ " <td>5.0</td>\n",
+ " <td>6.0</td>\n",
+ " <td>0.623813</td>\n",
+ " <td>0.031067</td>\n",
+ " <td>0.0</td>\n",
+ " <td>5.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " idf1 idp idr recall precision \\\n",
+ "set \n",
+ "control 0.722042 0.879472 0.612416 0.658158 0.945161 \n",
+ "surprise 0.738122 0.923004 0.614946 0.646612 0.970532 \n",
+ "\n",
+ " num_unique_objects mostly_tracked partially_tracked mostly_lost \\\n",
+ "set \n",
+ "control 86.0 35.0 22.0 29.0 \n",
+ "surprise 33.0 11.0 9.0 13.0 \n",
+ "\n",
+ " num_false_positives num_misses num_switches num_fragmentations \\\n",
+ "set \n",
+ "control 187.0 1674.0 26.0 38.0 \n",
+ "surprise 31.0 558.0 5.0 6.0 \n",
+ "\n",
+ " mota motp num_transfer num_ascend num_migrate \n",
+ "set \n",
+ "control 0.614662 0.033957 5.0 21.0 0.0 \n",
+ "surprise 0.623813 0.031067 0.0 5.0 0.0 "
+ ]
+ },
+ "execution_count": 7,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af[af.index == 'OVERALL'].groupby(['set']).mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Gate Openings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Percept gate openings when visible: M: 1.02 , STD: 0.681, Count: 2281\n",
+ "Percept gate openings when occluded: M: 0.216 , STD: 0.48, Count: 564\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "loci23",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl2.ipynb b/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl2.ipynb
new file mode 100644
index 0000000..5b8e275
--- /dev/null
+++ b/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl2.ipynb
@@ -0,0 +1,469 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import scipy.stats as stats\n",
+ "import os\n",
+ "\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "pd.options.mode.chained_assignment = None \n",
+ "plt.style.use('ggplot')\n",
+ "sns.color_palette(\"Paired\");\n",
+ "sns.set_theme();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data Loading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# setting path to results folder\n",
+ "root_path = '../../../out/pretrained/adept_ablations/lambda'\n",
+ "\n",
+ "# list all folders in root path that don't start with a dot\n",
+ "nets = ['adept_level1_ablation_lambda.run411']\n",
+ "\n",
+ "# empty frames that accumulate the CSV results of all nets\n",
+ "tf = pd.DataFrame()\n",
+ "sf = pd.DataFrame()\n",
+ "af = pd.DataFrame()\n",
+ "\n",
+ "# load statistics files from nets\n",
+ "for net in nets:\n",
+ " path = os.path.join(root_path, net, 'results')\n",
+ " with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:\n",
+ " tf_temp = pd.read_csv(f, index_col=0)\n",
+ " tf_temp['net'] = net\n",
+ " tf = pd.concat([tf,tf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:\n",
+ " sf_temp = pd.read_csv(f, index_col=0)\n",
+ " sf_temp['net'] = net\n",
+ " sf = pd.concat([sf,sf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'accframe.csv'), 'rb') as f:\n",
+ " af_temp = pd.read_csv(f, index_col=0)\n",
+ " af_temp['net'] = net\n",
+ " af = pd.concat([af,af_temp])\n",
+ "\n",
+ "# cast variables\n",
+ "sf['visible'] = sf['visible'].astype(bool)\n",
+ "sf['bound'] = sf['bound'].astype(bool)\n",
+ "sf['occluder'] = sf['occluder'].astype(bool)\n",
+ "sf['inimage'] = sf['inimage'].astype(bool)\n",
+ "sf['vanishing'] = sf['vanishing'].astype(bool)\n",
+ "sf['alpha_pos'] = 1-sf['alpha_pos']\n",
+ "sf['alpha_ges'] = 1-sf['alpha_ges']\n",
+ "\n",
+ "# scale to percentage\n",
+ "sf['TE'] = sf['TE'] * 100\n",
+ "\n",
+ "# add surprise as dummy code\n",
+ "tf['control'] = [('control' in set) for set in tf['set']]\n",
+ "sf['control'] = [('control' in set) for set in sf['set']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Tracking Error (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tracking Error when visible: M: 5.25 , STD: 7.52, Count: 1246\n",
+ "Tracking Error when occluded: M: 5.16 , STD: 4.93, Count: 342\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "\n",
+ "def get_stats(col):\n",
+ " return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'\n",
+ "\n",
+ "# When Visible\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Tracking Error when visible:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Occluded\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Tracking Error when occluded:' + get_stats(temp['TE']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Successful Trackings (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>set</th>\n",
+ " <th>evalmode</th>\n",
+ " <th>tracked_pos</th>\n",
+ " <th>tracked_neg</th>\n",
+ " <th>tracked_pos_pro</th>\n",
+ " <th>tracked_neg_pro</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>control</td>\n",
+ " <td>open</td>\n",
+ " <td>18</td>\n",
+ " <td>15</td>\n",
+ " <td>0.545455</td>\n",
+ " <td>0.454545</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " set evalmode tracked_pos tracked_neg tracked_pos_pro \\\n",
+ "0 control open 18 15 0.545455 \n",
+ "\n",
+ " tracked_neg_pro \n",
+ "0 0.454545 "
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target\n",
+ "# determine last visible frame numeric\n",
+ "grouping_factors = ['net','set','evalmode','scene','slot']\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()\n",
+ "ff.rename(columns = {'frame':'last_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# same for first bound frame\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()\n",
+ "ff.rename(columns = {'frame':'first_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf['last_visible'] = (sf['last_visible'] == sf['frame'])\n",
+ "\n",
+ "# extract the trials where the target was last visible and threshold the TE\n",
+ "ff = sf[sf['last_visible']] \n",
+ "ff['tracked_pos'] = (ff['TE'] < 10)\n",
+ "ff['tracked_neg'] = (ff['TE'] >= 10)\n",
+ "\n",
+ "# fill NaN with 0\n",
+ "sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')\n",
+ "sf['tracked_pos'].fillna(False, inplace=True)\n",
+ "sf['tracked_neg'].fillna(False, inplace=True)\n",
+ "\n",
+ "# Aggregate over all scenes\n",
+ "temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]\n",
+ "temp = temp.groupby(['set', 'evalmode']).sum()\n",
+ "temp = temp[['tracked_pos', 'tracked_neg']]\n",
+ "temp = temp.reset_index()\n",
+ "\n",
+ "temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "\n",
+ "temp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Mostly Tracked stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "temp = af[af.index == 'OVERALL']\n",
+ "temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']\n",
+ "temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']\n",
+ "temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']\n",
+ "g = temp[['mostly_tracked', 'partially_tracked', 'mostly_lost','set']].set_index(['set']).groupby(['set']).mean().plot(kind='bar', stacked=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MOTA "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>idf1</th>\n",
+ " <th>idp</th>\n",
+ " <th>idr</th>\n",
+ " <th>recall</th>\n",
+ " <th>precision</th>\n",
+ " <th>num_unique_objects</th>\n",
+ " <th>mostly_tracked</th>\n",
+ " <th>partially_tracked</th>\n",
+ " <th>mostly_lost</th>\n",
+ " <th>num_false_positives</th>\n",
+ " <th>num_misses</th>\n",
+ " <th>num_switches</th>\n",
+ " <th>num_fragmentations</th>\n",
+ " <th>mota</th>\n",
+ " <th>motp</th>\n",
+ " <th>num_transfer</th>\n",
+ " <th>num_ascend</th>\n",
+ " <th>num_migrate</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>set</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>control</th>\n",
+ " <td>0.753959</td>\n",
+ " <td>0.934199</td>\n",
+ " <td>0.632020</td>\n",
+ " <td>0.639984</td>\n",
+ " <td>0.945970</td>\n",
+ " <td>86.0</td>\n",
+ " <td>43.0</td>\n",
+ " <td>5.0</td>\n",
+ " <td>38.0</td>\n",
+ " <td>179.0</td>\n",
+ " <td>1763.0</td>\n",
+ " <td>7.0</td>\n",
+ " <td>22.0</td>\n",
+ " <td>0.602001</td>\n",
+ " <td>0.059856</td>\n",
+ " <td>0.0</td>\n",
+ " <td>7.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>surprise</th>\n",
+ " <td>0.751982</td>\n",
+ " <td>0.930841</td>\n",
+ " <td>0.630779</td>\n",
+ " <td>0.644079</td>\n",
+ " <td>0.950467</td>\n",
+ " <td>33.0</td>\n",
+ " <td>13.0</td>\n",
+ " <td>6.0</td>\n",
+ " <td>14.0</td>\n",
+ " <td>53.0</td>\n",
+ " <td>562.0</td>\n",
+ " <td>3.0</td>\n",
+ " <td>11.0</td>\n",
+ " <td>0.608613</td>\n",
+ " <td>0.058177</td>\n",
+ " <td>0.0</td>\n",
+ " <td>3.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " idf1 idp idr recall precision \\\n",
+ "set \n",
+ "control 0.753959 0.934199 0.632020 0.639984 0.945970 \n",
+ "surprise 0.751982 0.930841 0.630779 0.644079 0.950467 \n",
+ "\n",
+ " num_unique_objects mostly_tracked partially_tracked mostly_lost \\\n",
+ "set \n",
+ "control 86.0 43.0 5.0 38.0 \n",
+ "surprise 33.0 13.0 6.0 14.0 \n",
+ "\n",
+ " num_false_positives num_misses num_switches num_fragmentations \\\n",
+ "set \n",
+ "control 179.0 1763.0 7.0 22.0 \n",
+ "surprise 53.0 562.0 3.0 11.0 \n",
+ "\n",
+ " mota motp num_transfer num_ascend num_migrate \n",
+ "set \n",
+ "control 0.602001 0.059856 0.0 7.0 0.0 \n",
+ "surprise 0.608613 0.058177 0.0 3.0 0.0 "
+ ]
+ },
+ "execution_count": 14,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af[af.index == 'OVERALL'].groupby(['set']).mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Gate Openings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Percept gate openings when visible: M: 0.061 , STD: 0.108, Count: 1246\n",
+ "Percept gate openings when occluded: M: 0.00026 , STD: 0.00481, Count: 342\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "loci23",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl3.ipynb b/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl3.ipynb
new file mode 100644
index 0000000..f137a5f
--- /dev/null
+++ b/evaluation/adept_ablation/lambda/evaluation_loci_looped_abl3.ipynb
@@ -0,0 +1,469 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import scipy.stats as stats\n",
+ "import os\n",
+ "\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "pd.options.mode.chained_assignment = None \n",
+ "plt.style.use('ggplot')\n",
+ "sns.color_palette(\"Paired\");\n",
+ "sns.set_theme();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data Loading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# setting path to results folder\n",
+ "root_path = '../../../out/pretrained/adept_ablations/lambda'\n",
+ "\n",
+ "# list all folders in root path that don't start with a dot\n",
+ "nets = ['adept_level1_ablation_lambda.run511']\n",
+ "\n",
+ "# initialize empty data frames (filled from the CSV files below)\n",
+ "tf = pd.DataFrame()\n",
+ "sf = pd.DataFrame()\n",
+ "af = pd.DataFrame()\n",
+ "\n",
+ "# load statistics files from nets\n",
+ "for net in nets:\n",
+ " path = os.path.join(root_path, net, 'results')\n",
+ " with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:\n",
+ " tf_temp = pd.read_csv(f, index_col=0)\n",
+ " tf_temp['net'] = net\n",
+ " tf = pd.concat([tf,tf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:\n",
+ " sf_temp = pd.read_csv(f, index_col=0)\n",
+ " sf_temp['net'] = net\n",
+ " sf = pd.concat([sf,sf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'accframe.csv'), 'rb') as f:\n",
+ " af_temp = pd.read_csv(f, index_col=0)\n",
+ " af_temp['net'] = net\n",
+ " af = pd.concat([af,af_temp])\n",
+ "\n",
+ "# cast variables\n",
+ "sf['visible'] = sf['visible'].astype(bool)\n",
+ "sf['bound'] = sf['bound'].astype(bool)\n",
+ "sf['occluder'] = sf['occluder'].astype(bool)\n",
+ "sf['inimage'] = sf['inimage'].astype(bool)\n",
+ "sf['vanishing'] = sf['vanishing'].astype(bool)\n",
+ "sf['alpha_pos'] = 1-sf['alpha_pos']\n",
+ "sf['alpha_ges'] = 1-sf['alpha_ges']\n",
+ "\n",
+ "# scale to percentage\n",
+ "sf['TE'] = sf['TE'] * 100\n",
+ "\n",
+ "# add surprise as dummy code\n",
+ "tf['control'] = [('control' in set) for set in tf['set']]\n",
+ "sf['control'] = [('control' in set) for set in sf['set']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Tracking Error (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tracking Error when visible: M: 3.34 , STD: 4.22, Count: 1873\n",
+ "Tracking Error when occluded: M: 2.78 , STD: 1.66, Count: 492\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "\n",
+ "def get_stats(col):\n",
+ " return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'\n",
+ "\n",
+ "# When Visible\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Tracking Error when visible:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Occluded\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Tracking Error when occluded:' + get_stats(temp['TE']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Successful Trackings (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>set</th>\n",
+ " <th>evalmode</th>\n",
+ " <th>tracked_pos</th>\n",
+ " <th>tracked_neg</th>\n",
+ " <th>tracked_pos_pro</th>\n",
+ " <th>tracked_neg_pro</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>control</td>\n",
+ " <td>open</td>\n",
+ " <td>37</td>\n",
+ " <td>11</td>\n",
+ " <td>0.770833</td>\n",
+ " <td>0.229167</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " set evalmode tracked_pos tracked_neg tracked_pos_pro \\\n",
+ "0 control open 37 11 0.770833 \n",
+ "\n",
+ " tracked_neg_pro \n",
+ "0 0.229167 "
+ ]
+ },
+ "execution_count": 13,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target\n",
+ "# determine last visible frame numeric\n",
+ "grouping_factors = ['net','set','evalmode','scene','slot']\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()\n",
+ "ff.rename(columns = {'frame':'last_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# same for first bound frame\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()\n",
+ "ff.rename(columns = {'frame':'first_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf['last_visible'] = (sf['last_visible'] == sf['frame'])\n",
+ "\n",
+ "# extract the trials where the target was last visible and threshold the TE\n",
+ "ff = sf[sf['last_visible']] \n",
+ "ff['tracked_pos'] = (ff['TE'] < 10)\n",
+ "ff['tracked_neg'] = (ff['TE'] >= 10)\n",
+ "\n",
+ "# fill NaN with 0\n",
+ "sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')\n",
+ "sf['tracked_pos'].fillna(False, inplace=True)\n",
+ "sf['tracked_neg'].fillna(False, inplace=True)\n",
+ "\n",
+ "# Aggregate over all scenes\n",
+ "temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]\n",
+ "temp = temp.groupby(['set', 'evalmode']).sum()\n",
+ "temp = temp[['tracked_pos', 'tracked_neg']]\n",
+ "temp = temp.reset_index()\n",
+ "\n",
+ "temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "\n",
+ "temp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Mostly Tracked stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "temp = af[af.index == 'OVERALL']\n",
+ "temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']\n",
+ "temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']\n",
+ "temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']\n",
+ "g = temp[['mostly_tracked', 'partially_tracked', 'mostly_lost','set']].set_index(['set']).groupby(['set']).mean().plot(kind='bar', stacked=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MOTA "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>idf1</th>\n",
+ " <th>idp</th>\n",
+ " <th>idr</th>\n",
+ " <th>recall</th>\n",
+ " <th>precision</th>\n",
+ " <th>num_unique_objects</th>\n",
+ " <th>mostly_tracked</th>\n",
+ " <th>partially_tracked</th>\n",
+ " <th>mostly_lost</th>\n",
+ " <th>num_false_positives</th>\n",
+ " <th>num_misses</th>\n",
+ " <th>num_switches</th>\n",
+ " <th>num_fragmentations</th>\n",
+ " <th>mota</th>\n",
+ " <th>motp</th>\n",
+ " <th>num_transfer</th>\n",
+ " <th>num_ascend</th>\n",
+ " <th>num_migrate</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>set</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>control</th>\n",
+ " <td>0.773828</td>\n",
+ " <td>0.943857</td>\n",
+ " <td>0.655708</td>\n",
+ " <td>0.673474</td>\n",
+ " <td>0.969430</td>\n",
+ " <td>86.0</td>\n",
+ " <td>41.0</td>\n",
+ " <td>15.0</td>\n",
+ " <td>30.0</td>\n",
+ " <td>104.0</td>\n",
+ " <td>1599.0</td>\n",
+ " <td>9.0</td>\n",
+ " <td>31.0</td>\n",
+ " <td>0.650398</td>\n",
+ " <td>0.039370</td>\n",
+ " <td>0.0</td>\n",
+ " <td>9.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>surprise</th>\n",
+ " <td>0.776620</td>\n",
+ " <td>0.942186</td>\n",
+ " <td>0.660545</td>\n",
+ " <td>0.678911</td>\n",
+ " <td>0.968383</td>\n",
+ " <td>33.0</td>\n",
+ " <td>16.0</td>\n",
+ " <td>4.0</td>\n",
+ " <td>13.0</td>\n",
+ " <td>35.0</td>\n",
+ " <td>507.0</td>\n",
+ " <td>4.0</td>\n",
+ " <td>12.0</td>\n",
+ " <td>0.654212</td>\n",
+ " <td>0.039468</td>\n",
+ " <td>2.0</td>\n",
+ " <td>2.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " idf1 idp idr recall precision \\\n",
+ "set \n",
+ "control 0.773828 0.943857 0.655708 0.673474 0.969430 \n",
+ "surprise 0.776620 0.942186 0.660545 0.678911 0.968383 \n",
+ "\n",
+ " num_unique_objects mostly_tracked partially_tracked mostly_lost \\\n",
+ "set \n",
+ "control 86.0 41.0 15.0 30.0 \n",
+ "surprise 33.0 16.0 4.0 13.0 \n",
+ "\n",
+ " num_false_positives num_misses num_switches num_fragmentations \\\n",
+ "set \n",
+ "control 104.0 1599.0 9.0 31.0 \n",
+ "surprise 35.0 507.0 4.0 12.0 \n",
+ "\n",
+ " mota motp num_transfer num_ascend num_migrate \n",
+ "set \n",
+ "control 0.650398 0.039370 0.0 9.0 0.0 \n",
+ "surprise 0.654212 0.039468 2.0 2.0 0.0 "
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af[af.index == 'OVERALL'].groupby(['set']).mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Gate Openings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Percept gate openings when visible: M: 0.132 , STD: 0.145, Count: 1873\n",
+ "Percept gate openings when occluded: M: 0.00809 , STD: 0.0444, Count: 492\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "loci23",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/evaluation/adept_ablation/recon/evaluation_loci_looped_norecon.ipynb b/evaluation/adept_ablation/recon/evaluation_loci_looped_norecon.ipynb
new file mode 100644
index 0000000..a1bbc32
--- /dev/null
+++ b/evaluation/adept_ablation/recon/evaluation_loci_looped_norecon.ipynb
@@ -0,0 +1,464 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import scipy.stats as stats\n",
+ "import os\n",
+ "\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "pd.options.mode.chained_assignment = None \n",
+ "plt.style.use('ggplot')\n",
+ "sns.color_palette(\"Paired\");\n",
+ "sns.set_theme();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data Loading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# setting path to results folder\n",
+ "root_path = '../../../out/pretrained/adept_ablations/recon'\n",
+ "\n",
+ "# list all folders in root path that don't start with a dot\n",
+ "nets = ['norecon']\n",
+ "\n",
+ "# initialize empty data frames (filled from the CSV files below)\n",
+ "tf = pd.DataFrame()\n",
+ "sf = pd.DataFrame()\n",
+ "af = pd.DataFrame()\n",
+ "\n",
+ "# load statistics files from nets\n",
+ "for net in nets:\n",
+ " path = os.path.join(root_path, net, 'results')\n",
+ " with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:\n",
+ " tf_temp = pd.read_csv(f, index_col=0)\n",
+ " tf_temp['net'] = net\n",
+ " tf = pd.concat([tf,tf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:\n",
+ " sf_temp = pd.read_csv(f, index_col=0)\n",
+ " sf_temp['net'] = net\n",
+ " sf = pd.concat([sf,sf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'accframe.csv'), 'rb') as f:\n",
+ " af_temp = pd.read_csv(f, index_col=0)\n",
+ " af_temp['net'] = net\n",
+ " af = pd.concat([af,af_temp])\n",
+ "\n",
+ "# cast variables\n",
+ "sf['visible'] = sf['visible'].astype(bool)\n",
+ "sf['bound'] = sf['bound'].astype(bool)\n",
+ "sf['occluder'] = sf['occluder'].astype(bool)\n",
+ "sf['inimage'] = sf['inimage'].astype(bool)\n",
+ "sf['vanishing'] = sf['vanishing'].astype(bool)\n",
+ "sf['alpha_pos'] = 1-sf['alpha_pos']\n",
+ "sf['alpha_ges'] = 1-sf['alpha_ges']\n",
+ "\n",
+ "# scale to percentage\n",
+ "sf['TE'] = sf['TE'] * 100\n",
+ "\n",
+ "# add surprise as dummy code\n",
+ "tf['control'] = [('control' in set) for set in tf['set']]\n",
+ "sf['control'] = [('control' in set) for set in sf['set']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Tracking Error (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 18,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tracking Error when visible: M: 1.75 , STD: 1.23, Count: 1563\n",
+ "Tracking Error when occluded: M: 2.09 , STD: 1.39, Count: 486\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "\n",
+ "def get_stats(col):\n",
+ " return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'\n",
+ "\n",
+ "# When Visible\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Tracking Error when visible:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Occluded\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Tracking Error when occluded:' + get_stats(temp['TE']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Successful Trackings (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 19,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>set</th>\n",
+ " <th>evalmode</th>\n",
+ " <th>tracked_pos</th>\n",
+ " <th>tracked_neg</th>\n",
+ " <th>tracked_pos_pro</th>\n",
+ " <th>tracked_neg_pro</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>control</td>\n",
+ " <td>open</td>\n",
+ " <td>47</td>\n",
+ " <td>1</td>\n",
+ " <td>0.979167</td>\n",
+ " <td>0.020833</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " set evalmode tracked_pos tracked_neg tracked_pos_pro \\\n",
+ "0 control open 47 1 0.979167 \n",
+ "\n",
+ " tracked_neg_pro \n",
+ "0 0.020833 "
+ ]
+ },
+ "execution_count": 19,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target\n",
+ "# determine last visible frame numeric\n",
+ "grouping_factors = ['net','set','evalmode','scene','slot']\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()\n",
+ "ff.rename(columns = {'frame':'last_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# same for first bound frame\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()\n",
+ "ff.rename(columns = {'frame':'first_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf['last_visible'] = (sf['last_visible'] == sf['frame'])\n",
+ "\n",
+ "# extract the trials where the target was last visible and threshold the TE\n",
+ "ff = sf[sf['last_visible']] \n",
+ "ff['tracked_pos'] = (ff['TE'] < 10)\n",
+ "ff['tracked_neg'] = (ff['TE'] >= 10)\n",
+ "\n",
+ "# fill NaN with 0\n",
+ "sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')\n",
+ "sf['tracked_pos'].fillna(False, inplace=True)\n",
+ "sf['tracked_neg'].fillna(False, inplace=True)\n",
+ "\n",
+ "# Aggregate over all scenes\n",
+ "temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]\n",
+ "temp = temp.groupby(['set', 'evalmode']).sum()\n",
+ "temp = temp[['tracked_pos', 'tracked_neg']]\n",
+ "temp = temp.reset_index()\n",
+ "\n",
+ "temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "\n",
+ "temp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Mostly Tracked stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "temp = af[af.index == 'OVERALL']\n",
+ "temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']\n",
+ "temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']\n",
+ "temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']\n",
+ "g = temp[['mostly_tracked', 'partially_tracked', 'mostly_lost','set']].set_index(['set']).groupby(['set']).mean().plot(kind='bar', stacked=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MOTA "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>idf1</th>\n",
+ " <th>idp</th>\n",
+ " <th>idr</th>\n",
+ " <th>recall</th>\n",
+ " <th>precision</th>\n",
+ " <th>num_unique_objects</th>\n",
+ " <th>mostly_tracked</th>\n",
+ " <th>partially_tracked</th>\n",
+ " <th>mostly_lost</th>\n",
+ " <th>num_false_positives</th>\n",
+ " <th>num_misses</th>\n",
+ " <th>num_switches</th>\n",
+ " <th>num_fragmentations</th>\n",
+ " <th>mota</th>\n",
+ " <th>motp</th>\n",
+ " <th>num_transfer</th>\n",
+ " <th>num_ascend</th>\n",
+ " <th>num_migrate</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>set</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>control</th>\n",
+ " <td>0.793020</td>\n",
+ " <td>0.965426</td>\n",
+ " <td>0.672861</td>\n",
+ " <td>0.675311</td>\n",
+ " <td>0.968942</td>\n",
+ " <td>86.0</td>\n",
+ " <td>44.0</td>\n",
+ " <td>12.0</td>\n",
+ " <td>30.0</td>\n",
+ " <td>106.0</td>\n",
+ " <td>1590.0</td>\n",
+ " <td>2.0</td>\n",
+ " <td>37.0</td>\n",
+ " <td>0.653257</td>\n",
+ " <td>0.035403</td>\n",
+ " <td>0.0</td>\n",
+ " <td>2.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>surprise</th>\n",
+ " <td>0.781989</td>\n",
+ " <td>0.959484</td>\n",
+ " <td>0.659911</td>\n",
+ " <td>0.666244</td>\n",
+ " <td>0.968692</td>\n",
+ " <td>33.0</td>\n",
+ " <td>13.0</td>\n",
+ " <td>7.0</td>\n",
+ " <td>13.0</td>\n",
+ " <td>34.0</td>\n",
+ " <td>527.0</td>\n",
+ " <td>2.0</td>\n",
+ " <td>11.0</td>\n",
+ " <td>0.643445</td>\n",
+ " <td>0.034990</td>\n",
+ " <td>0.0</td>\n",
+ " <td>2.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " idf1 idp idr recall precision \\\n",
+ "set \n",
+ "control 0.793020 0.965426 0.672861 0.675311 0.968942 \n",
+ "surprise 0.781989 0.959484 0.659911 0.666244 0.968692 \n",
+ "\n",
+ " num_unique_objects mostly_tracked partially_tracked mostly_lost \\\n",
+ "set \n",
+ "control 86.0 44.0 12.0 30.0 \n",
+ "surprise 33.0 13.0 7.0 13.0 \n",
+ "\n",
+ " num_false_positives num_misses num_switches num_fragmentations \\\n",
+ "set \n",
+ "control 106.0 1590.0 2.0 37.0 \n",
+ "surprise 34.0 527.0 2.0 11.0 \n",
+ "\n",
+ " mota motp num_transfer num_ascend num_migrate \n",
+ "set \n",
+ "control 0.653257 0.035403 0.0 2.0 0.0 \n",
+ "surprise 0.643445 0.034990 0.0 2.0 0.0 "
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af[af.index == 'OVERALL'].groupby(['set']).mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Gate Openings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Percept gate openings when visible: M: 0.257 , STD: 0.237, Count: 1563\n",
+ "Percept gate openings when occluded: M: 0.00921 , STD: 0.0524, Count: 486\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "loci23",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/evaluation/adept_ablation/recon/evaluation_loci_looped_recon.ipynb b/evaluation/adept_ablation/recon/evaluation_loci_looped_recon.ipynb
new file mode 100644
index 0000000..a6422e0
--- /dev/null
+++ b/evaluation/adept_ablation/recon/evaluation_loci_looped_recon.ipynb
@@ -0,0 +1,466 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import numpy as np\n",
+ "import seaborn as sns\n",
+ "import warnings\n",
+ "import scipy.stats as stats\n",
+ "import os\n",
+ "\n",
+ "warnings.simplefilter(action='ignore', category=FutureWarning)\n",
+ "pd.options.mode.chained_assignment = None \n",
+ "plt.style.use('ggplot')\n",
+ "sns.color_palette(\"Paired\");\n",
+ "sns.set_theme();"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Data Loading"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# setting path to results folder\n",
+ "root_path = '../../../out/pretrained/adept_ablations/recon'\n",
+ "\n",
+ "# list all folders in root path that don't start with a dot\n",
+ "nets = ['recon']\n",
+ "\n",
+ "# initialize empty data frames (filled from the CSV files below)\n",
+ "tf = pd.DataFrame()\n",
+ "sf = pd.DataFrame()\n",
+ "af = pd.DataFrame()\n",
+ "\n",
+ "# load statistics files from nets\n",
+ "for net in nets:\n",
+ " path = os.path.join(root_path, net, 'results')\n",
+ " with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:\n",
+ " tf_temp = pd.read_csv(f, index_col=0)\n",
+ " tf_temp['net'] = net\n",
+ " tf = pd.concat([tf,tf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:\n",
+ " sf_temp = pd.read_csv(f, index_col=0)\n",
+ " sf_temp['net'] = net\n",
+ " sf = pd.concat([sf,sf_temp])\n",
+ "\n",
+ " with open(os.path.join(path, 'accframe.csv'), 'rb') as f:\n",
+ " af_temp = pd.read_csv(f, index_col=0)\n",
+ " af_temp['net'] = net\n",
+ " af = pd.concat([af,af_temp])\n",
+ "\n",
+ "# cast variables\n",
+ "sf['visible'] = sf['visible'].astype(bool)\n",
+ "sf['bound'] = sf['bound'].astype(bool)\n",
+ "sf['occluder'] = sf['occluder'].astype(bool)\n",
+ "sf['inimage'] = sf['inimage'].astype(bool)\n",
+ "sf['vanishing'] = sf['vanishing'].astype(bool)\n",
+ "sf['alpha_pos'] = 1-sf['alpha_pos']\n",
+ "sf['alpha_ges'] = 1-sf['alpha_ges']\n",
+ "\n",
+ "# scale to percentage\n",
+ "sf['TE'] = sf['TE'] * 100\n",
+ "\n",
+ "# add surprise as dummy code\n",
+ "tf['control'] = [('control' in set) for set in tf['set']]\n",
+ "sf['control'] = [('control' in set) for set in sf['set']]"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Tracking Error (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Tracking Error when visible: M: 2.95 , STD: 4.8, Count: 1686\n",
+ "Tracking Error when occluded: M: 2.54 , STD: 2.52, Count: 477\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "\n",
+ "def get_stats(col):\n",
+ " return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'\n",
+ "\n",
+ "# When Visible\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Tracking Error when visible:' + get_stats(temp['TE']))\n",
+ "\n",
+ "# When Occluded\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Tracking Error when occluded:' + get_stats(temp['TE']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Calculate Successful Trackings (TE)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>set</th>\n",
+ " <th>evalmode</th>\n",
+ " <th>tracked_pos</th>\n",
+ " <th>tracked_neg</th>\n",
+ " <th>tracked_pos_pro</th>\n",
+ " <th>tracked_neg_pro</th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>0</th>\n",
+ " <td>control</td>\n",
+ " <td>open</td>\n",
+ " <td>37</td>\n",
+ " <td>9</td>\n",
+ " <td>0.804348</td>\n",
+ " <td>0.195652</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " set evalmode tracked_pos tracked_neg tracked_pos_pro \\\n",
+ "0 control open 37 9 0.804348 \n",
+ "\n",
+ " tracked_neg_pro \n",
+ "0 0.195652 "
+ ]
+ },
+ "execution_count": 4,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "# successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target\n",
+ "# determine last visible frame numeric\n",
+ "grouping_factors = ['net','set','evalmode','scene','slot']\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()\n",
+ "ff.rename(columns = {'frame':'last_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# same for first bound frame\n",
+ "ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()\n",
+ "ff.rename(columns = {'frame':'first_visible'}, inplace = True)\n",
+ "sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')\n",
+ "\n",
+ "# add dummy variable to sf\n",
+ "sf['last_visible'] = (sf['last_visible'] == sf['frame'])\n",
+ "\n",
+ "# extract the trials where the target was last visible and threshold the TE\n",
+ "ff = sf[sf['last_visible']] \n",
+ "ff['tracked_pos'] = (ff['TE'] < 10)\n",
+ "ff['tracked_neg'] = (ff['TE'] >= 10)\n",
+ "\n",
+ "# fill NaN with 0\n",
+ "sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')\n",
+ "sf['tracked_pos'].fillna(False, inplace=True)\n",
+ "sf['tracked_neg'].fillna(False, inplace=True)\n",
+ "\n",
+ "# Aggregate over all scenes\n",
+ "temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]\n",
+ "temp = temp.groupby(['set', 'evalmode']).sum()\n",
+ "temp = temp[['tracked_pos', 'tracked_neg']]\n",
+ "temp = temp.reset_index()\n",
+ "\n",
+ "temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])\n",
+ "\n",
+ "temp"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Mostly Tracked stats"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "image/png": "",
+ "text/plain": [
+ "<Figure size 432x288 with 1 Axes>"
+ ]
+ },
+ "metadata": {
+ "needs_background": "light"
+ },
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "temp = af[af.index == 'OVERALL']\n",
+ "temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']\n",
+ "temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']\n",
+ "temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']\n",
+ "g = temp[['mostly_tracked', 'partially_tracked', 'mostly_lost','set']].set_index(['set']).groupby(['set']).mean().plot(kind='bar', stacked=True);"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# MOTA "
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/html": [
+ "<div>\n",
+ "<style scoped>\n",
+ " .dataframe tbody tr th:only-of-type {\n",
+ " vertical-align: middle;\n",
+ " }\n",
+ "\n",
+ " .dataframe tbody tr th {\n",
+ " vertical-align: top;\n",
+ " }\n",
+ "\n",
+ " .dataframe thead th {\n",
+ " text-align: right;\n",
+ " }\n",
+ "</style>\n",
+ "<table border=\"1\" class=\"dataframe\">\n",
+ " <thead>\n",
+ " <tr style=\"text-align: right;\">\n",
+ " <th></th>\n",
+ " <th>idf1</th>\n",
+ " <th>idp</th>\n",
+ " <th>idr</th>\n",
+ " <th>recall</th>\n",
+ " <th>precision</th>\n",
+ " <th>num_unique_objects</th>\n",
+ " <th>mostly_tracked</th>\n",
+ " <th>partially_tracked</th>\n",
+ " <th>mostly_lost</th>\n",
+ " <th>num_false_positives</th>\n",
+ " <th>num_misses</th>\n",
+ " <th>num_switches</th>\n",
+ " <th>num_fragmentations</th>\n",
+ " <th>mota</th>\n",
+ " <th>motp</th>\n",
+ " <th>num_transfer</th>\n",
+ " <th>num_ascend</th>\n",
+ " <th>num_migrate</th>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>set</th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " <th></th>\n",
+ " </tr>\n",
+ " </thead>\n",
+ " <tbody>\n",
+ " <tr>\n",
+ " <th>control</th>\n",
+ " <td>0.754576</td>\n",
+ " <td>0.919577</td>\n",
+ " <td>0.639779</td>\n",
+ " <td>0.662855</td>\n",
+ " <td>0.952744</td>\n",
+ " <td>86.0</td>\n",
+ " <td>38.0</td>\n",
+ " <td>18.0</td>\n",
+ " <td>30.0</td>\n",
+ " <td>161.0</td>\n",
+ " <td>1651.0</td>\n",
+ " <td>13.0</td>\n",
+ " <td>28.0</td>\n",
+ " <td>0.627323</td>\n",
+ " <td>0.037094</td>\n",
+ " <td>0.0</td>\n",
+ " <td>12.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " <tr>\n",
+ " <th>surprise</th>\n",
+ " <td>0.790401</td>\n",
+ " <td>0.968750</td>\n",
+ " <td>0.667511</td>\n",
+ " <td>0.667511</td>\n",
+ " <td>0.968750</td>\n",
+ " <td>33.0</td>\n",
+ " <td>13.0</td>\n",
+ " <td>8.0</td>\n",
+ " <td>12.0</td>\n",
+ " <td>34.0</td>\n",
+ " <td>525.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>11.0</td>\n",
+ " <td>0.645978</td>\n",
+ " <td>0.034979</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " <td>0.0</td>\n",
+ " </tr>\n",
+ " </tbody>\n",
+ "</table>\n",
+ "</div>"
+ ],
+ "text/plain": [
+ " idf1 idp idr recall precision \\\n",
+ "set \n",
+ "control 0.754576 0.919577 0.639779 0.662855 0.952744 \n",
+ "surprise 0.790401 0.968750 0.667511 0.667511 0.968750 \n",
+ "\n",
+ " num_unique_objects mostly_tracked partially_tracked mostly_lost \\\n",
+ "set \n",
+ "control 86.0 38.0 18.0 30.0 \n",
+ "surprise 33.0 13.0 8.0 12.0 \n",
+ "\n",
+ " num_false_positives num_misses num_switches num_fragmentations \\\n",
+ "set \n",
+ "control 161.0 1651.0 13.0 28.0 \n",
+ "surprise 34.0 525.0 0.0 11.0 \n",
+ "\n",
+ " mota motp num_transfer num_ascend num_migrate \n",
+ "set \n",
+ "control 0.627323 0.037094 0.0 12.0 0.0 \n",
+ "surprise 0.645978 0.034979 0.0 0.0 0.0 "
+ ]
+ },
+ "execution_count": 6,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "af[af.index == 'OVERALL'].groupby(['set']).mean()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# Gate Openings"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Percept gate openings when visible: M: 0.223 , STD: 0.303, Count: 1686\n",
+ "Percept gate openings when occluded: M: 0.00995 , STD: 0.0471, Count: 477\n"
+ ]
+ }
+ ],
+ "source": [
+ "grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)\n",
+ "temp = sf[grouping & sf.visible]\n",
+ "print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))\n",
+ "temp = sf[grouping & ~sf.visible]\n",
+ "print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "loci23",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.9.16"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/model/loci.py b/model/loci.py
index 96088bd..46f89f7 100644
--- a/model/loci.py
+++ b/model/loci.py
@@ -66,11 +66,6 @@ class Loci(nn.Module):
self.background = BackgroundEnhancer(
input_size = cfg.input_size,
- gestalt_size = cfg.background.gestalt_size,
- img_channels = cfg.img_channels,
- depth = cfg.background.num_layers,
- latent_channels = cfg.background.latent_channels,
- level1_channels = cfg.background.level1_channels,
batch_size = cfg.batch_size,
)
@@ -93,11 +88,16 @@ class Loci(nn.Module):
self.modulator = ObjectModulator(cfg.num_objects)
self.linear_gate = LinearInterpolation(cfg.num_objects)
- self.background.set_level(cfg.level)
self.encoder.set_level(cfg.level)
self.decoder.set_level(cfg.level)
self.initial_states.set_level(cfg.level)
+ # add flag option to enable/disable latent loss
+ if 'latent_loss_enabled' in cfg:
+ self.latent_loss_enabled = cfg.latent_loss_enabled
+ else:
+ self.latent_loss_enabled = False
+
def get_init_status(self):
init = []
for module in self.modules():
@@ -133,9 +133,6 @@ class Loci(nn.Module):
if reset:
self.reset_state()
- if train_background or self.get_init_status() < 1:
- return self.background(*input)
-
return self.run_end2end(*input, evaluate=evaluate, warmup=warmup, shuffleslots = shuffleslots, reset_mask = reset_mask, allow_spawn = allow_spawn, show_hidden = show_hidden, clean_slots = clean_slots)
def run_decoder(
@@ -196,11 +193,12 @@ class Loci(nn.Module):
):
position_loss = th.tensor(0, device=input.device)
time_loss = th.tensor(0, device=input.device)
+ latent_loss = th.tensor(0, device=input.device)
bg_mask = None
position_encoder = None
if error_last is None or mask_last is None:
- bg_mask = self.background(input, only_mask=True)
+ bg_mask = self.background(input)
error_last = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
position_last, gestalt_last, priority_last, error_cur = self.initial_states(
@@ -214,7 +212,7 @@ class Loci(nn.Module):
object_last_unprioritized = self.decoder(position_last, gestalt_last)[-1]
# background and bg_mask for the next time point
- bg_mask = self.background(input, error_last, mask_last[:,-1:], only_mask=True)
+ bg_mask = self.background(input)
# position and gestalt for the current time point
position_cur, gestalt_cur, priority_cur = self.encoder(input, error_last, mask_last, object_last_unprioritized, position_last, rawmask_last)
@@ -228,7 +226,7 @@ class Loci(nn.Module):
slots_bounded, slots_bounded_smooth, slots_occluded_cur, slots_partially_occluded_cur, slots_fully_visible_cur, slots_occlusionfactor_cur = self.occlusion_tracker(mask_cur, rawmask_cur, reset_mask)
# do not project into the future in the warmup phase
- slots_closed = th.ones_like(repeat(slots_bounded, 'b o -> b o c', c=2))
+ slots_closed = th.zeros_like(repeat(slots_bounded, 'b o -> b o c', c=2))
if warmup:
position_next = position_cur
gestalt_next = gestalt_cur
@@ -239,6 +237,7 @@ class Loci(nn.Module):
# update module
slots_closed = (1-self.percept_gate_controller(position_cur, gestalt_cur, priority_cur, slots_occlusionfactor_cur, position_last, gestalt_last, priority_last, slots_occlusionfactor_last, self.position_last2, evaluate=evaluate))
+ #slots_closed = repeat(rearrange(slots_occlusionfactor_cur, 'b (o c) -> b o c', c=1), 'b o 1 -> b o 2')
position_cur = self.linear_gate(position_cur, position_last, slots_closed[:, :, 1])
priority_cur = self.linear_gate(priority_cur, priority_last, slots_closed[:, :, 1])
@@ -313,6 +312,10 @@ class Loci(nn.Module):
position_next.detach(),
)
+ # add latent loss: current perception as target for last prediction
+ if self.cfg.inner_loop_enabled and self.latent_loss_enabled:
+ latent_loss = self.position_loss(position_cur.detach(), position_last, slots_bounded_smooth)
+
return (
output_next,
output_cur,
@@ -326,6 +329,8 @@ class Loci(nn.Module):
slots_occlusionfactor_next,
position_loss,
time_loss,
- slots_closed
+ latent_loss,
+ slots_closed,
+ slots_bounded
)
diff --git a/model/nn/background.py b/model/nn/background.py
index 38ec325..ff14e02 100644
--- a/model/nn/background.py
+++ b/model/nn/background.py
@@ -14,154 +14,21 @@ class BackgroundEnhancer(nn.Module):
def __init__(
self,
input_size: Tuple[int, int],
- img_channels: int,
- level1_channels,
- latent_channels,
- gestalt_size,
batch_size,
- depth
):
super(BackgroundEnhancer, self).__init__()
- latent_size = [input_size[0] // 16, input_size[1] // 16]
- self.input_size = input_size
-
- self.register_buffer('init', th.zeros(1).long())
- self.alpha = nn.Parameter(th.zeros(1)+1e-16)
-
- self.level = 1
- self.down_level2 = nn.Sequential(
- PatchDownConv(img_channels*2+2, level1_channels, alpha = 1e-16),
- *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)]
- )
-
- self.down_level1 = nn.Sequential(
- PatchDownConv(level1_channels, latent_channels, alpha = 1),
- *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)]
- )
-
- self.down_level0 = nn.Sequential(
- *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
- AggressiveConvToGestalt(latent_channels, gestalt_size, latent_size),
- LambdaModule(lambda x: rearrange(x, 'b c 1 1 -> b c')),
- Binarize(),
- )
-
- self.bias = nn.Parameter(th.zeros((1, gestalt_size, *latent_size)))
-
- self.to_grid = nn.Sequential(
- LinearResidual(gestalt_size, gestalt_size, input_relu = False),
- LambdaModule(lambda x: rearrange(x, 'b c -> b c 1 1')),
- LambdaModule(lambda x: x + self.bias),
- *[ResidualBlock(gestalt_size, gestalt_size) for i in range(depth)],
- )
-
-
- self.up_level0 = nn.Sequential(
- ResidualBlock(gestalt_size, latent_channels),
- *[ResidualBlock(latent_channels, latent_channels) for i in range(depth)],
- )
-
- self.up_level1 = nn.Sequential(
- *[ResidualBlock(latent_channels, latent_channels, alpha_residual = False) for i in range(depth)],
- PatchUpscale(latent_channels, level1_channels, alpha = 1),
- )
-
- self.up_level2 = nn.Sequential(
- *[ResidualBlock(level1_channels, level1_channels, alpha_residual = True) for i in range(depth)],
- PatchUpscale(level1_channels, img_channels, alpha = 1e-16),
- )
-
- self.to_channels = nn.ModuleList([
- SkipConnection(img_channels*2+2, latent_channels),
- SkipConnection(img_channels*2+2, level1_channels),
- SkipConnection(img_channels*2+2, img_channels*2+2),
- ])
-
- self.to_img = nn.ModuleList([
- SkipConnection(latent_channels, img_channels),
- SkipConnection(level1_channels, img_channels),
- SkipConnection(img_channels, img_channels),
- ])
-
+ self.batch_size = batch_size
+ self.height = input_size[0]
+ self.width = input_size[1]
self.mask = nn.Parameter(th.ones(1, 1, *input_size) * 10)
- self.object = nn.Parameter(th.ones(1, img_channels, *input_size))
-
- self.register_buffer('latent', th.zeros((batch_size, gestalt_size)), persistent=False)
+ self.register_buffer('init', th.zeros(1).long())
def get_init(self):
return self.init.item()
-
- def step_init(self):
- self.init = self.init + 1
-
- def detach(self):
- self.latent = self.latent.detach()
-
- def reset_state(self):
- self.latent = th.zeros_like(self.latent)
-
- def set_level(self, level):
- self.level = level
-
- def encoder(self, input):
- latent = self.to_channels[self.level](input)
-
- if self.level >= 2:
- latent = self.down_level2(latent)
-
- if self.level >= 1:
- latent = self.down_level1(latent)
-
- return self.down_level0(latent)
-
- def get_last_latent_gird(self):
- return self.to_grid(self.latent) * self.alpha
-
- def decoder(self, latent, input):
- grid = self.to_grid(latent)
- latent = self.up_level0(grid)
-
- if self.level >= 1:
- latent = self.up_level1(latent)
-
- if self.level >= 2:
- latent = self.up_level2(latent)
-
- object = reduce(self.object, '1 c (h h2) (w w2) -> 1 c h w', 'mean', h = input.shape[2], w = input.shape[3])
- object = repeat(object, '1 c h w -> b c h w', b = input.shape[0])
-
- return th.sigmoid(object + self.to_img[self.level](latent)), grid
-
- def forward(self, input: th.Tensor, error: th.Tensor = None, mask: th.Tensor = None, only_mask: bool = False):
+
+ def forward(self, input: th.Tensor):
- if only_mask:
- mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
- mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
- return mask
-
- last_bg = self.decoder(self.latent, input)[0]
-
- bg_error = th.sqrt(reduce((input - last_bg)**2, 'b c h w -> b 1 h w', 'mean')).detach()
- bg_mask = (bg_error < th.mean(bg_error) + th.std(bg_error)).float().detach()
-
- if error is None or self.get_init() < 2:
- error = bg_error
-
- if mask is None or self.get_init() < 2:
- mask = bg_mask
-
- self.latent = self.encoder(th.cat((input, last_bg, error, mask), dim=1))
-
mask = reduce(self.mask, '1 1 (h h2) (w w2) -> 1 1 h w', 'mean', h = input.shape[2], w = input.shape[3])
- mask = repeat(mask, '1 1 h w -> b 1 h w', b = input.shape[0]) * 0.1
-
- background, grid = self.decoder(self.latent, input)
-
- if self.get_init() < 1:
- return mask, background
-
- if self.get_init() < 2:
- return mask, th.zeros_like(background), th.zeros_like(grid), background
-
- return mask, background, grid * self.alpha, background
+ mask = repeat(mask, '1 1 h w -> b 1 h w', b = self.batch_size) * 0.1
+ return mask
diff --git a/model/nn/eprop_gate_l0rd.py b/model/nn/eprop_gate_l0rd.py
index 82a9895..c2473d1 100644
--- a/model/nn/eprop_gate_l0rd.py
+++ b/model/nn/eprop_gate_l0rd.py
@@ -49,7 +49,7 @@ class EpropGateL0rdFunction(Function):
e_w_rx.clone(), e_w_rh.clone(), e_b_r.clone(),
)
- return h, th.mean(H_g)
+ return h, H_g
@staticmethod
def backward(ctx, dh, _):
@@ -169,6 +169,7 @@ class EpropGateL0rd(nn.Module):
self.register_buffer("h_last", th.zeros(batch_size, num_hidden), persistent=False)
self.register_buffer("openings", th.zeros(1), persistent=False)
+ self.register_buffer("openings_perslot", th.zeros(batch_size), persistent=False)
# initialize weights
stdv_ih = np.sqrt(6/(self.num_inputs + self.num_hidden))
@@ -294,7 +295,7 @@ class EpropGateL0rdShared(EpropGateL0rd):
return o * p, h_last
def eprop_forward(self, x: th.Tensor, h_last: th.Tensor):
- h, openings = self.fcn(
+ h, H_g = self.fcn(
x, h_last,
self.w_gx, self.w_gh, self.b_g,
self.w_rx, self.w_rh, self.b_r,
@@ -305,7 +306,8 @@ class EpropGateL0rdShared(EpropGateL0rd):
)
)
- self.openings = openings
+ self.openings = th.mean(H_g)
+ self.openings_perslot = th.mean(H_g, dim=1)
p = th.tanh(x.mm(self.w_px.t()) + h.mm(self.w_ph.t()) + self.b_p)
o = th.sigmoid(x.mm(self.w_ox.t()) + h.mm(self.w_oh.t()) + self.b_o)
diff --git a/model/nn/eprop_transformer.py b/model/nn/eprop_transformer.py
index 4d89ec4..d2341dd 100644
--- a/model/nn/eprop_transformer.py
+++ b/model/nn/eprop_transformer.py
@@ -35,13 +35,21 @@ class EpropGateL0rdTransformer(nn.Module):
_layers.append(OutputEmbeding(num_hidden, num_outputs))
self.layers = nn.Sequential(*_layers)
+ self.attention = []
+ self.l0rds = []
+ for l in self.layers:
+ if 'AlphaAttention' in type(l).__name__:
+ self.attention.append(l)
+ elif 'EpropAlphaGateL0rd' in type(l).__name__:
+ self.l0rds.append(l)
def get_openings(self):
- openings = 0
+ openings = []
for i in range(self.depth):
- openings += self.layers[2 * (i + 1)].l0rd.openings.item()
+ openings.append(self.l0rds[i].l0rd.openings_perslot)
- return openings / self.depth
+ openings = th.mean(th.stack(openings, dim=0), dim=0)
+ return openings
def get_hidden(self):
states = []
diff --git a/model/nn/eprop_transformer_shared.py b/model/nn/eprop_transformer_shared.py
index 23a1b0f..79223c1 100644
--- a/model/nn/eprop_transformer_shared.py
+++ b/model/nn/eprop_transformer_shared.py
@@ -39,11 +39,12 @@ class EpropGateL0rdTransformerShared(nn.Module):
self.output_embeding = OutputEmbeding(num_hidden, num_outputs)
def get_openings(self):
- openings = 0
+ openings = []
for i in range(self.depth):
- openings += self.l0rds[i].l0rd.openings.item()
+ openings.append(self.l0rds[i].l0rd.openings_perslot)
- return openings / self.depth
+ openings = th.mean(th.stack(openings, dim=0), dim=0)
+ return openings
def get_hidden(self):
return self.hidden
diff --git a/model/nn/eprop_transformer_utils.py b/model/nn/eprop_transformer_utils.py
index 9e5e874..3219cd0 100644
--- a/model/nn/eprop_transformer_utils.py
+++ b/model/nn/eprop_transformer_utils.py
@@ -9,7 +9,8 @@ class AlphaAttention(nn.Module):
num_hidden,
num_objects,
heads,
- dropout = 0.0
+ dropout = 0.0,
+ need_weights = False
):
super(AlphaAttention, self).__init__()
@@ -23,10 +24,13 @@ class AlphaAttention(nn.Module):
dropout = dropout,
batch_first = True
)
+ self.need_weights = need_weights
+ self.att_weights = None
def forward(self, x: th.Tensor):
x = self.to_sequence(x)
- x = x + self.alpha * self.attention(x, x, x, need_weights=False)[0]
+ att, self.att_weights = self.attention(x, x, x, need_weights=self.need_weights)
+ x = x + self.alpha * att
return self.to_batch(x)
class InputEmbeding(nn.Module):
diff --git a/model/nn/predictor.py b/model/nn/predictor.py
index 94f13b8..5f08de9 100644
--- a/model/nn/predictor.py
+++ b/model/nn/predictor.py
@@ -61,7 +61,9 @@ class LociPredictor(nn.Module):
self.to_shared = LambdaModule(lambda x: rearrange(x, '(b o) c -> b (o c)', o=num_objects))
def get_openings(self):
- return self.predictor.get_openings()
+ openings = self.predictor.get_openings().detach()
+ openings = self.to_shared(openings[:, None])
+ return openings
def get_hidden(self):
return self.predictor.get_hidden()
@@ -69,6 +71,24 @@ class LociPredictor(nn.Module):
def set_hidden(self, hidden):
self.predictor.set_hidden(hidden)
+ def get_att_weights(self):
+ att_weights = []
+ for layer in self.predictor.attention:
+ if layer.att_weights is None:
+ return []
+ else:
+ att_weights.append(layer.att_weights)
+ att_weights = th.stack(att_weights)
+ return reduce(att_weights, 'l b o1 o2-> b o1 o2', 'mean')
+
+ def enable_att_weights(self):
+ for layer in self.predictor.attention:
+ layer.need_weights = True
+
+ def disable_att_weights(self):
+ for layer in self.predictor.attention:
+ layer.need_weights = False
+
def forward(
self,
gestalt: th.Tensor,
diff --git a/scripts/evaluation_adept.py b/scripts/evaluation_adept.py
index eab3ad9..b2a18c0 100644
--- a/scripts/evaluation_adept.py
+++ b/scripts/evaluation_adept.py
@@ -2,6 +2,7 @@ import torch as th
from torch.utils.data import Dataset, DataLoader, Subset
from torch import nn
import os
+from scripts.utils.eval_adept import eval_adept
from scripts.utils.plot_utils import plot_timestep
from scripts.utils.configuration import Configuration
from scripts.utils.io import init_device
@@ -21,10 +22,12 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
# Config
cfg_net = cfg.model
cfg_net.batch_size = 1
+ cfg_net.inner_loop_enabled = (cfg.max_updates > cfg.phases.start_inner_loop)
# Load model
net = load_model(cfg, cfg_net, file, device)
net.eval()
+ net.predictor.enable_att_weights()
# Plot config
object_view = True
@@ -216,7 +219,9 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
# 4. Plot
if (t % plot_frequency == 0) and (i < plot_first_samples) and (t >= 0):
- plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i)
+ att = net.predictor.get_att_weights()
+ openings = net.get_openings()
+ plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, att= att, openings=openings)
# fill jumping statistics
statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1))
@@ -235,8 +240,11 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv')
pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv')
pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv')
+
if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
os.remove(f'{root_path}/tmp.jpg')
+
+ eval_adept(f'{root_path}/statistics')
pass
@@ -309,6 +317,7 @@ def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_n
def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_num_slots, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask):
# tracking utils
+ gt_positions_target = gt_positions_target.clone()
pdist = nn.PairwiseDistance(p=2).to(position_cur.device)
# 1. association of newly bounded slots to ground truth objects
diff --git a/scripts/evaluation_adept_savi.py b/scripts/evaluation_adept_baselines.py
index 6a2d5c7..e5a8e7d 100644
--- a/scripts/evaluation_adept_savi.py
+++ b/scripts/evaluation_adept_baselines.py
@@ -5,15 +5,16 @@ import cv2
import numpy as np
import pandas as pd
import os
-from data.datasets.ADEPT.dataset import AdeptDataset
import motmetrics as mm
from scripts.evaluation_adept import calculate_tracking_error, get_evaluation_sets, update_mota_acc
-from scripts.utils.eval_utils import setup_result_folders, store_statistics
+from scripts.utils.eval_utils import boxes_to_centroids, masks_to_boxes, setup_result_folders, store_statistics
from scripts.utils.plot_utils import write_image
FG_THRE = 0.95
-def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 2):
+def evaluate(dataset: Dataset, file, n, model, plot_frequency= 1, plot_first_samples = 2):
+
+ assert model in ['savi', 'gswm']
# read pkl file
masks_complete = pd.read_pickle(file)
@@ -21,8 +22,12 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples =
# plot config
color_list = [[255,0,0], [0,255,0], [0,0,255], [255,255,0], [0,255,255], [255,0,255], [255,255,255]]
dot_size = 2
- skip_frames = 2
- offset = 15
+ if model == 'savi':
+ skip_frames = 2
+ offset = 15
+ elif model == 'gswm':
+ skip_frames = 2
+ offset = 0
# memory
statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': []}
@@ -52,33 +57,53 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples =
tensor = tensor[:,range(0, tensor.shape[1], skip_frames)]
sequence_len = tensor.shape[1]
- # load data
- masks = th.tensor(masks_complete['test'][f'control_{i}.mp4'])
- masks_before_softmax = th.tensor(masks_complete['test_raw'][f'control_{i}.mp4'])
-
- # calculate rawmasks
- bg_mask = masks_before_softmax.mean(dim=1)
- masks_raw = compute_maskraw(masks_before_softmax, bg_mask)
- slots_bound = compute_slots_bound(masks_raw)
+ if model == 'savi':
+ # load data
+ masks = th.tensor(masks_complete['test'][f'control_{i}.mp4']) # N, O, 1, H, W
+ masks_before_softmax = th.tensor(masks_complete['test_raw'][f'control_{i}.mp4'])
+ imgs_model = None
+ recons_model = None
+
+ # calculate rawmasks
+ bg_mask = masks_before_softmax.mean(dim=1)
+ masks_raw = compute_maskraw(masks_before_softmax, bg_mask, n_slots=7)
+ slots_bound = compute_slots_bound(masks_raw)
+
+ elif model == 'gswm':
+ # load data
+ masks = masks_complete[i]['visibility_mask'].squeeze(0)
+ masks_raw = masks_complete[i]['object_mask'].squeeze(0)
+ slots_bound = masks_complete[i]['z_pres'].squeeze(0)
+ slots_bound = (slots_bound > 0.9).float()
+
+ imgs_model = masks_complete[i]['imgs'].squeeze(0)
+ imgs_model[:] = imgs_model[:, [2,1,0]]
+ recons_model = masks_complete[i]['recon'].squeeze(0)
+ recons_model[:] = recons_model[:, [2,1,0]]
+
+ # consider only the first 7 slots
+ masks = masks[:,:7]
+ masks_raw = masks_raw[:,:7]
+ slots_bound = slots_bound[:,:7]
+
+ n_slots = masks.shape[1]
# threshold masks and calculate centroids
masks_binary = (masks_raw > FG_THRE).float()
masks2 = rearrange(masks_binary, 't o 1 h w -> (t o) h w')
boxes = masks_to_boxes(masks2.long())
- boxes = boxes.reshape(1, masks.shape[0], 7, 4)
+ boxes = boxes.reshape(1, masks.shape[0], n_slots, 4)
centroids = boxes_to_centroids(boxes)
# get rid of batch dimension
- association_table = th.ones(7) * -1
+ association_table = th.ones(n_slots) * -1
# iterate over frames
for t_index in range(offset,min(sequence_len,masks.shape[0])):
# move to next frame
input = tensor[:,t_index]
- target = th.clip(tensor[:,t_index+1], 0, 1)
gt_positions_target = gt_object_positions[:,t_index]
- gt_positions_target_next = gt_object_positions[:,t_index+1]
gt_visibility_target = gt_object_visibility[:,t_index]
position_cur = centroids[t_index]
@@ -87,32 +112,32 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples =
slots_bound_cur = rearrange(slots_bound_cur, 'o c -> 1 (o c)')
# calculate tracking error
- tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, 7, slots_bound_cur, None, association_table, gt_occluder_mask)
+ tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, n_slots, slots_bound_cur, None, association_table, gt_occluder_mask)
rawmask_size = reduce(masks_raw[t_index], 'o 1 h w-> 1 o', 'sum')
mask_size = reduce(masks[t_index], 'o 1 h w-> 1 o', 'sum')
statistics_complete_slots = store_statistics(statistics_complete_slots,
- ['control'] * 7,
- ['control'] * 7,
- [control_samples[i]] * 7,
- [t_index] * 7,
- range(7),
+ ['control'] * n_slots,
+ ['control'] * n_slots,
+ [control_samples[i]] * n_slots,
+ [t_index] * n_slots,
+ range(n_slots),
tracking_error_perslot.cpu().numpy().flatten(),
slots_visible.cpu().numpy().flatten().astype(int),
slots_bound_cur.cpu().numpy().flatten().astype(int),
slots_occluder.cpu().numpy().flatten().astype(int),
slots_in_image.cpu().numpy().flatten().astype(int),
- [0] * 7,
+ [0] * n_slots,
mask_size.cpu().numpy().flatten(),
rawmask_size.cpu().numpy().flatten(),
- [0] * 7,
- [0] * 7,
- [0] * 7,
+ [0] * n_slots,
+ [0] * n_slots,
+ [0] * n_slots,
association_table[0].cpu().numpy().flatten().astype(int),
extend = True)
- acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bound_cur, 7, gt_occluder_mask, slots_occluder, None)
+ acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bound_cur, n_slots, gt_occluder_mask, slots_occluder, None)
# plot_option
if (t_index % plot_frequency == 0) and (i < plot_first_samples) and (t_index >= 0):
@@ -141,7 +166,12 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples =
slot_frame_single = mask.transpose((1,2,0)).repeat(3, axis=2)
slot_frame = np.concatenate((slot_frame, slot_frame_single), axis=1)
- frame = np.concatenate((frame, slot_frame), axis=1)
+ if imgs_model is not None:
+ frame_model = imgs_model[t_index].numpy().transpose(1,2,0)
+ recon_model = recons_model[t_index].numpy().transpose(1,2,0)
+ frame = np.concatenate((frame, frame_model, recon_model, slot_frame), axis=1)
+ else:
+ frame = np.concatenate((frame, slot_frame), axis=1)
cv2.imwrite(f'{plot_path}object/objects-{i:04d}-{t_index:03d}.jpg', frame*255)
acc_memory_eval.append(acc)
@@ -153,57 +183,6 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples =
pd.DataFrame(summary).to_csv(os.path.join(root_path, 'statistics' , 'accframe.csv'))
pd.DataFrame(statistics_complete_slots).to_csv(os.path.join(root_path, 'statistics' , 'slotframe.csv'))
-def masks_to_boxes(masks: th.Tensor) -> th.Tensor:
- """
- Compute the bounding boxes around the provided masks.
-
- Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
- ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
-
- Args:
- masks (Tensor[N, H, W]): masks to transform where N is the number of masks
- and (H, W) are the spatial dimensions.
-
- Returns:
- Tensor[N, 4]: bounding boxes
- """
- if masks.numel() == 0:
- return th.zeros((0, 4), device=masks.device, dtype=th.float)
-
- n = masks.shape[0]
-
- bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float)
-
- for index, mask in enumerate(masks):
- if mask.sum() > 0:
- y, x = th.where(mask != 0)
-
- bounding_boxes[index, 0] = th.min(x)
- bounding_boxes[index, 1] = th.min(y)
- bounding_boxes[index, 2] = th.max(x)
- bounding_boxes[index, 3] = th.max(y)
-
- return bounding_boxes
-
-def boxes_to_centroids(boxes):
- """Post-process masks instead of directly taking argmax.
-
- Args:
- bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2]
-
- Returns:
- centroids: [B, T, N, 2], 2: [x, y]
- """
-
- centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2
- centroids = centroids.squeeze(0)
-
- # scale to [-1, 1]
- centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1
- centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1
-
- return centroids
-
def compute_slots_bound(masks):
# take sum over axis 3,4 with th
@@ -211,7 +190,7 @@ def compute_slots_bound(masks):
slots_bound = (masks_sum > FG_THRE).float()
return slots_bound
-def compute_maskraw(mask, bg_mask):
+def compute_maskraw(mask, bg_mask, n_slots):
# d is a diagonal matrix which defines what to take the softmax over
d_mask = th.diag(th.ones(8))
@@ -224,7 +203,7 @@ def compute_maskraw(mask, bg_mask):
maskraw = th.cat((mask, bg_mask), dim=1)
maskraw = repeat(maskraw, 'b o h w -> b r o h w', r = 8)
maskraw = maskraw[:,d_mask.bool()]
- maskraw = rearrange(maskraw, 'b (o r) h w -> b o r h w', o = 7)
+ maskraw = rearrange(maskraw, 'b (o r) h w -> b o r h w', o = n_slots)
# take softmax between each object mask and the background mask
maskraw = th.squeeze(th.softmax(maskraw, dim=2)[:,:,0], dim=2)
diff --git a/scripts/evaluation_bb.py b/scripts/evaluation_bb.py
new file mode 100644
index 0000000..1d21cfe
--- /dev/null
+++ b/scripts/evaluation_bb.py
@@ -0,0 +1,385 @@
+import pickle
+import cv2
+import torch as th
+from torch.utils.data import Dataset, DataLoader, Subset
+from torch import nn
+import os
+from scripts.evaluation_adept import calculate_tracking_error
+from scripts.evaluation_clevrer import compute_statistics_summary
+from scripts.utils.plot_utils import plot_timestep
+from scripts.utils.eval_metrics import masks_to_boxes, pred_eval_step, postproc_mask
+from scripts.utils.eval_utils import append_statistics, compute_position_from_mask, load_model, setup_result_folders, store_statistics
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device
+import numpy as np
+from einops import rearrange, repeat, reduce
+from copy import deepcopy
+import lpips
+import torchvision.transforms as transforms
+import motmetrics as mm
+
+def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_first_samples = 0):
+ """Evaluate a loaded model on the given dataset for every evaluation mode.
+
+ For each set/mode returned by get_evaluation_sets the dataset is iterated in
+ batches and every sequence is fed frame by frame through the network.
+ Predictions made once `burn_in_length` frames have passed are stored and
+ scored with pred_eval_step (image / segmentation metrics) and
+ distance_eval_step (position error, stored under 'meds'). Per-mode results
+ are summarised by compute_statistics_summary and pickled into
+ `{root_path}/statistics`.
+
+ Args:
+ cfg: configuration; `cfg.model` and `cfg.defaults.teacher_forcing` are read here.
+ dataset: dataset whose batches are unpacked below as frames, fixed background,
+ ground-truth positions, masks, presence masks and hidden masks (indices 0-5).
+ file: forwarded to load_model and setup_result_folders (presumably the checkpoint path).
+ n: forwarded to setup_result_folders (presumably the results name).
+ plot_first_samples: forwarded to plot_timestep as `num_vid`.
+ """
+
+ # Set up cpu or gpu training
+ device, verbose = init_device(cfg)
+
+ # Config
+ cfg_net = cfg.model
+ cfg_net.batch_size = 2 if verbose else 32
+ #cfg_net.num_objects = 3
+ cfg_net.inner_loop_enabled = True
+ if 'num_objects_test' in cfg_net:
+ cfg_net.num_objects = cfg_net.num_objects_test
+ # in verbose mode only the first four samples are evaluated
+ dataset = Subset(dataset, range(4)) if verbose else dataset
+
+ # Load model
+ net = load_model(cfg, cfg_net, file, device)
+ net.eval()
+ net.predictor.enable_att_weights()
+
+ # config
+ object_view = True
+ individual_views = False
+ root_path = None
+ use_meds = True
+
+ # get evaluation sets
+ set_test_array, evaluation_modes = get_evaluation_sets(dataset)
+
+ # memory
+ statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error_mse': []}
+ # NOTE(review): statistics_complete_slots is not referenced again in this function.
+ statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'bound': [], 'slot_error': [], 'rawmask_size': [], 'alpha_pos': [], 'alpha_ges': []}
+ metric_complete = None
+
+ # Evaluation Specifics
+ burn_in_length = 10
+ rollout_length = 90
+ rollout_length_stats = 10 # only consider the first 10 frames for statistics
+ target_size = (64, 64)
+
+ # Losses
+ lpipsloss = lpips.LPIPS(net='vgg').to(device)
+ # NOTE(review): mseloss is not referenced again in this function.
+ mseloss = nn.MSELoss()
+
+ for set_test in set_test_array:
+
+ for evaluation_mode in evaluation_modes:
+ print(f'Start evaluation loop: {evaluation_mode}')
+
+ # load data
+ dataloader = DataLoader(
+ dataset,
+ num_workers = 0,
+ pin_memory = False,
+ batch_size = cfg_net.batch_size,
+ shuffle = False,
+ drop_last = True,
+ )
+
+ # memory
+ root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views)
+ metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'meds': [], 'ari_hidden': [], 'fari_hidden': [], 'miou_hidden': []}
+ video_list = []
+
+ # set seed: if there is a number in the evaluation mode, use it as seed
+ # (only the last character is inspected; plotting is switched off for seeds > 1)
+ plot_mode = True
+ if evaluation_mode[-1].isdigit():
+ seed = int(evaluation_mode[-1])
+ th.manual_seed(seed)
+ np.random.seed(seed)
+ print(f'Set seed to {seed}')
+ if int(evaluation_mode[-1]) > 1:
+ plot_mode = False
+
+ with th.no_grad():
+ for i, input in enumerate(dataloader):
+ print(f'Processing sample {i+1}/{len(dataloader)}', flush=True)
+
+ # Load data
+ tensor = input[0].float().to(device)
+ background_fix = input[1].to(device)
+ gt_pos = input[2].to(device)
+ gt_mask = input[3].to(device)
+ gt_pres_mask = input[4].to(device)
+ gt_hidden_mask = input[5].to(device)
+ sequence_len = tensor.shape[1]
+
+ # Placeholders
+ mask_cur = None
+ mask_last = None
+ rawmask_last = None
+ position_last = None
+ gestalt_last = None
+ priority_last = None
+ slots_occlusionfactor = None
+ error_last = None
+
+ # Memory
+ statistics_batch = deepcopy(statistics_template)
+ pred_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, cfg_net.num_objects, 2)).to(device)
+ gt_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, cfg_net.num_objects, 2)).to(device)
+ pred_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+ gt_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+ pred_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
+ pred_hidden_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
+
+ # Counters
+ num_rollout = 0
+ num_burnin = 0
+
+ # Loop through frames
+ # (negative t values are warm-up steps; they all reuse frame 0 via t_run)
+ for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)):
+
+ # Move to next frame
+ t_run = max(t, 0)
+ input = tensor[:,t_run]
+ target_cur = tensor[:,t_run]
+ target = th.clip(tensor[:,t_run+1], 0, 1)
+ # presumably rescales pixel coordinates of a 64x64 frame to [-1, 1] - TODO confirm
+ gt_pos_t = gt_pos[:,t_run+1]/32-1
+ # append a constant column of ones along dim 2
+ gt_pos_t = th.concat((gt_pos_t, th.ones_like(gt_pos_t[:,:,:1])), dim=2)
+
+ rollout_index = t_run - burn_in_length
+ rollout_active = False
+ if t>=0:
+ if rollout_index >= 0:
+ num_rollout += 1
+ # 'vidpred_black': the network sees an all-zero input during rollout
+ if ('vidpred_black' in evaluation_mode):
+ input = output_next * 0
+ rollout_active = True
+ # 'vidpred_auto': the previous prediction is fed back as input
+ elif ('vidpred_auto' in evaluation_mode):
+ input = output_next
+ rollout_active = True
+ else:
+ num_burnin += 1
+
+ # obtain prediction
+ (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder_cur,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ ) = net(
+ input,
+ error_last,
+ mask_last,
+ rawmask_last,
+ position_last,
+ gestalt_last,
+ priority_last,
+ background_fix,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing),
+ evaluate=True,
+ warmup = (t < 0),
+ shuffleslots = True,
+ reset_mask = (t <= 0),
+ allow_spawn = True,
+ show_hidden = False,
+ clean_slots = False,
+ )
+
+ # 1. Track error for plots
+ if t >= 0:
+
+ if (rollout_index >= 0):
+ # store positions per batch
+ if use_meds:
+ # the `if False` branch (positions taken from the slot position output) is disabled;
+ # positions are always derived from the raw masks
+ if False:
+ pred_pos_batch[:,rollout_index] = rearrange(position_next, 'b (o c) -> b o c', o=cfg_net.num_objects)[:,:,:2]
+ else:
+ pred_pos_batch[:,rollout_index] = compute_position_from_mask(rawmask_next)
+
+ gt_pos_batch[:,rollout_index] = gt_pos_t[:,:,:2]
+
+ pred_img_batch[:,rollout_index] = output_next
+ gt_img_batch[:,rollout_index] = target
+
+ # Here we compute only the foreground segmentation mask
+ pred_mask_batch[:,rollout_index] = postproc_mask(mask_next[:,None,:,None])[:, 0]
+
+ # Here we compute the hidden segmentation
+ # (raw mask minus visible mask per slot, last channel dropped; a replacement last
+ # channel is 1 wherever no slot exceeds 0.5)
+ occluded_cur = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1]
+ occluded_sum_cur = 1-(reduce(occluded_cur, 'b c h w -> b h w', 'max') > 0.5).float()
+ occluded_cur = th.cat((occluded_cur, occluded_sum_cur[:,None]), dim=1)
+ pred_hidden_mask_batch[:,rollout_index] = postproc_mask(occluded_cur[:,None,:,None])[:, 0]
+
+ # 2. Remember output
+ mask_last = mask_next.clone()
+ rawmask_last = rawmask_next.clone()
+ position_last = position_next.clone()
+ gestalt_last = gestalt_next.clone()
+ priority_last = priority_next.clone()
+
+ # 3. Error for next frame
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+ # prediction error
+ # (channel-wise RMSE, square-rooted once more and weighted by the background error)
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next.clone()
+
+ # Plotting
+ if i == 0 and plot_mode:
+ att = net.predictor.get_att_weights()
+ # NOTE(review): `openings` is fetched here but `openings=None` is passed to plot_timestep below.
+ openings = net.get_openings()
+ img_tensor = plot_timestep(cfg, cfg_net, input, target_cur, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, individual_views, None, None, sequence_len, root_path, None, t_index, t, i, rollout_mode=rollout_active, num_vid=plot_first_samples, att= att, openings=None)
+ video_list.append(img_tensor)
+
+ # log video
+ if i == 0 and plot_mode:
+ video_tensor = rearrange(th.stack(video_list, dim=0), 't b c h w -> b t h w c')
+ save_videos(video_tensor, f'{plot_path}/object', verbose=verbose, trace_plot=True)
+
+ # Compute prediction accuracy based on Slotformer metrics (ARI, FARI, mIoU, AP, AR)
+ for b in range(cfg_net.batch_size):
+
+ # perceptual similarity from slotformer paper
+ # (ground-truth masks are sliced from burn_in_length+1 on, matching the stored rollout targets)
+ metric_dict = pred_eval_step(
+ gt = gt_img_batch[b:b+1],
+ pred = pred_img_batch[b:b+1],
+ pred_mask = pred_mask_batch.long()[b:b+1],
+ pred_mask_hidden = pred_hidden_mask_batch.long()[b:b+1],
+ pred_bbox = None,
+ gt_mask = gt_mask.long()[b:b+1, burn_in_length+1:],
+ gt_mask_hidden = gt_hidden_mask.long()[b:b+1, burn_in_length+1:],
+ gt_pres_mask = gt_pres_mask[b:b+1, burn_in_length+1:],
+ gt_bbox = None,
+ lpips_fn = lpipsloss,
+ eval_traj = True,
+ )
+
+ metric_dict['meds'] = distance_eval_step(gt_pos_batch[b], pred_pos_batch[b])
+ metric_complete = append_statistics(metric_dict, metric_complete)
+
+ # sanity check
+ # NOTE(review): with `and` this only raises when BOTH counts are off; the message
+ # suggests `or` was intended - confirm before relying on this check.
+ if (num_rollout != rollout_length) and (num_burnin != burn_in_length) and ('vidpred' in evaluation_mode):
+ raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.')
+
+
+ average_dic = compute_statistics_summary(metric_complete, evaluation_mode, root_path=root_path, consider_first_n_frames=rollout_length_stats)
+
+ # Store statistics
+ with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f:
+ pickle.dump(metric_complete, f)
+ with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.pkl'), 'wb') as f:
+ pickle.dump(average_dic, f)
+
+ print('-- Evaluation Done --')
+ if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
+ os.remove(f'{root_path}/tmp.jpg')
+ pass
+
+# store videos as jpgs and then use ffmpeg to convert to video
+def save_videos(video_tensor, plot_path, verbose=False, fps=10, trace_plot=False):
+    """Write every video of a batch to disk as JPEG frames, optionally as mp4 and as a trace image.
+
+    Args:
+        video_tensor: tensor indexed as [video, frame, ...]; it is moved to the CPU,
+            converted to numpy and cast to uint8 before writing
+            (presumably laid out as (b, t, h, w, c) - confirm against the caller).
+        plot_path: output directory; frames are written to `{plot_path}/img`.
+        verbose: if True, the frames are encoded to `{plot_path}/{b}.mp4` with ffmpeg
+            and the frame directory is removed afterwards.
+        fps: frame rate handed to ffmpeg.
+        trace_plot: if True, a per-pixel maximum over up to 20 brightness-weighted
+            frames starting at frame 15 is written to `{plot_path}/{b}_trace.jpg`.
+    """
+    # local imports keep this block self-contained
+    import shutil
+    import subprocess
+
+    video_tensor = video_tensor.cpu().numpy()
+    img_path = plot_path + '/img'
+    for b in range(video_tensor.shape[0]):
+        os.makedirs(img_path, exist_ok=True)
+        video = video_tensor[b].astype(np.uint8)
+        for t in range(video.shape[0]):
+            cv2.imwrite(f'{img_path}/{b}_{t:04d}.jpg', video[t])
+
+        if verbose:
+            # An argument list (no shell) keeps paths with spaces or shell
+            # metacharacters intact; ffmpeg expands the glob pattern itself.
+            ffmpeg_cmd = [
+                'ffmpeg', '-r', str(fps),
+                '-pattern_type', 'glob', '-i', f'{img_path}/*.jpg',
+                '-c:v', 'libx264', '-y', f'{plot_path}/{b}.mp4',
+            ]
+            try:
+                subprocess.run(ffmpeg_cmd, check=False)
+            except FileNotFoundError:
+                # stay best-effort, like the previous shell call: a missing
+                # ffmpeg binary must not abort the evaluation
+                print('ffmpeg not found, skipping video encoding')
+            shutil.rmtree(img_path, ignore_errors=True)
+
+        if trace_plot:
+            # trace plot: later frames get a higher weight, so the motion
+            # appears as a trail that brightens over time
+            start = 15
+            length = 20
+            frame = np.zeros_like(video[0])
+            # clamp the window so clips shorter than start+length do not raise IndexError
+            for i in range(start, min(start + length, video.shape[0])):
+                current = video[i] * (0.1 + (i-start)/length)
+                frame = np.maximum(frame, current)
+            cv2.imwrite(f'{plot_path}/{b}_trace.jpg', frame)
+
+
+def distance_eval_step(gt_pos, pred_pos):
+    """Return, per timestep, the mean Euclidean distance between matched positions.
+
+    Ground-truth and predicted positions are paired once, on the first timestep,
+    via a linear assignment on the squared-distance matrix; the same pairing is
+    reused for all later timesteps.
+
+    Args:
+        gt_pos: tensor indexed as [timestep, object, coordinate].
+        pred_pos: tensor with the same indexing; its first dimension sets the
+            number of timesteps.
+
+    Returns:
+        list with one value per timestep (np.nan where no pairing exists).
+    """
+    pairings = None
+    per_step_error = []
+    for t, pred_t in enumerate(pred_pos):
+        # map every coordinate v to (v + 1) / 2 before measuring distances
+        truth = (gt_pos[t].cpu().numpy() + 1) * 0.5
+        guess = (pred_t.cpu().numpy() + 1) * 0.5
+
+        distances = mm.distances.norm2squared_matrix(truth, guess, max_d2=1)
+        if pairings is None:
+            # fix the gt <-> prediction matching on the first frame only
+            pairings = [(int(g), int(p)) for g, p in zip(*mm.lap.linear_sum_assignment(distances))]
+
+        if len(pairings) > 0:
+            total = sum(np.sqrt(((truth[g] - guess[p])**2).sum()) for g, p in pairings)
+            per_step_error.append(total / len(pairings))
+        else:
+            per_step_error.append(np.nan)
+    return per_step_error
+
+
+def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden):
+    """Record the frame-level image error and the per-slot statistics for one timestep.
+
+    Returns:
+        tuple (statistics_complete_slots, statistics_batch) as returned by
+        store_statistics.
+    """
+    num_slots = cfg_net.num_objects
+    set_type = set_test['type']
+    scene = set_test['samples'][i]
+
+    # frame-level error between prediction and target
+    image_error = mseloss(output_next, target).item()
+    statistics_batch = store_statistics(statistics_batch, set_type, evaluation_mode, scene, t, image_error)
+
+    # compute slot-wise prediction error: weight prediction and target with
+    # every mask except the last one, then average the squared difference per slot
+    slot_masks = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w')
+    output_slot = slot_masks * repeat(output_next, 'b c h w -> b o c h w', o=num_slots)
+    target_slot = slot_masks * repeat(target, 'b c h w -> b o c h w', o=num_slots)
+    slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean')
+
+    # compute rawmask_size (summed raw mask per slot, last mask excluded)
+    rawmask_size = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum')
+
+    statistics_complete_slots = store_statistics(
+        statistics_complete_slots,
+        [set_type] * num_slots,
+        [evaluation_mode] * num_slots,
+        [scene] * num_slots,
+        [t] * num_slots,
+        range(num_slots),
+        slots_bounded.cpu().numpy().flatten().astype(int),
+        slot_error.cpu().numpy().flatten(),
+        rawmask_size.cpu().numpy().flatten(),
+        slots_closed[:, :, 1].cpu().numpy().flatten(),
+        slots_closed[:, :, 0].cpu().numpy().flatten(),
+        extend = True,
+    )
+
+    return statistics_complete_slots, statistics_batch
+
+def get_evaluation_sets(dataset):
+    """Return the test sets and the evaluation modes to run on them.
+
+    Args:
+        dataset: any sized collection; only `len(dataset)` is used.
+
+    Returns:
+        tuple (set_test_array, evaluation_modes): a one-element list holding a
+        dict with the sample indices ("samples") and the set name ("type"),
+        and the list of evaluation mode names.
+    """
+    # named `test_set` so the builtin `set` is not shadowed
+    test_set = {"samples": np.arange(len(dataset), dtype=int), "type": "test"}
+
+    # use 'open' for no blackouts; the trailing digit of the blackout modes
+    # is read as the random seed by evaluate()
+    evaluation_modes = ['open', 'vidpred_auto'] + [f'vidpred_black_{seed}' for seed in range(1, 6)]
+
+    return [test_set], evaluation_modes
diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py
index a43d50c..168fba3 100644
--- a/scripts/evaluation_clevrer.py
+++ b/scripts/evaluation_clevrer.py
@@ -237,7 +237,7 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.')
if not plotting_mode:
- average_dic = compute_statistics_summary(metric_complete, evaluation_mode)
+ average_dic = compute_statistics_summary(metric_complete, evaluation_mode, root_path=root_path)
# Store statistics
with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f:
@@ -250,24 +250,45 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
os.remove(f'{root_path}/tmp.jpg')
pass
-def compute_statistics_summary(metric_complete, evaluation_mode):
+def compute_statistics_summary(metric_complete, evaluation_mode, root_path=None, consider_first_n_frames = None):
+ """Summarise every metric, print the summary and optionally write it to a text file.
+
+ Args:
+ metric_complete: dict mapping a metric name to a list of per-sample sequences.
+ evaluation_mode: mode name; 'blackout' adds separate blackout/visible statistics
+ and is also used in the output file name.
+ root_path: if given, the printed text is written to
+ `{root_path}/statistics/{evaluation_mode}_metric_average.txt`.
+ consider_first_n_frames: if given, every per-sample sequence is cut to this
+ many leading entries before the statistics are computed.
+
+ Returns:
+ dict with `<metric>_complete_average`, `_complete_std` and `_complete_sum`
+ entries (plus blackout/visible entries in 'blackout' mode).
+ """
+ string = ''
+ # NOTE(review): the `last` parameter of add_text is never used.
+ def add_text(string, text, last=False):
+ string = string + ' \n ' + text
+ return string
+
average_dic = {}
+ # NOTE(review): this truncation assigns into metric_complete itself, so the
+ # caller's data is shortened as well - confirm this is intended.
+ if consider_first_n_frames is not None:
+ for key in metric_complete:
+ for sample in range(len(metric_complete[key])):
+ metric_complete[key][sample] = metric_complete[key][sample][:consider_first_n_frames]
+
for key in metric_complete:
# take average over all frames
- average_dic[key + 'complete_average'] = np.mean(metric_complete[key])
- average_dic[key + 'complete_std'] = np.std(metric_complete[key])
- print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}')
+ average_dic[key + '_complete_average'] = np.mean(metric_complete[key])
+ average_dic[key + '_complete_std'] = np.std(metric_complete[key])
+ # sum over the remaining axes of the mean taken along axis 0
+ average_dic[key + '_complete_sum'] = np.sum(np.mean(metric_complete[key], axis=0)) # checked with GSWM code!
+ string = add_text(string, f'{key} complete average: {average_dic[key + "_complete_average"]:.4f} +/- {average_dic[key + "_complete_std"]:.4f} (sum: {average_dic[key + "_complete_sum"]:.4f})')
+ #print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f} (sum: {average_dic[key + "complete_sum"]:.4f})')
if evaluation_mode == 'blackout':
- # take average only for frames where blackout occurs
+ # take average only for frames where blackout occurs
blackout_mask = np.array(metric_complete['blackout']) > 0
- average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask])
- average_dic[key + 'blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask])
- average_dic[key + 'visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False])
- average_dic[key + 'visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False])
+ average_dic[key + '_blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask])
+ average_dic[key + '_blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask])
+ average_dic[key + '_visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False])
+ average_dic[key + '_visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False])
- print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}')
- print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}')
+ #print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}')
+ #print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}')
+ string = add_text(string, f'{key} blackout average: {average_dic[key + "_blackout_average"]:.4f} +/- {average_dic[key + "_blackout_std"]:.4f}')
+ string = add_text(string, f'{key} visible average: {average_dic[key + "_visible_average"]:.4f} +/- {average_dic[key + "_visible_std"]:.4f}')
+
+ print(string)
+ if root_path is not None:
+ # NOTE(review): the next line is a bare tuple expression with no effect; it looks like a leftover.
+ f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'
+ with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.txt'), 'w') as f:
+ f.write(string)
+
return average_dic
def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden):
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py
index 96ed32d..bc461ec 100644
--- a/scripts/exec/eval.py
+++ b/scripts/exec/eval.py
@@ -1,10 +1,29 @@
import argparse
import sys
from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
from scripts.utils.configuration import Configuration
-from scripts import evaluation_adept, evaluation_clevrer
+from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb
+def main(load, n, cfg):
+    """Build the test set selected by `cfg.datatype` and run the matching evaluation.
+
+    Args:
+        load: forwarded to the dataset-specific `evaluate` function as its `file` argument.
+        n: forwarded to the dataset-specific `evaluate` function as its `n` argument.
+        cfg: configuration; `cfg.datatype`, `cfg.dataset`, `cfg.model.latent_size`
+            and `cfg.model.level` are read here (and `cfg.scenario` for bouncing balls).
+
+    Raises:
+        Exception: if `cfg.datatype` is not one of the supported datasets.
+    """
+    # frame size: latent size scaled by 2**(2*level), with the two latent dimensions swapped
+    upscale = 2**(cfg.model.level*2)
+    size = (cfg.model.latent_size[1] * upscale, cfg.model.latent_size[0] * upscale)
+
+    datatype = cfg.datatype
+    if datatype not in ("clevrer", "adept", "bouncingballs"):
+        raise Exception("Dataset not supported")
+
+    # Load dataset
+    if datatype == "clevrer":
+        testset = ClevrerDataset("./", cfg.dataset, 'val', size, use_slotformer=True, evaluation=True)
+        evaluation_clevrer.evaluate(cfg, testset, load, n, plot_frequency= 2, plot_first_samples = 3) # only plotting
+        evaluation_clevrer.evaluate(cfg, testset, load, n, plot_first_samples = 0) # evaluation
+        return
+
+    if datatype == "adept":
+        testset = AdeptDataset("./", cfg.dataset, 'createdown', size)
+        evaluation_adept.evaluate(cfg, testset, load, n, plot_frequency= 1, plot_first_samples = 2)
+        return
+
+    # remaining case: bouncingballs
+    testset = BouncingBallDataset("./", cfg.dataset, "test", size, type_name = cfg.scenario)
+    evaluation_bb.evaluate(cfg, testset, load, n, plot_first_samples = 4)
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-cfg", default="", help='path to the configuration file')
@@ -16,13 +35,4 @@ if __name__ == "__main__":
cfg = Configuration(args.cfg)
print(f'Evaluating model {args.load}')
- # Load dataset
- if cfg.datatype == "clevrer":
- testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True)
- evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting
- evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation
- elif cfg.datatype == "adept":
- testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
- evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
- else:
- raise Exception("Dataset not supported") \ No newline at end of file
+ main(args.load, args.n, cfg) \ No newline at end of file
diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_baselines.py
index 87d4e59..9e451ab 100644
--- a/scripts/exec/eval_savi.py
+++ b/scripts/exec/eval_baselines.py
@@ -1,12 +1,13 @@
import argparse
import sys
from data.datasets.ADEPT.dataset import AdeptDataset
-from scripts.evaluation_adept_savi import evaluate
+from scripts.evaluation_adept_baselines import evaluate
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-load", default="", type=str, help='path to savi slots')
parser.add_argument("-n", default="", type=str, help='results name')
+ parser.add_argument("-model", default="", type=str, help='model, either savi or gswm')
# Load configuration
args = parser.parse_args(sys.argv[1:])
@@ -14,4 +15,4 @@ if __name__ == "__main__":
# Load dataset
testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2)))
- evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file
+ evaluate(testset, args.load, args.n, args.model, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file
diff --git a/scripts/exec/train.py b/scripts/exec/train.py
index f6e78fd..6fa6176 100644
--- a/scripts/exec/train.py
+++ b/scripts/exec/train.py
@@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration
from scripts import training
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset
+from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -20,12 +22,17 @@ if __name__ == "__main__":
print(f'Training model {cfg.model_path}')
# Load dataset
+ size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))
if cfg.datatype == "clevrer":
- trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False)
- valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True)
+ trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False)
+ valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True)
elif cfg.datatype == "adept":
- trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
- valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ trainset = AdeptDataset("./", cfg.dataset, "train", size)
+ valset = AdeptDataset("./", cfg.dataset, "test", size)
+ valset.train = True
+ elif cfg.datatype == "bouncingballs":
+ trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario)
+ valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario)
else:
raise Exception("Dataset not supported")
diff --git a/scripts/training.py b/scripts/training.py
index b807e54..fd4a4bf 100644
--- a/scripts/training.py
+++ b/scripts/training.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
import torch as th
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
@@ -5,15 +6,18 @@ import cv2
import numpy as np
import os
from einops import rearrange, repeat, reduce
-from scripts.utils.configuration import Configuration
-from scripts.utils.io import init_device, model_path, LossLogger
+from scripts.utils.configuration import Configuration, Dict
+from scripts.utils.io import WriterWrapper, init_device, model_path, LossLogger
from scripts.utils.optimizers import RAdam
from model.loci import Loci
import random
from scripts.utils.io import Timer
-from scripts.utils.plot_utils import color_mask
-from scripts.validation import validation_clevrer, validation_adept
+from scripts.utils.plot_utils import color_mask, plot_timestep
+from scripts.validation import validation_bb, validation_clevrer, validation_adept
+from scripts.exec.eval import main as eval_main
+os.environ["WANDB__SERVICE_WAIT"] = "300"
+# NOTE(review): the key below contains a space ("WANDB_ DISABLE_SERVICE"). If
+# WANDB_DISABLE_SERVICE was intended, this assignment sets a variable of a
+# different name and has no effect on wandb - confirm and correct the key.
+os.environ["WANDB_ DISABLE_SERVICE"] = "true"
def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
@@ -21,6 +25,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
device, verbose = init_device(cfg)
if verbose:
valset = Subset(valset, range(0, 8))
+ use_wandb = (verbose == False)
+
+ # generate random seed if not set
+ if not ('seed' in cfg.defaults):
+ cfg.defaults['seed'] = random.randint(0, 100)
+ th.manual_seed(cfg.defaults.seed)
# Define model path
path = model_path(cfg, overwrite=False)
@@ -34,12 +44,16 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
teacher_forcing = cfg.defaults.teacher_forcing
)
net = net.to(device=device)
+ #net = th.compile(net)
+ #net = th.jit.script(net)
# Log model size
log_modelsize(net)
+ writer = WriterWrapper(use_wandb, cfg)
# Init Optimizers
optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net)
+ scheduler = init_lr_scheduler_StepLR(cfg, optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update)
# Option to load model
if file != "":
@@ -58,9 +72,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
print(f'loaded {file}', flush=True)
# Set up data loaders
- trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True)
- valset.train = True #valset.dataset.train = True
- valloader = get_loader(cfg, valset, cfg_net, shuffle=False)
+ trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True, verbose=verbose)
+ valloader = get_loader(cfg, valset, cfg_net, shuffle=False, verbose=verbose)
# initial save
save_model(
@@ -80,16 +93,24 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
print('!!! Net init status: ', net.get_init_status())
# Set up statistics
- loss_tracker = LossLogger()
+ loss_tracker = LossLogger(writer)
+
+ # Init loss function
+ imageloss = initialize_loss_function(cfg)
# Set up training variables
num_time_steps = 0
bptt_steps = cfg.bptt.bptt_steps
+ if not 'plot_interval' in cfg.defaults:
+ cfg.defaults.plot_interval = 20000
+ blackout_rate = cfg.blackout.blackout_rate if ('blackout' in cfg) else 0.0
+ rollout_length = cfg.vp.rollout_length if ('vp' in cfg) else 0
+ burnin_length = cfg.vp.burnin_length if ('vp' in cfg) else 0
increase_bptt_steps = False
background_blendin_factor = 0.0
th.backends.cudnn.benchmark = True
+ plot_next_sample = False
timer = Timer()
- bceloss = nn.BCELoss()
# Init net to current num_updates
if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1:
@@ -101,7 +122,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
net.inc_init_level()
for param in optimizer_init.param_groups:
- param['lr'] = cfg.learning_rate
+ param['lr'] = cfg.learning_rate.lr
if num_updates > cfg.phases.start_inner_loop:
net.cfg.inner_loop_enabled = True
@@ -117,16 +138,18 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
# Validation every epoch
if epoch >= 0:
if cfg.datatype == 'adept':
- validation_adept(valloader, net, cfg, device)
+ validation_adept(valloader, net, cfg, device, writer, epoch, path)
elif cfg.datatype == 'clevrer':
- validation_clevrer(valloader, net, cfg, device)
+ validation_clevrer(valloader, net, cfg, device, writer, epoch, path)
+ elif cfg.datatype == "bouncingballs" and (epoch % 2 == 0):
+ validation_bb(valloader, net, cfg, device, writer, epoch, path)
# Start epoch training
print('Start epoch:', epoch)
# Backprop through time steps
if increase_bptt_steps:
- bptt_steps = max(bptt_steps + 1, cfg.bptt.bptt_steps_max)
+ bptt_steps = min(bptt_steps + 1, cfg.bptt.bptt_steps_max)
print('Increase closed loop steps to', bptt_steps)
increase_bptt_steps = False
@@ -149,7 +172,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
slots_occlusionfactor = None
# Apply skip frames to sequence
- selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames)
+ start = random.randrange(cfg.defaults.skip_frames) if rollout_length == 0 else random.randrange(0, (tensor.shape[1]-(rollout_length+burnin_length)))
+ selec = range(start, tensor.shape[1], cfg.defaults.skip_frames)
tensor = tensor[:,selec]
sequence_len = tensor.shape[1]
@@ -159,6 +183,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
target = th.clip(input, 0, 1).detach()
error_last = None
+ # plotting mode
+ plot_this_sample = plot_next_sample
+ plot_next_sample = False
+ video_list = []
+ num_rollout = 0
+
# First apply teacher forcing for the first x frames
for t in range(-cfg.defaults.teacher_forcing, sequence_len-1):
@@ -185,13 +215,27 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout:
error_last = th.zeros_like(error_last)
- # Apply sensation blackout when training clevrere
- if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer':
- if t >= 10:
- blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device)
+ # Apply sensation blackout when training clevrer
+ if net.cfg.inner_loop_enabled and blackout_rate > 0:
+ if t >= cfg.blackout.blackout_start_timestep:
+ blackout = th.tensor((np.random.rand(cfg_net.batch_size) < blackout_rate)[:,None,None,None]).float().to(device)
input = blackout * (input * 0) + (1-blackout) * input
error_last = blackout * (error_last * 0) + (1-blackout) * error_last
-
+
+ # Apply rollout training
+ if net.cfg.inner_loop_enabled and (rollout_length > 0) and (burnin_length-1) < t:
+ input = input * 0
+ error_last = error_last * 0
+ num_rollout += 1
+
+ if (burnin_length + rollout_length -1) == t:
+ run_optimizers = True
+ detach = True
+
+ if (burnin_length + rollout_length) == t:
+ break
+
+
# Forward Pass
(
output_next,
@@ -206,7 +250,9 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
slots_occlusionfactor,
position_loss,
time_loss,
- slots_closed
+ latent_loss,
+ slots_closed,
+ slots_bounded
) = net(
input, # current frame
error_last, # error of last frame --> missing object
@@ -228,6 +274,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
# Loss weighting
position_loss = position_loss * cfg_net.position_regularizer
time_loss = time_loss * cfg_net.time_regularizer
+ if latent_loss.item() > 0:
+ latent_loss = latent_loss * cfg_net.latent_regularizer
# Compute background error
bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
@@ -253,10 +301,14 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
background_blendin_factor = min(1, background_blendin_factor + 0.001)
# Final Loss computation
- encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer
- cliped_output_next = th.clip(output_next, 0, 1)
- loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss
+ encoding_loss = imageloss(output_cur, input) * cfg_net.encoder_regularizer
+ prediction_loss = imageloss(output_next, target)
+ loss = prediction_loss + encoding_loss + position_loss + time_loss + latent_loss
+ # apply loss decay according to num_rollout
+ if rollout_length > 0:
+ loss = loss * 0.75**(num_rollout-1)
+
# Accumulate loss over BPP steps
summed_loss = loss if summed_loss is None else summed_loss + loss
mask = mask.detach()
@@ -295,6 +347,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
# Update net status
update_net_status(num_updates, net, cfg, optimizer_init)
+ step_lr_scheduler(scheduler)
if num_updates == cfg.phases.start_inner_loop:
print('Start inner loop')
@@ -303,38 +356,39 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0):
increase_bptt_steps = True
+ if net.cfg.inner_loop_enabled and ('blackout' in cfg) and (cfg.blackout.blackout_increase_every > 0) and ((num_updates-cfg.num_updates) % cfg.blackout.blackout_increase_every == 0) and ((num_updates-cfg.num_updates) > 0):
+ blackout_rate = min(blackout_rate + cfg.blackout.blackout_increase_rate, cfg.blackout.blackout_rate_max)
+
# Plots for online evaluation
- if num_updates % 20000 == 0:
- plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next)
+ if num_updates % cfg.defaults.plot_interval == 0:
+ plot_next_sample = True
+ if plot_this_sample:
+ img_tensor = plot_online(cfg, path, f'{epoch}_{batch_index}', input, background, mask, sequence_len, t, output_next, bg_error_next)
+ video_list.append(img_tensor)
# Track statisitcs
if t >= cfg.defaults.statistics_offset:
- track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next)
- loss_tracker.update_average_loss(loss.item())
+ track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates)
+ loss_tracker.update_average_loss(loss.item(), num_updates)
+ writer.add_scalar('Train/BPTT_steps', bptt_steps, num_updates)
+ writer.add_scalar('Train/Background_Blendin', background_blendin_factor, num_updates)
+ writer.add_scalar('Train/Blackout_Rate', blackout_rate, num_updates)
# Logging
if num_updates % 100 == 0 and run_optimizers:
print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True)
# Training finished
+ net_path = None
if num_updates > cfg.max_updates:
- save_model(
- os.path.join(path, 'nets', 'net_final.pt'),
- net,
- optimizer_init,
- optimizer_encoder,
- optimizer_decoder,
- optimizer_predictor,
- optimizer_background
- )
- print("Training finished")
- return
-
- # Checkpointing
+ net_path = os.path.join(path, 'nets', 'net_final.pt')
if num_updates % 50000 == 0 and run_optimizers:
+ net_path = os.path.join(path, 'nets', f'net_{num_updates}.pt')
+
+ if net_path is not None:
save_model(
- os.path.join(path, 'nets', f'net_{num_updates}.pt'),
+ net_path,
net,
optimizer_init,
optimizer_encoder,
@@ -342,33 +396,56 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
optimizer_predictor,
optimizer_background
)
- pass
+ if ('final' in net_path) or ('num_objects_test' in cfg.model):
+ eval_main(net_path, 1, deepcopy(cfg))
-def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next):
-
- # Prediction Loss
- mseloss = nn.MSELoss()
- loss_next = mseloss(output_next * bg_error_next, target * bg_error_next)
+ if 'final' in net_path:
+ print("Training finished")
+ writer.flush()
+ return
- # Encoder Loss (only for stats)
- loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur)
+ if plot_this_sample:
+ video_tensor = rearrange(th.stack(video_list, dim=0)[None], 'b t h w c -> b t c h w')
+ writer.add_video('Train/Video', video_tensor, num_updates)
+
+ pass
+def track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates):
+
# area of foreground mask
num_objects = th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item()
# difference in shape
- _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
- _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
+ _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects)
+ _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects)
max_mask = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float()
avg_gestalt = (th.sum(_gestalt * max_mask) / (1e-16 + th.sum(max_mask)))
avg_gestalt2 = (th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask)))
avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1))
# udpdate gates
- avg_update_gestalt = slots_closed[:,:,0].mean()
- avg_update_position = slots_closed[:,:,1].mean()
+ num_bounded = reduce(slots_bounded, 'b o -> b', 'sum').mean().item()
+ slots_closed = slots_closed * slots_bounded[:,:,None]
+ avg_update_gestalt = slots_closed[:,:,0].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0
+ avg_update_position = slots_closed[:,:,1].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0
+ avg_update_gestalt = float(avg_update_gestalt)
+ avg_update_position = float(avg_update_position)
+
+ # Prediction Loss + Encoder Loss as MSE + only foreground pixels
+ mseloss = nn.MSELoss()
+ loss_next = mseloss(output_next * bg_error_next, target * bg_error_next)
+ loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur)
+
+ # learning rate
+ if scheduler is not None:
+ lr = scheduler[0].get_last_lr()[0]
+ else:
+ lr = cfg.learning_rate.lr
- loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position)
+ # GateL0RD openings
+ openings = th.mean(net.get_openings()).item()
+
+ loss_tracker.update_complete(position_loss, time_loss, latent_loss, loss_cur, loss_next, num_objects, openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, num_bounded, lr, num_updates)
pass
def log_modelsize(net):
@@ -381,15 +458,49 @@ def log_modelsize(net):
print("\n")
pass
+def initialize_loss_function(cfg):
+ if not ('loss' in cfg.defaults):
+ cfg.defaults.loss = 'bce'
+
+ if cfg.defaults.loss == 'mse':
+ imageloss = nn.MSELoss()
+ elif cfg.defaults.loss == 'bce':
+ imageloss = nn.BCELoss()
+ else:
+ raise NotImplementedError
+
+ return imageloss
+
def init_optimizer(cfg, net):
- optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30)
- optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate)
- optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate)
- optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate)
- optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate)
- optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = cfg.learning_rate)
+ # backward compatibility:
+ if not isinstance(cfg.learning_rate, dict):
+ cfg.learning_rate = Dict({'lr': cfg.learning_rate})
+
+ lr = cfg.learning_rate.lr
+ optimizer_init = RAdam(net.initial_states.parameters(), lr = lr * 30)
+ optimizer_encoder = RAdam(net.encoder.parameters(), lr = lr)
+ optimizer_decoder = RAdam(net.decoder.parameters(), lr = lr)
+ optimizer_predictor = RAdam(net.predictor.parameters(), lr = lr)
+ optimizer_background = RAdam([net.background.mask], lr = lr)
+ optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = lr)
return optimizer_init,optimizer_encoder,optimizer_decoder,optimizer_predictor,optimizer_background,optimizer_update
+def init_lr_scheduler_StepLR(cfg, *optimizer_list):
+ scheduler_list = None
+ if 'deacrease_lr_every' in cfg.learning_rate:
+ print('Init lr scheduler')
+ scheduler_list = []
+ for optimizer in optimizer_list:
+ scheduler = th.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.learning_rate.deacrease_lr_every, gamma=cfg.learning_rate.deacrease_lr_factor)
+ scheduler_list.append(scheduler)
+ return scheduler_list
+
+def step_lr_scheduler(scheduler_list):
+ if scheduler_list is not None:
+ for scheduler in scheduler_list:
+ scheduler.step()
+ pass
+
def save_model(
file,
net,
@@ -432,23 +543,23 @@ def load_model(
if load_optimizers:
optimizer_init.load_state_dict(state[f'optimizer_init'])
for n in range(len(optimizer_init.param_groups)):
- optimizer_init.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_init.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_encoder.load_state_dict(state[f'optimizer_encoder'])
for n in range(len(optimizer_encoder.param_groups)):
- optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_decoder.load_state_dict(state[f'optimizer_decoder'])
for n in range(len(optimizer_decoder.param_groups)):
- optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_predictor.load_state_dict(state['optimizer_predictor'])
for n in range(len(optimizer_predictor.param_groups)):
- optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_background.load_state_dict(state['optimizer_background'])
for n in range(len(optimizer_background.param_groups)):
- optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate
+ optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate.lr
# 1. Fill model with values of net
model = {}
@@ -484,40 +595,61 @@ def update_net_status(num_updates, net, cfg, optimizer_init):
if num_updates == cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
net.inc_init_level()
for param in optimizer_init.param_groups:
- param['lr'] = cfg.learning_rate
+ param['lr'] = cfg.learning_rate.lr
pass
def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next):
- plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
- if not os.path.exists(plot_path):
+
+ # highlight error
+ grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
+ object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
+ highlited_input = grayscale * (1 - object_mask_cur)
+ highlited_input += grayscale * object_mask_cur * 0.3333333
+ cmask = color_mask(mask[:,:-1])
+ highlited_input = highlited_input + cmask * 0.6666666
+
+ input_ = rearrange(input[0], 'c h w -> h w c').detach().cpu()
+ background_ = rearrange(background[0], 'c h w -> h w c').detach().cpu()
+ mask_ = rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu()
+ output_next_ = rearrange(output_next[0], 'c h w -> h w c').detach().cpu()
+ bg_error_next_ = rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu()
+ highlited_input_ = rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu()
+
+ if False:
+ plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
os.makedirs(plot_path, exist_ok=True)
-
- # highlight error
- grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
- object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
- highlited_input = grayscale * (1 - object_mask_cur)
- highlited_input += grayscale * object_mask_cur * 0.3333333
- cmask = color_mask(mask[:,:-1])
- highlited_input = highlited_input + cmask * 0.6666666
-
- cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- pass
-
-def get_loader(cfg, set, cfg_net, shuffle = True):
- loader = DataLoader(
- set,
- pin_memory = True,
- num_workers = cfg.defaults.num_workers,
- batch_size = cfg_net.batch_size,
- shuffle = shuffle,
- drop_last = True,
- prefetch_factor = cfg.defaults.prefetch_factor,
- persistent_workers = True
- )
+ cv2.imwrite(os.path.join(plot_path, f'input-{t+cfg.defaults.teacher_forcing:03d}.jpg'), input_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'background-{t+cfg.defaults.teacher_forcing:03d}.jpg'), background_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'error_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), bg_error_next_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'background_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), mask_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'output_next-{t+cfg.defaults.teacher_forcing:03d}.jpg'), output_next_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'output_highlight-{t+cfg.defaults.teacher_forcing:03d}.jpg'), highlited_input_.numpy() * 255)
+
+ # stack input, highlight and output into one image horizontally
+ img_tensor = th.cat([input_, highlited_input_, output_next_], dim=1)
+ return img_tensor
+
+def get_loader(cfg, set, cfg_net, shuffle = True, verbose=False):
+ if ((cfg.datatype == 'bouncingballs') and verbose) or ((cfg.datatype == 'adept') and not shuffle):
+ loader = DataLoader(
+ set,
+ pin_memory = True,
+ num_workers = 0,
+ batch_size = cfg_net.batch_size,
+ shuffle = shuffle,
+ drop_last = True,
+ persistent_workers = False
+ )
+ else:
+ loader = DataLoader(
+ set,
+ pin_memory = True,
+ num_workers = cfg.defaults.num_workers,
+ batch_size = cfg_net.batch_size,
+ shuffle = shuffle,
+ drop_last = True,
+ prefetch_factor = cfg.defaults.prefetch_factor,
+ persistent_workers = True
+ )
return loader \ No newline at end of file
diff --git a/scripts/utils/eval_adept.py b/scripts/utils/eval_adept.py
new file mode 100644
index 0000000..7b8dfa8
--- /dev/null
+++ b/scripts/utils/eval_adept.py
@@ -0,0 +1,169 @@
+import pandas as pd
+import warnings
+import os
+import argparse
+
+warnings.simplefilter(action='ignore', category=FutureWarning)
+pd.options.mode.chained_assignment = None
+
+def eval_adept(path):
+ net = 'net1'
+
+ # read csv files
+ tf = pd.DataFrame()
+ sf = pd.DataFrame()
+ af = pd.DataFrame()
+
+ with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:
+ tf_temp = pd.read_csv(f, index_col=0)
+ tf_temp['net'] = net
+ tf = pd.concat([tf,tf_temp])
+
+ with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:
+ sf_temp = pd.read_csv(f, index_col=0)
+ sf_temp['net'] = net
+ sf = pd.concat([sf,sf_temp])
+
+ with open(os.path.join(path, 'accframe.csv'), 'rb') as f:
+ af_temp = pd.read_csv(f, index_col=0)
+ af_temp['net'] = net
+ af = pd.concat([af,af_temp])
+
+ # cast variables
+ sf['visible'] = sf['visible'].astype(bool)
+ sf['bound'] = sf['bound'].astype(bool)
+ sf['occluder'] = sf['occluder'].astype(bool)
+ sf['inimage'] = sf['inimage'].astype(bool)
+ sf['vanishing'] = sf['vanishing'].astype(bool)
+ sf['alpha_pos'] = 1-sf['alpha_pos']
+ sf['alpha_ges'] = 1-sf['alpha_ges']
+
+ # scale to percentage
+ sf['TE'] = sf['TE'] * 100
+
+ # add surprise as dummy code
+ tf['control'] = [('control' in set) for set in tf['set']]
+ sf['control'] = [('control' in set) for set in sf['set']]
+
+
+
+ # STATS:
+ tracking_error_visible = 0
+ tracking_error_occluded = 0
+ num_positive_trackings = 0
+ mota = 0
+ gate_openings_visible = 0
+ gate_openings_occluded = 0
+
+
+ print('Tracking Error ------------------------------')
+ grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)
+
+ def get_stats(col):
+ return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'
+
+ # When Visible
+ temp = sf[grouping & sf.visible]
+ print(f'Tracking Error when visible:' + get_stats(temp['TE']))
+ tracking_error_visible = temp['TE'].mean()
+
+ # When Occluded
+ temp = sf[grouping & ~sf.visible]
+ print(f'Tracking Error when occluded:' + get_stats(temp['TE']))
+ tracking_error_occluded = temp['TE'].mean()
+
+
+
+
+
+
+ print('Positive Trackings ------------------------------')
+ # successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target
+ # determine last visible frame numeric
+ grouping_factors = ['net','set','evalmode','scene','slot']
+ ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()
+ ff.rename(columns = {'frame':'last_visible'}, inplace = True)
+ sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')
+
+ # same for first visible frame
+ ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()
+ ff.rename(columns = {'frame':'first_visible'}, inplace = True)
+ sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')
+
+ # add dummy variable to sf
+ sf['last_visible'] = (sf['last_visible'] == sf['frame'])
+
+ # extract the trials where the target was last visible and threshold the TE
+ ff = sf[sf['last_visible']]
+ ff['tracked_pos'] = (ff['TE'] < 10)
+ ff['tracked_neg'] = (ff['TE'] >= 10)
+
+ # merge tracking flags back into sf and fill NaN with False
+ sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')
+ sf['tracked_pos'].fillna(False, inplace=True)
+ sf['tracked_neg'].fillna(False, inplace=True)
+
+ # Aggregate over all scenes
+ temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]
+ temp = temp.groupby(['set', 'evalmode']).sum()
+ temp = temp[['tracked_pos', 'tracked_neg']]
+ temp = temp.reset_index()
+
+ temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])
+ temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])
+ print(temp)
+ num_positive_trackings = temp['tracked_pos_pro']
+
+
+
+
+
+ print('Mostly Trecked /MOTA ------------------------------')
+ temp = af[af.index == 'OVERALL']
+ temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']
+ temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']
+ temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']
+ print(temp)
+ mota = temp['mota']
+
+
+ print('Openings ------------------------------')
+ grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)
+ temp = sf[grouping & sf.visible]
+ print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))
+ gate_openings_visible = temp['alpha_pos'].mean() + temp['alpha_ges'].mean()
+
+ temp = sf[grouping & ~sf.visible]
+ print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))
+ gate_openings_occluded = temp['alpha_pos'].mean() + temp['alpha_ges'].mean()
+
+
+ print('------------------------------------------------')
+ print('------------------------------------------------')
+ str = ''
+ str += f'net: {net}\n'
+ str += f'Tracking Error when visible: {tracking_error_visible:.3}\n'
+ str += f'Tracking Error when occluded: {tracking_error_occluded:.3}\n'
+ str += 'Positive Trackings: ' + ', '.join(f'{val:.3}' for val in num_positive_trackings) + '\n'
+ str += 'MOTA: ' + ', '.join(f'{val:.3}' for val in mota) + '\n'
+ str += f'Percept gate openings when visible: {gate_openings_visible:.3}\n'
+ str += f'Percept gate openings when occluded: {gate_openings_occluded:.3}\n'
+
+ print(str)
+
+ # write string to file
+ with open(os.path.join(path, 'results.txt'), 'w') as f:
+ f.write(str)
+
+
+
+if __name__ == "__main__":
+
+ # use argparse to get the path to the results folder
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--path', type=str, default='')
+ args = parser.parse_args()
+
+ # setting path to results folder
+ path = args.path
+ eval_adept(path) \ No newline at end of file
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py
index 6f5106d..98af468 100644
--- a/scripts/utils/eval_metrics.py
+++ b/scripts/utils/eval_metrics.py
@@ -3,6 +3,8 @@ vp_utils.py
SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer
'''
+import cv2
+from einops import rearrange
import numpy as np
from scipy.optimize import linear_sum_assignment
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
@@ -11,7 +13,7 @@ import torch
import torch.nn.functional as F
import torchvision.ops as vops
-FG_THRE = 0.5
+FG_THRE = 0.3 # 0.5 for the rest, 0.3 for bouncingballs
PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0),
(255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200),
(255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200),
@@ -272,6 +274,8 @@ def pred_eval_step(
gt_bbox=None,
pred_bbox=None,
eval_traj=True,
+ gt_mask_hidden=None,
+ pred_mask_hidden=None,
):
"""Both of shape [B, T, C, H, W], torch.Tensor.
masks of shape [B, T, H, W].
@@ -288,13 +292,14 @@ def pred_eval_step(
if eval_traj:
assert len(gt_mask.shape) == len(pred_mask.shape) == 4
assert gt_mask.shape == pred_mask.shape
- if eval_traj:
+ if eval_traj and gt_bbox is not None:
assert len(gt_pres_mask.shape) == 3
assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4
T = gt.shape[1]
# compute perceptual dist & mask metrics before converting to numpy
all_percept_dist, all_ari, all_fari, all_miou = [], [], [], []
+ all_ari_hidden, all_fari_hidden, all_miou_hidden = [], [], []
for t in range(T):
one_gt, one_pred = gt[:, t], pred[:, t]
percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item()
@@ -307,6 +312,29 @@ def pred_eval_step(
all_ari.append(ari)
all_fari.append(fari)
all_miou.append(miou)
+
+ # hidden:
+ if gt_mask_hidden is not None:
+ one_gt_mask, one_pred_mask = gt_mask_hidden[:, t], pred_mask_hidden[:, t]
+ ari = ARI_metric(one_gt_mask, one_pred_mask)
+ fari = fARI_metric(one_gt_mask, one_pred_mask)
+ miou = miou_metric(one_gt_mask, one_pred_mask)
+ all_ari_hidden.append(ari)
+ all_fari_hidden.append(fari)
+ all_miou_hidden.append(miou)
+
+
+ # display masks with cv2
+ if False:
+ one_gt_mask_, one_pred_mask_ = one_gt_mask.clone(), one_pred_mask.clone()
+
+ # concat one_gt_mask and one_pred_mask horizontally
+ frame = np.concatenate((one_gt_mask_, one_pred_mask_), axis=1)
+ frame = rearrange(frame, 'c h w -> h w c') * (255/6)
+ frame = frame.astype(np.uint8)
+ cv2.imshow('frame', frame)
+ cv2.waitKey(0)
+
else:
all_ari.append(0.)
all_fari.append(0.)
@@ -315,14 +343,14 @@ def pred_eval_step(
# compute bbox metrics
all_ap, all_ar = [], []
for t in range(T):
- if not eval_traj:
+ if not eval_traj or gt_bbox is None:
all_ap.append(0.)
all_ar.append(0.)
continue
one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \
gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t]
ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox,
- one_pred_bbox)
+ one_pred_bbox)
all_ap.append(ap)
all_ar.append(ar)
@@ -333,11 +361,13 @@ def pred_eval_step(
one_gt, one_pred = gt[:, t], pred[:, t]
mse = mse_metric(one_gt, one_pred)
psnr = psnr_metric(one_gt, one_pred)
- ssim = ssim_metric(one_gt, one_pred)
+ #ssim = ssim_metric(one_gt, one_pred)
+ ssim = 0
all_mse.append(mse)
all_ssim.append(ssim)
all_psnr.append(psnr)
- return {
+
+ res = {
'mse': all_mse,
'ssim': all_ssim,
'psnr': all_psnr,
@@ -346,5 +376,12 @@ def pred_eval_step(
'fari': all_fari,
'miou': all_miou,
'ap': all_ap,
- 'ar': all_ar,
+ 'ar': all_ar
}
+
+ if gt_mask_hidden is not None:
+ res['ari_hidden'] = all_ari_hidden
+ res['fari_hidden'] = all_fari_hidden
+ res['miou_hidden'] = all_miou_hidden
+
+ return res
diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py
index faab7ec..a01ffd0 100644
--- a/scripts/utils/eval_utils.py
+++ b/scripts/utils/eval_utils.py
@@ -1,18 +1,92 @@
import os
+import shutil
+from einops import rearrange
import torch as th
from model.loci import Loci
+def masks_to_boxes(masks: th.Tensor) -> th.Tensor:
+ """
+ Compute the bounding boxes around the provided masks.
+
+ Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
+ ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+ Args:
+ masks (Tensor[N, H, W]): masks to transform where N is the number of masks
+ and (H, W) are the spatial dimensions.
+
+ Returns:
+ Tensor[N, 4]: bounding boxes
+ """
+ if masks.numel() == 0:
+ return th.zeros((0, 4), device=masks.device, dtype=th.float)
+
+ n = masks.shape[0]
+
+ bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float)
+
+ for index, mask in enumerate(masks):
+ if mask.sum() > 0:
+ y, x = th.where(mask != 0)
+
+ bounding_boxes[index, 0] = th.min(x)
+ bounding_boxes[index, 1] = th.min(y)
+ bounding_boxes[index, 2] = th.max(x)
+ bounding_boxes[index, 3] = th.max(y)
+
+ return bounding_boxes
+
+def boxes_to_centroids(boxes):
+ """Compute the centroids of bounding boxes, rescaled from a 64-pixel grid to [-1, 1].
+
+ Args:
+ bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2]
+
+ Returns:
+ centroids: [B, T, N, 2], 2: [x, y]
+ """
+
+ centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2
+ centroids = centroids.squeeze(0)
+
+ # scale to [-1, 1]
+ centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1
+ centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1
+
+ return centroids
+
+def compute_position_from_mask(mask):
+ """
+ Compute the position of the object from the mask.
+
+ Args:
+ mask (Tensor[B, N, H, W]): masks to transform where N is the number of masks
+ and (H, W) are the spatial dimensions.
+
+ Returns:
+ Tensor[B, N, 2]: position of the object
+
+ """
+ masks_binary = (mask > 0.8).float()[:, :-1]
+ b, o, h, w = masks_binary.shape
+ masks2 = rearrange(masks_binary, 'b o h w -> (b o) h w')
+ boxes = masks_to_boxes(masks2.long())
+ boxes = rearrange(boxes, '(b o) c -> b 1 o c', b=b, o=o)
+ centroids = boxes_to_centroids(boxes)
+ centroids = centroids[:, :, :, [1, 0]].squeeze(1)
+ return centroids
+
def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views):
net_name = file.split('/')[-1].split('.')[0]
#root_path = file.split('nets')[0]
- root_path = os.path.join(*file.split('/')[0:-1])
+ root_path = os.path.join(*file.split('/')[0:-2])
root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type'])
plot_path = os.path.join(root_path, evaluation_mode)
# create directories
- #if os.path.exists(plot_path):
- # shutil.rmtree(plot_path)
+ if os.path.exists(plot_path):
+ shutil.rmtree(plot_path)
os.makedirs(plot_path, exist_ok = True)
if object_view:
os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True)
@@ -59,13 +133,21 @@ def load_model(cfg, cfg_net, file, device):
print(f"load {file} to device {device}")
state = th.load(file, map_location=device)
- # backward compatibility
+ # 1. Get keys of current model while ensuring backward compatibility
model = {}
+ allowed_keys = []
+ rand_state = net.state_dict()
+ for key, value in rand_state.items():
+ allowed_keys.append(key)
+
+ # 2. Overwrite with values from file
for key, value in state["model"].items():
# replace update_module with percept_gate_controller in key string:
key = key.replace("update_module", "percept_gate_controller")
- model[key.replace(".module.", ".")] = value
+ if key in allowed_keys:
+ model[key.replace(".module.", ".")] = value
+
net.load_state_dict(model)
# ???
diff --git a/scripts/utils/io.py b/scripts/utils/io.py
index 9bd8158..787c575 100644
--- a/scripts/utils/io.py
+++ b/scripts/utils/io.py
@@ -65,7 +65,7 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True):
:param move_old: Moves old folder with the same name to an old folder, if not overwrite
:return: Model path
"""
- _path = os.path.join('out')
+ _path = os.path.join('out', cfg.dataset)
path = os.path.join(_path, cfg.model_path)
if not os.path.exists(_path):
@@ -95,14 +95,15 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True):
class LossLogger:
- def __init__(self):
+ def __init__(self, writer):
self.avgloss = UEMA()
self.avg_position_loss = UEMA()
self.avg_time_loss = UEMA()
- self.avg_encoder_loss = UEMA()
- self.avg_mse_object_loss = UEMA()
- self.avg_long_mse_object_loss = UEMA(33333)
+ self.avg_latent_loss = UEMA()
+ self.avg_encoding_loss = UEMA()
+ self.avg_prediction_loss = UEMA()
+ self.avg_prediction_loss_long = UEMA(33333)
self.avg_num_objects = UEMA()
self.avg_openings = UEMA()
self.avg_gestalt = UEMA()
@@ -110,28 +111,73 @@ class LossLogger:
self.avg_gestalt_mean = UEMA()
self.avg_update_gestalt = UEMA()
self.avg_update_position = UEMA()
+ self.avg_num_bounded = UEMA()
+
+ self.writer = writer
- def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position):
+ def update_complete(self, avg_position_loss, avg_time_loss, avg_latent_loss, avg_encoding_loss, avg_prediction_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, avg_num_bounded, lr, num_updates):
self.avg_position_loss.update(avg_position_loss.item())
self.avg_time_loss.update(avg_time_loss.item())
- self.avg_encoder_loss.update(avg_encoder_loss.item())
- self.avg_mse_object_loss.update(avg_mse_object_loss.item())
- self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item())
+ self.avg_latent_loss.update(avg_latent_loss.item())
+ self.avg_encoding_loss.update(avg_encoding_loss.item())
+ self.avg_prediction_loss.update(avg_prediction_loss.item())
+ self.avg_prediction_loss_long.update(avg_prediction_loss.item())
self.avg_num_objects.update(avg_num_objects)
self.avg_openings.update(avg_openings)
self.avg_gestalt.update(avg_gestalt.item())
self.avg_gestalt2.update(avg_gestalt2.item())
self.avg_gestalt_mean.update(avg_gestalt_mean.item())
- self.avg_update_gestalt.update(avg_update_gestalt.item())
- self.avg_update_position.update(avg_update_position.item())
+ self.avg_update_gestalt.update(avg_update_gestalt)
+ self.avg_update_position.update(avg_update_position)
+ self.avg_num_bounded.update(avg_num_bounded)
+
+ self.writer.add_scalar("Train/Position Loss", avg_position_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Time Loss", avg_time_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Latent Loss", avg_latent_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Encoder Loss", avg_encoding_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Prediction Loss", avg_prediction_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Number of Objects", avg_num_objects, num_updates)
+ self.writer.add_scalar("Train/Openings", avg_openings, num_updates)
+ self.writer.add_scalar("Train/Gestalt", avg_gestalt.item(), num_updates)
+ self.writer.add_scalar("Train/Gestalt2", avg_gestalt2.item(), num_updates)
+ self.writer.add_scalar("Train/Gestalt Mean", avg_gestalt_mean.item(), num_updates)
+ self.writer.add_scalar("Train/Update Gestalt", avg_update_gestalt, num_updates)
+ self.writer.add_scalar("Train/Update Position", avg_update_position, num_updates)
+ self.writer.add_scalar("Train/Number Bounded", avg_num_bounded, num_updates)
+ self.writer.add_scalar("Train/Learning Rate", lr, num_updates)
+
pass
- def update_average_loss(self, avgloss):
+ def update_average_loss(self, avgloss, num_updates):
self.avgloss.update(avgloss)
+ self.writer.add_scalar("Train/Loss", avgloss, num_updates)
pass
def get_log(self):
- info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
+ info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_prediction_loss):.2e}|{float(self.avg_prediction_loss_long):.2e}, reg: {float(self.avg_encoding_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_latent_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
return info
+
+class WriterWrapper():
+
+ def __init__(self, use_wandb: bool, cfg: Configuration):
+ if use_wandb:
+ from torch.utils.tensorboard import SummaryWriter
+ import wandb
+ wandb.init(project=f'Loci_Looped_{cfg.dataset}', name= cfg.model_path, sync_tensorboard=True, config=cfg)
+ self.writer = SummaryWriter()
+ else:
+ self.writer = None
+
+ def add_scalar(self, name, value, step):
+ if self.writer is not None:
+ self.writer.add_scalar(name, value, step)
+
+ def add_video(self, name, value, step):
+ if self.writer is not None:
+ self.writer.add_video(name, value, step)
+
+ def flush(self):
+ if self.writer is not None:
+ self.writer.flush()
diff --git a/scripts/utils/plot_utils.py b/scripts/utils/plot_utils.py
index 5cb20de..1c83456 100644
--- a/scripts/utils/plot_utils.py
+++ b/scripts/utils/plot_utils.py
@@ -151,9 +151,10 @@ def to_rgb(tensor: th.Tensor):
tensor
), dim=1)
-def visualise_gate(gate, h, w):
+def visualise_gate(gate, h, w, invert = False):
bar = th.ones((1,h,w), device=gate.device) * 0.9
black = int(w*gate.item())
+ black = w-black if invert else black
if black > 0:
bar[:,:, -black:] = 0
return bar
@@ -266,28 +267,30 @@ def plot_online_error(error, error_name, target, t, i, sequence_len, root_path,
return error_plot
-def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
+def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, rollout_mode=False, openings=None):
# add ground truth positions of objects to image
+ target = target.clone()
if gt_positions_target_next is not None:
for o in range(gt_positions_target_next.shape[1]):
position = gt_positions_target_next[0, o]
position = position/2 + 0.5
- if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
- width = 5
- w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item()
+ if (len(position.shape) < 3 or position[2] > 0.0) and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
+ width = int(target.shape[2]*0.05)
+ w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item() # made for bouncing balls
h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item()
col = get_color(o).view(3,1,1)
target[0,:,(w-width):(w+width), (h-width):(h+width)] = col
# add these positions to the associated slots velocity_next2d ilustration
- slots = (association_table[0] == o).nonzero()
- for s in slots.flatten():
- velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col
+ if association_table is not None:
+ slots = (association_table[0] == o).nonzero()
+ for s in slots.flatten():
+ velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col
- if output_hidden is not None and s != largest_object:
- output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+ if output_hidden is not None and s != largest_object:
+ output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col
gateheight = 60
ch = 40
@@ -296,22 +299,28 @@ def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots
gh_margin = int((gh-gh_bar)/2)
margin = 20
slots_margin = 10
- height = size[0] * 6 + 18*5
+ height = size[0] * 6 + 18*6
width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1)
img = th.ones((3, height, width), device = object_next.device) * 0.4
row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin])
col1 = range(margin, margin + size[1]*2)
col2 = range(width-(margin+size[1]*2), width-margin)
+ # add frame around image
+ if rollout_mode:
+ img[0,margin-2:margin+size[0]*2+2, margin-2:margin+size[1]*2+2] = 1
+
img[:,row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0]
img[:,row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0]
img[:,row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0]
+ # add large error plots to image
if error_plot is not None:
img[:,row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize= True)
if error_plot2 is not None:
img[:,row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize= True)
+ # fill columns with slots
for o in range(num_objects):
col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin
@@ -321,23 +330,35 @@ def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots
if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
img[:,margin:margin+ch, col] = get_color(o).view(3,1,1).to(object_next.device)
+ # gestalt gate
img[:,margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:,o, 0].to(object_next.device), h=gh_bar, w=len(col))
offset = gh+margin-gh_margin+ch+2*margin
row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
img[:,row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0,o].to(object_next.device))
img[:,row(1)[0]:row(1)[1], col] = preprocess(object_next[:,o].to(object_next.device))
+
+ # small error plots top row
if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
img[:,row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True)
+ # switch to bottom row
offset = margin*2-8
row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+
+ # position gate
img[:,row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:,o, 1].to(object_next.device), h=gh_bar, w=len(col))
img[:,row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0]
+
+ # small error plots bottom row
if (error_plot_slots is not None) and len(error_plot_slots) > o:
img[:,row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True)
- img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+ # add gatelord gate visualisation to image
+ if openings is not None:
+ img[:,row(5)[1]+gh_margin:row(5)[1]+gh-gh_margin, col] = visualise_gate(openings[:,o].to(object_next.device), h=gh_bar, w=len(col), invert = True)
+
+ img = rearrange(img * 255, 'c h w -> h w c').cpu()
return img
@@ -347,8 +368,57 @@ def write_image(file, img):
pass
-def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i):
-
+def extract_element(tensor, index):
+ if tensor is None:
+ return None
+ return tensor[index:index+1]
+
+def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, sample_i, rollout_mode=False, num_vid=2, att=None, openings=None):
+
+ if len(input) > 1:
+ img_list = None
+ for i in range(min(len(input), num_vid)):
+ _input = extract_element(input, i)
+ _target = extract_element(target, i)
+ _mask_cur = extract_element(mask_cur, i)
+ _mask_next = extract_element(mask_next, i)
+ _output_next = extract_element(output_next, i)
+ _position_encoder_cur = extract_element(position_encoder_cur, i)
+ _position_next = extract_element(position_next, i)
+ _rawmask_hidden = extract_element(rawmask_hidden, i)
+ _rawmask_cur = extract_element(rawmask_cur, i)
+ _rawmask_next = extract_element(rawmask_next, i)
+ _largest_object = extract_element(largest_object, i)
+ _object_cur = extract_element(object_cur, i)
+ _object_next = extract_element(object_next, i)
+ _object_hidden = extract_element(object_hidden, i)
+ _slots_bounded = extract_element(slots_bounded, i)
+ _slots_partially_occluded_cur = extract_element(slots_partially_occluded_cur, i)
+ _slots_occluded_cur = extract_element(slots_occluded_cur, i)
+ _slots_partially_occluded_next = extract_element(slots_partially_occluded_next, i)
+ _slots_occluded_next = extract_element(slots_occluded_next, i)
+ _slots_closed = extract_element(slots_closed, i)
+ _gt_positions_target_next = extract_element(gt_positions_target_next, i)
+ _association_table = extract_element(association_table, i)
+ _error_next = extract_element(error_next, i)
+ _output_hidden = extract_element(output_hidden, i)
+ _att = extract_element(att, i)
+ _openings = extract_element(openings, i)
+
+ img = plot_timestep_single(cfg, cfg_net, _input, _target, _mask_cur, _mask_next, _output_next, _position_encoder_cur, _position_next, _rawmask_hidden, _rawmask_cur, _rawmask_next, _largest_object, _object_cur, _object_next, _object_hidden, _slots_bounded, _slots_partially_occluded_cur, _slots_occluded_cur, _slots_partially_occluded_next, _slots_occluded_next, _slots_closed, _gt_positions_target_next, _association_table, _error_next, _output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode, att=_att, openings=_openings)
+ if img_list is None:
+ img_list = img.unsqueeze(0)
+ else:
+ img_list = th.cat((img_list, img.unsqueeze(0)), dim=0)
+
+ return img_list.permute(0, 3, 1, 2)
+
+ else:
+ img = plot_timestep_single(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, sample_i, rollout_mode, att=att, openings=openings)
+ return img
+
+def plot_timestep_single(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode=False, att=None, openings=None):
+
# Create eposition helpers
size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device)
@@ -364,7 +434,7 @@ def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next,
velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)
# compute occlusion
- if (cfg.datatype == "adept"):
+ if (cfg.datatype == "adept") and rawmask_hidden is not None:
rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale)
rawmask_cur_h[:,largest_object] = rawmask_cur_l[:,largest_object]
@@ -385,30 +455,39 @@ def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next,
mask_next = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')
if object_view:
- if (cfg.datatype == "adept"):
+ if (cfg.datatype == "adept") and statistics_complete_slots is not None:
num_objects = 4
error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded)
- error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
+ #error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path)
error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path)
- img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object)
- else:
+ att_histogram = plot_attention_histogram(att, target, root_path)
+ img = plot_object_view(error_plot, error_plot2, error_plot_slots, att_histogram, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, openings=openings)
+ elif (cfg.datatype == "clevrer") and statistics_complete_slots is not None:
num_objects = cfg_net.num_objects
error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path)
- img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object)
-
- cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img)
+ img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, openings=openings)
+ else:
+ num_objects = cfg_net.num_objects
+ att_histogram = plot_attention_histogram(att, target, root_path)
+ img = plot_object_view(None, None, att_histogram, None, input, output, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, rollout_mode=rollout_mode, openings=openings)
+
+ if plot_path is not None:
+ cv2.imwrite(f'{plot_path}/object/{i:04d}-{t_index:03d}.jpg', img.numpy())
if individual_views:
# ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']:
write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0])
- write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0])
+ write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', target[0])
write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0,-1])
- write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0])
+ #write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0])
write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0])
+ for o in range(len(rawmask_next[0])):
+ write_image(f'{plot_path}/individual/rgb/object-{i:04d}-{o}-{t_index:03d}.jpg', object_next[0][o])
+ write_image(f'{plot_path}/individual/rawmask/rawmask-{i:04d}-{o}-{t_index:03d}.jpg', rawmask_next[0][o])
- pass
+ return img
def get_position_helper(cfg_net, device):
size = cfg_net.input_size
@@ -425,4 +504,35 @@ def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cu
slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None,None,None]
slots_occluded_next = th.squeeze(slots_occluded_next)[..., None,None,None]
- return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next \ No newline at end of file
+ return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next
+
+def plot_attention_histogram(att, target, root_path):
+ att_plots = []
+ if (att is not None) and (len(att) > 0):
+ att = att[0]
+ for object_attention in att:
+
+ fig, ax = plt.subplots(figsize=(round(target.shape[3]/100,2), round(target.shape[2]/100,2)))
+
+ # Plot one bar per entry of this attention vector (count taken from its length, not fixed at 6)
+ num_objects = len(object_attention)
+ ax.bar(range(num_objects), object_attention.cpu())
+ ax.set_ylim([0,1])
+ ax.set_xlim([-1,num_objects])
+ ax.set_xticks(range(num_objects))
+ #ax.set_xticklabels(['1','2','3','4','5','6'])
+ ax.set_ylabel('attention')
+ ax.set_xlabel('object')
+ ax.set_title('Attention histogram')
+
+ # fixed
+ fig.tight_layout()
+ plt.savefig(f'{root_path}/tmp.jpg')
+ plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
+ plot = th.from_numpy(np.array(plot).transpose(2,0,1))
+ plt.close(fig)
+ att_plots.append(plot)
+
+ return att_plots
+ else:
+ return None \ No newline at end of file
diff --git a/scripts/validation.py b/scripts/validation.py
index 5b1638d..b59cc39 100644
--- a/scripts/validation.py
+++ b/scripts/validation.py
@@ -1,8 +1,11 @@
+import os
import torch as th
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from einops import rearrange, repeat, reduce
+from scripts.evaluation_bb import distance_eval_step
+from scripts.evaluation_clevrer import compute_statistics_summary
from scripts.utils.configuration import Configuration
from model.loci import Loci
import time
@@ -10,12 +13,19 @@ import lpips
from skimage.metrics import structural_similarity as ssimloss
from skimage.metrics import peak_signal_noise_ratio as psnrloss
-def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+from scripts.utils.eval_metrics import masks_to_boxes, postproc_mask, pred_eval_step
+from scripts.utils.eval_utils import append_statistics, compute_position_from_mask
+from scripts.utils.plot_utils import plot_timestep
+
+def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path):
# memory
mseloss = nn.MSELoss()
- avgloss = 0
+ loss_next = 0
start_time = time.time()
+ cfg_net = cfg.model
+ num_steps = 0
+ plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}')
with th.no_grad():
for i, input in enumerate(valloader):
@@ -103,7 +113,8 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic
# 1. Track error
if t >= 0:
loss = mseloss(output_next, target)
- avgloss += loss.item()
+ loss_next += loss.item()
+ num_steps += 1
# 2. Remember output
mask_last = mask_next.clone()
@@ -122,12 +133,19 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic
error_next = th.sqrt(error_next) * bg_error_next
error_last = error_next.clone()
- print(f"Validation loss: {avgloss / len(valloader.dataset):.2e}, Time: {time.time() - start_time}")
-
+ # Plotting
+ if (i == 0) and (t_index % 2 == 0) and (epoch % 3 == 0):
+ os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
+ openings = net.get_openings()
+ img_tensor = plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings)
+
+ print(f"Validation loss: {loss_next / num_steps:.2e}, Time: {time.time() - start_time}")
+ writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch)
+
pass
-def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path):
# memory
mseloss = nn.MSELoss()
@@ -140,6 +158,9 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev
burn_in_length = 6
rollout_length = 42
+
+ plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}')
+ os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
with th.no_grad():
for i, input in enumerate(valloader):
@@ -256,6 +277,243 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev
error_next = th.sqrt(error_next) * bg_error_next
error_last = error_next.clone()
+ # Plotting
+ if i == 0:
+ openings = net.get_openings()
+ img_tensor = plot_timestep(cfg, cfg.model, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings)
+
+
print(f"MSE loss: {avgloss_mse / len(valloader.dataset):.2e}, LPIPS loss: {avgloss_lpips / len(valloader.dataset):.2e}, PSNR loss: {avgloss_psnr / len(valloader.dataset):.2e}, SSIM loss: {avgloss_ssim / len(valloader.dataset):.2e}, Time: {time.time() - start_time}")
+ writer.add_scalar('Val/Prediction Loss', avgloss_mse / len(valloader.dataset), epoch)
+ writer.add_scalar('Val/LPIPS Loss', avgloss_lpips / len(valloader.dataset), epoch)
+ writer.add_scalar('Val/PSNR Loss', avgloss_psnr / len(valloader.dataset), epoch)
+ writer.add_scalar('Val/SSIM Loss', avgloss_ssim / len(valloader.dataset), epoch)
+
+ pass
+
+
+def validation_bb(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path):
+
+ # memory
+ start_time = time.time()
+ net.eval()
+ evaluation_mode = 'vidpred_black'
+ use_meds = True
+
+ # Evaluation Specifics
+ burn_in_length = 10
+ rollout_length = 20
+ rollout_length_stats = 10 # only consider the first 10 frames for statistics
+ target_size = (64, 64)
+
+ # Losses
+ lpipsloss = lpips.LPIPS(net='vgg').to(device)
+ mseloss = nn.MSELoss()
+ metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'meds': [], 'ari_hidden': [], 'fari_hidden': [], 'miou_hidden': []}
+ loss_next = 0.0
+ loss_cur = 0.0
+ num_steps = 0
+ plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}')
+ os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
+
+ with th.no_grad():
+ for i, input in enumerate(valloader):
+
+ # Load data
+ tensor = input[0].float().to(device)
+ background_fix = input[1].to(device)
+ gt_pos = input[2].to(device)
+ gt_mask = input[3].to(device)
+ gt_pres_mask = input[4].to(device)
+ gt_hidden_mask = input[5].to(device)
+ sequence_len = tensor.shape[1]
+
+ # placeholders
+ mask_cur = None
+ mask_last = None
+ rawmask_last = None
+ position_last = None
+ gestalt_last = None
+ priority_last = None
+ gt_positions_target = None
+ slots_occlusionfactor = None
+ error_last = None
+
+ # Memory
+ cfg_net = cfg.model
+ num_objects_bb = gt_pos.shape[2]
+ pred_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device)
+ gt_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device)
+ pred_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+ gt_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+ pred_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
+ pred_hidden_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
+ # Counters
+ num_rollout = 0
+ num_burnin = 0
+
+ # loop through frames
+ for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, burn_in_length+rollout_length)):
+
+ # Move to next frame
+ t_run = max(t, 0)
+ input = tensor[:,t_run]
+ target_cur = tensor[:,t_run]
+ target = th.clip(tensor[:,t_run+1], 0, 1)
+ gt_pos_t = gt_pos[:,t_run+1]/32-1
+ gt_pos_t = th.concat((gt_pos_t, th.ones_like(gt_pos_t[:,:,:1])), dim=2)
+
+ rollout_index = t_run - burn_in_length
+ rollout_active = False
+ if t>=0:
+ if rollout_index >= 0:
+ num_rollout += 1
+ if (evaluation_mode == 'vidpred_black'):
+ input = output_next * 0
+ error_last = error_last * 0
+ rollout_active = True
+ elif (evaluation_mode == 'vidpred_auto'):
+ input = output_next
+ error_last = error_last * 0
+ rollout_active = True
+ else:
+ num_burnin += 1
+
+ # obtain prediction
+ (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder_cur,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ ) = net(
+ input,
+ error_last,
+ mask_last,
+ rawmask_last,
+ position_last,
+ gestalt_last,
+ priority_last,
+ background_fix,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing),
+ evaluate=True,
+ warmup = (t < 0),
+ shuffleslots = True,
+ reset_mask = (t <= 0),
+ allow_spawn = True,
+ show_hidden = False,
+ clean_slots = False,
+ )
+
+ # 1. Track error
+ if t >= 0:
+
+ if (rollout_index >= 0):
+ # store positions per batch
+ if use_meds:
+ if False:
+ pred_pos_batch[:,rollout_index] = rearrange(position_next, 'b (o c) -> b o c', o=cfg_net.num_objects)[:,:,:2]
+ else:
+ pred_pos_batch[:,rollout_index] = compute_position_from_mask(rawmask_next)
+
+ gt_pos_batch[:,rollout_index] = gt_pos_t[:,:,:2]
+
+ pred_img_batch[:,rollout_index] = output_next
+ gt_img_batch[:,rollout_index] = target
+
+ # Here we compute only the foreground segmentation mask
+ pred_mask_batch[:,rollout_index] = postproc_mask(mask_next[:,None,:,None])[:, 0]
+
+ # Here we compute the hidden segmentation
+ occluded_cur = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1]
+ occluded_sum_cur = 1-(reduce(occluded_cur, 'b c h w -> b h w', 'max') > 0.5).float()
+ occluded_cur = th.cat((occluded_cur, occluded_sum_cur[:,None]), dim=1)
+ pred_hidden_mask_batch[:,rollout_index] = postproc_mask(occluded_cur[:,None,:,None])[:, 0]
+
+ # 2. Remember output
+ mask_last = mask_next.clone()
+ rawmask_last = rawmask_next.clone()
+ position_last = position_next.clone()
+ gestalt_last = gestalt_next.clone()
+ priority_last = priority_next.clone()
+
+ # 3. Error for next frame
+ # background error
+ bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+ # prediction error
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next.clone()
+
+ # Prediction and encoder Loss
+ loss_next += mseloss(output_next * bg_error_next, target * bg_error_next)
+ loss_cur += mseloss(output_cur * bg_error_cur, input * bg_error_cur)
+ num_steps += 1
+
+ # Plotting
+ if i == 0:
+ openings = net.get_openings()
+ img_tensor = plot_timestep(cfg, cfg_net, input, target_cur, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_pos_t, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode=rollout_active, openings=openings)
+
+ for b in range(cfg_net.batch_size):
+
+ # perceptual similarity from slotformer paper
+ metric_dict = pred_eval_step(
+ gt = gt_img_batch[b:b+1],
+ pred = pred_img_batch[b:b+1],
+ pred_mask = pred_mask_batch.long()[b:b+1],
+ pred_mask_hidden = pred_hidden_mask_batch.long()[b:b+1],
+ pred_bbox = None,
+ gt_mask = gt_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1],
+ gt_mask_hidden = gt_hidden_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1],
+ gt_pres_mask = gt_pres_mask[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1],
+ gt_bbox = None,
+ lpips_fn = lpipsloss,
+ eval_traj = True,
+ )
+
+ metric_dict['meds'] = distance_eval_step(gt_pos_batch[b], pred_pos_batch[b])
+ metric_complete = append_statistics(metric_dict, metric_complete)
+
+ # sanity check
+ if (num_rollout != rollout_length) and (num_burnin != burn_in_length):
+ raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.')
+
+ dic = compute_statistics_summary(metric_complete, evaluation_mode, consider_first_n_frames=rollout_length_stats)
+ writer.add_scalar('Val/Meds', dic['meds_complete_sum'], epoch)
+
+ writer.add_scalar('Val/ARI_hidden', dic['ari_hidden_complete_average'], epoch)
+ writer.add_scalar('Val/ARI', dic['ari_complete_average'], epoch)
+ writer.add_scalar('Val/LPIPS', dic['percept_dist_complete_average'], epoch)
+
+ writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch)
+ writer.add_scalar('Val/Encoding Loss', loss_cur / num_steps, epoch)
+
+ net.train()
pass \ No newline at end of file