aboutsummaryrefslogtreecommitdiff
path: root/scripts/evaluation_adept.py
diff options
context:
space:
mode:
authorfredeee2023-11-02 10:47:21 +0100
committerfredeee2023-11-02 10:47:21 +0100
commitf8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree87288be6f851ab69405e524b81940c501c52789a /scripts/evaluation_adept.py
parentf16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initiaƶ commit
Diffstat (limited to 'scripts/evaluation_adept.py')
-rw-r--r--scripts/evaluation_adept.py398
1 files changed, 398 insertions, 0 deletions
diff --git a/scripts/evaluation_adept.py b/scripts/evaluation_adept.py
new file mode 100644
index 0000000..eab3ad9
--- /dev/null
+++ b/scripts/evaluation_adept.py
@@ -0,0 +1,398 @@
+import torch as th
+from torch.utils.data import Dataset, DataLoader, Subset
+from torch import nn
+import os
+from scripts.utils.plot_utils import plot_timestep
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device
+import numpy as np
+from einops import rearrange, repeat, reduce
+import motmetrics as mm
+from copy import deepcopy
+import pandas as pd
+from scripts.utils.eval_utils import append_statistics, load_model, setup_result_folders, store_statistics
+
+
+def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 2):
+
+ # Set up cpu or gpu training
+ device, verbose = init_device(cfg)
+
+ # Config
+ cfg_net = cfg.model
+ cfg_net.batch_size = 1
+
+ # Load model
+ net = load_model(cfg, cfg_net, file, device)
+ net.eval()
+
+ # Plot config
+ object_view = True
+ individual_views = False
+ root_path = None
+
+ # get evaluation sets for control and surprise condition
+ set_test_array, evaluation_modes = get_evaluation_sets(dataset)
+
+ # memory
+ statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error': [], 'TE': []}
+ statistics_complete = deepcopy(statistics_template)
+ statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': [], 'vanishing': []}
+ acc_memory_complete = None
+
+ for set_test in set_test_array:
+
+ for evaluation_mode in evaluation_modes:
+ print(f'Start evaluation loop: {evaluation_mode} - {set_test["type"]}')
+
+ # Load data
+ dataloader = DataLoader(
+ Subset(dataset, set_test['samples']),
+ num_workers = 1,
+ pin_memory = False,
+ batch_size = 1,
+ shuffle = False
+ )
+
+ # memory
+ mseloss = nn.MSELoss()
+ root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views)
+ acc_memory_eval = []
+
+ with th.no_grad():
+ for i, input in enumerate(dataloader):
+ print(f'Processing sample {i+1}/{len(dataloader)}', flush=True)
+
+ # Data
+ tensor = input[0].float().to(device)
+ background_fix = input[1].to(device)
+ gt_object_positions = input[3].to(device)
+ gt_object_visibility = input[4].to(device)
+ gt_occluder_mask = input[5].to(device)
+
+ # Apply skip frames
+ gt_object_positions = gt_object_positions[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ gt_object_visibility = gt_object_visibility[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ tensor = tensor[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ sequence_len = tensor.shape[1]
+
+ # Placehodlers
+ mask_cur = None
+ mask_last = None
+ rawmask_last = None
+ position_last = None
+ gestalt_last = None
+ priority_last = None
+ gt_positions_target = None
+ slots_occlusionfactor = None
+ error_last = None
+
+ # Memory
+ association_table = th.ones(cfg_net.num_objects).to(device) * -1
+ acc = mm.MOTAccumulator(auto_id=True)
+ statistics_batch = deepcopy(statistics_template)
+ slots_vanishing_memory = np.zeros(cfg_net.num_objects)
+
+ # loop through frames
+ for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)):
+
+ # Move to next frame
+ t_run = max(t, 0)
+ input = tensor[:,t_run]
+ target = th.clip(tensor[:,t_run+1], 0, 1)
+ gt_positions_target = gt_object_positions[:,t_run]
+ gt_positions_target_next = gt_object_positions[:,t_run+1]
+ gt_visibility_target = gt_object_visibility[:,t_run]
+
+ # Forward Pass
+ (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder_cur,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ ) = net(
+ input,
+ error_last,
+ mask_last,
+ rawmask_last,
+ position_last,
+ gestalt_last,
+ priority_last,
+ background_fix,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing),
+ evaluate=True,
+ warmup = (t < 0),
+ shuffleslots = False,
+ reset_mask = (t <= 0),
+ allow_spawn = True,
+ show_hidden = True,
+ clean_slots = (t <= 0),
+ )
+
+ # 1. Track error
+ if t >= 0:
+
+ # Position error: MSE between predicted position and target position
+ tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_net.num_objects, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask)
+
+ statistics_batch = store_statistics(statistics_batch,
+ set_test['type'],
+ evaluation_mode,
+ set_test['samples'][i],
+ t,
+ mseloss(output_next, target).item(),
+ tracking_error)
+
+ # Compute rawmask_size
+ rawmask_size, rawmask_size_hidden, mask_size = compute_mask_sizes(mask_next, rawmask_next, rawmask_hidden)
+
+ # Compute slot-wise prediction error
+ slot_error = compute_slot_error(cfg_net, target, output_next, mask_next, mask_size)
+
+ # Check if objects vanishes: they leave the scene suprsingly in the suprrise condition
+ slots_vanishing = compute_vanishing_slots(gt_positions_target, association_table, gt_positions_target_next)
+ slots_vanishing_memory = slots_vanishing + slots_vanishing_memory
+
+ # Store slot statistics
+ statistics_complete_slots = store_statistics(statistics_complete_slots,
+ [set_test['type']] * cfg_net.num_objects,
+ [evaluation_mode] * cfg_net.num_objects,
+ [set_test['samples'][i]] * cfg_net.num_objects,
+ [t] * cfg_net.num_objects,
+ range(cfg_net.num_objects),
+ tracking_error_perslot.cpu().numpy().flatten(),
+ slots_visible.cpu().numpy().flatten().astype(int),
+ slots_bounded.cpu().numpy().flatten().astype(int),
+ slots_occluder.cpu().numpy().flatten().astype(int),
+ slots_in_image.cpu().numpy().flatten().astype(int),
+ slot_error.cpu().numpy().flatten(),
+ mask_size.cpu().numpy().flatten(),
+ rawmask_size.cpu().numpy().flatten(),
+ rawmask_size_hidden.cpu().numpy().flatten(),
+ slots_closed[:, :, 1].cpu().numpy().flatten(),
+ slots_closed[:, :, 0].cpu().numpy().flatten(),
+ association_table[0].cpu().numpy().flatten().astype(int),
+ extend = True)
+
+ # Compute MOTA
+ acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bounded, cfg_net.num_objects, gt_occluder_mask, slots_occluder, rawmask_next)
+
+ # 2. Remember output
+ mask_last = mask_next.clone()
+ rawmask_last = rawmask_next.clone()
+ position_last = position_next.clone()
+ gestalt_last = gestalt_next.clone()
+ priority_last = priority_next.clone()
+
+ # 3. Error for next frame
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next.clone()
+
+ # 4. Plot
+ if (t % plot_frequency == 0) and (i < plot_first_samples) and (t >= 0):
+ plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i)
+
+ # fill jumping statistics
+ statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1))
+
+ # store batch statistics in complete statistics
+ acc_memory_eval.append(acc)
+ statistics_complete = append_statistics(statistics_complete, statistics_batch, extend = True)
+
+ summary = mm.metrics.create().compute_many(acc_memory_eval, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
+ summary['set'] = set_test['type']
+ summary['evalmode'] = evaluation_mode
+ acc_memory_complete = summary.copy() if acc_memory_complete is None else pd.concat([acc_memory_complete, summary])
+
+
+ print('-- Evaluation Done --')
+ pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv')
+ pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv')
+ pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv')
+ if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
+ os.remove(f'{root_path}/tmp.jpg')
+ pass
+
+
+def compute_vanishing_slots(gt_positions_target, association_table, gt_positions_target_next):
+ objects_vanishing = th.abs(gt_positions_target[:,:,2] - gt_positions_target_next[:,:,2]) > 0.2
+ objects_vanishing = th.where(objects_vanishing.flatten())[0]
+ slots_vanishing = [(obj.item() in objects_vanishing) for obj in association_table[0]]
+ return slots_vanishing
+
+def compute_slot_error(cfg_net, target, output_next, mask_next, mask_size):
+ output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects)
+ target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects)
+ slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean')/mask_size
+ return slot_error
+
+def compute_mask_sizes(mask_next, rawmask_next, rawmask_hidden):
+ rawmask_size = reduce(rawmask_next[:, :-1], 'b o h w-> b o', 'sum')
+ rawmask_size_hidden = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum')
+ mask_size = reduce(mask_next[:, :-1], 'b o h w-> b o', 'sum')
+ return rawmask_size,rawmask_size_hidden,mask_size
+
+def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_num_objects, gt_occluder_mask, slots_occluder, rawmask, ignore_occluder = False):
+
+ # num objects
+ num_objects = len(gt_positions[0])
+
+ # get rid of batch dimension and priority dimension
+ pos = rearrange(estimated_positions.detach()[0], '(o c) -> o c', o=cfg_num_objects)[:, :2]
+ targets = gt_positions[0, :, :2]
+
+ # stretch positions to account for frame ratio, Specific for ADEPT!
+ pos = th.clip(pos, -1, 1)
+ pos[:, 0] = pos[:, 0] * 1.5
+ targets[:, 0] = targets[:, 0] * 1.5
+
+ # remove objects that are not in the image
+ edge = 1
+ in_image = th.cat([targets[:, 0] < (1.5 * edge), targets[:, 0] > (-1.5 * edge), targets[:, 1] < (1 * edge), targets[:, 1] > (-1 * edge)])
+ in_image = th.all(rearrange(in_image, '(c o) -> o c', o=num_objects), dim=1)
+
+ if ignore_occluder:
+ in_image = (gt_occluder_mask[0] == 0) * in_image
+ targets = targets[in_image]
+
+ # test if position estimates in image
+ in_image_pos = th.cat([pos[:, 0] < (1.5 * edge), pos[:, 0] > (-1.5 * edge), pos[:, 1] < (1 * edge), pos[:, 1] > (-1 * edge)])
+ in_image_pos = th.all(rearrange(in_image_pos, '(c o) -> c o', o=cfg_num_objects), dim=0, keepdim=True)
+
+ # only position estimates that are in image and bound
+ if rawmask is not None:
+ rawmask_size = reduce(rawmask[:, :-1], 'b o h w-> b o', 'sum')
+ m = (slots_bounded * in_image_pos * (rawmask_size > 100)).bool()
+ else:
+ m = (slots_bounded * in_image_pos).bool()
+ if ignore_occluder:
+ m = (m * (1 - slots_occluder)).bool()
+
+ pos = pos[repeat(m, '1 o -> o 2')]
+ pos = rearrange(pos, '(o c) -> o c', c = 2)
+
+ # compute pairwise distances
+ diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2)).item()
+ C = mm.distances.norm2squared_matrix(targets.cpu().numpy(), pos.cpu().numpy(), max_d2=diagonal_length*0.1)
+
+ # upadate accumulator
+ acc.update( (th.where(in_image)[0]).cpu(), (th.where(m)[1]).cpu(), C)
+
+ return acc
+
+def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_num_slots, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask):
+
+ # tracking utils
+ pdist = nn.PairwiseDistance(p=2).to(position_cur.device)
+
+ # 1. association of newly bounded slots to ground truth objects
+ # num objects
+ num_objects = len(gt_positions_target[0])
+
+ # get rid of batch dimension and priority dimension
+ pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_num_slots)[:, :2]
+ targets = gt_positions_target[0, :, :2]
+
+ # stretch positions to account for frame ratio, Specific for ADEPT!
+ pos = th.clip(pos, -1, 1)
+ pos[:, 0] = pos[:, 0] * 1.5
+ targets[:, 0] = targets[:, 0] * 1.5
+ diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2))
+
+ # reshape and repeat for comparison
+ pos = repeat(pos, 'o c -> (o r) c', r=num_objects)
+ targets = repeat(targets, 'o c -> (r o) c', r=cfg_num_slots)
+
+ # comparison
+ distance = pdist(pos, targets)
+ distance = rearrange(distance, '(o r) -> o r', r=num_objects)
+
+ # find closest target for each slot
+ distance = th.min(distance, dim=1, keepdim=True)
+
+ # update association table
+ slots_newly_bounded = slots_bounded * (association_table == -1)
+ if slots_occluded_cur is not None:
+ slots_newly_bounded = slots_newly_bounded * (1-slots_occluded_cur)
+ association_table = association_table * (1-slots_newly_bounded) + slots_newly_bounded * distance[1].T
+
+ # 2. position error
+ # get rid of batch dimension and priority dimension
+ pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_num_slots)[:, :2]
+ targets = gt_positions_target[0, :, :3]
+
+ # stretch positions to account for frame ratio, Specific for ADEPT!
+ pos[:, 0] = pos[:, 0] * 1.5
+ targets[:, 0] = targets[:, 0] * 1.5
+
+ # gather targets according to association table
+ targets = targets[association_table.long()][0]
+
+ # determine which slosts are within the image
+ slots_in_image = th.cat([targets[:, 0] < 1.5, targets[:, 0] > -1.5, targets[:, 1] < 1, targets[:, 1] > -1, targets[:, 2] > 0])
+ slots_in_image = rearrange(slots_in_image, '(c o) -> o c', o=cfg_num_slots)
+ slots_in_image = th.all(slots_in_image, dim=1)
+
+ # define which slots to consider for tracking error
+ slots_to_track = slots_bounded * slots_in_image
+
+ # compute position error
+ targets = targets[:, :2]
+ tracking_error_perslot = th.sqrt(th.sum((pos - targets)**2, dim=1))/diagonal_length
+ tracking_error_perslot = tracking_error_perslot[None, :] * slots_to_track
+ tracking_error = th.sum(tracking_error_perslot).item()/max(th.sum(slots_to_track).item(), 1)
+
+ # compute which slots are visible
+ visible_objects = th.where(gt_visibility_target[0] == 1)[0]
+ slots_visible = th.tensor([[int(obj.item()) in visible_objects for obj in association_table[0]]]).float().to(slots_to_track.device)
+ slots_visible = slots_visible * slots_to_track
+
+ # determine which objects are bound to the occluder
+ occluder_objects = th.where(gt_occluder_mask[0] == 1)[0]
+ slots_occluder = th.tensor([[int(obj.item()) in occluder_objects for obj in association_table[0]]]).float().to(slots_to_track.device)
+ slots_occluder = slots_occluder * slots_to_track
+
+ return tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder
+
+def get_evaluation_sets(dataset):
+
+ # Standad evaluation
+ evaluation_modes = ['open']
+
+ # Important!
+ # filter different scenarios: 1 as control and 0,3 as surprise (see Smith et al. 2020)
+ suprise_mask = [(sample.case in [1]) for i,sample in enumerate(dataset.samples)]
+ control_mask = [(sample.case in [0,3]) for i,sample in enumerate(dataset.samples)]
+
+ # Create test sets
+ set_surprise = {"samples": np.where(suprise_mask)[0].tolist(), "type": 'surprise'}
+ set_control = {"samples": np.where(control_mask)[0].tolist(), "type": 'control'}
+ set_test_array = [set_control, set_surprise]
+
+ return set_test_array, evaluation_modes \ No newline at end of file