aboutsummaryrefslogtreecommitdiff
path: root/scripts/utils/eval_metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/utils/eval_metrics.py')
-rw-r--r--scripts/utils/eval_metrics.py350
1 files changed, 350 insertions, 0 deletions
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py
new file mode 100644
index 0000000..6f5106d
--- /dev/null
+++ b/scripts/utils/eval_metrics.py
@@ -0,0 +1,350 @@
+'''
+vp_utils.py
+SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer
+'''
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from skimage.metrics import structural_similarity, peak_signal_noise_ratio
+
+import torch
+import torch.nn.functional as F
+import torchvision.ops as vops
+
# Score threshold used by postproc_mask: pixels whose best slot score is
# below this value are handed to the background slot.
FG_THRE = 0.5

# RGB triplets in the 0-255 range.
# NOTE(review): presumably one colour per segment id for visualisation --
# confirm against the callers, nothing in this file indexes it.
PALETTE = [
    (0, 255, 0),
    (0, 0, 255),
    (0, 255, 255),
    (255, 255, 0),
    (255, 0, 255),
    (100, 100, 255),
    (200, 200, 100),
    (170, 120, 200),
    (255, 0, 0),
    (200, 100, 100),
    (10, 200, 100),
    (200, 200, 200),
    (50, 50, 50),
]
# The same palette as a uint8 array of shape [len(PALETTE), 3].
PALETTE_np = np.array(PALETTE, dtype=np.uint8)
# The same palette as a float tensor rescaled from [0, 255] to [-1, 1].
PALETTE_torch = torch.from_numpy(PALETTE_np).float() / 255. * 2. - 1.
+
def to_rgb_from_tensor(x):
    """Undo a mean=0.5 / std=0.5 normalization and clip to [0, 1].

    Args:
        x: tensor of normalized values (nominally in [-1, 1]).

    Returns:
        Tensor of the same shape with values `x * 0.5 + 0.5`, clamped to
        the closed interval [0, 1].
    """
    rescaled = x * 0.5 + 0.5
    return rescaled.clamp(0, 1)
+
def postproc_mask(batch_masks):
    """Turn per-slot mask scores into segment ids with a background rule.

    Instead of a plain argmax over slots, every pixel whose best slot score
    is below FG_THRE is assigned to the "background" slot, which is chosen
    per frame as the slot whose own maximum score is the smallest.

    Args:
        batch_masks: [B, T, N, 1, H, W] per-slot mask scores.

    Returns:
        masks: [B, T, H, W] integer slot index per pixel.
    """
    B, T, N, _, H, W = batch_masks.shape
    # Work on a flattened copy so the caller's tensor is never modified.
    scores = batch_masks.clone().reshape(B * T, N, H * W)

    # Background slot = the slot with the weakest peak response per frame.
    peak_per_slot = scores.max(-1)[0]  # [B*T, N]
    bg_slot = peak_per_slot.argmin(-1)  # [B*T]

    # Pixels that no slot claims with a score of at least FG_THRE.
    best_per_pixel = scores.max(1)[0]  # [B*T, H*W]
    unclaimed = best_per_pixel < FG_THRE  # [B*T, H*W]

    # Boolean [B*T, N] selector that is True only at each frame's bg slot.
    slot_ids = torch.arange(N, device=bg_slot.device)
    is_bg_slot = slot_ids[None, :] == bg_slot[:, None]

    # Force the background slot to win the argmax on unclaimed pixels.
    override = is_bg_slot[:, :, None] & unclaimed[:, None, :]
    scores[override] = 1.

    return scores.argmax(1).reshape(B, T, H, W)
+
+
def masks_to_boxes_w_empty_mask(binary_masks):
    """Compute one bounding box per mask, tolerating all-zero masks.

    Args:
        binary_masks: [B, H, W] binary masks.

    Returns:
        bboxes: [B, 4] float tensor on the masks' device. Rows for masks
            with at least one non-zero pixel hold the result of
            `vops.masks_to_boxes`; rows for empty masks are filled with -1.
    """
    num_masks = binary_masks.shape[0]
    has_pixels = binary_masks.sum([-1, -2]) > 0  # [B]
    # Start from the "no box" sentinel and overwrite the non-empty rows.
    bboxes = torch.full((num_masks, 4), -1., dtype=torch.float32,
                        device=binary_masks.device)
    bboxes[has_pixels] = vops.masks_to_boxes(binary_masks[has_pixels])
    return bboxes
+
+
def masks_to_boxes(masks, num_boxes=7):
    """Convert integer segmentation masks into per-slot bounding boxes.

    Args:
        masks: [B, T, H, W] integer segment ids (output of an argmax).
        num_boxes: number of boxes to generate per frame (num_slots); used
            as the number of one-hot classes.

    Returns:
        bboxes: [B, T, N, 4] with N == num_boxes; rows are whatever
            masks_to_boxes_w_empty_mask returns for each slot mask.
    """
    B, T, _, _ = masks.shape
    one_hot = F.one_hot(masks, num_classes=num_boxes)  # [B, T, H, W, N]
    # Move the slot axis in front of the spatial axes: [B, T, N, H, W].
    per_slot = one_hot.permute(0, 1, 4, 2, 3).contiguous()
    flat_boxes = masks_to_boxes_w_empty_mask(per_slot.flatten(0, 2))
    return flat_boxes.reshape(B, T, num_boxes, 4)
+
+
def mse_metric(x, y):
    """Squared error summed over H and W, averaged over batch and channel.

    Args:
        x, y: [B, C, H, W] images of identical shape.

    Returns:
        Scalar: mean over (B, C) of the per-image spatial sum of squared
        differences.
    """
    # Video-prediction work commonly sums (rather than averages) the error
    # over the spatial dimensions, see e.g.
    # https://github.com/Yunbo426/predrnn-pp/blob/40764c52f433290aa02414b5f25c09c72b98e0af/train.py#L244
    squared_error = (x - y)**2
    spatial_sum = squared_error.sum(-1).sum(-1)  # [B, C]
    return spatial_sum.mean()
+
+
def psnr_metric(x, y):
    """Mean peak signal-to-noise ratio over a batch of images.

    Args:
        x, y: [B, C, H, W] images; `data_range=1.` is passed to skimage, so
            values are expected to span a range of 1.

    Returns:
        Mean of the per-image PSNR values.
    """
    per_image = []
    for idx in range(x.shape[0]):
        score = peak_signal_noise_ratio(x[idx], y[idx], data_range=1.)
        per_image.append(score)
    return np.mean(per_image)
+
+
def ssim_metric(x, y):
    """Mean structural similarity over a batch of images.

    Args:
        x, y: [B, C, H, W] images. Both are multiplied by 255 and compared
            with `data_range=255`, treating axis 0 of each image as the
            channel axis.

    Returns:
        Mean of the per-image SSIM values.
    """
    scaled_x = x * 255.
    scaled_y = y * 255.
    # Gaussian-weighted SSIM settings shared by every image in the batch.
    ssim_kwargs = dict(
        channel_axis=0,
        gaussian_weights=True,
        sigma=1.5,
        use_sample_covariance=False,
        data_range=255,
    )
    per_image = []
    for idx in range(scaled_x.shape[0]):
        score = structural_similarity(scaled_x[idx], scaled_y[idx],
                                      **ssim_kwargs)
        per_image.append(score)
    return np.mean(per_image)
+
+
def perceptual_dist(x, y, loss_fn):
    """Average the output of a perceptual loss function.

    Args:
        x, y: [B, C, H, W] images.
        loss_fn: callable invoked as `loss_fn(x, y)`; its result must
            provide `.mean()`.

    Returns:
        `loss_fn(x, y).mean()`.
    """
    distances = loss_fn(x, y)
    return distances.mean()
+
+
def _to_index_tensor(ids):
    """Cast integer/bool id tensors to int64, the only dtype F.one_hot takes.

    Floating-point and complex tensors are returned untouched so that
    F.one_hot still rejects them instead of silently truncating values.
    """
    if ids.is_floating_point() or ids.is_complex():
        return ids
    return ids.long()


def adjusted_rand_index(true_ids, pred_ids, ignore_background=False):
    """Computes the adjusted Rand index (ARI), a clustering similarity score.

    Code borrowed from https://github.com/google-research/slot-attention-video/blob/e8ab54620d0f1934b332ddc09f1dba7bc07ff601/savi/lib/metrics.py#L111

    Args:
        true_ids: An integer-valued tensor of shape
            [batch_size, seq_len, H, W] (or [batch_size, H, W], in which
            case a seq_len axis of size 1 is inserted). The true cluster
            assignment encoded as non-negative integer ids. Any integer
            dtype is accepted; it is cast to int64 internally.
        pred_ids: An integer-valued tensor with the same layout as
            `true_ids`. The predicted cluster assignment encoded as
            non-negative integer ids.
        ignore_background: Boolean, if True, then ignore all pixels where
            true_ids == 0 (default: False).

    Returns:
        ARI scores as a float32 tensor of shape [batch_size].
    """
    # F.one_hot raises on anything but int64, while the public wrappers only
    # check that the dtype name contains "int" (int32, uint8, ...).
    true_ids = _to_index_tensor(true_ids)
    pred_ids = _to_index_tensor(pred_ids)

    if len(true_ids.shape) == 3:
        true_ids = true_ids.unsqueeze(1)
    if len(pred_ids.shape) == 3:
        pred_ids = pred_ids.unsqueeze(1)

    true_oh = F.one_hot(true_ids).float()
    pred_oh = F.one_hot(pred_ids).float()

    if ignore_background:
        # Drop the id-0 class so background pixels contribute to no cluster.
        true_oh = true_oh[..., 1:]

    # Contingency table per batch element: N[b, c, k] is the number of
    # pixels (over all frames) with true id c and predicted id k.
    N = torch.einsum("bthwc,bthwk->bck", true_oh, pred_oh)
    A = torch.sum(N, dim=-1)  # row-sum (batch_size, c)
    B = torch.sum(N, dim=-2)  # col-sum (batch_size, k)
    num_points = torch.sum(A, dim=1)

    rindex = torch.sum(N * (N - 1), dim=[1, 2])
    aindex = torch.sum(A * (A - 1), dim=1)
    bindex = torch.sum(B * (B - 1), dim=1)
    # clamp avoids 0/0 when fewer than two points are counted.
    expected_rindex = aindex * bindex / torch.clamp(
        num_points * (num_points - 1), min=1)
    max_rindex = (aindex + bindex) / 2
    denominator = max_rindex - expected_rindex
    ari = (rindex - expected_rindex) / denominator

    # There are two cases for which the denominator can be zero:
    # 1. If both label_pred and label_true assign all pixels to a single
    #    cluster.
    #    (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
    # 2. If both label_pred and label_true assign max 1 point to each cluster.
    #    (max_rindex == expected_rindex == rindex == 0)
    # In both cases, we want the ARI score to be 1.0:
    return torch.where(denominator != 0, ari, torch.ones_like(ari))
+
+
def ARI_metric(x, y):
    """Batch-averaged adjusted Rand index between two segmentations.

    Args:
        x, y: [B, H, W] integer segmentation masks (after argmax).

    Returns:
        Python float: mean of adjusted_rand_index(x, y) over the batch.
    """
    for seg in (x, y):
        assert 'int' in str(seg.dtype)
    per_sample = adjusted_rand_index(x, y)
    return per_sample.mean().item()
+
+
def fARI_metric(x, y):
    """Batch-averaged foreground ARI (pixels with x == 0 are ignored).

    Args:
        x: [B, H, W] integer segmentation mask used as the reference; its
            id 0 is treated as background.
        y: [B, H, W] integer segmentation mask to compare against `x`.

    Returns:
        Python float: mean over the batch of
        adjusted_rand_index(x, y, ignore_background=True).
    """
    for seg in (x, y):
        assert 'int' in str(seg.dtype)
    per_sample = adjusted_rand_index(x, y, ignore_background=True)
    return per_sample.mean().item()
+
+
def bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox, ovthresh=0.5):
    """Compute precision and recall of predicted bounding boxes.

    Each present ground-truth box is greedily matched to the prediction it
    overlaps most; the match counts as a true positive when the IoU reaches
    `ovthresh` and that prediction has not been claimed by an earlier
    ground-truth box.

    Args:
        gt_pres_mask: A boolean tensor of shape [N]; only ground-truth boxes
            whose entry is truthy are evaluated.
        gt_bbox: A tensor of shape [N, 4].
        pred_bbox: A tensor of shape [M, 4]; rows whose first coordinate is
            negative (the -1 "empty" sentinel) are discarded.
        ovthresh: minimum IoU for a match to count as a true positive.

    Returns:
        (precision, recall) as Python floats. A ratio whose denominator is
        empty is defined as 1.0: precision is 1.0 when there are no valid
        predictions (nothing predicted wrongly) and recall is 1.0 when there
        are no present ground-truth boxes (nothing missed).
    """
    # Boolean-mask indexing already yields fresh tensors, so the inputs are
    # never modified.
    gt_bbox = gt_bbox[gt_pres_mask.bool()]
    pred_bbox = pred_bbox[pred_bbox[:, 0] >= 0.]
    assert gt_bbox.shape[1] == pred_bbox.shape[1] == 4
    N, M = gt_bbox.shape[0], pred_bbox.shape[0]

    # With either side empty there can be no true positive; return early to
    # avoid argmax over an empty row and a division by zero.
    if N == 0 or M == 0:
        precision = 1.0 if M == 0 else 0.0
        recall = 1.0 if N == 0 else 0.0
        return precision, recall

    pred_taken = [False] * M
    bbox_ious = vops.box_iou(gt_bbox, pred_bbox)  # [N, M]

    # Find the best iou match for each ground truth bbox.
    tp = 0
    for gt_idx in range(N):
        best_pred = bbox_ious[gt_idx].argmax().item()
        best_iou = bbox_ious[gt_idx, best_pred].item()
        if best_iou >= ovthresh and not pred_taken[best_pred]:
            tp += 1
            pred_taken[best_pred] = True

    precision = tp / float(M)
    recall = tp / float(N)
    return precision, recall
+
+
def batch_bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox):
    """Average bbox precision and recall over the leading (batch) axis.

    Args:
        gt_pres_mask, gt_bbox, pred_bbox: batched inputs; element `i` along
            axis 0 of each is passed to bbox_precision_recall.

    Returns:
        (mean precision, mean recall) over the batch.
    """
    batch_size = gt_pres_mask.shape[0]
    per_sample = [
        bbox_precision_recall(gt_pres_mask[b], gt_bbox[b], pred_bbox[b])
        for b in range(batch_size)
    ]
    precisions = [precision for precision, _ in per_sample]
    recalls = [recall for _, recall in per_sample]
    return np.mean(precisions), np.mean(recalls)
+
+
def hungarian_miou(gt_mask, pred_mask):
    """Mean IoU of foreground objects under the best one-to-one matching.

    Args:
        gt_mask: [H*W] int64 segment ids, 0 is the gt background index.
        pred_mask: [H*W] int64 predicted segment ids.

    Returns:
        Sum of the IoUs of the Hungarian-matched (gt object, predicted
        segment) pairs divided by the number of gt foreground classes N;
        gt objects left unmatched (more objects than predicted segments)
        contribute 0. NaN when the gt mask has no foreground class at all,
        since the mean is then undefined.
    """
    true_oh = F.one_hot(gt_mask).float()[..., 1:]  # only foreground, [HW, N]
    pred_oh = F.one_hot(pred_mask).float()  # [HW, M]
    N, M = true_oh.shape[-1], pred_oh.shape[-1]
    # NOTE(review): F.one_hot sizes the class axis by the largest id, so a
    # gt id with no pixels in this frame still counts as an (unmatched)
    # object with IoU 0 -- confirm this is intended.
    if N == 0:
        # No foreground object to score.
        return float('nan')

    # compute all pairwise IoU
    intersect = (true_oh[:, :, None] * pred_oh[:, None, :]).sum(0)  # [N, M]
    area_sum = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :]  # [N, M]
    # |A u B| = |A| + |B| - |A n B|; without the subtraction a perfect
    # match would only score 0.5.
    union = area_sum - intersect
    iou = intersect / (union + 1e-8)  # [N, M]
    iou = iou.detach().cpu().numpy()

    # find the best match for each gt
    row_ind, col_ind = linear_sum_assignment(iou, maximize=True)
    # there are two possibilities here
    #   1. M >= N, just take the best match mean
    #   2. M < N, some objects are not detected, their iou is 0
    if M >= N:
        assert (row_ind == np.arange(N)).all()
        return iou[row_ind, col_ind].mean()
    return iou[row_ind, col_ind].sum() / float(N)
+
+
def miou_metric(gt_mask, pred_mask):
    """Batch-averaged Hungarian-matched mIoU.

    Args:
        gt_mask, pred_mask: [B, H, W] integer segmentation masks (after
            argmax).

    Returns:
        Mean over the batch of hungarian_miou applied to each pair of
        flattened masks.
    """
    for seg in (gt_mask, pred_mask):
        assert 'int' in str(seg.dtype)
    # Collapse the spatial axes: [B, H, W] -> [B, H*W].
    flat_gt = gt_mask.flatten(1, 2)
    flat_pred = pred_mask.flatten(1, 2)
    per_sample = []
    for b in range(flat_gt.shape[0]):
        per_sample.append(hungarian_miou(flat_gt[b], flat_pred[b]))
    return np.mean(per_sample)
+
+
@torch.no_grad()
def pred_eval_step(
    gt,
    pred,
    lpips_fn,
    gt_mask=None,
    pred_mask=None,
    gt_pres_mask=None,
    gt_bbox=None,
    pred_bbox=None,
    eval_traj=True,
):
    """Compute per-timestep evaluation metrics for predicted videos.

    Args:
        gt, pred: [B, T, C, H, W] torch.Tensor videos with C == 3, in the
            normalized range that to_rgb_from_tensor maps back to [0, 1].
        lpips_fn: perceptual loss callable, invoked as lpips_fn(gt, pred)
            on one timestep at a time.
        gt_mask, pred_mask: [B, T, H, W] integer segmentation masks.
        gt_pres_mask: [B, T, N] presence mask of the gt boxes.
        gt_bbox, pred_bbox: [B, T, N/M, 4] bounding boxes.
        eval_traj: whether to evaluate the trajectory (measured by bbox and
            mask). When True, all mask/bbox arguments must be given; when
            False, the mask and bbox metrics are reported as 0.

    Returns:
        Dict mapping 'mse', 'ssim', 'psnr', 'percept_dist', 'ari', 'fari',
        'miou', 'ap' and 'ar' to lists holding one value per timestep.
    """
    assert len(gt.shape) == len(pred.shape) == 5
    assert gt.shape == pred.shape
    assert gt.shape[2] == 3
    if eval_traj:
        assert len(gt_mask.shape) == len(pred_mask.shape) == 4
        assert gt_mask.shape == pred_mask.shape
        assert len(gt_pres_mask.shape) == 3
        assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4
    num_steps = gt.shape[1]

    metric_names = ('mse', 'ssim', 'psnr', 'percept_dist', 'ari', 'fari',
                    'miou', 'ap', 'ar')
    results = {name: [] for name in metric_names}

    # Perceptual distance and mask metrics need torch tensors, so they run
    # before the videos are converted to numpy.
    for t in range(num_steps):
        frame_dist = perceptual_dist(gt[:, t], pred[:, t], lpips_fn).item()
        results['percept_dist'].append(frame_dist)
        if eval_traj:
            frame_gt_mask, frame_pred_mask = gt_mask[:, t], pred_mask[:, t]
            ari = ARI_metric(frame_gt_mask, frame_pred_mask)
            fari = fARI_metric(frame_gt_mask, frame_pred_mask)
            miou = miou_metric(frame_gt_mask, frame_pred_mask)
        else:
            ari, fari, miou = 0., 0., 0.
        results['ari'].append(ari)
        results['fari'].append(fari)
        results['miou'].append(miou)

    # Bounding-box precision / recall.
    for t in range(num_steps):
        if eval_traj:
            ap, ar = batch_bbox_precision_recall(
                gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t])
        else:
            ap, ar = 0., 0.
        results['ap'].append(ap)
        results['ar'].append(ar)

    # Pixel-level metrics are computed on [0, 1] numpy images.
    gt_np = to_rgb_from_tensor(gt).cpu().numpy()
    pred_np = to_rgb_from_tensor(pred).cpu().numpy()
    for t in range(num_steps):
        frame_gt, frame_pred = gt_np[:, t], pred_np[:, t]
        mse = mse_metric(frame_gt, frame_pred)
        psnr = psnr_metric(frame_gt, frame_pred)
        ssim = ssim_metric(frame_gt, frame_pred)
        results['mse'].append(mse)
        results['ssim'].append(ssim)
        results['psnr'].append(psnr)

    return results