aboutsummaryrefslogtreecommitdiff
path: root/scripts/utils
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/utils')
-rw-r--r--scripts/utils/eval_adept.py169
-rw-r--r--scripts/utils/eval_metrics.py51
-rw-r--r--scripts/utils/eval_utils.py92
-rw-r--r--scripts/utils/io.py72
-rw-r--r--scripts/utils/plot_utils.py162
5 files changed, 495 insertions, 51 deletions
diff --git a/scripts/utils/eval_adept.py b/scripts/utils/eval_adept.py
new file mode 100644
index 0000000..7b8dfa8
--- /dev/null
+++ b/scripts/utils/eval_adept.py
@@ -0,0 +1,169 @@
+import pandas as pd
+import warnings
+import os
+import argparse
+
+warnings.simplefilter(action='ignore', category=FutureWarning)
+pd.options.mode.chained_assignment = None
+
+def eval_adept(path):
+ net = 'net1'
+
+ # read csv files
+ tf = pd.DataFrame()
+ sf = pd.DataFrame()
+ af = pd.DataFrame()
+
+ with open(os.path.join(path, 'trialframe.csv'), 'rb') as f:
+ tf_temp = pd.read_csv(f, index_col=0)
+ tf_temp['net'] = net
+ tf = pd.concat([tf,tf_temp])
+
+ with open(os.path.join(path, 'slotframe.csv'), 'rb') as f:
+ sf_temp = pd.read_csv(f, index_col=0)
+ sf_temp['net'] = net
+ sf = pd.concat([sf,sf_temp])
+
+ with open(os.path.join(path, 'accframe.csv'), 'rb') as f:
+ af_temp = pd.read_csv(f, index_col=0)
+ af_temp['net'] = net
+ af = pd.concat([af,af_temp])
+
+ # cast variables
+ sf['visible'] = sf['visible'].astype(bool)
+ sf['bound'] = sf['bound'].astype(bool)
+ sf['occluder'] = sf['occluder'].astype(bool)
+ sf['inimage'] = sf['inimage'].astype(bool)
+ sf['vanishing'] = sf['vanishing'].astype(bool)
+ sf['alpha_pos'] = 1-sf['alpha_pos']
+ sf['alpha_ges'] = 1-sf['alpha_ges']
+
+ # scale to percentage
+ sf['TE'] = sf['TE'] * 100
+
+ # add surprise as dummy code
+ tf['control'] = [('control' in set) for set in tf['set']]
+ sf['control'] = [('control' in set) for set in sf['set']]
+
+
+
+ # STATS:
+ tracking_error_visible = 0
+ tracking_error_occluded = 0
+ num_positive_trackings = 0
+ mota = 0
+ gate_openings_visible = 0
+ gate_openings_occluded = 0
+
+
+ print('Tracking Error ------------------------------')
+ grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)
+
+ def get_stats(col):
+ return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}'
+
+ # When Visible
+ temp = sf[grouping & sf.visible]
+ print(f'Tracking Error when visible:' + get_stats(temp['TE']))
+ tracking_error_visible = temp['TE'].mean()
+
+ # When Occluded
+ temp = sf[grouping & ~sf.visible]
+ print(f'Tracking Error when occluded:' + get_stats(temp['TE']))
+ tracking_error_occluded = temp['TE'].mean()
+
+
+
+
+
+
+ print('Positive Trackings ------------------------------')
+ # successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target
+ # determine last visible frame numeric
+ grouping_factors = ['net','set','evalmode','scene','slot']
+ ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max()
+ ff.rename(columns = {'frame':'last_visible'}, inplace = True)
+ sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left')
+
+ # same for first visible frame
+ ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min()
+ ff.rename(columns = {'frame':'first_visible'}, inplace = True)
+ sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left')
+
+ # add dummy variable to sf
+ sf['last_visible'] = (sf['last_visible'] == sf['frame'])
+
+ # extract the trials where the target was last visible and threshold the TE
+ ff = sf[sf['last_visible']]
+ ff['tracked_pos'] = (ff['TE'] < 10)
+ ff['tracked_neg'] = (ff['TE'] >= 10)
+
+ # merge tracking flags into sf and fill NaN with False
+ sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left')
+ sf['tracked_pos'].fillna(False, inplace=True)
+ sf['tracked_neg'].fillna(False, inplace=True)
+
+ # Aggregate over all scenes
+ temp = sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)]
+ temp = temp.groupby(['set', 'evalmode']).sum()
+ temp = temp[['tracked_pos', 'tracked_neg']]
+ temp = temp.reset_index()
+
+ temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg'])
+ temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg'])
+ print(temp)
+ num_positive_trackings = temp['tracked_pos_pro']
+
+
+
+
+
+ print('Mostly Trecked /MOTA ------------------------------')
+ temp = af[af.index == 'OVERALL']
+ temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects']
+ temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects']
+ temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects']
+ print(temp)
+ mota = temp['mota']
+
+
+ print('Openings ------------------------------')
+ grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control)
+ temp = sf[grouping & sf.visible]
+ print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))
+ gate_openings_visible = temp['alpha_pos'].mean() + temp['alpha_ges'].mean()
+
+ temp = sf[grouping & ~sf.visible]
+ print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges']))
+ gate_openings_occluded = temp['alpha_pos'].mean() + temp['alpha_ges'].mean()
+
+
+ print('------------------------------------------------')
+ print('------------------------------------------------')
+ str = ''
+ str += f'net: {net}\n'
+ str += f'Tracking Error when visible: {tracking_error_visible:.3}\n'
+ str += f'Tracking Error when occluded: {tracking_error_occluded:.3}\n'
+ str += 'Positive Trackings: ' + ', '.join(f'{val:.3}' for val in num_positive_trackings) + '\n'
+ str += 'MOTA: ' + ', '.join(f'{val:.3}' for val in mota) + '\n'
+ str += f'Percept gate openings when visible: {gate_openings_visible:.3}\n'
+ str += f'Percept gate openings when occluded: {gate_openings_occluded:.3}\n'
+
+ print(str)
+
+ # write string to file
+ with open(os.path.join(path, 'results.txt'), 'w') as f:
+ f.write(str)
+
+
+
+if __name__ == "__main__":
+
+ # use argparse to get the path to the results folder
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--path', type=str, default='')
+ args = parser.parse_args()
+
+ # setting path to results folder
+ path = args.path
+ eval_adept(path) \ No newline at end of file
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py
index 6f5106d..98af468 100644
--- a/scripts/utils/eval_metrics.py
+++ b/scripts/utils/eval_metrics.py
@@ -3,6 +3,8 @@ vp_utils.py
SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer
'''
+import cv2
+from einops import rearrange
import numpy as np
from scipy.optimize import linear_sum_assignment
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
@@ -11,7 +13,7 @@ import torch
import torch.nn.functional as F
import torchvision.ops as vops
-FG_THRE = 0.5
+FG_THRE = 0.3 # 0.5 for the rest, 0.3 for bouncingballs
PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0),
(255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200),
(255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200),
@@ -272,6 +274,8 @@ def pred_eval_step(
gt_bbox=None,
pred_bbox=None,
eval_traj=True,
+ gt_mask_hidden=None,
+ pred_mask_hidden=None,
):
"""Both of shape [B, T, C, H, W], torch.Tensor.
masks of shape [B, T, H, W].
@@ -288,13 +292,14 @@ def pred_eval_step(
if eval_traj:
assert len(gt_mask.shape) == len(pred_mask.shape) == 4
assert gt_mask.shape == pred_mask.shape
- if eval_traj:
+ if eval_traj and gt_bbox is not None:
assert len(gt_pres_mask.shape) == 3
assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4
T = gt.shape[1]
# compute perceptual dist & mask metrics before converting to numpy
all_percept_dist, all_ari, all_fari, all_miou = [], [], [], []
+ all_ari_hidden, all_fari_hidden, all_miou_hidden = [], [], []
for t in range(T):
one_gt, one_pred = gt[:, t], pred[:, t]
percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item()
@@ -307,6 +312,29 @@ def pred_eval_step(
all_ari.append(ari)
all_fari.append(fari)
all_miou.append(miou)
+
+ # hidden:
+ if gt_mask_hidden is not None:
+ one_gt_mask, one_pred_mask = gt_mask_hidden[:, t], pred_mask_hidden[:, t]
+ ari = ARI_metric(one_gt_mask, one_pred_mask)
+ fari = fARI_metric(one_gt_mask, one_pred_mask)
+ miou = miou_metric(one_gt_mask, one_pred_mask)
+ all_ari_hidden.append(ari)
+ all_fari_hidden.append(fari)
+ all_miou_hidden.append(miou)
+
+
+ # display masks with cv2
+ if False:
+ one_gt_mask_, one_pred_mask_ = one_gt_mask.clone(), one_pred_mask.clone()
+
+ # concat one_gt_mask and one_pred_mask horizontally
+ frame = np.concatenate((one_gt_mask_, one_pred_mask_), axis=1)
+ frame = rearrange(frame, 'c h w -> h w c') * (255/6)
+ frame = frame.astype(np.uint8)
+ cv2.imshow('frame', frame)
+ cv2.waitKey(0)
+
else:
all_ari.append(0.)
all_fari.append(0.)
@@ -315,14 +343,14 @@ def pred_eval_step(
# compute bbox metrics
all_ap, all_ar = [], []
for t in range(T):
- if not eval_traj:
+ if not eval_traj or gt_bbox is None:
all_ap.append(0.)
all_ar.append(0.)
continue
one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \
gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t]
ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox,
- one_pred_bbox)
+ one_pred_bbox)
all_ap.append(ap)
all_ar.append(ar)
@@ -333,11 +361,13 @@ def pred_eval_step(
one_gt, one_pred = gt[:, t], pred[:, t]
mse = mse_metric(one_gt, one_pred)
psnr = psnr_metric(one_gt, one_pred)
- ssim = ssim_metric(one_gt, one_pred)
+ #ssim = ssim_metric(one_gt, one_pred)
+ ssim = 0
all_mse.append(mse)
all_ssim.append(ssim)
all_psnr.append(psnr)
- return {
+
+ res = {
'mse': all_mse,
'ssim': all_ssim,
'psnr': all_psnr,
@@ -346,5 +376,12 @@ def pred_eval_step(
'fari': all_fari,
'miou': all_miou,
'ap': all_ap,
- 'ar': all_ar,
+ 'ar': all_ar
}
+
+ if gt_mask_hidden is not None:
+ res['ari_hidden'] = all_ari_hidden
+ res['fari_hidden'] = all_fari_hidden
+ res['miou_hidden'] = all_miou_hidden
+
+ return res
diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py
index faab7ec..a01ffd0 100644
--- a/scripts/utils/eval_utils.py
+++ b/scripts/utils/eval_utils.py
@@ -1,18 +1,92 @@
import os
+import shutil
+from einops import rearrange
import torch as th
from model.loci import Loci
+def masks_to_boxes(masks: th.Tensor) -> th.Tensor:
+ """
+ Compute the bounding boxes around the provided masks.
+
+ Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
+ ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+ Args:
+ masks (Tensor[N, H, W]): masks to transform where N is the number of masks
+ and (H, W) are the spatial dimensions.
+
+ Returns:
+ Tensor[N, 4]: bounding boxes
+ """
+ if masks.numel() == 0:
+ return th.zeros((0, 4), device=masks.device, dtype=th.float)
+
+ n = masks.shape[0]
+
+ bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float)
+
+ for index, mask in enumerate(masks):
+ if mask.sum() > 0:
+ y, x = th.where(mask != 0)
+
+ bounding_boxes[index, 0] = th.min(x)
+ bounding_boxes[index, 1] = th.min(y)
+ bounding_boxes[index, 2] = th.max(x)
+ bounding_boxes[index, 3] = th.max(y)
+
+ return bounding_boxes
+
+def boxes_to_centroids(boxes):
+ """Compute the centers of bounding boxes, scaled to [-1, 1] assuming a 64-pixel extent.
+
+ Args:
+ bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2]
+
+ Returns:
+ centroids: [B, T, N, 2], 2: [x, y]
+ """
+
+ centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2
+ centroids = centroids.squeeze(0)
+
+ # scale to [-1, 1]
+ centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1
+ centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1
+
+ return centroids
+
+def compute_position_from_mask(mask):
+ """
+ Compute the position of the object from the mask.
+
+ Args:
+ mask (Tensor[B, N, H, W]): masks to transform where N is the number of masks
+ and (H, W) are the spatial dimensions.
+
+ Returns:
+ Tensor[B, N, 2]: position of the object
+
+ """
+ masks_binary = (mask > 0.8).float()[:, :-1]
+ b, o, h, w = masks_binary.shape
+ masks2 = rearrange(masks_binary, 'b o h w -> (b o) h w')
+ boxes = masks_to_boxes(masks2.long())
+ boxes = rearrange(boxes, '(b o) c -> b 1 o c', b=b, o=o)
+ centroids = boxes_to_centroids(boxes)
+ centroids = centroids[:, :, :, [1, 0]].squeeze(1)
+ return centroids
+
def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views):
net_name = file.split('/')[-1].split('.')[0]
#root_path = file.split('nets')[0]
- root_path = os.path.join(*file.split('/')[0:-1])
+ root_path = os.path.join(*file.split('/')[0:-2])
root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type'])
plot_path = os.path.join(root_path, evaluation_mode)
# create directories
- #if os.path.exists(plot_path):
- # shutil.rmtree(plot_path)
+ if os.path.exists(plot_path):
+ shutil.rmtree(plot_path)
os.makedirs(plot_path, exist_ok = True)
if object_view:
os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True)
@@ -59,13 +133,21 @@ def load_model(cfg, cfg_net, file, device):
print(f"load {file} to device {device}")
state = th.load(file, map_location=device)
- # backward compatibility
+ # 1. Get keys of current model while ensuring backward compatibility
model = {}
+ allowed_keys = []
+ rand_state = net.state_dict()
+ for key, value in rand_state.items():
+ allowed_keys.append(key)
+
+ # 2. Overwrite with values from file
for key, value in state["model"].items():
# replace update_module with percept_gate_controller in key string:
key = key.replace("update_module", "percept_gate_controller")
- model[key.replace(".module.", ".")] = value
+ if key in allowed_keys:
+ model[key.replace(".module.", ".")] = value
+
net.load_state_dict(model)
# ???
diff --git a/scripts/utils/io.py b/scripts/utils/io.py
index 9bd8158..787c575 100644
--- a/scripts/utils/io.py
+++ b/scripts/utils/io.py
@@ -65,7 +65,7 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True):
:param move_old: Moves old folder with the same name to an old folder, if not overwrite
:return: Model path
"""
- _path = os.path.join('out')
+ _path = os.path.join('out', cfg.dataset)
path = os.path.join(_path, cfg.model_path)
if not os.path.exists(_path):
@@ -95,14 +95,15 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True):
class LossLogger:
- def __init__(self):
+ def __init__(self, writer):
self.avgloss = UEMA()
self.avg_position_loss = UEMA()
self.avg_time_loss = UEMA()
- self.avg_encoder_loss = UEMA()
- self.avg_mse_object_loss = UEMA()
- self.avg_long_mse_object_loss = UEMA(33333)
+ self.avg_latent_loss = UEMA()
+ self.avg_encoding_loss = UEMA()
+ self.avg_prediction_loss = UEMA()
+ self.avg_prediction_loss_long = UEMA(33333)
self.avg_num_objects = UEMA()
self.avg_openings = UEMA()
self.avg_gestalt = UEMA()
@@ -110,28 +111,73 @@ class LossLogger:
self.avg_gestalt_mean = UEMA()
self.avg_update_gestalt = UEMA()
self.avg_update_position = UEMA()
+ self.avg_num_bounded = UEMA()
+
+ self.writer = writer
- def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position):
+ def update_complete(self, avg_position_loss, avg_time_loss, avg_latent_loss, avg_encoding_loss, avg_prediction_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, avg_num_bounded, lr, num_updates):
self.avg_position_loss.update(avg_position_loss.item())
self.avg_time_loss.update(avg_time_loss.item())
- self.avg_encoder_loss.update(avg_encoder_loss.item())
- self.avg_mse_object_loss.update(avg_mse_object_loss.item())
- self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item())
+ self.avg_latent_loss.update(avg_latent_loss.item())
+ self.avg_encoding_loss.update(avg_encoding_loss.item())
+ self.avg_prediction_loss.update(avg_prediction_loss.item())
+ self.avg_prediction_loss_long.update(avg_prediction_loss.item())
self.avg_num_objects.update(avg_num_objects)
self.avg_openings.update(avg_openings)
self.avg_gestalt.update(avg_gestalt.item())
self.avg_gestalt2.update(avg_gestalt2.item())
self.avg_gestalt_mean.update(avg_gestalt_mean.item())
- self.avg_update_gestalt.update(avg_update_gestalt.item())
- self.avg_update_position.update(avg_update_position.item())
+ self.avg_update_gestalt.update(avg_update_gestalt)
+ self.avg_update_position.update(avg_update_position)
+ self.avg_num_bounded.update(avg_num_bounded)
+
+ self.writer.add_scalar("Train/Position Loss", avg_position_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Time Loss", avg_time_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Latent Loss", avg_latent_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Encoder Loss", avg_encoding_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Prediction Loss", avg_prediction_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Number of Objects", avg_num_objects, num_updates)
+ self.writer.add_scalar("Train/Openings", avg_openings, num_updates)
+ self.writer.add_scalar("Train/Gestalt", avg_gestalt.item(), num_updates)
+ self.writer.add_scalar("Train/Gestalt2", avg_gestalt2.item(), num_updates)
+ self.writer.add_scalar("Train/Gestalt Mean", avg_gestalt_mean.item(), num_updates)
+ self.writer.add_scalar("Train/Update Gestalt", avg_update_gestalt, num_updates)
+ self.writer.add_scalar("Train/Update Position", avg_update_position, num_updates)
+ self.writer.add_scalar("Train/Number Bounded", avg_num_bounded, num_updates)
+ self.writer.add_scalar("Train/Learning Rate", lr, num_updates)
+
pass
- def update_average_loss(self, avgloss):
+ def update_average_loss(self, avgloss, num_updates):
self.avgloss.update(avgloss)
+ self.writer.add_scalar("Train/Loss", avgloss, num_updates)
pass
def get_log(self):
- info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
+ info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_prediction_loss):.2e}|{float(self.avg_prediction_loss_long):.2e}, reg: {float(self.avg_encoding_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_latent_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
return info
+
+class WriterWrapper():
+
+ def __init__(self, use_wandb: bool, cfg: Configuration):
+ if use_wandb:
+ from torch.utils.tensorboard import SummaryWriter
+ import wandb
+ wandb.init(project=f'Loci_Looped_{cfg.dataset}', name= cfg.model_path, sync_tensorboard=True, config=cfg)
+ self.writer = SummaryWriter()
+ else:
+ self.writer = None
+
+ def add_scalar(self, name, value, step):
+ if self.writer is not None:
+ self.writer.add_scalar(name, value, step)
+
+ def add_video(self, name, value, step):
+ if self.writer is not None:
+ self.writer.add_video(name, value, step)
+
+ def flush(self):
+ if self.writer is not None:
+ self.writer.flush()
diff --git a/scripts/utils/plot_utils.py b/scripts/utils/plot_utils.py
index 5cb20de..1c83456 100644
--- a/scripts/utils/plot_utils.py
+++ b/scripts/utils/plot_utils.py
@@ -151,9 +151,10 @@ def to_rgb(tensor: th.Tensor):
tensor
), dim=1)
-def visualise_gate(gate, h, w):
+def visualise_gate(gate, h, w, invert = False):
bar = th.ones((1,h,w), device=gate.device) * 0.9
black = int(w*gate.item())
+ black = w-black if invert else black
if black > 0:
bar[:,:, -black:] = 0
return bar
@@ -266,28 +267,30 @@ def plot_online_error(error, error_name, target, t, i, sequence_len, root_path,
return error_plot
-def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
+def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, rollout_mode=False, openings=None):
# add ground truth positions of objects to image
+ target = target.clone()
if gt_positions_target_next is not None:
for o in range(gt_positions_target_next.shape[1]):
position = gt_positions_target_next[0, o]
position = position/2 + 0.5
- if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
- width = 5
- w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item()
+ if (len(position.shape) < 3 or position[2] > 0.0) and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
+ width = int(target.shape[2]*0.05)
+ w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item() # made for bouncing balls
h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item()
col = get_color(o).view(3,1,1)
target[0,:,(w-width):(w+width), (h-width):(h+width)] = col
# add these positions to the associated slots velocity_next2d ilustration
- slots = (association_table[0] == o).nonzero()
- for s in slots.flatten():
- velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col
+ if association_table is not None:
+ slots = (association_table[0] == o).nonzero()
+ for s in slots.flatten():
+ velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col
- if output_hidden is not None and s != largest_object:
- output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+ if output_hidden is not None and s != largest_object:
+ output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col
gateheight = 60
ch = 40
@@ -296,22 +299,28 @@ def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots
gh_margin = int((gh-gh_bar)/2)
margin = 20
slots_margin = 10
- height = size[0] * 6 + 18*5
+ height = size[0] * 6 + 18*6
width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1)
img = th.ones((3, height, width), device = object_next.device) * 0.4
row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin])
col1 = range(margin, margin + size[1]*2)
col2 = range(width-(margin+size[1]*2), width-margin)
+ # add frame around image
+ if rollout_mode:
+ img[0,margin-2:margin+size[0]*2+2, margin-2:margin+size[1]*2+2] = 1
+
img[:,row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0]
img[:,row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0]
img[:,row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0]
+ # add large error plots to image
if error_plot is not None:
img[:,row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize= True)
if error_plot2 is not None:
img[:,row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize= True)
+ # fill columns with slots
for o in range(num_objects):
col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin
@@ -321,23 +330,35 @@ def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots
if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
img[:,margin:margin+ch, col] = get_color(o).view(3,1,1).to(object_next.device)
+ # gestalt gate
img[:,margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:,o, 0].to(object_next.device), h=gh_bar, w=len(col))
offset = gh+margin-gh_margin+ch+2*margin
row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
img[:,row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0,o].to(object_next.device))
img[:,row(1)[0]:row(1)[1], col] = preprocess(object_next[:,o].to(object_next.device))
+
+ # small error plots top row
if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
img[:,row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True)
+ # switch to bottom row
offset = margin*2-8
row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+
+ # position gate
img[:,row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:,o, 1].to(object_next.device), h=gh_bar, w=len(col))
img[:,row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0]
+
+ # small error plots bottom row
if (error_plot_slots is not None) and len(error_plot_slots) > o:
img[:,row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True)
- img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+ # add gatelord gate visualisation to image
+ if openings is not None:
+ img[:,row(5)[1]+gh_margin:row(5)[1]+gh-gh_margin, col] = visualise_gate(openings[:,o].to(object_next.device), h=gh_bar, w=len(col), invert = True)
+
+ img = rearrange(img * 255, 'c h w -> h w c').cpu()
return img
@@ -347,8 +368,57 @@ def write_image(file, img):
pass
-def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i):
-
+def extract_element(tensor, index):
+ if tensor is None:
+ return None
+ return tensor[index:index+1]
+
+def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, sample_i, rollout_mode=False, num_vid=2, att=None, openings=None):
+
+ if len(input) > 1:
+ img_list = None
+ for i in range(min(len(input), num_vid)):
+ _input = extract_element(input, i)
+ _target = extract_element(target, i)
+ _mask_cur = extract_element(mask_cur, i)
+ _mask_next = extract_element(mask_next, i)
+ _output_next = extract_element(output_next, i)
+ _position_encoder_cur = extract_element(position_encoder_cur, i)
+ _position_next = extract_element(position_next, i)
+ _rawmask_hidden = extract_element(rawmask_hidden, i)
+ _rawmask_cur = extract_element(rawmask_cur, i)
+ _rawmask_next = extract_element(rawmask_next, i)
+ _largest_object = extract_element(largest_object, i)
+ _object_cur = extract_element(object_cur, i)
+ _object_next = extract_element(object_next, i)
+ _object_hidden = extract_element(object_hidden, i)
+ _slots_bounded = extract_element(slots_bounded, i)
+ _slots_partially_occluded_cur = extract_element(slots_partially_occluded_cur, i)
+ _slots_occluded_cur = extract_element(slots_occluded_cur, i)
+ _slots_partially_occluded_next = extract_element(slots_partially_occluded_next, i)
+ _slots_occluded_next = extract_element(slots_occluded_next, i)
+ _slots_closed = extract_element(slots_closed, i)
+ _gt_positions_target_next = extract_element(gt_positions_target_next, i)
+ _association_table = extract_element(association_table, i)
+ _error_next = extract_element(error_next, i)
+ _output_hidden = extract_element(output_hidden, i)
+ _att = extract_element(att, i)
+ _openings = extract_element(openings, i)
+
+ img = plot_timestep_single(cfg, cfg_net, _input, _target, _mask_cur, _mask_next, _output_next, _position_encoder_cur, _position_next, _rawmask_hidden, _rawmask_cur, _rawmask_next, _largest_object, _object_cur, _object_next, _object_hidden, _slots_bounded, _slots_partially_occluded_cur, _slots_occluded_cur, _slots_partially_occluded_next, _slots_occluded_next, _slots_closed, _gt_positions_target_next, _association_table, _error_next, _output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode, att=_att, openings=_openings)
+ if img_list is None:
+ img_list = img.unsqueeze(0)
+ else:
+ img_list = th.cat((img_list, img.unsqueeze(0)), dim=0)
+
+ return img_list.permute(0, 3, 1, 2)
+
+ else:
+ img = plot_timestep_single(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, sample_i, rollout_mode, att=att, openings=openings)
+ return img
+
+def plot_timestep_single(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode=False, att=None, openings=None):
+
# Create eposition helpers
size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device)
@@ -364,7 +434,7 @@ def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next,
velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)
# compute occlusion
- if (cfg.datatype == "adept"):
+ if (cfg.datatype == "adept") and rawmask_hidden is not None:
rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale)
rawmask_cur_h[:,largest_object] = rawmask_cur_l[:,largest_object]
@@ -385,30 +455,39 @@ def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next,
mask_next = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')
if object_view:
- if (cfg.datatype == "adept"):
+ if (cfg.datatype == "adept") and statistics_complete_slots is not None:
num_objects = 4
error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded)
- error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
+ #error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path)
error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path)
- img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object)
- else:
+ att_histogram = plot_attention_histogram(att, target, root_path)
+ img = plot_object_view(error_plot, error_plot2, error_plot_slots, att_histogram, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, openings=openings)
+ elif (cfg.datatype == "clevrer") and statistics_complete_slots is not None:
num_objects = cfg_net.num_objects
error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path)
- img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object)
-
- cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img)
+ img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, openings=openings)
+ else:
+ num_objects = cfg_net.num_objects
+ att_histogram = plot_attention_histogram(att, target, root_path)
+ img = plot_object_view(None, None, att_histogram, None, input, output, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, rollout_mode=rollout_mode, openings=openings)
+
+ if plot_path is not None:
+ cv2.imwrite(f'{plot_path}/object/{i:04d}-{t_index:03d}.jpg', img.numpy())
if individual_views:
# ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']:
write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0])
- write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0])
+ write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', target[0])
write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0,-1])
- write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0])
+ #write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0])
write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0])
+ for o in range(len(rawmask_next[0])):
+ write_image(f'{plot_path}/individual/rgb/object-{i:04d}-{o}-{t_index:03d}.jpg', object_next[0][o])
+ write_image(f'{plot_path}/individual/rawmask/rawmask-{i:04d}-{o}-{t_index:03d}.jpg', rawmask_next[0][o])
- pass
+ return img
def get_position_helper(cfg_net, device):
size = cfg_net.input_size
@@ -425,4 +504,35 @@ def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cu
slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None,None,None]
slots_occluded_next = th.squeeze(slots_occluded_next)[..., None,None,None]
- return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next \ No newline at end of file
+ return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next
+
+def plot_attention_histogram(att, target, root_path):
+ att_plots = []
+ if (att is not None) and (len(att) > 0):
+ att = att[0]
+ for object_attention in att:
+
+ fig, ax = plt.subplots(figsize=(round(target.shape[3]/100,2), round(target.shape[2]/100,2)))
+
+ # Plot one bar per entry in object_attention
+ num_objects = len(object_attention)
+ ax.bar(range(num_objects), object_attention.cpu())
+ ax.set_ylim([0,1])
+ ax.set_xlim([-1,num_objects])
+ ax.set_xticks(range(num_objects))
+ #ax.set_xticklabels(['1','2','3','4','5','6'])
+ ax.set_ylabel('attention')
+ ax.set_xlabel('object')
+ ax.set_title('Attention histogram')
+
+ # Save the figure to a scratch file, then read the canvas back as an RGB image
+ fig.tight_layout()
+ plt.savefig(f'{root_path}/tmp.jpg')
+ plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
+ plot = th.from_numpy(np.array(plot).transpose(2,0,1))
+ plt.close(fig)
+ att_plots.append(plot)
+
+ return att_plots
+ else:
+ return None \ No newline at end of file