author | fredeee | 2023-11-02 10:47:21 +0100
committer | fredeee | 2023-11-02 10:47:21 +0100
commit | f8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree | 87288be6f851ab69405e524b81940c501c52789a /scripts/utils
parent | f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initial commit
Diffstat (limited to 'scripts/utils')
-rw-r--r-- | scripts/utils/configuration.py | 83
-rw-r--r-- | scripts/utils/eval_metrics.py | 350
-rw-r--r-- | scripts/utils/eval_utils.py | 78
-rw-r--r-- | scripts/utils/io.py | 137
-rw-r--r-- | scripts/utils/optimizers.py | 100
-rw-r--r-- | scripts/utils/plot_utils.py | 428
6 files changed, 1176 insertions, 0 deletions
diff --git a/scripts/utils/configuration.py b/scripts/utils/configuration.py
new file mode 100644
index 0000000..7c38095
--- /dev/null
+++ b/scripts/utils/configuration.py
@@ -0,0 +1,83 @@
+import json
+import jsmin
+import os
+
+class Dict(dict):
+    """
+    Dictionary that allows attribute access and can exclude names from being loaded
+    """
+    def __init__(self, dictionary: dict = None):
+        super(Dict, self).__init__()
+
+        if dictionary is not None:
+            self.load(dictionary)
+
+    def __getattr__(self, item):
+        try:
+            return self[item] if item in self else getattr(super(Dict, self), item)
+        except AttributeError:
+            raise AttributeError(f'This dictionary has no attribute "{item}"')
+
+    def load(self, dictionary: dict, name_list: list = None):
+        """
+        Loads a dictionary
+        :param dictionary: Dictionary to be loaded
+        :param name_list: List of names to be updated
+        """
+        for name in dictionary:
+            data = dictionary[name]
+            if name_list is None or name in name_list:
+                if isinstance(data, dict):
+                    if name in self:
+                        self[name].load(data)
+                    else:
+                        self[name] = Dict(data)
+                elif isinstance(data, list):
+                    self[name] = list()
+                    for item in data:
+                        if isinstance(item, dict):
+                            self[name].append(Dict(item))
+                        else:
+                            self[name].append(item)
+                else:
+                    self[name] = data
+
+    def save(self, path):
+        """
+        Saves the dictionary into a json file
+        :param path: Path of the json file
+        """
+        os.makedirs(path, exist_ok=True)
+
+        path = os.path.join(path, 'cfg.json')
+
+        with open(path, 'w') as file:
+            json.dump(self, file, indent=True)
+
+
+class Configuration(Dict):
+    """
+    Configuration loaded from a json file
+    """
+    def __init__(self, path: str, default_path=None):
+        super(Configuration, self).__init__()
+
+        if default_path is not None:
+            self.load(default_path)
+
+        self.load(path)
+
+    def load_model(self, path: str):
+        self.load(path, name_list=["model"])
+
+    def load(self, path: str, name_list: list = None):
+        """
+        Loads attributes from a json file
+        :param path: Path of the json file
+        :param name_list: List of names to be updated
+        :return:
+        """
+        with open(path) as file:
+            data = json.loads(jsmin.jsmin(file.read()))
+
+        super(Configuration, self).load(data, name_list)
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py
new file mode 100644
index 0000000..6f5106d
--- /dev/null
+++ b/scripts/utils/eval_metrics.py
@@ -0,0 +1,350 @@
+'''
+vp_utils.py
+SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer
+'''
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from skimage.metrics import structural_similarity, peak_signal_noise_ratio
+
+import torch
+import torch.nn.functional as F
+import torchvision.ops as vops
+
+FG_THRE = 0.5
+PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0),
+           (255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200),
+           (255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200),
+           (50, 50, 50)]
+PALETTE_np = np.array(PALETTE, dtype=np.uint8)
+PALETTE_torch = torch.from_numpy(PALETTE_np).float() / 255. * 2. - 1.
+
+def to_rgb_from_tensor(x):
+    """Reverse the Normalize operation in torchvision."""
+    return (x * 0.5 + 0.5).clamp(0, 1)
+
+def postproc_mask(batch_masks):
+    """Post-process masks instead of directly taking argmax.
+ + Args: + batch_masks: [B, T, N, 1, H, W] + + Returns: + masks: [B, T, H, W] + """ + batch_masks = batch_masks.clone() + B, T, N, _, H, W = batch_masks.shape + batch_masks = batch_masks.reshape(B * T, N, H * W) + slots_max = batch_masks.max(-1)[0] # [B*T, N] + bg_idx = slots_max.argmin(-1) # [B*T] + spatial_max = batch_masks.max(1)[0] # [B*T, H*W] + bg_mask = (spatial_max < FG_THRE) # [B*T, H*W] + idx_mask = torch.zeros((B * T, N)).type_as(bg_mask) + idx_mask[torch.arange(B * T), bg_idx] = True + # set the background mask score to 1 + batch_masks[idx_mask.unsqueeze(-1) * bg_mask.unsqueeze(1)] = 1. + masks = batch_masks.argmax(1) # [B*T, H*W] + return masks.reshape(B, T, H, W) + + +def masks_to_boxes_w_empty_mask(binary_masks): + """binary_masks: [B, H, W].""" + B = binary_masks.shape[0] + obj_mask = (binary_masks.sum([-1, -2]) > 0) # [B] + bboxes = torch.ones((B, 4)).float().to(binary_masks.device) * -1. + bboxes[obj_mask] = vops.masks_to_boxes(binary_masks[obj_mask]) + return bboxes + + +def masks_to_boxes(masks, num_boxes=7): + """Convert seg_masks to bboxes. + + Args: + masks: [B, T, H, W], output after taking argmax + num_boxes: number of boxes to generate (num_slots) + + Returns: + bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2] + """ + B, T, H, W = masks.shape + binary_masks = F.one_hot(masks, num_classes=num_boxes) # [B, T, H, W, N] + binary_masks = binary_masks.permute(0, 1, 4, 2, 3) # [B, T, N, H, W] + binary_masks = binary_masks.contiguous().flatten(0, 2) + bboxes = masks_to_boxes_w_empty_mask(binary_masks) # [B*T*N, 4] + bboxes = bboxes.reshape(B, T, num_boxes, 4) # [B, T, N, 4] + return bboxes + + +def mse_metric(x, y): + """x/y: [B, C, H, W]""" + # people often sum over spatial dimensions in video prediction MSE + # see e.g. https://github.com/Yunbo426/predrnn-pp/blob/40764c52f433290aa02414b5f25c09c72b98e0af/train.py#L244 + return ((x - y)**2).sum(-1).sum(-1).mean() + + +def psnr_metric(x, y): + """x/y: [B, C, H, W]""" + psnrs = [ + peak_signal_noise_ratio( + x[i], + y[i], + data_range=1., + ) for i in range(x.shape[0]) + ] + return np.mean(psnrs) + + +def ssim_metric(x, y): + """x/y: [B, C, H, W]""" + x = x * 255. + y = y * 255. + ssims = [ + structural_similarity( + x[i], + y[i], + channel_axis=0, + gaussian_weights=True, + sigma=1.5, + use_sample_covariance=False, + data_range=255, + ) for i in range(x.shape[0]) + ] + return np.mean(ssims) + + +def perceptual_dist(x, y, loss_fn): + """x/y: [B, C, H, W]""" + return loss_fn(x, y).mean() + + +def adjusted_rand_index(true_ids, pred_ids, ignore_background=False): + """Computes the adjusted Rand index (ARI), a clustering similarity score. + + Code borrowed from https://github.com/google-research/slot-attention-video/blob/e8ab54620d0f1934b332ddc09f1dba7bc07ff601/savi/lib/metrics.py#L111 + + Args: + true_ids: An integer-valued array of shape + [batch_size, seq_len, H, W]. The true cluster assignment encoded + as integer ids. + pred_ids: An integer-valued array of shape + [batch_size, seq_len, H, W]. The predicted cluster assignment + encoded as integer ids. + ignore_background: Boolean, if True, then ignore all pixels where + true_ids == 0 (default: False). + + Returns: + ARI scores as a float32 array of shape [batch_size]. + """ + if len(true_ids.shape) == 3: + true_ids = true_ids.unsqueeze(1) + if len(pred_ids.shape) == 3: + pred_ids = pred_ids.unsqueeze(1) + + true_oh = F.one_hot(true_ids).float() + pred_oh = F.one_hot(pred_ids).float() + + if ignore_background: + true_oh = true_oh[..., 1:] # Remove the background row. 
+ + N = torch.einsum("bthwc,bthwk->bck", true_oh, pred_oh) + A = torch.sum(N, dim=-1) # row-sum (batch_size, c) + B = torch.sum(N, dim=-2) # col-sum (batch_size, k) + num_points = torch.sum(A, dim=1) + + rindex = torch.sum(N * (N - 1), dim=[1, 2]) + aindex = torch.sum(A * (A - 1), dim=1) + bindex = torch.sum(B * (B - 1), dim=1) + expected_rindex = aindex * bindex / torch.clamp( + num_points * (num_points - 1), min=1) + max_rindex = (aindex + bindex) / 2 + denominator = max_rindex - expected_rindex + ari = (rindex - expected_rindex) / denominator + + # There are two cases for which the denominator can be zero: + # 1. If both label_pred and label_true assign all pixels to a single cluster. + # (max_rindex == expected_rindex == rindex == num_points * (num_points-1)) + # 2. If both label_pred and label_true assign max 1 point to each cluster. + # (max_rindex == expected_rindex == rindex == 0) + # In both cases, we want the ARI score to be 1.0: + return torch.where(denominator != 0, ari, torch.tensor(1.).type_as(ari)) + + +def ARI_metric(x, y): + """x/y: [B, H, W], both are seg_masks after argmax.""" + assert 'int' in str(x.dtype) + assert 'int' in str(y.dtype) + return adjusted_rand_index(x, y).mean().item() + + +def fARI_metric(x, y): + """x/y: [B, H, W], both are seg_masks after argmax.""" + assert 'int' in str(x.dtype) + assert 'int' in str(y.dtype) + return adjusted_rand_index(x, y, ignore_background=True).mean().item() + + +def bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox, ovthresh=0.5): + """Compute the precision of predicted bounding boxes. + + Args: + gt_pres_mask: A boolean tensor of shape [N] + gt_bbox: A tensor of shape [N, 4] + pred_bbox: A tensor of shape [M, 4] + """ + gt_bbox, pred_bbox = gt_bbox.clone(), pred_bbox.clone() + gt_bbox = gt_bbox[gt_pres_mask.bool()] + pred_bbox = pred_bbox[pred_bbox[:, 0] >= 0.] + N, M = gt_bbox.shape[0], pred_bbox.shape[0] + assert gt_bbox.shape[1] == pred_bbox.shape[1] == 4 + # assert M >= N + tp, fp = 0, 0 + bbox_used = [False] * pred_bbox.shape[0] + bbox_ious = vops.box_iou(gt_bbox, pred_bbox) # [N, M] + + # Find the best iou match for each ground truth bbox. + for i in range(N): + best_iou_idx = bbox_ious[i].argmax().item() + best_iou = bbox_ious[i, best_iou_idx].item() + if best_iou >= ovthresh and not bbox_used[best_iou_idx]: + tp += 1 + bbox_used[best_iou_idx] = True + else: + fp += 1 + + # compute precision and recall + precision = tp / float(M) + recall = tp / float(N) + return precision, recall + + +def batch_bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox): + """Compute the precision of predicted bounding boxes over batch.""" + aps, ars = [], [] + for i in range(gt_pres_mask.shape[0]): + ap, ar = bbox_precision_recall(gt_pres_mask[i], gt_bbox[i], + pred_bbox[i]) + aps.append(ap) + ars.append(ar) + return np.mean(aps), np.mean(ars) + + +def hungarian_miou(gt_mask, pred_mask): + """both mask: [H*W] after argmax, 0 is gt background index.""" + true_oh = F.one_hot(gt_mask).float()[..., 1:] # only foreground, [HW, N] + pred_oh = F.one_hot(pred_mask).float() # [HW, M] + N, M = true_oh.shape[-1], pred_oh.shape[-1] + # compute all pairwise IoU + intersect = (true_oh[:, :, None] * pred_oh[:, None, :]).sum(0) # [N, M] + union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :] # [N, M] + iou = intersect / (union + 1e-8) # [N, M] + iou = iou.detach().cpu().numpy() + # find the best match for each gt + row_ind, col_ind = linear_sum_assignment(iou, maximize=True) + # there are two possibilities here + # 1. 
M >= N, just take the best match mean + # 2. M < N, some objects are not detected, their iou is 0 + if M >= N: + assert (row_ind == np.arange(N)).all() + return iou[row_ind, col_ind].mean() + return iou[row_ind, col_ind].sum() / float(N) + + +def miou_metric(gt_mask, pred_mask): + """both mask: [B, H, W], both are seg_masks after argmax.""" + assert 'int' in str(gt_mask.dtype) + assert 'int' in str(pred_mask.dtype) + gt_mask, pred_mask = gt_mask.flatten(1, 2), pred_mask.flatten(1, 2) + ious = [ + hungarian_miou(gt_mask[i], pred_mask[i]) + for i in range(gt_mask.shape[0]) + ] + return np.mean(ious) + + +@torch.no_grad() +def pred_eval_step( + gt, + pred, + lpips_fn, + gt_mask=None, + pred_mask=None, + gt_pres_mask=None, + gt_bbox=None, + pred_bbox=None, + eval_traj=True, +): + """Both of shape [B, T, C, H, W], torch.Tensor. + masks of shape [B, T, H, W]. + pres_mask of shape [B, T, N]. + bboxes of shape [B, T, N/M, 4]. + + eval_traj: whether to evaluate the trajectory (measured by bbox and mask). + + Compute metrics for every timestep. + """ + assert len(gt.shape) == len(pred.shape) == 5 + assert gt.shape == pred.shape + assert gt.shape[2] == 3 + if eval_traj: + assert len(gt_mask.shape) == len(pred_mask.shape) == 4 + assert gt_mask.shape == pred_mask.shape + if eval_traj: + assert len(gt_pres_mask.shape) == 3 + assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4 + T = gt.shape[1] + + # compute perceptual dist & mask metrics before converting to numpy + all_percept_dist, all_ari, all_fari, all_miou = [], [], [], [] + for t in range(T): + one_gt, one_pred = gt[:, t], pred[:, t] + percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item() + all_percept_dist.append(percept_dist) + if eval_traj: + one_gt_mask, one_pred_mask = gt_mask[:, t], pred_mask[:, t] + ari = ARI_metric(one_gt_mask, one_pred_mask) + fari = fARI_metric(one_gt_mask, one_pred_mask) + miou = miou_metric(one_gt_mask, one_pred_mask) + all_ari.append(ari) + all_fari.append(fari) + all_miou.append(miou) + else: + all_ari.append(0.) + all_fari.append(0.) + all_miou.append(0.) + + # compute bbox metrics + all_ap, all_ar = [], [] + for t in range(T): + if not eval_traj: + all_ap.append(0.) + all_ar.append(0.) 
+ continue + one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \ + gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t] + ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox, + one_pred_bbox) + all_ap.append(ap) + all_ar.append(ar) + + gt = to_rgb_from_tensor(gt).cpu().numpy() + pred = to_rgb_from_tensor(pred).cpu().numpy() + all_mse, all_ssim, all_psnr = [], [], [] + for t in range(T): + one_gt, one_pred = gt[:, t], pred[:, t] + mse = mse_metric(one_gt, one_pred) + psnr = psnr_metric(one_gt, one_pred) + ssim = ssim_metric(one_gt, one_pred) + all_mse.append(mse) + all_ssim.append(ssim) + all_psnr.append(psnr) + return { + 'mse': all_mse, + 'ssim': all_ssim, + 'psnr': all_psnr, + 'percept_dist': all_percept_dist, + 'ari': all_ari, + 'fari': all_fari, + 'miou': all_miou, + 'ap': all_ap, + 'ar': all_ar, + } diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py new file mode 100644 index 0000000..faab7ec --- /dev/null +++ b/scripts/utils/eval_utils.py @@ -0,0 +1,78 @@ +import os +import torch as th +from model.loci import Loci + +def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views): + + net_name = file.split('/')[-1].split('.')[0] + #root_path = file.split('nets')[0] + root_path = os.path.join(*file.split('/')[0:-1]) + root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type']) + plot_path = os.path.join(root_path, evaluation_mode) + + # create directories + #if os.path.exists(plot_path): + # shutil.rmtree(plot_path) + os.makedirs(plot_path, exist_ok = True) + if object_view: + os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True) + if individual_views: + os.makedirs(os.path.join(plot_path, 'individual'), exist_ok = True) + for group in ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination']: + os.makedirs(os.path.join(plot_path, 'individual', group), exist_ok = True) + os.makedirs(os.path.join(root_path, 'statistics'), exist_ok = True) + + # final directory + plot_path = plot_path + '/' + print(f"save plots to {plot_path}") + + return root_path, plot_path + +def store_statistics(memory, *args, extend=False): + for i,key in enumerate(memory.keys()): + if i >= len(args): + break + if extend: + memory[key].extend(args[i]) + else: + memory[key].append(args[i]) + return memory + +def append_statistics(memory1, memory2, ignore=[], extend=False): + for key in memory1: + if key not in ignore: + if extend: + memory2[key] = memory2[key] + memory1[key] + else: + memory2[key].append(memory1[key]) + return memory2 + +def load_model(cfg, cfg_net, file, device): + + net = Loci( + cfg_net, + teacher_forcing = cfg.defaults.teacher_forcing + ) + + # load model + if file != '': + print(f"load {file} to device {device}") + state = th.load(file, map_location=device) + + # backward compatibility + model = {} + for key, value in state["model"].items(): + # replace update_module with percept_gate_controller in key string: + key = key.replace("update_module", "percept_gate_controller") + model[key.replace(".module.", ".")] = value + + net.load_state_dict(model) + + # ??? + if net.get_init_status() < 1: + net.inc_init_level() + + # set network to evaluation mode + net = net.to(device=device) + + return net
\ No newline at end of file diff --git a/scripts/utils/io.py b/scripts/utils/io.py new file mode 100644 index 0000000..9bd8158 --- /dev/null +++ b/scripts/utils/io.py @@ -0,0 +1,137 @@ +import os +from scripts.utils.configuration import Configuration +import time +import torch as th +import numpy as np +from einops import rearrange, repeat, reduce + + +def init_device(cfg): + print(f'Cuda available: {th.cuda.is_available()} Cuda count: {th.cuda.device_count()}') + if th.cuda.is_available(): + device = th.device("cuda:0") + verbose = False + cfg.device = "cuda:0" + cfg.model.device = "cuda:0" + print('!!! USING GPU !!!') + else: + device = th.device("cpu") + verbose = True + cfg.device = "cpu" + cfg.model.device = "cpu" + cfg.model.batch_size = 2 + cfg.defaults.teacher_forcing = 4 + print('!!! USING CPU !!!') + return device,verbose + +class Timer: + + def __init__(self): + self.last = time.time() + self.passed = 0 + self.sum = 0 + + def __str__(self): + self.passed = self.passed * 0.99 + time.time() - self.last + self.sum = self.sum * 0.99 + 1 + passed = self.passed / self.sum + self.last = time.time() + + if passed > 1: + return f"{passed:.2f}s/it" + + return f"{1.0/passed:.2f}it/s" + +class UEMA: + + def __init__(self, memory = 100): + self.value = 0 + self.sum = 1e-30 + self.decay = np.exp(-1 / memory) + + def update(self, value): + self.value = self.value * self.decay + value + self.sum = self.sum * self.decay + 1 + + def __float__(self): + return self.value / self.sum + + +def model_path(cfg: Configuration, overwrite=False, move_old=True): + """ + Makes the model path, option to not overwrite + :param cfg: Configuration file with the model path + :param overwrite: Overwrites the files in the directory, else makes a new directory + :param move_old: Moves old folder with the same name to an old folder, if not overwrite + :return: Model path + """ + _path = os.path.join('out') + path = os.path.join(_path, cfg.model_path) + + if not os.path.exists(_path): + os.makedirs(_path) + + if not overwrite: + if move_old: + # Moves existing directory to an old folder + if os.path.exists(path): + old_path = os.path.join(_path, f'{cfg.model_path}_old') + if not os.path.exists(old_path): + os.makedirs(old_path) + _old_path = os.path.join(old_path, cfg.model_path) + i = 0 + while os.path.exists(_old_path): + i = i + 1 + _old_path = os.path.join(old_path, f'{cfg.model_path}_{i}') + os.renames(path, _old_path) + else: + # Increases number after directory name for each new path + i = 0 + while os.path.exists(path): + i = i + 1 + path = os.path.join(_path, f'{cfg.model_path}_{i}') + + return path + +class LossLogger: + + def __init__(self): + + self.avgloss = UEMA() + self.avg_position_loss = UEMA() + self.avg_time_loss = UEMA() + self.avg_encoder_loss = UEMA() + self.avg_mse_object_loss = UEMA() + self.avg_long_mse_object_loss = UEMA(33333) + self.avg_num_objects = UEMA() + self.avg_openings = UEMA() + self.avg_gestalt = UEMA() + self.avg_gestalt2 = UEMA() + self.avg_gestalt_mean = UEMA() + self.avg_update_gestalt = UEMA() + self.avg_update_position = UEMA() + + + def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position): + + self.avg_position_loss.update(avg_position_loss.item()) + self.avg_time_loss.update(avg_time_loss.item()) + self.avg_encoder_loss.update(avg_encoder_loss.item()) + 
self.avg_mse_object_loss.update(avg_mse_object_loss.item()) + self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item()) + self.avg_num_objects.update(avg_num_objects) + self.avg_openings.update(avg_openings) + self.avg_gestalt.update(avg_gestalt.item()) + self.avg_gestalt2.update(avg_gestalt2.item()) + self.avg_gestalt_mean.update(avg_gestalt_mean.item()) + self.avg_update_gestalt.update(avg_update_gestalt.item()) + self.avg_update_position.update(avg_update_position.item()) + pass + + def update_average_loss(self, avgloss): + self.avgloss.update(avgloss) + pass + + def get_log(self): + info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}' + return info diff --git a/scripts/utils/optimizers.py b/scripts/utils/optimizers.py new file mode 100644 index 0000000..49283b3 --- /dev/null +++ b/scripts/utils/optimizers.py @@ -0,0 +1,100 @@ +import math +import torch as th +import numpy as np +from torch.optim.optimizer import Optimizer + +""" +Liyuan Liu , Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020). +On the Variance of the Adaptive Learning Rate and Beyond. the Eighth International Conference on Learning +Representations. +""" +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, degenerated_to_sgd=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = th.zeros_like(p_data_fp32) + state['exp_avg_sq'] = th.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value 
= 1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1) + + state['step'] += 1 + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr']) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + p.data.copy_(p_data_fp32) + + return loss
\ No newline at end of file diff --git a/scripts/utils/plot_utils.py b/scripts/utils/plot_utils.py new file mode 100644 index 0000000..5cb20de --- /dev/null +++ b/scripts/utils/plot_utils.py @@ -0,0 +1,428 @@ +import torch as th +from torch import nn +import numpy as np +import cv2 +from einops import rearrange, repeat +import matplotlib.pyplot as plt +import PIL +from model.utils.nn_utils import Gaus2D, Vector2D + +def preprocess(tensor, scale=1, normalize=False, mean_std_normalize=False): + + if tensor is None: + return None + + if normalize: + min_ = th.min(tensor) + max_ = th.max(tensor) + tensor = (tensor - min_) / (max_ - min_) + + if mean_std_normalize: + mean = th.mean(tensor) + std = th.std(tensor) + tensor = th.clip((tensor - mean) / (2 * std), -1, 1) * 0.5 + 0.5 + + if scale > 1: + upsample = nn.Upsample(scale_factor=scale).to(tensor[0].device) + tensor = upsample(tensor) + + return tensor + +def preprocess_multi(*args, scale): + return [preprocess(a, scale) for a in args] + +def color_mask(mask): + + colors = th.tensor([ + [ 255, 0, 0 ], + [ 0, 0, 255 ], + [ 255, 255, 0 ], + [ 255, 0, 255 ], + [ 0, 255, 255 ], + [ 0, 255, 0 ], + [ 255, 128, 0 ], + [ 128, 255, 0 ], + [ 128, 0, 255 ], + [ 255, 0, 128 ], + [ 0, 255, 128 ], + [ 0, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 255 ], + [ 128, 255, 255 ], + [ 128, 255, 255 ], + [ 255, 255, 128 ], + [ 255, 255, 128 ], + [ 255, 128, 255 ], + [ 128, 0, 0 ], + [ 0, 0, 128 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + ], device = mask.device) / 255.0 + + colors = colors.view(1, -1, 3, 1, 1) + mask = mask.unsqueeze(dim=2) + + return th.sum(colors[:,:mask.shape[1]] * mask, dim=1) + +def get_color(o): + colors = th.tensor([ + [ 255, 0, 0 ], + [ 0, 0, 255 ], + [ 255, 255, 0 ], + [ 255, 0, 255 ], + [ 0, 255, 255 ], + [ 0, 255, 0 ], + [ 255, 128, 0 ], + [ 128, 255, 0 ], + [ 128, 0, 255 ], + [ 255, 0, 128 ], + [ 0, 255, 128 ], + [ 0, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 255 ], + [ 128, 255, 255 ], + [ 128, 255, 255 ], + [ 255, 255, 128 ], + [ 255, 255, 128 ], + [ 255, 128, 255 ], + [ 128, 0, 0 ], + [ 0, 0, 128 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + ]) / 255.0 + + colors = colors.view(48,3) + return colors[o] + +def to_rgb(tensor: th.Tensor): + return th.cat(( + tensor * 0.6 + 0.4, + tensor, + tensor + ), dim=1) + +def visualise_gate(gate, h, w): + bar = th.ones((1,h,w), device=gate.device) * 0.9 + black = int(w*gate.item()) + if black > 0: + bar[:,:, -black:] = 0 + return bar + +def get_highlighted_input(input, mask_cur): + + # highlight 
error + highlighted_input = input + if mask_cur is not None: + grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114 + object_mask_cur = th.sum(mask_cur[:,:-1], dim=1).unsqueeze(dim=1) + highlighted_input = grayscale * (1 - object_mask_cur) + highlighted_input += grayscale * object_mask_cur * 0.3333333 + cmask = color_mask(mask_cur[:,:-1]) + highlighted_input = highlighted_input + cmask * 0.6666666 + + return highlighted_input + +def color_slots(image, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur): + + image = (1-image) * slots_bounded + image * (1-slots_bounded) + image = th.clip(image - 0.3, 0,1) * slots_partially_occluded_cur + image * (1-slots_partially_occluded_cur) + image = th.clip(image - 0.3, 0,1) * slots_occluded_cur + image * (1-slots_occluded_cur) + + return image + +def compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale): + + # compute occlusion mask + occluded_cur = th.clip(rawmask_cur - mask_cur, 0, 1)[:,:-1] + occluded_next = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1] + + # to rgb + rawmask_cur = repeat(rawmask_cur[:,:-1], 'b o h w -> b (o 3) h w') + rawmask_next = repeat(rawmask_next[:,:-1], 'b o h w -> b (o 3) h w') + + # scale + occluded_next = preprocess(occluded_next, scale) + occluded_cur = preprocess(occluded_cur, scale) + rawmask_cur = preprocess(rawmask_cur, scale) + rawmask_next = preprocess(rawmask_next, scale) + + # set occlusion to red + rawmask_cur = rearrange(rawmask_cur, 'b (o c) h w -> b o c h w', c = 3) + rawmask_cur[:,:,0] = rawmask_cur[:,:,0] * (1 - occluded_next) + rawmask_cur[:,:,1] = rawmask_cur[:,:,1] * (1 - occluded_next) + + rawmask_next = rearrange(rawmask_next, 'b (o c) h w -> b o c h w', c = 3) + rawmask_next[:,:,0] = rawmask_next[:,:,0] * (1 - occluded_next) + rawmask_next[:,:,1] = rawmask_next[:,:,1] * (1 - occluded_next) + + return rawmask_cur, rawmask_next + +def plot_online_error_slots(errors, error_name, target, sequence_len, root_path, visibilty_memory, slots_bounded, ylim=0.3): + error_plots = [] + if len(errors) > 0: + num_slots = int(th.sum(slots_bounded).item()) + errors = rearrange(np.array(errors), '(l o) -> o l', o=len(slots_bounded))[:num_slots] + visibilty_memory = rearrange(np.array(visibilty_memory), '(l o) -> o l', o=len(slots_bounded))[:num_slots] + for error,visibility in zip(errors, visibilty_memory): + + if len(error) < sequence_len: + fig, ax = plt.subplots(figsize=(round(target.shape[3]/100,2), round(target.shape[2]/100,2))) + plt.plot(error, label=error_name) + + visibility = np.concatenate((visibility, np.ones(sequence_len-len(error)))) + ax.fill_between(range(sequence_len), 0, 1, where=visibility==0, color='orange', alpha=0.3, transform=ax.get_xaxis_transform()) + + plt.xlim((0,sequence_len)) + plt.ylim((0,ylim)) + fig.tight_layout() + plt.savefig(f'{root_path}/tmp.jpg') + + error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) + error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1)) + plt.close(fig) + + error_plots.append(error_plot) + + return error_plots + +def plot_online_error(error, error_name, target, t, i, sequence_len, root_path, online_surprise = False): + + fig = plt.figure(figsize=( round(target.shape[3]/50,2), round(target.shape[2]/50,2) )) + plt.plot(error, label=error_name) + + if online_surprise: + # compute moving average of error + moving_average_length = 10 + if t > moving_average_length: + moving_average_length += 1 + average_error = 
np.mean(error[-moving_average_length:-1]) + current_sd = np.std(error[-moving_average_length:-1]) + current_error = error[-1] + + if current_error > average_error + 2 * current_sd: + fig.set_facecolor('orange') + + plt.xlim((0,sequence_len)) + plt.legend() + # increase title size + plt.title(f'{error_name}', fontsize=20) + plt.xlabel('timestep') + plt.ylabel('error') + plt.savefig(f'{root_path}/tmp.jpg') + + error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) + error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1)) + plt.close(fig) + + return error_plot + +def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object): + + # add ground truth positions of objects to image + if gt_positions_target_next is not None: + for o in range(gt_positions_target_next.shape[1]): + position = gt_positions_target_next[0, o] + position = position/2 + 0.5 + + if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0: + width = 5 + w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item() + h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item() + col = get_color(o).view(3,1,1) + target[0,:,(w-width):(w+width), (h-width):(h+width)] = col + + # add these positions to the associated slots velocity_next2d ilustration + slots = (association_table[0] == o).nonzero() + for s in slots.flatten(): + velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col + + if output_hidden is not None and s != largest_object: + output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col + + gateheight = 60 + ch = 40 + gh = 40 + gh_bar = gh-20 + gh_margin = int((gh-gh_bar)/2) + margin = 20 + slots_margin = 10 + height = size[0] * 6 + 18*5 + width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1) + img = th.ones((3, height, width), device = object_next.device) * 0.4 + row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin]) + col1 = range(margin, margin + size[1]*2) + col2 = range(width-(margin+size[1]*2), width-margin) + + img[:,row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0] + img[:,row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0] + img[:,row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0] + + if error_plot is not None: + img[:,row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize= True) + if error_plot2 is not None: + img[:,row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize= True) + + for o in range(num_objects): + + col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin + col = range(col, col + size[1]) + + # color bar for the gate + if (error_plot_slots2 is not None) and len(error_plot_slots2) > o: + img[:,margin:margin+ch, col] = get_color(o).view(3,1,1).to(object_next.device) + + img[:,margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:,o, 0].to(object_next.device), h=gh_bar, w=len(col)) + offset = gh+margin-gh_margin+ch+2*margin + row = (lambda row_index: [offset+(size[0]+6)*row_index, 
offset+size[0]*(row_index+1)+6*row_index]) + + img[:,row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0,o].to(object_next.device)) + img[:,row(1)[0]:row(1)[1], col] = preprocess(object_next[:,o].to(object_next.device)) + if (error_plot_slots2 is not None) and len(error_plot_slots2) > o: + img[:,row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True) + + offset = margin*2-8 + row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index]) + img[:,row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:,o, 1].to(object_next.device), h=gh_bar, w=len(col)) + img[:,row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0] + if (error_plot_slots is not None) and len(error_plot_slots) > o: + img[:,row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True) + + img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy() + + return img + +def write_image(file, img): + img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy() + cv2.imwrite(file, img) + + pass + +def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i): + + # Create eposition helpers + size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device) + + # Compute plot content + highlighted_input = get_highlighted_input(input, mask_cur) + output = th.clip(output_next, 0, 1) + position_cur2d = gaus2d(rearrange(position_encoder_cur, 'b (o c) -> (b o) c', o=cfg_net.num_objects)) + velocity_next2d = vector2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects)) + + # color slots + slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next = reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next) + position_cur2d = color_slots(position_cur2d, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur) + velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next) + + # compute occlusion + if (cfg.datatype == "adept"): + rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale) + rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale) + rawmask_cur_h[:,largest_object] = rawmask_cur_l[:,largest_object] + rawmask_next_h[:,largest_object] = rawmask_next_l[:,largest_object] + rawmask_cur = rawmask_cur_h + rawmask_next = rawmask_next_h + object_hidden[:, largest_object] = object_next[:, largest_object] + object_next = object_hidden + else: + rawmask_cur, rawmask_next = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale) + + # scale plot content + input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next = preprocess_multi(input, target, output, highlighted_input, object_next, 
object_cur, mask_next, error_next, output_hidden, output_next, scale=scale) + + # reshape + object_next = rearrange(object_next, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels) + object_cur = rearrange(object_cur, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels) + mask_next = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w') + + if object_view: + if (cfg.datatype == "adept"): + num_objects = 4 + error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded) + error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001) + error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path) + error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path) + img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object) + else: + num_objects = cfg_net.num_objects + error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001) + error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path) + img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object) + + cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img) + + if individual_views: + # ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']: + write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0]) + write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0]) + write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0,-1]) + write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0]) + write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0]) + + pass + +def get_position_helper(cfg_net, device): + size = cfg_net.input_size + gaus2d = Gaus2D(size).to(device) + vector2d = Vector2D(size).to(device) + scale = size[0] // (cfg_net.latent_size[0] * 2**(cfg_net.level*2)) + return size,gaus2d,vector2d,scale + +def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next): + + slots_bounded = th.squeeze(slots_bounded)[..., None,None,None] + slots_partially_occluded_cur = th.squeeze(slots_partially_occluded_cur)[..., None,None,None] + slots_occluded_cur = th.squeeze(slots_occluded_cur)[..., None,None,None] + slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None,None,None] + slots_occluded_next = 
th.squeeze(slots_occluded_next)[..., None,None,None] + + return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next
\ No newline at end of file
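
A few usage sketches for the utilities introduced above. First, the Configuration class in scripts/utils/configuration.py layers attribute access and comment-tolerant JSON parsing (via jsmin) over a plain dict, and its load() merges nested dictionaries rather than overwriting them. A minimal sketch; the keys and output path are illustrative, not taken from this commit:

from scripts.utils.configuration import Configuration, Dict

# Dict merges nested dictionaries instead of replacing them wholesale.
d = Dict({'model': {'lr': 1e-4}})
d.load({'model': {'batch_size': 32}})
assert d.model.lr == 1e-4 and d.model.batch_size == 32

d.save('out/run1')                        # writes out/run1/cfg.json
cfg = Configuration('out/run1/cfg.json')  # jsmin strips // comments before json.loads
assert cfg.model.batch_size == 32
# name_list restricts which top-level keys get updated, as load_model
# does internally with name_list=["model"].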
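postproc_mask and masks_to_boxes in eval_metrics.py are shape-sensitive. A sketch with random soft masks (all sizes arbitrary) to show the expected tensor shapes:

import torch
from scripts.utils.eval_metrics import postproc_mask, masks_to_boxes

soft_masks = torch.rand(1, 2, 5, 1, 16, 16)   # [B, T, N, 1, H, W] soft slot masks
seg = postproc_mask(soft_masks)               # [B, T, H, W], integer slot ids
boxes = masks_to_boxes(seg, num_boxes=5)      # [B, T, N, 4] as [x1, y1, x2, y2]
# slots that own no pixels in a frame come back as [-1, -1, -1, -1]
print(seg.shape, boxes.shape)                 # [1, 2, 16, 16] and [1, 2, 5, 4]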
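The adjusted Rand index is invariant to label permutation, which a two-cluster toy case makes concrete; a quick check with hand-built ids:

import torch
from scripts.utils.eval_metrics import adjusted_rand_index

true_ids = torch.tensor([[[0, 0, 1, 1]]])     # [B, H, W] integer cluster ids
pred_ids = torch.tensor([[[1, 1, 0, 0]]])     # same partition, labels swapped
print(adjusted_rand_index(true_ids, pred_ids))  # tensor([1.])
# with ignore_background=True (as in fARI_metric), the true class 0
# channel is dropped, so background pixels no longer contribute.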
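pred_eval_step ties the per-timestep metrics together. A smoke-test sketch with random data and eval_traj=False, assuming the lpips pip package for the perceptual distance (inputs are expected in [-1, 1], as produced by torchvision normalization):

import torch
import lpips
from scripts.utils.eval_metrics import pred_eval_step

lpips_fn = lpips.LPIPS(net='vgg').eval()
gt = torch.rand(2, 4, 3, 64, 64) * 2 - 1      # [B, T, C, H, W] in [-1, 1]
pred = torch.rand(2, 4, 3, 64, 64) * 2 - 1
scores = pred_eval_step(gt, pred, lpips_fn, eval_traj=False)
# every entry is a list with one value per timestep; with eval_traj=False
# the mask/bbox metrics (ari, fari, miou, ap, ar) are filled with zeros.
print({k: len(v) for k, v in scores.items()})  # all lists have length T=4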
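store_statistics and append_statistics in eval_utils.py rely on insertion-ordered dict keys (Python 3.7+): positional arguments are matched to keys in order. A toy illustration with made-up keys:

from scripts.utils.eval_utils import store_statistics, append_statistics

memory = {'loss': [], 'ari': []}
memory = store_statistics(memory, 0.12, 0.87)  # loss <- 0.12, ari <- 0.87

batch = {'loss': [0.10, 0.11], 'ari': [0.90, 0.85]}
totals = {'loss': [], 'ari': []}
totals = append_statistics(batch, totals, extend=True)  # concatenates the lists
print(memory, totals)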
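The UEMA in io.py is an unbiased exponential moving average: both the weighted sum and the weight mass decay by exp(-1/memory), so their ratio reproduces a constant input exactly instead of warming up from zero. A quick check:

from scripts.utils.io import UEMA

uema = UEMA(memory=100)
for value in [1.0, 1.0, 1.0]:
    uema.update(value)
assert abs(float(uema) - 1.0) < 1e-6  # no zero-initialization bias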
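RAdam is used like any torch optimizer, but note this is the repo's own implementation (default eps=1e-16), not torch.optim.RAdam: with the default betas the variance-rectified update only activates once N_sma >= 5, around the fifth step, and before that parameters stay put unless degenerated_to_sgd=True. A minimal sketch on a throwaway model:

import torch as th
from scripts.utils.optimizers import RAdam

model = th.nn.Linear(4, 1)
opt = RAdam(model.parameters(), lr=1e-3)
for _ in range(10):
    loss = model(th.randn(8, 4)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()  # rectified Adam update once N_sma >= 5; a no-op before that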