author    fredeee 2023-11-02 10:47:21 +0100
committer fredeee 2023-11-02 10:47:21 +0100
commit    f8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree      87288be6f851ab69405e524b81940c501c52789a /scripts/utils
parent    f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)

initial commit
Diffstat (limited to 'scripts/utils')
-rw-r--r--  scripts/utils/configuration.py   83
-rw-r--r--  scripts/utils/eval_metrics.py   350
-rw-r--r--  scripts/utils/eval_utils.py      78
-rw-r--r--  scripts/utils/io.py             137
-rw-r--r--  scripts/utils/optimizers.py     100
-rw-r--r--  scripts/utils/plot_utils.py     428
6 files changed, 1176 insertions(+), 0 deletions(-)
diff --git a/scripts/utils/configuration.py b/scripts/utils/configuration.py
new file mode 100644
index 0000000..7c38095
--- /dev/null
+++ b/scripts/utils/configuration.py
@@ -0,0 +1,83 @@
+import json
+import jsmin
+import os
+
+class Dict(dict):
+ """
+ Dictionary that allows attribute-style access and can restrict loading to a selected subset of names
+ """
+ def __init__(self, dictionary: dict = None):
+ super(Dict, self).__init__()
+
+ if dictionary is not None:
+ self.load(dictionary)
+
+ def __getattr__(self, item):
+ try:
+ return self[item] if item in self else getattr(super(Dict, self), item)
+ except AttributeError:
+ raise AttributeError(f'This dictionary has no attribute "{item}"')
+
+ def load(self, dictionary: dict, name_list: list = None):
+ """
+ Loads a dictionary
+ :param dictionary: Dictionary to be loaded
+ :param name_list: List of names to be updated
+ """
+ for name in dictionary:
+ data = dictionary[name]
+ if name_list is None or name in name_list:
+ if isinstance(data, dict):
+ if name in self:
+ self[name].load(data)
+ else:
+ self[name] = Dict(data)
+ elif isinstance(data, list):
+ self[name] = list()
+ for item in data:
+ if isinstance(item, dict):
+ self[name].append(Dict(item))
+ else:
+ self[name].append(item)
+ else:
+ self[name] = data
+
+ def save(self, path):
+ """
+ Saves the dictionary as cfg.json
+ :param path: Directory in which cfg.json is created
+ """
+ os.makedirs(path, exist_ok=True)
+
+ path = os.path.join(path, 'cfg.json')
+
+ with open(path, 'w') as file:
+ json.dump(self, file, indent=True)
+
+
+class Configuration(Dict):
+ """
+ Configuration loaded from a json file
+ """
+ def __init__(self, path: str, default_path=None):
+ super(Configuration, self).__init__()
+
+ if default_path is not None:
+ self.load(default_path)
+
+ self.load(path)
+
+ def load_model(self, path: str):
+ self.load(path, name_list=["model"])
+
+ def load(self, path: str, name_list: list = None):
+ """
+ Loads attributes from a json file
+ :param path: Path of the json file
+ :param name_list: List of names to be updated
+ :return:
+ """
+ with open(path) as file:
+ data = json.loads(jsmin.jsmin(file.read()))
+
+ super(Configuration, self).load(data, name_list)
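A quick usage sketch for the Dict/Configuration classes above (editor's illustration, not part of the commit; assumes the repository is on the import path). Nested dictionaries, including those inside lists, are converted to Dict, so values support attribute-style access, and repeated loads merge into existing sub-dictionaries:

    from scripts.utils.configuration import Dict

    d = Dict({'model': {'batch_size': 32}, 'tags': [{'name': 'a'}, 'b']})
    assert d.model.batch_size == 32        # nested dicts become Dict instances
    assert d.tags[0].name == 'a'           # dicts inside lists are converted too
    d.load({'model': {'batch_size': 64}})  # a second load merges into existing sub-Dicts
    assert d.model.batch_size == 64

Configuration additionally runs the file through jsmin before json.loads, so config files may contain //-style comments.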
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py
new file mode 100644
index 0000000..6f5106d
--- /dev/null
+++ b/scripts/utils/eval_metrics.py
@@ -0,0 +1,350 @@
+'''
+vp_utils.py
+SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer
+'''
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from skimage.metrics import structural_similarity, peak_signal_noise_ratio
+
+import torch
+import torch.nn.functional as F
+import torchvision.ops as vops
+
+FG_THRE = 0.5
+PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0),
+ (255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200),
+ (255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200),
+ (50, 50, 50)]
+PALETTE_np = np.array(PALETTE, dtype=np.uint8)
+PALETTE_torch = torch.from_numpy(PALETTE_np).float() / 255. * 2. - 1.
+
+def to_rgb_from_tensor(x):
+ """Reverse the Normalize operation in torchvision."""
+ return (x * 0.5 + 0.5).clamp(0, 1)
+
+def postproc_mask(batch_masks):
+ """Post-process masks instead of directly taking argmax.
+
+ Args:
+ batch_masks: [B, T, N, 1, H, W]
+
+ Returns:
+ masks: [B, T, H, W]
+ """
+ batch_masks = batch_masks.clone()
+ B, T, N, _, H, W = batch_masks.shape
+ batch_masks = batch_masks.reshape(B * T, N, H * W)
+ slots_max = batch_masks.max(-1)[0] # [B*T, N]
+ bg_idx = slots_max.argmin(-1) # [B*T]
+ spatial_max = batch_masks.max(1)[0] # [B*T, H*W]
+ bg_mask = (spatial_max < FG_THRE) # [B*T, H*W]
+ idx_mask = torch.zeros((B * T, N)).type_as(bg_mask)
+ idx_mask[torch.arange(B * T), bg_idx] = True
+ # set the background mask score to 1
+ batch_masks[idx_mask.unsqueeze(-1) * bg_mask.unsqueeze(1)] = 1.
+ masks = batch_masks.argmax(1) # [B*T, H*W]
+ return masks.reshape(B, T, H, W)
+
+
+def masks_to_boxes_w_empty_mask(binary_masks):
+ """binary_masks: [B, H, W]."""
+ B = binary_masks.shape[0]
+ obj_mask = (binary_masks.sum([-1, -2]) > 0) # [B]
+ bboxes = torch.ones((B, 4)).float().to(binary_masks.device) * -1.
+ bboxes[obj_mask] = vops.masks_to_boxes(binary_masks[obj_mask])
+ return bboxes
+
+
+def masks_to_boxes(masks, num_boxes=7):
+ """Convert seg_masks to bboxes.
+
+ Args:
+ masks: [B, T, H, W], output after taking argmax
+ num_boxes: number of boxes to generate (num_slots)
+
+ Returns:
+ bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2]
+ """
+ B, T, H, W = masks.shape
+ binary_masks = F.one_hot(masks, num_classes=num_boxes) # [B, T, H, W, N]
+ binary_masks = binary_masks.permute(0, 1, 4, 2, 3) # [B, T, N, H, W]
+ binary_masks = binary_masks.contiguous().flatten(0, 2)
+ bboxes = masks_to_boxes_w_empty_mask(binary_masks) # [B*T*N, 4]
+ bboxes = bboxes.reshape(B, T, num_boxes, 4) # [B, T, N, 4]
+ return bboxes
+
+
+def mse_metric(x, y):
+ """x/y: [B, C, H, W]"""
+ # people often sum over spatial dimensions in video prediction MSE
+ # see e.g. https://github.com/Yunbo426/predrnn-pp/blob/40764c52f433290aa02414b5f25c09c72b98e0af/train.py#L244
+ return ((x - y)**2).sum(-1).sum(-1).mean()
+
+
+def psnr_metric(x, y):
+ """x/y: [B, C, H, W]"""
+ psnrs = [
+ peak_signal_noise_ratio(
+ x[i],
+ y[i],
+ data_range=1.,
+ ) for i in range(x.shape[0])
+ ]
+ return np.mean(psnrs)
+
+
+def ssim_metric(x, y):
+ """x/y: [B, C, H, W]"""
+ x = x * 255.
+ y = y * 255.
+ ssims = [
+ structural_similarity(
+ x[i],
+ y[i],
+ channel_axis=0,
+ gaussian_weights=True,
+ sigma=1.5,
+ use_sample_covariance=False,
+ data_range=255,
+ ) for i in range(x.shape[0])
+ ]
+ return np.mean(ssims)
+
+
+def perceptual_dist(x, y, loss_fn):
+ """x/y: [B, C, H, W]"""
+ return loss_fn(x, y).mean()
+
+
+def adjusted_rand_index(true_ids, pred_ids, ignore_background=False):
+ """Computes the adjusted Rand index (ARI), a clustering similarity score.
+
+ Code borrowed from https://github.com/google-research/slot-attention-video/blob/e8ab54620d0f1934b332ddc09f1dba7bc07ff601/savi/lib/metrics.py#L111
+
+ Args:
+ true_ids: An integer-valued array of shape
+ [batch_size, seq_len, H, W]. The true cluster assignment encoded
+ as integer ids.
+ pred_ids: An integer-valued array of shape
+ [batch_size, seq_len, H, W]. The predicted cluster assignment
+ encoded as integer ids.
+ ignore_background: Boolean, if True, then ignore all pixels where
+ true_ids == 0 (default: False).
+
+ Returns:
+ ARI scores as a float32 array of shape [batch_size].
+ """
+ if len(true_ids.shape) == 3:
+ true_ids = true_ids.unsqueeze(1)
+ if len(pred_ids.shape) == 3:
+ pred_ids = pred_ids.unsqueeze(1)
+
+ true_oh = F.one_hot(true_ids).float()
+ pred_oh = F.one_hot(pred_ids).float()
+
+ if ignore_background:
+ true_oh = true_oh[..., 1:] # Remove the background row.
+
+ N = torch.einsum("bthwc,bthwk->bck", true_oh, pred_oh)
+ A = torch.sum(N, dim=-1) # row-sum (batch_size, c)
+ B = torch.sum(N, dim=-2) # col-sum (batch_size, k)
+ num_points = torch.sum(A, dim=1)
+
+ rindex = torch.sum(N * (N - 1), dim=[1, 2])
+ aindex = torch.sum(A * (A - 1), dim=1)
+ bindex = torch.sum(B * (B - 1), dim=1)
+ expected_rindex = aindex * bindex / torch.clamp(
+ num_points * (num_points - 1), min=1)
+ max_rindex = (aindex + bindex) / 2
+ denominator = max_rindex - expected_rindex
+ ari = (rindex - expected_rindex) / denominator
+
+ # There are two cases for which the denominator can be zero:
+ # 1. If both label_pred and label_true assign all pixels to a single cluster.
+ # (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
+ # 2. If both label_pred and label_true assign max 1 point to each cluster.
+ # (max_rindex == expected_rindex == rindex == 0)
+ # In both cases, we want the ARI score to be 1.0:
+ return torch.where(denominator != 0, ari, torch.tensor(1.).type_as(ari))
+
+
+def ARI_metric(x, y):
+ """x/y: [B, H, W], both are seg_masks after argmax."""
+ assert 'int' in str(x.dtype)
+ assert 'int' in str(y.dtype)
+ return adjusted_rand_index(x, y).mean().item()
+
+
+def fARI_metric(x, y):
+ """x/y: [B, H, W], both are seg_masks after argmax."""
+ assert 'int' in str(x.dtype)
+ assert 'int' in str(y.dtype)
+ return adjusted_rand_index(x, y, ignore_background=True).mean().item()
+
+
+def bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox, ovthresh=0.5):
+ """Compute the precision of predicted bounding boxes.
+
+ Args:
+ gt_pres_mask: A boolean tensor of shape [N]
+ gt_bbox: A tensor of shape [N, 4]
+ pred_bbox: A tensor of shape [M, 4]
+ """
+ gt_bbox, pred_bbox = gt_bbox.clone(), pred_bbox.clone()
+ gt_bbox = gt_bbox[gt_pres_mask.bool()]
+ pred_bbox = pred_bbox[pred_bbox[:, 0] >= 0.]
+ N, M = gt_bbox.shape[0], pred_bbox.shape[0]
+ assert gt_bbox.shape[1] == pred_bbox.shape[1] == 4
+ # assert M >= N
+ tp, fp = 0, 0
+ bbox_used = [False] * pred_bbox.shape[0]
+ bbox_ious = vops.box_iou(gt_bbox, pred_bbox) # [N, M]
+
+ # Find the best iou match for each ground truth bbox.
+ for i in range(N):
+ best_iou_idx = bbox_ious[i].argmax().item()
+ best_iou = bbox_ious[i, best_iou_idx].item()
+ if best_iou >= ovthresh and not bbox_used[best_iou_idx]:
+ tp += 1
+ bbox_used[best_iou_idx] = True
+ else:
+ fp += 1
+
+ # compute precision and recall
+ precision = tp / float(M)
+ recall = tp / float(N)
+ return precision, recall
+
+
+def batch_bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox):
+ """Compute the precision of predicted bounding boxes over batch."""
+ aps, ars = [], []
+ for i in range(gt_pres_mask.shape[0]):
+ ap, ar = bbox_precision_recall(gt_pres_mask[i], gt_bbox[i],
+ pred_bbox[i])
+ aps.append(ap)
+ ars.append(ar)
+ return np.mean(aps), np.mean(ars)
+
+
+def hungarian_miou(gt_mask, pred_mask):
+ """both mask: [H*W] after argmax, 0 is gt background index."""
+ true_oh = F.one_hot(gt_mask).float()[..., 1:] # only foreground, [HW, N]
+ pred_oh = F.one_hot(pred_mask).float() # [HW, M]
+ N, M = true_oh.shape[-1], pred_oh.shape[-1]
+ # compute all pairwise IoU
+ intersect = (true_oh[:, :, None] * pred_oh[:, None, :]).sum(0) # [N, M]
+ union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :] # [N, M]
+ iou = intersect / (union + 1e-8) # [N, M]
+ iou = iou.detach().cpu().numpy()
+ # find the best match for each gt
+ row_ind, col_ind = linear_sum_assignment(iou, maximize=True)
+ # there are two possibilities here
+ # 1. M >= N, just take the best match mean
+ # 2. M < N, some objects are not detected, their iou is 0
+ if M >= N:
+ assert (row_ind == np.arange(N)).all()
+ return iou[row_ind, col_ind].mean()
+ return iou[row_ind, col_ind].sum() / float(N)
+
+
+def miou_metric(gt_mask, pred_mask):
+ """both mask: [B, H, W], both are seg_masks after argmax."""
+ assert 'int' in str(gt_mask.dtype)
+ assert 'int' in str(pred_mask.dtype)
+ gt_mask, pred_mask = gt_mask.flatten(1, 2), pred_mask.flatten(1, 2)
+ ious = [
+ hungarian_miou(gt_mask[i], pred_mask[i])
+ for i in range(gt_mask.shape[0])
+ ]
+ return np.mean(ious)
+
+
+@torch.no_grad()
+def pred_eval_step(
+ gt,
+ pred,
+ lpips_fn,
+ gt_mask=None,
+ pred_mask=None,
+ gt_pres_mask=None,
+ gt_bbox=None,
+ pred_bbox=None,
+ eval_traj=True,
+):
+ """Both of shape [B, T, C, H, W], torch.Tensor.
+ masks of shape [B, T, H, W].
+ pres_mask of shape [B, T, N].
+ bboxes of shape [B, T, N/M, 4].
+
+ eval_traj: whether to evaluate the trajectory (measured by bbox and mask).
+
+ Compute metrics for every timestep.
+ """
+ assert len(gt.shape) == len(pred.shape) == 5
+ assert gt.shape == pred.shape
+ assert gt.shape[2] == 3
+ if eval_traj:
+ assert len(gt_mask.shape) == len(pred_mask.shape) == 4
+ assert gt_mask.shape == pred_mask.shape
+ if eval_traj:
+ assert len(gt_pres_mask.shape) == 3
+ assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4
+ T = gt.shape[1]
+
+ # compute perceptual dist & mask metrics before converting to numpy
+ all_percept_dist, all_ari, all_fari, all_miou = [], [], [], []
+ for t in range(T):
+ one_gt, one_pred = gt[:, t], pred[:, t]
+ percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item()
+ all_percept_dist.append(percept_dist)
+ if eval_traj:
+ one_gt_mask, one_pred_mask = gt_mask[:, t], pred_mask[:, t]
+ ari = ARI_metric(one_gt_mask, one_pred_mask)
+ fari = fARI_metric(one_gt_mask, one_pred_mask)
+ miou = miou_metric(one_gt_mask, one_pred_mask)
+ all_ari.append(ari)
+ all_fari.append(fari)
+ all_miou.append(miou)
+ else:
+ all_ari.append(0.)
+ all_fari.append(0.)
+ all_miou.append(0.)
+
+ # compute bbox metrics
+ all_ap, all_ar = [], []
+ for t in range(T):
+ if not eval_traj:
+ all_ap.append(0.)
+ all_ar.append(0.)
+ continue
+ one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \
+ gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t]
+ ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox,
+ one_pred_bbox)
+ all_ap.append(ap)
+ all_ar.append(ar)
+
+ gt = to_rgb_from_tensor(gt).cpu().numpy()
+ pred = to_rgb_from_tensor(pred).cpu().numpy()
+ all_mse, all_ssim, all_psnr = [], [], []
+ for t in range(T):
+ one_gt, one_pred = gt[:, t], pred[:, t]
+ mse = mse_metric(one_gt, one_pred)
+ psnr = psnr_metric(one_gt, one_pred)
+ ssim = ssim_metric(one_gt, one_pred)
+ all_mse.append(mse)
+ all_ssim.append(ssim)
+ all_psnr.append(psnr)
+ return {
+ 'mse': all_mse,
+ 'ssim': all_ssim,
+ 'psnr': all_psnr,
+ 'percept_dist': all_percept_dist,
+ 'ari': all_ari,
+ 'fari': all_fari,
+ 'miou': all_miou,
+ 'ap': all_ap,
+ 'ar': all_ar,
+ }
diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py
new file mode 100644
index 0000000..faab7ec
--- /dev/null
+++ b/scripts/utils/eval_utils.py
@@ -0,0 +1,78 @@
+import os
+import torch as th
+from model.loci import Loci
+
+def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views):
+
+ net_name = file.split('/')[-1].split('.')[0]
+ #root_path = file.split('nets')[0]
+ root_path = os.path.join(*file.split('/')[0:-1])
+ root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type'])
+ plot_path = os.path.join(root_path, evaluation_mode)
+
+ # create directories
+ #if os.path.exists(plot_path):
+ # shutil.rmtree(plot_path)
+ os.makedirs(plot_path, exist_ok = True)
+ if object_view:
+ os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True)
+ if individual_views:
+ os.makedirs(os.path.join(plot_path, 'individual'), exist_ok = True)
+ for group in ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination']:
+ os.makedirs(os.path.join(plot_path, 'individual', group), exist_ok = True)
+ os.makedirs(os.path.join(root_path, 'statistics'), exist_ok = True)
+
+ # final directory
+ plot_path = plot_path + '/'
+ print(f"save plots to {plot_path}")
+
+ return root_path, plot_path
+
+def store_statistics(memory, *args, extend=False):
+ for i,key in enumerate(memory.keys()):
+ if i >= len(args):
+ break
+ if extend:
+ memory[key].extend(args[i])
+ else:
+ memory[key].append(args[i])
+ return memory
+
+def append_statistics(memory1, memory2, ignore=[], extend=False):
+ for key in memory1:
+ if key not in ignore:
+ if extend:
+ memory2[key] = memory2[key] + memory1[key]
+ else:
+ memory2[key].append(memory1[key])
+ return memory2
+
+def load_model(cfg, cfg_net, file, device):
+
+ net = Loci(
+ cfg_net,
+ teacher_forcing = cfg.defaults.teacher_forcing
+ )
+
+ # load model
+ if file != '':
+ print(f"load {file} to device {device}")
+ state = th.load(file, map_location=device)
+
+ # backward compatibility
+ model = {}
+ for key, value in state["model"].items():
+ # replace update_module with percept_gate_controller in key string:
+ key = key.replace("update_module", "percept_gate_controller")
+ model[key.replace(".module.", ".")] = value
+
+ net.load_state_dict(model)
+
+ # make sure the initialization level is at least 1
+ if net.get_init_status() < 1:
+ net.inc_init_level()
+
+ # move network to the target device
+ net = net.to(device=device)
+
+ return net
\ No newline at end of file
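The statistics helpers can be used independently of the model (editor's sketch, not part of the commit; the module imports Loci at load time, so the repository's model package must be importable). Note that store_statistics relies on the dictionary's key insertion order matching the positional arguments:

    from scripts.utils.eval_utils import store_statistics, append_statistics

    batch = {'mse': [], 'ssim': []}              # key order must match the *args order
    batch = store_statistics(batch, 0.01, 0.95)  # appends 0.01 to 'mse', 0.95 to 'ssim'
    totals = {'mse': [], 'ssim': []}
    totals = append_statistics(batch, totals)    # totals['mse'] == [[0.01]]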
diff --git a/scripts/utils/io.py b/scripts/utils/io.py
new file mode 100644
index 0000000..9bd8158
--- /dev/null
+++ b/scripts/utils/io.py
@@ -0,0 +1,137 @@
+import os
+from scripts.utils.configuration import Configuration
+import time
+import torch as th
+import numpy as np
+from einops import rearrange, repeat, reduce
+
+
+def init_device(cfg):
+ print(f'Cuda available: {th.cuda.is_available()} Cuda count: {th.cuda.device_count()}')
+ if th.cuda.is_available():
+ device = th.device("cuda:0")
+ verbose = False
+ cfg.device = "cuda:0"
+ cfg.model.device = "cuda:0"
+ print('!!! USING GPU !!!')
+ else:
+ device = th.device("cpu")
+ verbose = True
+ cfg.device = "cpu"
+ cfg.model.device = "cpu"
+ cfg.model.batch_size = 2
+ cfg.defaults.teacher_forcing = 4
+ print('!!! USING CPU !!!')
+ return device, verbose
+
+class Timer:
+
+ def __init__(self):
+ self.last = time.time()
+ self.passed = 0
+ self.sum = 0
+
+ def __str__(self):
+ self.passed = self.passed * 0.99 + time.time() - self.last
+ self.sum = self.sum * 0.99 + 1
+ passed = self.passed / self.sum
+ self.last = time.time()
+
+ if passed > 1:
+ return f"{passed:.2f}s/it"
+
+ return f"{1.0/passed:.2f}it/s"
+
+class UEMA:
+
+ def __init__(self, memory = 100):
+ self.value = 0
+ self.sum = 1e-30
+ self.decay = np.exp(-1 / memory)
+
+ def update(self, value):
+ self.value = self.value * self.decay + value
+ self.sum = self.sum * self.decay + 1
+
+ def __float__(self):
+ return self.value / self.sum
+
+
+def model_path(cfg: Configuration, overwrite=False, move_old=True):
+ """
+ Builds the model output path, optionally without overwriting an existing directory
+ :param cfg: Configuration containing the model path
+ :param overwrite: If True, reuse the existing directory; otherwise create a new one
+ :param move_old: If not overwriting, move an existing directory of the same name into an '_old' folder
+ :return: Model path
+ """
+ _path = os.path.join('out')
+ path = os.path.join(_path, cfg.model_path)
+
+ if not os.path.exists(_path):
+ os.makedirs(_path)
+
+ if not overwrite:
+ if move_old:
+ # Moves existing directory to an old folder
+ if os.path.exists(path):
+ old_path = os.path.join(_path, f'{cfg.model_path}_old')
+ if not os.path.exists(old_path):
+ os.makedirs(old_path)
+ _old_path = os.path.join(old_path, cfg.model_path)
+ i = 0
+ while os.path.exists(_old_path):
+ i = i + 1
+ _old_path = os.path.join(old_path, f'{cfg.model_path}_{i}')
+ os.renames(path, _old_path)
+ else:
+ # Increases number after directory name for each new path
+ i = 0
+ while os.path.exists(path):
+ i = i + 1
+ path = os.path.join(_path, f'{cfg.model_path}_{i}')
+
+ return path
+
+class LossLogger:
+
+ def __init__(self):
+
+ self.avgloss = UEMA()
+ self.avg_position_loss = UEMA()
+ self.avg_time_loss = UEMA()
+ self.avg_encoder_loss = UEMA()
+ self.avg_mse_object_loss = UEMA()
+ self.avg_long_mse_object_loss = UEMA(33333)
+ self.avg_num_objects = UEMA()
+ self.avg_openings = UEMA()
+ self.avg_gestalt = UEMA()
+ self.avg_gestalt2 = UEMA()
+ self.avg_gestalt_mean = UEMA()
+ self.avg_update_gestalt = UEMA()
+ self.avg_update_position = UEMA()
+
+
+ def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position):
+
+ self.avg_position_loss.update(avg_position_loss.item())
+ self.avg_time_loss.update(avg_time_loss.item())
+ self.avg_encoder_loss.update(avg_encoder_loss.item())
+ self.avg_mse_object_loss.update(avg_mse_object_loss.item())
+ self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item())
+ self.avg_num_objects.update(avg_num_objects)
+ self.avg_openings.update(avg_openings)
+ self.avg_gestalt.update(avg_gestalt.item())
+ self.avg_gestalt2.update(avg_gestalt2.item())
+ self.avg_gestalt_mean.update(avg_gestalt_mean.item())
+ self.avg_update_gestalt.update(avg_update_gestalt.item())
+ self.avg_update_position.update(avg_update_position.item())
+
+ def update_average_loss(self, avgloss):
+ self.avgloss.update(avgloss)
+
+ def get_log(self):
+ info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
+ return info
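UEMA above is a bias-corrected exponential moving average: both the weighted sum of values and the accumulated weight decay by exp(-1/memory), and __float__ returns their ratio, which removes the usual cold-start bias (editor's sketch, not part of the commit; assumes the repository is importable):

    from scripts.utils.io import UEMA

    ema = UEMA(memory=100)      # decay = exp(-1/100), roughly 0.99
    ema.update(5.0)
    print(float(ema))           # 5.0 -- unbiased even after a single update
    for _ in range(1000):
        ema.update(1.0)
    print(float(ema))           # close to 1.0 once the old value has decayed away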
diff --git a/scripts/utils/optimizers.py b/scripts/utils/optimizers.py
new file mode 100644
index 0000000..49283b3
--- /dev/null
+++ b/scripts/utils/optimizers.py
@@ -0,0 +1,100 @@
+import math
+import torch as th
+import numpy as np
+from torch.optim.optimizer import Optimizer
+
+"""
+Liyuan Liu, Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020).
+On the Variance of the Adaptive Learning Rate and Beyond. In the Eighth International Conference on Learning
+Representations (ICLR 2020).
+"""
+class RAdam(Optimizer):
+
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, degenerated_to_sgd=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+ self.degenerated_to_sgd = degenerated_to_sgd
+ if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+ for param in params:
+ if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
+ param['buffer'] = [[None, None, None] for _ in range(10)]
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
+ super(RAdam, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(RAdam, self).__setstate__(state)
+
+ def step(self, closure=None):
+
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.float()
+ if grad.is_sparse:
+ raise RuntimeError('RAdam does not support sparse gradients')
+
+ p_data_fp32 = p.data.float()
+
+ state = self.state[p]
+
+ if len(state) == 0:
+ state['step'] = 0
+ state['exp_avg'] = th.zeros_like(p_data_fp32)
+ state['exp_avg_sq'] = th.zeros_like(p_data_fp32)
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2)
+ exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1)
+
+ state['step'] += 1
+ buffered = group['buffer'][int(state['step'] % 10)]
+ if state['step'] == buffered[0]:
+ N_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ N_sma_max = 2 / (1 - beta2) - 1
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = N_sma
+
+ # more conservative since it's an approximated value
+ if N_sma >= 5:
+ step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
+ elif self.degenerated_to_sgd:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ else:
+ step_size = -1
+ buffered[2] = step_size
+
+ # more conservative since it's an approximated value
+ if N_sma >= 5:
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr'])
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+ p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr'])
+ p.data.copy_(p_data_fp32)
+ elif step_size > 0:
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr'])
+ p_data_fp32.add_(exp_avg, alpha = -step_size * group['lr'])
+ p.data.copy_(p_data_fp32)
+
+ return loss
\ No newline at end of file
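RAdam is a drop-in replacement for Adam (editor's sketch, not part of the commit; assumes the repository is importable). With degenerated_to_sgd=False, the first few steps are skipped entirely (step_size = -1) until the variance rectification term N_sma reaches 5:

    import torch as th
    from scripts.utils.optimizers import RAdam

    model = th.nn.Linear(8, 1)
    opt = RAdam(model.parameters(), lr=1e-3, weight_decay=1e-4)
    x, y = th.randn(16, 8), th.randn(16, 1)
    for _ in range(10):
        opt.zero_grad()
        loss = th.nn.functional.mse_loss(model(x), y)
        loss.backward()
        opt.step()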
diff --git a/scripts/utils/plot_utils.py b/scripts/utils/plot_utils.py
new file mode 100644
index 0000000..5cb20de
--- /dev/null
+++ b/scripts/utils/plot_utils.py
@@ -0,0 +1,428 @@
+import torch as th
+from torch import nn
+import numpy as np
+import cv2
+from einops import rearrange, repeat
+import matplotlib.pyplot as plt
+import PIL
+from model.utils.nn_utils import Gaus2D, Vector2D
+
+def preprocess(tensor, scale=1, normalize=False, mean_std_normalize=False):
+
+ if tensor is None:
+ return None
+
+ if normalize:
+ min_ = th.min(tensor)
+ max_ = th.max(tensor)
+ tensor = (tensor - min_) / (max_ - min_)
+
+ if mean_std_normalize:
+ mean = th.mean(tensor)
+ std = th.std(tensor)
+ tensor = th.clip((tensor - mean) / (2 * std), -1, 1) * 0.5 + 0.5
+
+ if scale > 1:
+ upsample = nn.Upsample(scale_factor=scale).to(tensor[0].device)
+ tensor = upsample(tensor)
+
+ return tensor
+
+def preprocess_multi(*args, scale):
+ return [preprocess(a, scale) for a in args]
+
+def color_mask(mask):
+
+ colors = th.tensor([
+ [ 255, 0, 0 ],
+ [ 0, 0, 255 ],
+ [ 255, 255, 0 ],
+ [ 255, 0, 255 ],
+ [ 0, 255, 255 ],
+ [ 0, 255, 0 ],
+ [ 255, 128, 0 ],
+ [ 128, 255, 0 ],
+ [ 128, 0, 255 ],
+ [ 255, 0, 128 ],
+ [ 0, 255, 128 ],
+ [ 0, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 255 ],
+ [ 128, 255, 255 ],
+ [ 128, 255, 255 ],
+ [ 255, 255, 128 ],
+ [ 255, 255, 128 ],
+ [ 255, 128, 255 ],
+ [ 128, 0, 0 ],
+ [ 0, 0, 128 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ ], device = mask.device) / 255.0
+
+ colors = colors.view(1, -1, 3, 1, 1)
+ mask = mask.unsqueeze(dim=2)
+
+ return th.sum(colors[:,:mask.shape[1]] * mask, dim=1)
+
+def get_color(o):
+ colors = th.tensor([
+ [ 255, 0, 0 ],
+ [ 0, 0, 255 ],
+ [ 255, 255, 0 ],
+ [ 255, 0, 255 ],
+ [ 0, 255, 255 ],
+ [ 0, 255, 0 ],
+ [ 255, 128, 0 ],
+ [ 128, 255, 0 ],
+ [ 128, 0, 255 ],
+ [ 255, 0, 128 ],
+ [ 0, 255, 128 ],
+ [ 0, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 255 ],
+ [ 128, 255, 255 ],
+ [ 128, 255, 255 ],
+ [ 255, 255, 128 ],
+ [ 255, 255, 128 ],
+ [ 255, 128, 255 ],
+ [ 128, 0, 0 ],
+ [ 0, 0, 128 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ ]) / 255.0
+
+ colors = colors.view(48,3)
+ return colors[o]
+
+def to_rgb(tensor: th.Tensor):
+ return th.cat((
+ tensor * 0.6 + 0.4,
+ tensor,
+ tensor
+ ), dim=1)
+
+def visualise_gate(gate, h, w):
+ bar = th.ones((1,h,w), device=gate.device) * 0.9
+ black = int(w*gate.item())
+ if black > 0:
+ bar[:,:, -black:] = 0
+ return bar
+
+def get_highlighted_input(input, mask_cur):
+
+ # highlight error
+ highlighted_input = input
+ if mask_cur is not None:
+ grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
+ object_mask_cur = th.sum(mask_cur[:,:-1], dim=1).unsqueeze(dim=1)
+ highlighted_input = grayscale * (1 - object_mask_cur)
+ highlighted_input += grayscale * object_mask_cur * 0.3333333
+ cmask = color_mask(mask_cur[:,:-1])
+ highlighted_input = highlighted_input + cmask * 0.6666666
+
+ return highlighted_input
+
+def color_slots(image, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur):
+
+ image = (1-image) * slots_bounded + image * (1-slots_bounded)
+ image = th.clip(image - 0.3, 0,1) * slots_partially_occluded_cur + image * (1-slots_partially_occluded_cur)
+ image = th.clip(image - 0.3, 0,1) * slots_occluded_cur + image * (1-slots_occluded_cur)
+
+ return image
+
+def compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale):
+
+ # compute occlusion mask
+ occluded_cur = th.clip(rawmask_cur - mask_cur, 0, 1)[:,:-1]
+ occluded_next = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1]
+
+ # to rgb
+ rawmask_cur = repeat(rawmask_cur[:,:-1], 'b o h w -> b (o 3) h w')
+ rawmask_next = repeat(rawmask_next[:,:-1], 'b o h w -> b (o 3) h w')
+
+ # scale
+ occluded_next = preprocess(occluded_next, scale)
+ occluded_cur = preprocess(occluded_cur, scale)
+ rawmask_cur = preprocess(rawmask_cur, scale)
+ rawmask_next = preprocess(rawmask_next, scale)
+
+ # set occlusion to red
+ rawmask_cur = rearrange(rawmask_cur, 'b (o c) h w -> b o c h w', c = 3)
+ rawmask_cur[:,:,0] = rawmask_cur[:,:,0] * (1 - occluded_next)
+ rawmask_cur[:,:,1] = rawmask_cur[:,:,1] * (1 - occluded_next)
+
+ rawmask_next = rearrange(rawmask_next, 'b (o c) h w -> b o c h w', c = 3)
+ rawmask_next[:,:,0] = rawmask_next[:,:,0] * (1 - occluded_next)
+ rawmask_next[:,:,1] = rawmask_next[:,:,1] * (1 - occluded_next)
+
+ return rawmask_cur, rawmask_next
+
+def plot_online_error_slots(errors, error_name, target, sequence_len, root_path, visibilty_memory, slots_bounded, ylim=0.3):
+ error_plots = []
+ if len(errors) > 0:
+ num_slots = int(th.sum(slots_bounded).item())
+ errors = rearrange(np.array(errors), '(l o) -> o l', o=len(slots_bounded))[:num_slots]
+ visibilty_memory = rearrange(np.array(visibilty_memory), '(l o) -> o l', o=len(slots_bounded))[:num_slots]
+ for error,visibility in zip(errors, visibilty_memory):
+
+ if len(error) < sequence_len:
+ fig, ax = plt.subplots(figsize=(round(target.shape[3]/100,2), round(target.shape[2]/100,2)))
+ plt.plot(error, label=error_name)
+
+ visibility = np.concatenate((visibility, np.ones(sequence_len-len(error))))
+ ax.fill_between(range(sequence_len), 0, 1, where=visibility==0, color='orange', alpha=0.3, transform=ax.get_xaxis_transform())
+
+ plt.xlim((0,sequence_len))
+ plt.ylim((0,ylim))
+ fig.tight_layout()
+ plt.savefig(f'{root_path}/tmp.jpg')
+
+ error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
+ error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1))
+ plt.close(fig)
+
+ error_plots.append(error_plot)
+
+ return error_plots
+
+def plot_online_error(error, error_name, target, t, i, sequence_len, root_path, online_surprise = False):
+
+ fig = plt.figure(figsize=( round(target.shape[3]/50,2), round(target.shape[2]/50,2) ))
+ plt.plot(error, label=error_name)
+
+ if online_surprise:
+ # compute moving average of error
+ moving_average_length = 10
+ if t > moving_average_length:
+ moving_average_length += 1
+ average_error = np.mean(error[-moving_average_length:-1])
+ current_sd = np.std(error[-moving_average_length:-1])
+ current_error = error[-1]
+
+ if current_error > average_error + 2 * current_sd:
+ fig.set_facecolor('orange')
+
+ plt.xlim((0,sequence_len))
+ plt.legend()
+ # increase title size
+ plt.title(f'{error_name}', fontsize=20)
+ plt.xlabel('timestep')
+ plt.ylabel('error')
+ plt.savefig(f'{root_path}/tmp.jpg')
+
+ error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
+ error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1))
+ plt.close(fig)
+
+ return error_plot
+
+def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
+
+ # add ground truth positions of objects to image
+ if gt_positions_target_next is not None:
+ for o in range(gt_positions_target_next.shape[1]):
+ position = gt_positions_target_next[0, o]
+ position = position/2 + 0.5
+
+ if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
+ width = 5
+ w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item()
+ h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item()
+ col = get_color(o).view(3,1,1)
+ target[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+ # add these positions to the associated slots' velocity_next2d illustration
+ slots = (association_table[0] == o).nonzero()
+ for s in slots.flatten():
+ velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+ if output_hidden is not None and s != largest_object:
+ output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+ gateheight = 60
+ ch = 40
+ gh = 40
+ gh_bar = gh-20
+ gh_margin = int((gh-gh_bar)/2)
+ margin = 20
+ slots_margin = 10
+ height = size[0] * 6 + 18*5
+ width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1)
+ img = th.ones((3, height, width), device = object_next.device) * 0.4
+ row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin])
+ col1 = range(margin, margin + size[1]*2)
+ col2 = range(width-(margin+size[1]*2), width-margin)
+
+ img[:,row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0]
+ img[:,row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0]
+ img[:,row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0]
+
+ if error_plot is not None:
+ img[:,row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize= True)
+ if error_plot2 is not None:
+ img[:,row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize= True)
+
+ for o in range(num_objects):
+
+ col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin
+ col = range(col, col + size[1])
+
+ # color bar for the gate
+ if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
+ img[:,margin:margin+ch, col] = get_color(o).view(3,1,1).to(object_next.device)
+
+ img[:,margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:,o, 0].to(object_next.device), h=gh_bar, w=len(col))
+ offset = gh+margin-gh_margin+ch+2*margin
+ row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+
+ img[:,row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0,o].to(object_next.device))
+ img[:,row(1)[0]:row(1)[1], col] = preprocess(object_next[:,o].to(object_next.device))
+ if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
+ img[:,row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True)
+
+ offset = margin*2-8
+ row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+ img[:,row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:,o, 1].to(object_next.device), h=gh_bar, w=len(col))
+ img[:,row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0]
+ if (error_plot_slots is not None) and len(error_plot_slots) > o:
+ img[:,row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True)
+
+ img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+
+ return img
+
+def write_image(file, img):
+ img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+ cv2.imwrite(file, img)
+
+
+def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i):
+
+ # Create position helpers
+ size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device)
+
+ # Compute plot content
+ highlighted_input = get_highlighted_input(input, mask_cur)
+ output = th.clip(output_next, 0, 1)
+ position_cur2d = gaus2d(rearrange(position_encoder_cur, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
+ velocity_next2d = vector2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
+
+ # color slots
+ slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next = reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next)
+ position_cur2d = color_slots(position_cur2d, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur)
+ velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)
+
+ # compute occlusion
+ if (cfg.datatype == "adept"):
+ rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
+ rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale)
+ rawmask_cur_h[:,largest_object] = rawmask_cur_l[:,largest_object]
+ rawmask_next_h[:,largest_object] = rawmask_next_l[:,largest_object]
+ rawmask_cur = rawmask_cur_h
+ rawmask_next = rawmask_next_h
+ object_hidden[:, largest_object] = object_next[:, largest_object]
+ object_next = object_hidden
+ else:
+ rawmask_cur, rawmask_next = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
+
+ # scale plot content
+ input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next = preprocess_multi(input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next, scale=scale)
+
+ # reshape
+ object_next = rearrange(object_next, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
+ object_cur = rearrange(object_cur, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
+ mask_next = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')
+
+ if object_view:
+ if (cfg.datatype == "adept"):
+ num_objects = 4
+ error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded)
+ error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
+ error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path)
+ error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path)
+ img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object)
+ else:
+ num_objects = cfg_net.num_objects
+ error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
+ error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path)
+ img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object)
+
+ cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img)
+
+ if individual_views:
+ # ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']:
+ write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0])
+ write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0])
+ write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0,-1])
+ write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0])
+ write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0])
+
+
+def get_position_helper(cfg_net, device):
+ size = cfg_net.input_size
+ gaus2d = Gaus2D(size).to(device)
+ vector2d = Vector2D(size).to(device)
+ scale = size[0] // (cfg_net.latent_size[0] * 2**(cfg_net.level*2))
+ return size, gaus2d, vector2d, scale
+
+def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next):
+
+ slots_bounded = th.squeeze(slots_bounded)[..., None,None,None]
+ slots_partially_occluded_cur = th.squeeze(slots_partially_occluded_cur)[..., None,None,None]
+ slots_occluded_cur = th.squeeze(slots_occluded_cur)[..., None,None,None]
+ slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None,None,None]
+ slots_occluded_next = th.squeeze(slots_occluded_next)[..., None,None,None]
+
+ return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next
\ No newline at end of file
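A small sketch of color_mask above (editor's illustration, not part of the commit; plot_utils imports model.utils.nn_utils, so the repository's model package plus cv2/matplotlib must be available). A soft [B, O, H, W] mask is blended with one palette color per object:

    import torch as th
    from scripts.utils.plot_utils import color_mask

    mask = th.rand(1, 5, 64, 64)                   # [B, O, H, W] soft object assignment
    mask = mask / mask.sum(dim=1, keepdim=True)    # normalize over the object dimension
    rgb = color_mask(mask)                         # [B, 3, H, W]
    print(rgb.shape)                               # torch.Size([1, 3, 64, 64])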