aboutsummaryrefslogtreecommitdiff
path: root/scripts/evaluation_adept_savi.py
diff options
context:
space:
mode:
authorfredeee2024-03-23 13:27:00 +0100
committerfredeee2024-03-23 13:27:00 +0100
commit6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree0545ff1b8beb051993c2d75fd81306db1a22274d /scripts/evaluation_adept_savi.py
parentad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'scripts/evaluation_adept_savi.py')
-rw-r--r--scripts/evaluation_adept_savi.py233
1 files changed, 0 insertions, 233 deletions
diff --git a/scripts/evaluation_adept_savi.py b/scripts/evaluation_adept_savi.py
deleted file mode 100644
index 6a2d5c7..0000000
--- a/scripts/evaluation_adept_savi.py
+++ /dev/null
@@ -1,233 +0,0 @@
-from einops import rearrange, reduce, repeat
-import torch as th
-from torch.utils.data import Dataset, DataLoader, Subset
-import cv2
-import numpy as np
-import pandas as pd
-import os
-from data.datasets.ADEPT.dataset import AdeptDataset
-import motmetrics as mm
-from scripts.evaluation_adept import calculate_tracking_error, get_evaluation_sets, update_mota_acc
-from scripts.utils.eval_utils import setup_result_folders, store_statistics
-from scripts.utils.plot_utils import write_image
-
-FG_THRE = 0.95
-
-def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 2):
-
- # read pkl file
- masks_complete = pd.read_pickle(file)
-
- # plot config
- color_list = [[255,0,0], [0,255,0], [0,0,255], [255,255,0], [0,255,255], [255,0,255], [255,255,255]]
- dot_size = 2
- skip_frames = 2
- offset = 15
-
- # memory
- statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': []}
- acc_memory_eval = []
-
- # load adept dataset
- set_test_array, evaluation_modes = get_evaluation_sets(dataset)
- control_samples = set_test_array[0]['samples'] # only consider control set
- evalset = Subset(dataset, control_samples)
- root_path, plot_path = setup_result_folders(file, n, set_test_array[0], evaluation_modes[0], True, False)
-
- for i in range(len(evalset)):
- print(f'Processing sample {i+1}/{len(evalset)}', flush=True)
- input = evalset[i]
- acc = mm.MOTAccumulator(auto_id=True)
-
- # get input frame and target frame
- tensor = th.tensor(input[0]).float().unsqueeze(0)
- background_fix = th.tensor(input[1]).unsqueeze(0)
- gt_object_positions = th.tensor(input[3]).unsqueeze(0)
- gt_object_visibility = th.tensor(input[4]).unsqueeze(0)
- gt_occluder_mask = th.tensor(input[5]).unsqueeze(0)
-
- # apply skip frames
- gt_object_positions = gt_object_positions[:,range(0, tensor.shape[1], skip_frames)]
- gt_object_visibility = gt_object_visibility[:,range(0, tensor.shape[1], skip_frames)]
- tensor = tensor[:,range(0, tensor.shape[1], skip_frames)]
- sequence_len = tensor.shape[1]
-
- # load data
- masks = th.tensor(masks_complete['test'][f'control_{i}.mp4'])
- masks_before_softmax = th.tensor(masks_complete['test_raw'][f'control_{i}.mp4'])
-
- # calculate rawmasks
- bg_mask = masks_before_softmax.mean(dim=1)
- masks_raw = compute_maskraw(masks_before_softmax, bg_mask)
- slots_bound = compute_slots_bound(masks_raw)
-
- # threshold masks and calculate centroids
- masks_binary = (masks_raw > FG_THRE).float()
- masks2 = rearrange(masks_binary, 't o 1 h w -> (t o) h w')
- boxes = masks_to_boxes(masks2.long())
- boxes = boxes.reshape(1, masks.shape[0], 7, 4)
- centroids = boxes_to_centroids(boxes)
-
- # get rid of batch dimension
- association_table = th.ones(7) * -1
-
- # iterate over frames
- for t_index in range(offset,min(sequence_len,masks.shape[0])):
-
- # move to next frame
- input = tensor[:,t_index]
- target = th.clip(tensor[:,t_index+1], 0, 1)
- gt_positions_target = gt_object_positions[:,t_index]
- gt_positions_target_next = gt_object_positions[:,t_index+1]
- gt_visibility_target = gt_object_visibility[:,t_index]
-
- position_cur = centroids[t_index]
- position_cur = rearrange(position_cur, 'o c -> 1 (o c)')
- slots_bound_cur = slots_bound[t_index]
- slots_bound_cur = rearrange(slots_bound_cur, 'o c -> 1 (o c)')
-
- # calculate tracking error
- tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, 7, slots_bound_cur, None, association_table, gt_occluder_mask)
-
- rawmask_size = reduce(masks_raw[t_index], 'o 1 h w-> 1 o', 'sum')
- mask_size = reduce(masks[t_index], 'o 1 h w-> 1 o', 'sum')
-
- statistics_complete_slots = store_statistics(statistics_complete_slots,
- ['control'] * 7,
- ['control'] * 7,
- [control_samples[i]] * 7,
- [t_index] * 7,
- range(7),
- tracking_error_perslot.cpu().numpy().flatten(),
- slots_visible.cpu().numpy().flatten().astype(int),
- slots_bound_cur.cpu().numpy().flatten().astype(int),
- slots_occluder.cpu().numpy().flatten().astype(int),
- slots_in_image.cpu().numpy().flatten().astype(int),
- [0] * 7,
- mask_size.cpu().numpy().flatten(),
- rawmask_size.cpu().numpy().flatten(),
- [0] * 7,
- [0] * 7,
- [0] * 7,
- association_table[0].cpu().numpy().flatten().astype(int),
- extend = True)
-
- acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bound_cur, 7, gt_occluder_mask, slots_occluder, None)
-
- # plot_option
- if (t_index % plot_frequency == 0) and (i < plot_first_samples) and (t_index >= 0):
- masks_to_display = masks_binary.numpy() # masks_binary.numpy()
-
- frame = tensor[0, t_index]
- frame = frame.numpy().transpose(1,2,0)
- frame = cv2.resize(frame, (64,64))
-
- centroids_frame = centroids[t_index]
- centroids_frame[:,0] = (centroids_frame[:,0] + 1) * 64 / 2
- centroids_frame[:,1] = (centroids_frame[:,1] + 1) * 64 / 2
-
- bound_frame = slots_bound[t_index]
- for c_index,centroid_slot in enumerate(centroids_frame):
- if bound_frame[c_index] == 1:
- frame[int(centroid_slot[1]-dot_size):int(centroid_slot[1]+dot_size), int(centroid_slot[0]-dot_size):int(centroid_slot[0]+dot_size)] = color_list[c_index]
-
- # slot images
- slot_frame = masks_to_display[t_index].max(axis=0)
- slot_frame = slot_frame.reshape((64,64,1)).repeat(3, axis=2)
-
- if True:
- for mask in masks_to_display[t_index]:
- #slot_frame_single = mask.reshape((64,64,1)).repeat(3, axis=2)
- slot_frame_single = mask.transpose((1,2,0)).repeat(3, axis=2)
- slot_frame = np.concatenate((slot_frame, slot_frame_single), axis=1)
-
- frame = np.concatenate((frame, slot_frame), axis=1)
- cv2.imwrite(f'{plot_path}object/objects-{i:04d}-{t_index:03d}.jpg', frame*255)
-
- acc_memory_eval.append(acc)
-
- mh = mm.metrics.create()
- summary = mh.compute_many(acc_memory_eval, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
- summary['set'] = 'control'
- summary['evalmode'] = 'control'
- pd.DataFrame(summary).to_csv(os.path.join(root_path, 'statistics' , 'accframe.csv'))
- pd.DataFrame(statistics_complete_slots).to_csv(os.path.join(root_path, 'statistics' , 'slotframe.csv'))
-
-def masks_to_boxes(masks: th.Tensor) -> th.Tensor:
- """
- Compute the bounding boxes around the provided masks.
-
- Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
- ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
-
- Args:
- masks (Tensor[N, H, W]): masks to transform where N is the number of masks
- and (H, W) are the spatial dimensions.
-
- Returns:
- Tensor[N, 4]: bounding boxes
- """
- if masks.numel() == 0:
- return th.zeros((0, 4), device=masks.device, dtype=th.float)
-
- n = masks.shape[0]
-
- bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float)
-
- for index, mask in enumerate(masks):
- if mask.sum() > 0:
- y, x = th.where(mask != 0)
-
- bounding_boxes[index, 0] = th.min(x)
- bounding_boxes[index, 1] = th.min(y)
- bounding_boxes[index, 2] = th.max(x)
- bounding_boxes[index, 3] = th.max(y)
-
- return bounding_boxes
-
-def boxes_to_centroids(boxes):
- """Post-process masks instead of directly taking argmax.
-
- Args:
- bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2]
-
- Returns:
- centroids: [B, T, N, 2], 2: [x, y]
- """
-
- centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2
- centroids = centroids.squeeze(0)
-
- # scale to [-1, 1]
- centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1
- centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1
-
- return centroids
-
-def compute_slots_bound(masks):
-
- # take sum over axis 3,4 with th
- masks_sum = masks.amax(dim=(3,4))
- slots_bound = (masks_sum > FG_THRE).float()
- return slots_bound
-
-def compute_maskraw(mask, bg_mask):
-
- # d is a diagonal matrix which defines what to take the softmax over
- d_mask = th.diag(th.ones(8))
- d_mask[:,-1] = 1
- d_mask[-1,-1] = 0
-
- mask = mask.squeeze(2)
-
- # take subset of maskraw with the diagonal matrix
- maskraw = th.cat((mask, bg_mask), dim=1)
- maskraw = repeat(maskraw, 'b o h w -> b r o h w', r = 8)
- maskraw = maskraw[:,d_mask.bool()]
- maskraw = rearrange(maskraw, 'b (o r) h w -> b o r h w', o = 7)
-
- # take softmax between each object mask and the background mask
- maskraw = th.squeeze(th.softmax(maskraw, dim=2)[:,:,0], dim=2)
- maskraw = maskraw.unsqueeze(2)
-
- return maskraw \ No newline at end of file