aboutsummaryrefslogtreecommitdiff
path: root/scripts/evaluation_clevrer.py
diff options
context:
space:
mode:
authorfredeee2023-11-02 10:47:21 +0100
committerfredeee2023-11-02 10:47:21 +0100
commitf8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree87288be6f851ab69405e524b81940c501c52789a /scripts/evaluation_clevrer.py
parentf16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initial commit
Diffstat (limited to 'scripts/evaluation_clevrer.py')
-rw-r--r--scripts/evaluation_clevrer.py311
1 files changed, 311 insertions, 0 deletions
diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py
new file mode 100644
index 0000000..a43d50c
--- /dev/null
+++ b/scripts/evaluation_clevrer.py
@@ -0,0 +1,311 @@
+import pickle
+import torch as th
+from torch.utils.data import Dataset, DataLoader, Subset
+from torch import nn
+import os
+from scripts.utils.plot_utils import plot_timestep
+from scripts.utils.eval_metrics import masks_to_boxes, pred_eval_step, postproc_mask
+from scripts.utils.eval_utils import append_statistics, load_model, setup_result_folders, store_statistics
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device
+import numpy as np
+from einops import rearrange, repeat, reduce
+from copy import deepcopy
+import lpips
+import torchvision.transforms as transforms
+
def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 0):
    """Evaluate a trained model checkpoint on a CLEVRER evaluation dataset.

    For every evaluation set and evaluation mode ('blackout' / 'open') the
    model is run with a teacher-forced warmup, a burn-in phase (6 frames) and
    a rollout phase (42 frames). Predicted frames and masks are collected and
    scored (MSE, SSIM, PSNR, LPIPS, ARI, FARI, mIoU, AP, AR) via
    pred_eval_step, and both the complete and the averaged statistics are
    pickled into the result folder. If plot_first_samples > 0 the function
    instead switches to plotting mode (batch size 1) and only renders
    visualisations for the first samples.

    Args:
        cfg: experiment Configuration; cfg.model and cfg.defaults are used.
        dataset: evaluation Dataset; burn_in_length / rollout_length /
            skip_length attributes are written onto it below.
        file: model checkpoint file, forwarded to load_model.
        n: run index, only used by setup_result_folders for naming.
        plot_frequency: in plotting mode, plot every n-th timestep.
        plot_first_samples: number of samples to visualise (0 = metric run).
    """

    # Set up cpu or gpu training
    device, verbose = init_device(cfg)

    # Config: large batches for the metric run, batch size 1 when plotting
    cfg_net = cfg.model
    cfg_net.batch_size = 2 if verbose else 32
    cfg_net.batch_size = 1 if plot_first_samples > 0 else cfg_net.batch_size # only plotting first samples

    # Load model
    net = load_model(cfg, cfg_net, file, device)
    net.eval()

    # config
    object_view = True
    individual_views = False
    root_path = None
    plotting_mode = (cfg_net.batch_size == 1) and (plot_first_samples > 0)

    # get evaluation sets
    set_test_array, evaluation_modes = get_evaluation_sets(dataset)

    # memory: templates for the plotting statistics dictionaries
    statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error_mse': []}
    statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'bound': [], 'slot_error': [], 'rawmask_size': [], 'alpha_pos': [], 'alpha_ges': []}
    metric_complete = None

    # Evaluation Specifics (lengths are in dataset frames after skipping)
    burn_in_length = 6
    rollout_length = 42
    cfg.defaults.skip_frames = 2
    blackout_p = 0.2
    target_size = (64,64)
    # the dataset is expected to honour these attributes when slicing clips
    dataset.burn_in_length = burn_in_length
    dataset.rollout_length = rollout_length
    dataset.skip_length = cfg.defaults.skip_frames

    # Transformation utils: metrics are computed on 64x64 (normalized) frames
    to_small = transforms.Resize(target_size)
    to_normalize = transforms.Normalize((0.5, ), (0.5, ))
    to_smallnorm = transforms.Compose([to_small, to_normalize])

    # Losses
    lpipsloss = lpips.LPIPS(net='vgg').to(device)
    mseloss = nn.MSELoss()

    for set_test in set_test_array:

        for evaluation_mode in evaluation_modes:
            print(f'Start evaluation loop: {evaluation_mode}')

            # load data (in plotting mode only the first plot_first_samples)
            dataloader = DataLoader(
                Subset(dataset, range(plot_first_samples)) if plotting_mode else dataset,
                num_workers = 1,
                pin_memory = False,
                batch_size = cfg_net.batch_size,
                shuffle = False,
                drop_last = True,
            )

            # memory: fresh result folders and metric accumulators per mode
            root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views)
            metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'blackout': []}

            with th.no_grad():
                for i, input in enumerate(dataloader):
                    print(f'Processing sample {i+1}/{len(dataloader)}', flush=True)

                    # Load data; assumes the dataset yields
                    # (video, background, gt_mask, gt_bbox, gt_pres_mask) -- TODO confirm against dataset
                    tensor = input[0].float().to(device)
                    background_fix = input[1].to(device)
                    gt_mask = input[2].to(device)
                    gt_bbox = input[3].to(device)
                    gt_pres_mask = input[4].to(device)
                    sequence_len = tensor.shape[1]

                    # Placeholders for the recurrent model state (reset per sample)
                    mask_cur = None
                    mask_last = None
                    rawmask_last = None
                    position_last = None
                    gestalt_last = None
                    priority_last = None
                    slots_occlusionfactor = None
                    error_last = None

                    # Memory for the rollout-phase predictions and masks
                    pred = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
                    gt = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
                    pred_mask = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
                    statistics_batch = deepcopy(statistics_template)

                    # Counters
                    num_rollout = 0
                    num_burnin = 0
                    blackout_mem = [0]

                    # Loop through frames; negative t values are teacher-forced warmup
                    for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)):

                        # Move to next frame (warmup keeps feeding frame 0)
                        t_run = max(t, 0)
                        input = tensor[:,t_run]  # NOTE(review): rebinds the loop variable `input` (the batch tuple) -- deliberate but shadowing
                        target = th.clip(tensor[:,t_run+1], 0, 1)

                        # rollout_index >= 0 marks the rollout phase
                        rollout_index = t_run - burn_in_length
                        if (rollout_index >= 0) and (evaluation_mode == 'blackout'):
                            # With probability blackout_p, zero out the input frame
                            # and the fed-back error for the whole batch
                            blackout = (th.rand(1) < blackout_p).float().to(device)
                            input = blackout * (input * 0) + (1-blackout) * input
                            error_last = blackout * (error_last * 0) + (1-blackout) * error_last
                            blackout_mem.append(blackout.int().cpu().item())

                        elif t>=0:
                            num_burnin += 1

                        if (rollout_index >= 0) and (evaluation_mode != 'blackout'):
                            blackout_mem.append(0)

                        # obtain prediction: one recurrent step of the model
                        (
                            output_next,
                            position_next,
                            gestalt_next,
                            priority_next,
                            mask_next,
                            rawmask_next,
                            object_next,
                            background,
                            slots_occlusionfactor,
                            output_cur,
                            position_cur,
                            gestalt_cur,
                            priority_cur,
                            mask_cur,
                            rawmask_cur,
                            object_cur,
                            position_encoder_cur,
                            slots_bounded,
                            slots_partially_occluded_cur,
                            slots_occluded_cur,
                            slots_partially_occluded_next,
                            slots_occluded_next,
                            slots_closed,
                            output_hidden,
                            largest_object,
                            rawmask_hidden,
                            object_hidden
                        ) = net(
                            input,
                            error_last,
                            mask_last,
                            rawmask_last,
                            position_last,
                            gestalt_last,
                            priority_last,
                            background_fix,
                            slots_occlusionfactor,
                            reset = (t == -cfg.defaults.teacher_forcing),
                            evaluate=True,
                            warmup = (t < 0),
                            shuffleslots = False,
                            reset_mask = (t <= 0),
                            allow_spawn = True,
                            show_hidden = plotting_mode,
                            clean_slots = (t <= 0),
                        )

                        # 1. Track error for plots
                        if t >= 0:

                            # save prediction (>= -1: the last burn-in step fills slot 0 -- TODO confirm intended)
                            if rollout_index >= -1:
                                pred[:,rollout_index+1] = to_smallnorm(output_next)
                                gt[:,rollout_index+1] = to_smallnorm(target)
                                pred_mask[:,rollout_index+1] = postproc_mask(to_small(mask_next)[:,None,:,None])[:, 0]
                                num_rollout += 1

                            if plotting_mode and object_view:
                                statistics_complete_slots, statistics_batch = compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden)

                        # 2. Remember output as the next step's recurrent state
                        mask_last = mask_next.clone()
                        rawmask_last = rawmask_next.clone()
                        position_last = position_next.clone()
                        gestalt_last = gestalt_next.clone()
                        priority_last = priority_next.clone()

                        # 3. Error for next frame: per-pixel RMSE vs the background
                        bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()

                        # prediction error, weighted by the background error
                        # NOTE(review): sqrt is applied twice to the prediction error
                        # (here and on the line above it) -- confirm this weighting is intended
                        error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                        error_next = th.sqrt(error_next) * bg_error_next
                        error_last = error_next.clone()

                        # 4. plot preparation
                        if plotting_mode and (t % plot_frequency == 0):
                            plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i)

                    # Compute prediction accuracy based on Slotformer metrics (ARI, FARI, mIoU, AP, AR)
                    if not plotting_mode:
                        for b in range(cfg_net.batch_size):
                            metric_dict = pred_eval_step(
                                gt=gt[b:b+1],
                                pred=pred[b:b+1],
                                pred_mask=pred_mask.long()[b:b+1],
                                pred_bbox=masks_to_boxes(pred_mask.long()[b:b+1], cfg_net.num_objects+1),
                                gt_mask=gt_mask.long()[b:b+1],
                                gt_pres_mask=gt_pres_mask[b:b+1],
                                gt_bbox=gt_bbox[b:b+1],
                                lpips_fn=lpipsloss,
                                eval_traj=True,
                            )
                            metric_dict['blackout'] = blackout_mem
                            metric_complete = append_statistics(metric_dict, metric_complete)

                    # sanity check
                    # NOTE(review): uses `and`, so it only raises when BOTH counters mismatch;
                    # `or` looks intended. Currently unreachable since evaluation_modes
                    # never contains 'rollout' (see get_evaluation_sets).
                    if (num_rollout != rollout_length) and (num_burnin != burn_in_length) and (evaluation_mode == 'rollout'):
                        raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.')

            if not plotting_mode:
                average_dic = compute_statistics_summary(metric_complete, evaluation_mode)

                # Store statistics (complete per-frame lists and summary averages)
                with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f:
                    pickle.dump(metric_complete, f)
                with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.pkl'), 'wb') as f:
                    pickle.dump(average_dic, f)

            print('-- Evaluation Done --')
    # clean up the scratch image left behind by the plotting helpers
    if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
        os.remove(f'{root_path}/tmp.jpg')
    pass
+
def compute_statistics_summary(metric_complete, evaluation_mode):
    """Aggregate per-frame metric lists into mean/std summary values.

    Args:
        metric_complete: dict mapping metric name -> list of per-frame values
            (includes a 'blackout' key with 0/1 blackout indicators when
            evaluation_mode == 'blackout').
        evaluation_mode: evaluation mode string; for 'blackout' additional
            averages restricted to blacked-out and to visible frames are
            computed.

    Returns:
        dict with keys '<metric>complete_average' / '<metric>complete_std'
        (and, in blackout mode, '<metric>blackout_*' / '<metric>visible_*').
        Note: key names intentionally keep the historical no-separator
        concatenation so previously stored pickles stay compatible.
    """
    average_dic = {}
    for key in metric_complete:
        # take average over all frames
        average_dic[key + 'complete_average'] = np.mean(metric_complete[key])
        average_dic[key + 'complete_std'] = np.std(metric_complete[key])
        print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}')

        if evaluation_mode == 'blackout':
            # take average only for frames where blackout occurs
            values = np.array(metric_complete[key])
            blackout_mask = np.array(metric_complete['blackout']) > 0
            visible_mask = ~blackout_mask
            # guard empty selections: np.mean/np.std on an empty slice would
            # emit a RuntimeWarning; return the same nan value silently instead
            if blackout_mask.any():
                average_dic[key + 'blackout_average'] = np.mean(values[blackout_mask])
                average_dic[key + 'blackout_std'] = np.std(values[blackout_mask])
            else:
                average_dic[key + 'blackout_average'] = np.nan
                average_dic[key + 'blackout_std'] = np.nan
            if visible_mask.any():
                average_dic[key + 'visible_average'] = np.mean(values[visible_mask])
                average_dic[key + 'visible_std'] = np.std(values[visible_mask])
            else:
                average_dic[key + 'visible_average'] = np.nan
                average_dic[key + 'visible_std'] = np.nan

            print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}')
            print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}')
    return average_dic
+
def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden):
    """Accumulate frame-level and slot-level statistics for plotting.

    Appends the frame reconstruction MSE to statistics_batch and the
    per-slot error / mask-size / gate statistics to
    statistics_complete_slots, returning both updated dictionaries.
    """
    n_slots = cfg_net.num_objects

    # frame-level reconstruction error
    frame_mse = mseloss(output_next, target).item()
    statistics_batch = store_statistics(statistics_batch,
                                        set_test['type'],
                                        evaluation_mode,
                                        set_test['samples'][i],
                                        t,
                                        frame_mse)

    # slot-wise prediction error: weight prediction and target by each
    # slot's (non-background) mask, then average the squared difference
    slot_weights = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w')
    masked_pred = slot_weights * repeat(output_next, 'b c h w -> b o c h w', o=n_slots)
    masked_target = slot_weights * repeat(target, 'b c h w -> b o c h w', o=n_slots)
    slot_error = reduce((masked_pred - masked_target)**2, 'b o c h w -> b o', 'mean')

    # total raw-mask activation per (non-background) slot
    rawmask_size = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum')

    statistics_complete_slots = store_statistics(statistics_complete_slots,
                                                 [set_test['type']] * n_slots,
                                                 [evaluation_mode] * n_slots,
                                                 [set_test['samples'][i]] * n_slots,
                                                 [t] * n_slots,
                                                 range(n_slots),
                                                 slots_bounded.cpu().numpy().flatten().astype(int),
                                                 slot_error.cpu().numpy().flatten(),
                                                 rawmask_size.cpu().numpy().flatten(),
                                                 slots_closed[:, :, 1].cpu().numpy().flatten(),
                                                 slots_closed[:, :, 0].cpu().numpy().flatten(),
                                                 extend = True)

    return statistics_complete_slots, statistics_batch
+
def get_evaluation_sets(dataset):
    """Build the list of evaluation subsets and the evaluation modes.

    Args:
        dataset: evaluation dataset; only its length is used.

    Returns:
        (set_test_array, evaluation_modes): a single test set covering every
        sample index, and the list of evaluation modes to run.
    """
    # one test set over all sample indices (renamed from `set`, which
    # shadowed the builtin)
    test_set = {"samples": np.arange(len(dataset), dtype=int), "type": "test"}
    evaluation_modes = ['blackout', 'open'] # use 'open' for no blackouts
    set_test_array = [test_set]

    return set_test_array, evaluation_modes