aboutsummaryrefslogtreecommitdiff
path: root/scripts/utils/eval_metrics.py
blob: 6f5106d894fdeb62dbbe07bb8c3d1510d778b19f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
'''
Evaluation metrics for video prediction and object-centric segmentation.

Adapted from vp_utils.py in https://github.com/pairlab/SlotFormer
'''

import numpy as np
from scipy.optimize import linear_sum_assignment
from skimage.metrics import structural_similarity, peak_signal_noise_ratio

import torch
import torch.nn.functional as F
import torchvision.ops as vops

# Pixels whose best slot score stays below this value are treated as
# unclaimed and handed to the background slot in postproc_mask.
FG_THRE = 0.5
# Color table: 13 color triples with components in [0, 255].
PALETTE = [
    (0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0),
    (255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200),
    (255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200),
    (50, 50, 50),
]
PALETTE_np = np.array(PALETTE, dtype=np.uint8)  # [13, 3] uint8
# The same colors as float32, rescaled from [0, 255] to [-1, 1].
PALETTE_torch = torch.from_numpy(PALETTE_np).float().div(255.).mul(2.).sub(1.)

def to_rgb_from_tensor(x):
    """Undo a (x - 0.5) / 0.5 normalization and clamp the result to [0, 1].

    This is the inverse of torchvision's Normalize(mean=0.5, std=0.5): values
    in [-1, 1] map to [0, 1], and anything outside is clipped.
    """
    unnormalized = x * 0.5 + 0.5
    return unnormalized.clamp(0, 1)

def postproc_mask(batch_masks):
    """Turn soft slot masks into hard label maps, with a background fallback.

    Instead of a plain argmax, the slot with the lowest peak score is treated
    as background. Every pixel where no slot reaches FG_THRE gets a score of 1
    on that background slot before the argmax is taken. The input tensor is
    not modified (a clone is edited).

    Args:
        batch_masks: [B, T, N, 1, H, W]

    Returns:
        masks: [B, T, H, W], index of the winning slot per pixel.
    """
    B, T, N, _, H, W = batch_masks.shape
    scores = batch_masks.clone().reshape(B * T, N, H * W)  # [B*T, N, H*W]
    # Background slot = the slot whose strongest pixel is weakest.
    bg_slot = scores.max(dim=-1).values.argmin(dim=-1)  # [B*T]
    # Pixels that no slot claims with confidence >= FG_THRE.
    unclaimed = scores.max(dim=1).values < FG_THRE  # [B*T, H*W]
    is_bg_slot = torch.zeros((B * T, N)).type_as(unclaimed)
    is_bg_slot[torch.arange(B * T), bg_slot] = True
    # Force unclaimed pixels onto the background slot.
    scores[is_bg_slot.unsqueeze(-1) & unclaimed.unsqueeze(1)] = 1.
    return scores.argmax(dim=1).reshape(B, T, H, W)


def masks_to_boxes_w_empty_mask(binary_masks):
    """Like torchvision's masks_to_boxes, but tolerates empty masks.

    Args:
        binary_masks: [B, H, W]

    Returns:
        bboxes: float32 tensor [B, 4] on the masks' device. Rows whose mask
            sums to 0 are filled with -1; the others hold the box computed by
            vops.masks_to_boxes.
    """
    num_masks = binary_masks.shape[0]
    non_empty = binary_masks.sum([-1, -2]) > 0  # [B]
    bboxes = torch.full((num_masks, 4), -1., dtype=torch.float32,
                        device=binary_masks.device)
    bboxes[non_empty] = vops.masks_to_boxes(binary_masks[non_empty])
    return bboxes


def masks_to_boxes(masks, num_boxes=7):
    """Convert argmax segmentation masks into one box per slot.

    Args:
        masks: [B, T, H, W] integer label map (output after taking argmax);
            F.one_hot requires an int64 tensor here.
        num_boxes: number of boxes to generate (num_slots); labels must lie
            in [0, num_boxes).

    Returns:
        bboxes: [B, T, N, 4] with N = num_boxes, 4: [x1, y1, x2, y2]. Slots
            that cover no pixel get the -1 placeholder from
            masks_to_boxes_w_empty_mask.
    """
    batch, steps, _, _ = masks.shape
    # [B, T, H, W, N] -> [B, T, N, H, W] so every slot becomes its own mask.
    per_slot = F.one_hot(masks, num_classes=num_boxes).permute(0, 1, 4, 2, 3)
    stacked = per_slot.contiguous().flatten(0, 2)  # [B*T*N, H, W]
    boxes = masks_to_boxes_w_empty_mask(stacked)  # [B*T*N, 4]
    return boxes.reshape(batch, steps, num_boxes, 4)


def mse_metric(x, y):
    """Squared error summed over H and W, then averaged over batch and channel.

    x/y: [B, C, H, W]

    Summing over the spatial dimensions (rather than averaging) follows common
    video-prediction practice, see e.g.
    https://github.com/Yunbo426/predrnn-pp/blob/40764c52f433290aa02414b5f25c09c72b98e0af/train.py#L244
    """
    squared_err = (x - y) ** 2
    per_image_channel = squared_err.sum(-1).sum(-1)  # sum W, then H -> [B, C]
    return per_image_channel.mean()


def psnr_metric(x, y):
    """Mean PSNR over a batch of images.

    Args:
        x, y: [B, C, H, W] arrays. PSNR is computed with data_range=1., so
            values are expected in [0, 1]; x is passed as the reference image.

    Returns:
        np.mean of the per-image PSNR values.
    """
    per_image = []
    for idx in range(x.shape[0]):
        per_image.append(
            peak_signal_noise_ratio(x[idx], y[idx], data_range=1.))
    return np.mean(per_image)


def ssim_metric(x, y):
    """Mean SSIM over a batch of images.

    Args:
        x, y: [B, C, H, W] arrays. They are scaled by 255 and compared with
            data_range=255, so values are expected in [0, 1].

    Returns:
        np.mean of the per-image SSIM, computed channel-first with a Gaussian
        window (sigma=1.5) and use_sample_covariance=False.
    """
    x_255, y_255 = x * 255., y * 255.
    per_image = []
    for idx in range(x_255.shape[0]):
        per_image.append(
            structural_similarity(
                x_255[idx],
                y_255[idx],
                channel_axis=0,
                gaussian_weights=True,
                sigma=1.5,
                use_sample_covariance=False,
                data_range=255,
            ))
    return np.mean(per_image)


def perceptual_dist(x, y, loss_fn):
    """Average of loss_fn(x, y) over all of its outputs.

    x/y: [B, C, H, W]; loss_fn is a perceptual-distance callable (e.g. the
    LPIPS function passed in by pred_eval_step).
    """
    distances = loss_fn(x, y)
    return distances.mean()


def adjusted_rand_index(true_ids, pred_ids, ignore_background=False):
    """Computes the adjusted Rand index (ARI), a clustering similarity score.

    Ported from https://github.com/google-research/slot-attention-video/blob/e8ab54620d0f1934b332ddc09f1dba7bc07ff601/savi/lib/metrics.py#L111

    Args:
        true_ids: Integer tensor of shape [batch_size, seq_len, H, W] (or
            [batch_size, H, W], which gets a singleton seq axis). The true
            cluster assignment encoded as integer ids.
        pred_ids: Integer tensor of the same layout holding the predicted
            cluster assignment.
        ignore_background: If True, pixels where true_ids == 0 are ignored
            (default: False).

    Returns:
        ARI scores as a float tensor of shape [batch_size].
    """
    # Single frames get a time axis so the einsum below always sees 5-D input.
    if true_ids.ndim == 3:
        true_ids = true_ids.unsqueeze(1)
    if pred_ids.ndim == 3:
        pred_ids = pred_ids.unsqueeze(1)

    true_onehot = F.one_hot(true_ids).float()
    if ignore_background:
        # Dropping the class-0 column removes background pixels from every cell.
        true_onehot = true_onehot[..., 1:]
    pred_onehot = F.one_hot(pred_ids).float()

    # Contingency table: pixels shared by each (true, pred) cluster pair.
    contingency = torch.einsum("bthwc,bthwk->bck", true_onehot, pred_onehot)
    true_sizes = contingency.sum(dim=-1)  # row sums, [batch_size, c]
    pred_sizes = contingency.sum(dim=-2)  # column sums, [batch_size, k]
    n_points = true_sizes.sum(dim=1)

    rindex = (contingency * (contingency - 1)).sum(dim=(1, 2))
    aindex = (true_sizes * (true_sizes - 1)).sum(dim=1)
    bindex = (pred_sizes * (pred_sizes - 1)).sum(dim=1)
    expected_rindex = aindex * bindex / (n_points * (n_points - 1)).clamp(min=1)
    max_rindex = (aindex + bindex) / 2
    denominator = max_rindex - expected_rindex
    ari = (rindex - expected_rindex) / denominator

    # The denominator is zero when both assignments put every pixel in one
    # cluster, or when both put at most one point in each cluster. Either way
    # the partitions agree, so the score is defined as 1.0.
    return torch.where(denominator != 0, ari, torch.ones_like(ari))


def ARI_metric(x, y):
    """Batch-mean ARI, background pixels included.

    x/y: [B, H, W] integer seg_masks (after argmax). Returns a Python float.
    """
    for mask in (x, y):
        assert 'int' in str(mask.dtype)
    scores = adjusted_rand_index(x, y)
    return scores.mean().item()


def fARI_metric(x, y):
    """Batch-mean foreground ARI (pixels with gt id 0 are ignored).

    x/y: [B, H, W] integer seg_masks (after argmax); x is the ground truth.
    Returns a Python float.
    """
    for mask in (x, y):
        assert 'int' in str(mask.dtype)
    scores = adjusted_rand_index(x, y, ignore_background=True)
    return scores.mean().item()


def bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox, ovthresh=0.5,
                          *, zero_division=0.):
    """Compute precision and recall of predicted bounding boxes for one frame.

    Each present gt box is greedily matched to the predicted box with the
    highest IoU. The match is a true positive if that IoU is at least
    `ovthresh` and the predicted box was not already claimed by an earlier gt
    box. A gt box whose best prediction is already taken counts as a miss,
    even if some other prediction overlaps it.

    Args:
        gt_pres_mask: A tensor of shape [N] (cast to bool) marking which gt
            boxes are present.
        gt_bbox: A tensor of shape [N, 4].
        pred_bbox: A tensor of shape [M, 4]. Rows with x1 < 0 (e.g. the -1
            placeholders for empty slot masks) are ignored.
        ovthresh: Minimum IoU for a prediction to count as a match.
        zero_division: Value used for precision when there are no valid
            predictions, and for recall when no gt box is present (mirrors
            scikit-learn's `zero_division`). Without this, those cases would
            divide by zero.

    Returns:
        (precision, recall) as Python floats.
    """
    # Boolean indexing already yields new tensors, so the inputs are untouched.
    gt_bbox = gt_bbox[gt_pres_mask.bool()]
    pred_bbox = pred_bbox[pred_bbox[:, 0] >= 0.]
    num_gt, num_pred = gt_bbox.shape[0], pred_bbox.shape[0]
    assert gt_bbox.shape[1] == pred_bbox.shape[1] == 4

    tp = 0
    # Matching needs at least one box on each side; argmax over an empty IoU
    # row would raise.
    if num_gt > 0 and num_pred > 0:
        bbox_ious = vops.box_iou(gt_bbox, pred_bbox)  # [N, M]
        used = [False] * num_pred
        for i in range(num_gt):
            best_idx = bbox_ious[i].argmax().item()
            if bbox_ious[i, best_idx].item() >= ovthresh and not used[best_idx]:
                tp += 1
                used[best_idx] = True

    precision = tp / float(num_pred) if num_pred > 0 else float(zero_division)
    recall = tp / float(num_gt) if num_gt > 0 else float(zero_division)
    return precision, recall


def batch_bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox):
    """Mean box precision and mean box recall over the batch.

    Calls bbox_precision_recall once per batch element (indexing the leading
    axis of all three inputs) and returns (np.mean(precisions),
    np.mean(recalls)).
    """
    per_sample = [
        bbox_precision_recall(gt_pres_mask[b], gt_bbox[b], pred_bbox[b])
        for b in range(gt_pres_mask.shape[0])
    ]
    precisions = [prec for prec, _ in per_sample]
    recalls = [rec for _, rec in per_sample]
    return np.mean(precisions), np.mean(recalls)


def hungarian_miou(gt_mask, pred_mask):
    """Mean IoU of gt objects under an optimal one-to-one matching to predictions.

    Args:
        gt_mask: integer tensor [H*W] of gt ids after argmax; id 0 is the
            background and is excluded.
        pred_mask: integer tensor [H*W] of predicted ids after argmax.

    Returns:
        Sum of the matched IoUs divided by N, the number of gt foreground ids
        (1..max(gt_mask)). When there are fewer predicted ids than gt ids
        (M < N), unmatched gt ids contribute 0. Returns nan when N == 0.
    """
    true_oh = F.one_hot(gt_mask).float()[..., 1:]  # only foreground, [HW, N]
    pred_oh = F.one_hot(pred_mask).float()  # [HW, M]
    num_gt = true_oh.shape[-1]
    if num_gt == 0:
        # No foreground ids: mIoU is undefined. Return nan explicitly rather
        # than via a mean over an empty array.
        return float('nan')
    # NOTE(review): gt ids in 1..max that are absent from this frame still count
    # towards N with IoU 0 -- confirm this is intended.

    # Pairwise pixel counts; a matmul avoids materializing an [HW, N, M] tensor.
    intersect = true_oh.t() @ pred_oh  # [N, M]
    # |A u B| = |A| + |B| - |A n B|.
    union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :] - intersect
    iou = (intersect / (union + 1e-8)).detach().cpu().numpy()  # [N, M]

    # Best one-to-one assignment. With M >= N every gt id is matched. With
    # M < N only M are, and the rest implicitly score 0, so always divide by N.
    row_ind, col_ind = linear_sum_assignment(iou, maximize=True)
    return iou[row_ind, col_ind].sum() / float(num_gt)


def miou_metric(gt_mask, pred_mask):
    """Batch mean of hungarian_miou.

    Both masks: [B, H, W] integer seg_masks after argmax. Each sample is
    flattened to [H*W] and scored with hungarian_miou; the result is np.mean
    of the per-sample scores.
    """
    for mask in (gt_mask, pred_mask):
        assert 'int' in str(mask.dtype)
    flat_gt = gt_mask.flatten(1, 2)
    flat_pred = pred_mask.flatten(1, 2)
    per_sample = [
        hungarian_miou(flat_gt[b], flat_pred[b])
        for b in range(flat_gt.shape[0])
    ]
    return np.mean(per_sample)


@torch.no_grad()
def pred_eval_step(
    gt,
    pred,
    lpips_fn,
    gt_mask=None,
    pred_mask=None,
    gt_pres_mask=None,
    gt_bbox=None,
    pred_bbox=None,
    eval_traj=True,
):
    """Evaluate a predicted video against ground truth at every timestep.

    Args:
        gt, pred: torch.Tensor videos of shape [B, T, C, H, W] with C == 3.
            Pixel metrics are computed after to_rgb_from_tensor, which maps
            x -> clamp(x * 0.5 + 0.5, 0, 1); inputs are presumably in [-1, 1].
        lpips_fn: perceptual-distance callable applied to [B, C, H, W] frames.
        gt_mask, pred_mask: segmentation masks of shape [B, T, H, W].
        gt_pres_mask: object presence mask of shape [B, T, N].
        gt_bbox, pred_bbox: boxes of shape [B, T, N/M, 4].
        eval_traj: whether to evaluate the trajectory (measured by bbox and
            mask). When False, the mask/bbox inputs are not read and those
            metrics are reported as 0. for every timestep.

    Returns:
        dict mapping each metric name ('mse', 'ssim', 'psnr', 'percept_dist',
        'ari', 'fari', 'miou', 'ap', 'ar') to a list with one value per step.
    """
    assert len(gt.shape) == len(pred.shape) == 5
    assert gt.shape == pred.shape
    assert gt.shape[2] == 3
    if eval_traj:
        assert len(gt_mask.shape) == len(pred_mask.shape) == 4
        assert gt_mask.shape == pred_mask.shape
        assert len(gt_pres_mask.shape) == 3
        assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4
    num_steps = gt.shape[1]

    # Perceptual distance and mask metrics run on the torch tensors directly.
    percept_scores, ari_scores, fari_scores, miou_scores = [], [], [], []
    for step in range(num_steps):
        percept_scores.append(
            perceptual_dist(gt[:, step], pred[:, step], lpips_fn).item())
        if not eval_traj:
            ari_scores.append(0.)
            fari_scores.append(0.)
            miou_scores.append(0.)
            continue
        step_gt_mask, step_pred_mask = gt_mask[:, step], pred_mask[:, step]
        ari_scores.append(ARI_metric(step_gt_mask, step_pred_mask))
        fari_scores.append(fARI_metric(step_gt_mask, step_pred_mask))
        miou_scores.append(miou_metric(step_gt_mask, step_pred_mask))

    # Box precision / recall per step.
    ap_scores, ar_scores = [], []
    for step in range(num_steps):
        if eval_traj:
            step_ap, step_ar = batch_bbox_precision_recall(
                gt_pres_mask[:, step], gt_bbox[:, step], pred_bbox[:, step])
        else:
            step_ap, step_ar = 0., 0.
        ap_scores.append(step_ap)
        ar_scores.append(step_ar)

    # Pixel metrics run on [0, 1] numpy images.
    gt_np = to_rgb_from_tensor(gt).cpu().numpy()
    pred_np = to_rgb_from_tensor(pred).cpu().numpy()
    mse_scores, ssim_scores, psnr_scores = [], [], []
    for step in range(num_steps):
        frame_gt, frame_pred = gt_np[:, step], pred_np[:, step]
        mse_scores.append(mse_metric(frame_gt, frame_pred))
        psnr_scores.append(psnr_metric(frame_gt, frame_pred))
        ssim_scores.append(ssim_metric(frame_gt, frame_pred))

    return {
        'mse': mse_scores,
        'ssim': ssim_scores,
        'psnr': psnr_scores,
        'percept_dist': percept_scores,
        'ari': ari_scores,
        'fari': fari_scores,
        'miou': miou_scores,
        'ap': ap_scores,
        'ar': ar_scores,
    }