diff options
Diffstat (limited to 'scripts/utils/eval_metrics.py')
-rw-r--r-- | scripts/utils/eval_metrics.py | 51 |
1 files changed, 44 insertions, 7 deletions
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py index 6f5106d..98af468 100644 --- a/scripts/utils/eval_metrics.py +++ b/scripts/utils/eval_metrics.py @@ -3,6 +3,8 @@ vp_utils.py SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer ''' +import cv2 +from einops import rearrange import numpy as np from scipy.optimize import linear_sum_assignment from skimage.metrics import structural_similarity, peak_signal_noise_ratio @@ -11,7 +13,7 @@ import torch import torch.nn.functional as F import torchvision.ops as vops -FG_THRE = 0.5 +FG_THRE = 0.3 # 0.5 for the rest, 0.3 for bouncingballs PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0), (255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200), (255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200), @@ -272,6 +274,8 @@ def pred_eval_step( gt_bbox=None, pred_bbox=None, eval_traj=True, + gt_mask_hidden=None, + pred_mask_hidden=None, ): """Both of shape [B, T, C, H, W], torch.Tensor. masks of shape [B, T, H, W]. 
@@ -288,13 +292,14 @@ def pred_eval_step( if eval_traj: assert len(gt_mask.shape) == len(pred_mask.shape) == 4 assert gt_mask.shape == pred_mask.shape - if eval_traj: + if eval_traj and gt_bbox is not None: assert len(gt_pres_mask.shape) == 3 assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4 T = gt.shape[1] # compute perceptual dist & mask metrics before converting to numpy all_percept_dist, all_ari, all_fari, all_miou = [], [], [], [] + all_ari_hidden, all_fari_hidden, all_miou_hidden = [], [], [] for t in range(T): one_gt, one_pred = gt[:, t], pred[:, t] percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item() @@ -307,6 +312,29 @@ def pred_eval_step( all_ari.append(ari) all_fari.append(fari) all_miou.append(miou) + + # hidden: + if gt_mask_hidden is not None: + one_gt_mask, one_pred_mask = gt_mask_hidden[:, t], pred_mask_hidden[:, t] + ari = ARI_metric(one_gt_mask, one_pred_mask) + fari = fARI_metric(one_gt_mask, one_pred_mask) + miou = miou_metric(one_gt_mask, one_pred_mask) + all_ari_hidden.append(ari) + all_fari_hidden.append(fari) + all_miou_hidden.append(miou) + + + # display masks with cv2 + if False: + one_gt_mask_, one_pred_mask_ = one_gt_mask.clone(), one_pred_mask.clone() + + # concat one_gt_mask and one_pred_mask horizontally + frame = np.concatenate((one_gt_mask_, one_pred_mask_), axis=1) + frame = rearrange(frame, 'c h w -> h w c') * (255/6) + frame = frame.astype(np.uint8) + cv2.imshow('frame', frame) + cv2.waitKey(0) + else: all_ari.append(0.) all_fari.append(0.) @@ -315,14 +343,14 @@ def pred_eval_step( # compute bbox metrics all_ap, all_ar = [], [] for t in range(T): - if not eval_traj: + if not eval_traj or gt_bbox is None: all_ap.append(0.) all_ar.append(0.) 
continue one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \ gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t] ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox, - one_pred_bbox) + one_pred_bbox) all_ap.append(ap) all_ar.append(ar) @@ -333,11 +361,13 @@ def pred_eval_step( one_gt, one_pred = gt[:, t], pred[:, t] mse = mse_metric(one_gt, one_pred) psnr = psnr_metric(one_gt, one_pred) - ssim = ssim_metric(one_gt, one_pred) + #ssim = ssim_metric(one_gt, one_pred) + ssim = 0 all_mse.append(mse) all_ssim.append(ssim) all_psnr.append(psnr) - return { + + res = { 'mse': all_mse, 'ssim': all_ssim, 'psnr': all_psnr, @@ -346,5 +376,12 @@ def pred_eval_step( 'fari': all_fari, 'miou': all_miou, 'ap': all_ap, - 'ar': all_ar, + 'ar': all_ar } + + if gt_mask_hidden is not None: + res['ari_hidden'] = all_ari_hidden + res['fari_hidden'] = all_fari_hidden + res['miou_hidden'] = all_miou_hidden + + return res |