aboutsummaryrefslogtreecommitdiff
path: root/scripts/utils/eval_metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/utils/eval_metrics.py')
-rw-r--r--scripts/utils/eval_metrics.py51
1 files changed, 44 insertions, 7 deletions
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py
index 6f5106d..98af468 100644
--- a/scripts/utils/eval_metrics.py
+++ b/scripts/utils/eval_metrics.py
@@ -3,6 +3,8 @@ vp_utils.py
SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer
'''
+import cv2
+from einops import rearrange
import numpy as np
from scipy.optimize import linear_sum_assignment
from skimage.metrics import structural_similarity, peak_signal_noise_ratio
@@ -11,7 +13,7 @@ import torch
import torch.nn.functional as F
import torchvision.ops as vops
-FG_THRE = 0.5
+FG_THRE = 0.3 # 0.5 for the rest, 0.3 for bouncingballs
PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0),
(255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200),
(255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200),
@@ -272,6 +274,8 @@ def pred_eval_step(
gt_bbox=None,
pred_bbox=None,
eval_traj=True,
+ gt_mask_hidden=None,
+ pred_mask_hidden=None,
):
"""Both of shape [B, T, C, H, W], torch.Tensor.
masks of shape [B, T, H, W].
@@ -288,13 +292,14 @@ def pred_eval_step(
if eval_traj:
assert len(gt_mask.shape) == len(pred_mask.shape) == 4
assert gt_mask.shape == pred_mask.shape
- if eval_traj:
+ if eval_traj and gt_bbox is not None:
assert len(gt_pres_mask.shape) == 3
assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4
T = gt.shape[1]
# compute perceptual dist & mask metrics before converting to numpy
all_percept_dist, all_ari, all_fari, all_miou = [], [], [], []
+ all_ari_hidden, all_fari_hidden, all_miou_hidden = [], [], []
for t in range(T):
one_gt, one_pred = gt[:, t], pred[:, t]
percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item()
@@ -307,6 +312,29 @@ def pred_eval_step(
all_ari.append(ari)
all_fari.append(fari)
all_miou.append(miou)
+
+ # hidden: same mask metrics (ARI, fARI, mIoU), computed on the hidden masks when provided
+ if gt_mask_hidden is not None:
+ one_gt_mask, one_pred_mask = gt_mask_hidden[:, t], pred_mask_hidden[:, t]
+ ari = ARI_metric(one_gt_mask, one_pred_mask)
+ fari = fARI_metric(one_gt_mask, one_pred_mask)
+ miou = miou_metric(one_gt_mask, one_pred_mask)
+ all_ari_hidden.append(ari)
+ all_fari_hidden.append(fari)
+ all_miou_hidden.append(miou)
+
+
+ # display masks with cv2
+ if False:
+ one_gt_mask_, one_pred_mask_ = one_gt_mask.clone(), one_pred_mask.clone()
+
+ # concat one_gt_mask and one_pred_mask horizontally
+ frame = np.concatenate((one_gt_mask_, one_pred_mask_), axis=1)
+ frame = rearrange(frame, 'c h w -> h w c') * (255/6)
+ frame = frame.astype(np.uint8)
+ cv2.imshow('frame', frame)
+ cv2.waitKey(0)
+
else:
all_ari.append(0.)
all_fari.append(0.)
@@ -315,14 +343,14 @@ def pred_eval_step(
# compute bbox metrics
all_ap, all_ar = [], []
for t in range(T):
- if not eval_traj:
+ if not eval_traj or gt_bbox is None:
all_ap.append(0.)
all_ar.append(0.)
continue
one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \
gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t]
ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox,
- one_pred_bbox)
+ one_pred_bbox)
all_ap.append(ap)
all_ar.append(ar)
@@ -333,11 +361,13 @@ def pred_eval_step(
one_gt, one_pred = gt[:, t], pred[:, t]
mse = mse_metric(one_gt, one_pred)
psnr = psnr_metric(one_gt, one_pred)
- ssim = ssim_metric(one_gt, one_pred)
+ #ssim = ssim_metric(one_gt, one_pred)
+ ssim = 0
all_mse.append(mse)
all_ssim.append(ssim)
all_psnr.append(psnr)
- return {
+
+ res = {
'mse': all_mse,
'ssim': all_ssim,
'psnr': all_psnr,
@@ -346,5 +376,12 @@ def pred_eval_step(
'fari': all_fari,
'miou': all_miou,
'ap': all_ap,
- 'ar': all_ar,
+ 'ar': all_ar
}
+
+ if gt_mask_hidden is not None:
+ res['ari_hidden'] = all_ari_hidden
+ res['fari_hidden'] = all_fari_hidden
+ res['miou_hidden'] = all_miou_hidden
+
+ return res