diff options
author | fredeee | 2023-11-02 10:47:21 +0100 |
---|---|---|
committer | fredeee | 2023-11-02 10:47:21 +0100 |
commit | f8302ee886ef9b631f11a52900dac964a61350e1 (patch) | |
tree | 87288be6f851ab69405e524b81940c501c52789a /scripts/evaluation_clevrer.py | |
parent | f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff) |
initiaƶ commit
Diffstat (limited to 'scripts/evaluation_clevrer.py')
-rw-r--r-- | scripts/evaluation_clevrer.py | 311 |
1 files changed, 311 insertions, 0 deletions
diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py new file mode 100644 index 0000000..a43d50c --- /dev/null +++ b/scripts/evaluation_clevrer.py @@ -0,0 +1,311 @@ +import pickle +import torch as th +from torch.utils.data import Dataset, DataLoader, Subset +from torch import nn +import os +from scripts.utils.plot_utils import plot_timestep +from scripts.utils.eval_metrics import masks_to_boxes, pred_eval_step, postproc_mask +from scripts.utils.eval_utils import append_statistics, load_model, setup_result_folders, store_statistics +from scripts.utils.configuration import Configuration +from scripts.utils.io import init_device +import numpy as np +from einops import rearrange, repeat, reduce +from copy import deepcopy +import lpips +import torchvision.transforms as transforms + +def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 0): + + # Set up cpu or gpu training + device, verbose = init_device(cfg) + + # Config + cfg_net = cfg.model + cfg_net.batch_size = 2 if verbose else 32 + cfg_net.batch_size = 1 if plot_first_samples > 0 else cfg_net.batch_size # only plotting first samples + + # Load model + net = load_model(cfg, cfg_net, file, device) + net.eval() + + # config + object_view = True + individual_views = False + root_path = None + plotting_mode = (cfg_net.batch_size == 1) and (plot_first_samples > 0) + + # get evaluation sets + set_test_array, evaluation_modes = get_evaluation_sets(dataset) + + # memory + statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error_mse': []} + statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'bound': [], 'slot_error': [], 'rawmask_size': [], 'alpha_pos': [], 'alpha_ges': []} + metric_complete = None + + # Evaluation Specifics + burn_in_length = 6 + rollout_length = 42 + cfg.defaults.skip_frames = 2 + blackout_p = 0.2 + target_size = (64,64) + dataset.burn_in_length = burn_in_length + dataset.rollout_length = rollout_length + dataset.skip_length = cfg.defaults.skip_frames + + # Transformation utils + to_small = transforms.Resize(target_size) + to_normalize = transforms.Normalize((0.5, ), (0.5, )) + to_smallnorm = transforms.Compose([to_small, to_normalize]) + + # Losses + lpipsloss = lpips.LPIPS(net='vgg').to(device) + mseloss = nn.MSELoss() + + for set_test in set_test_array: + + for evaluation_mode in evaluation_modes: + print(f'Start evaluation loop: {evaluation_mode}') + + # load data + dataloader = DataLoader( + Subset(dataset, range(plot_first_samples)) if plotting_mode else dataset, + num_workers = 1, + pin_memory = False, + batch_size = cfg_net.batch_size, + shuffle = False, + drop_last = True, + ) + + # memory + root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views) + metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'blackout': []} + + with th.no_grad(): + for i, input in enumerate(dataloader): + print(f'Processing sample {i+1}/{len(dataloader)}', flush=True) + + # Load data + tensor = input[0].float().to(device) + background_fix = input[1].to(device) + gt_mask = input[2].to(device) + gt_bbox = input[3].to(device) + gt_pres_mask = input[4].to(device) + sequence_len = tensor.shape[1] + + # Placehodlers + mask_cur = None + mask_last = None + rawmask_last = None + position_last = None + gestalt_last = None + priority_last = None + slots_occlusionfactor = None + error_last = None + + # Memory + pred = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + gt = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + pred_mask = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + statistics_batch = deepcopy(statistics_template) + + # Counters + num_rollout = 0 + num_burnin = 0 + blackout_mem = [0] + + # Loop through frames + for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)): + + # Move to next frame + t_run = max(t, 0) + input = tensor[:,t_run] + target = th.clip(tensor[:,t_run+1], 0, 1) + + rollout_index = t_run - burn_in_length + if (rollout_index >= 0) and (evaluation_mode == 'blackout'): + blackout = (th.rand(1) < blackout_p).float().to(device) + input = blackout * (input * 0) + (1-blackout) * input + error_last = blackout * (error_last * 0) + (1-blackout) * error_last + blackout_mem.append(blackout.int().cpu().item()) + + elif t>=0: + num_burnin += 1 + + if (rollout_index >= 0) and (evaluation_mode != 'blackout'): + blackout_mem.append(0) + + # obtain prediction + ( + output_next, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor, + output_cur, + position_cur, + gestalt_cur, + priority_cur, + mask_cur, + rawmask_cur, + object_cur, + position_encoder_cur, + slots_bounded, + slots_partially_occluded_cur, + slots_occluded_cur, + slots_partially_occluded_next, + slots_occluded_next, + slots_closed, + output_hidden, + largest_object, + rawmask_hidden, + object_hidden + ) = net( + input, + error_last, + mask_last, + rawmask_last, + position_last, + gestalt_last, + priority_last, + background_fix, + slots_occlusionfactor, + reset = (t == -cfg.defaults.teacher_forcing), + evaluate=True, + warmup = (t < 0), + shuffleslots = False, + reset_mask = (t <= 0), + allow_spawn = True, + show_hidden = plotting_mode, + clean_slots = (t <= 0), + ) + + # 1. Track error for plots + if t >= 0: + + # save prediction + if rollout_index >= -1: + pred[:,rollout_index+1] = to_smallnorm(output_next) + gt[:,rollout_index+1] = to_smallnorm(target) + pred_mask[:,rollout_index+1] = postproc_mask(to_small(mask_next)[:,None,:,None])[:, 0] + num_rollout += 1 + + if plotting_mode and object_view: + statistics_complete_slots, statistics_batch = compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden) + + # 2. Remember output + mask_last = mask_next.clone() + rawmask_last = rawmask_next.clone() + position_last = position_next.clone() + gestalt_last = gestalt_next.clone() + priority_last = priority_next.clone() + + # 3. Error for next frame + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + # prediction error + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next.clone() + + # 4. plot preparation + if plotting_mode and (t % plot_frequency == 0): + plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i) + + # Compute prediction accuracy based on Slotformer metrics (ARI, FARI, mIoU, AP, AR) + if not plotting_mode: + for b in range(cfg_net.batch_size): + metric_dict = pred_eval_step( + gt=gt[b:b+1], + pred=pred[b:b+1], + pred_mask=pred_mask.long()[b:b+1], + pred_bbox=masks_to_boxes(pred_mask.long()[b:b+1], cfg_net.num_objects+1), + gt_mask=gt_mask.long()[b:b+1], + gt_pres_mask=gt_pres_mask[b:b+1], + gt_bbox=gt_bbox[b:b+1], + lpips_fn=lpipsloss, + eval_traj=True, + ) + metric_dict['blackout'] = blackout_mem + metric_complete = append_statistics(metric_dict, metric_complete) + + # sanity check + if (num_rollout != rollout_length) and (num_burnin != burn_in_length) and (evaluation_mode == 'rollout'): + raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.') + + if not plotting_mode: + average_dic = compute_statistics_summary(metric_complete, evaluation_mode) + + # Store statistics + with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f: + pickle.dump(metric_complete, f) + with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.pkl'), 'wb') as f: + pickle.dump(average_dic, f) + + print('-- Evaluation Done --') + if object_view and os.path.exists(f'{root_path}/tmp.jpg'): + os.remove(f'{root_path}/tmp.jpg') + pass + +def compute_statistics_summary(metric_complete, evaluation_mode): + average_dic = {} + for key in metric_complete: + # take average over all frames + average_dic[key + 'complete_average'] = np.mean(metric_complete[key]) + average_dic[key + 'complete_std'] = np.std(metric_complete[key]) + print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}') + + if evaluation_mode == 'blackout': + # take average only for frames where blackout occurs + blackout_mask = np.array(metric_complete['blackout']) > 0 + average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + 'blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + 'visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False]) + average_dic[key + 'visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False]) + + print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}') + print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}') + return average_dic + +def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden): + statistics_batch = store_statistics(statistics_batch, + set_test['type'], + evaluation_mode, + set_test['samples'][i], + t, + mseloss(output_next, target).item() + ) + + # compute slot-wise prediction error + output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects) + target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects) + slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean') + + # compute rawmask_size + rawmask_size = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum') + + statistics_complete_slots = store_statistics(statistics_complete_slots, + [set_test['type']] * cfg_net.num_objects, + [evaluation_mode] * cfg_net.num_objects, + [set_test['samples'][i]] * cfg_net.num_objects, + [t] * cfg_net.num_objects, + range(cfg_net.num_objects), + slots_bounded.cpu().numpy().flatten().astype(int), + slot_error.cpu().numpy().flatten(), + rawmask_size.cpu().numpy().flatten(), + slots_closed[:, :, 1].cpu().numpy().flatten(), + slots_closed[:, :, 0].cpu().numpy().flatten(), + extend = True) + + return statistics_complete_slots,statistics_batch + +def get_evaluation_sets(dataset): + + set = {"samples": np.arange(len(dataset), dtype=int), "type": "test"} + evaluation_modes = ['blackout', 'open'] # use 'open' for no blackouts + set_test_array = [set] + + return set_test_array, evaluation_modes |