diff options
Diffstat (limited to 'scripts/validation.py')
-rw-r--r-- | scripts/validation.py | 270 |
1 file changed, 264 insertions, 6 deletions
diff --git a/scripts/validation.py b/scripts/validation.py index 5b1638d..b59cc39 100644 --- a/scripts/validation.py +++ b/scripts/validation.py @@ -1,8 +1,11 @@ +import os import torch as th from torch import nn from torch.utils.data import DataLoader import numpy as np from einops import rearrange, repeat, reduce +from scripts.evaluation_bb import distance_eval_step +from scripts.evaluation_clevrer import compute_statistics_summary from scripts.utils.configuration import Configuration from model.loci import Loci import time @@ -10,12 +13,19 @@ import lpips from skimage.metrics import structural_similarity as ssimloss from skimage.metrics import peak_signal_noise_ratio as psnrloss -def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device): +from scripts.utils.eval_metrics import masks_to_boxes, postproc_mask, pred_eval_step +from scripts.utils.eval_utils import append_statistics, compute_position_from_mask +from scripts.utils.plot_utils import plot_timestep + +def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path): # memory mseloss = nn.MSELoss() - avgloss = 0 + loss_next = 0 start_time = time.time() + cfg_net = cfg.model + num_steps = 0 + plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}') with th.no_grad(): for i, input in enumerate(valloader): @@ -103,7 +113,8 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic # 1. Track error if t >= 0: loss = mseloss(output_next, target) - avgloss += loss.item() + loss_next += loss.item() + num_steps += 1 # 2. 
Remember output mask_last = mask_next.clone() @@ -122,12 +133,19 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic error_next = th.sqrt(error_next) * bg_error_next error_last = error_next.clone() - print(f"Validation loss: {avgloss / len(valloader.dataset):.2e}, Time: {time.time() - start_time}") - + # PLotting + if (i == 0) and (t_index % 2 == 0) and (epoch % 3 == 0): + os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True) + openings = net.get_openings() + img_tensor = plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings) + + print(f"Validation loss: {loss_next / num_steps:.2e}, Time: {time.time() - start_time}") + writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch) + pass -def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device): +def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path): # memory mseloss = nn.MSELoss() @@ -140,6 +158,9 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev burn_in_length = 6 rollout_length = 42 + + plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}') + os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True) with th.no_grad(): for i, input in enumerate(valloader): @@ -256,6 +277,243 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev error_next = th.sqrt(error_next) * bg_error_next error_last = error_next.clone() + # PLotting + if i == 0: + openings = net.get_openings() + img_tensor = plot_timestep(cfg, 
cfg.model, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings) + + print(f"MSE loss: {avgloss_mse / len(valloader.dataset):.2e}, LPIPS loss: {avgloss_lpips / len(valloader.dataset):.2e}, PSNR loss: {avgloss_psnr / len(valloader.dataset):.2e}, SSIM loss: {avgloss_ssim / len(valloader.dataset):.2e}, Time: {time.time() - start_time}") + writer.add_scalar('Val/Prediction Loss', avgloss_mse / len(valloader.dataset), epoch) + writer.add_scalar('Val/LPIPS Loss', avgloss_lpips / len(valloader.dataset), epoch) + writer.add_scalar('Val/PSNR Loss', avgloss_psnr / len(valloader.dataset), epoch) + writer.add_scalar('Val/SSIM Loss', avgloss_ssim / len(valloader.dataset), epoch) + + pass + + +def validation_bb(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path): + + # memory + start_time = time.time() + net.eval() + evaluation_mode = 'vidpred_black' + use_meds = True + + # Evaluation Specifics + burn_in_length = 10 + rollout_length = 20 + rollout_length_stats = 10 # only consider the first 10 frames for statistics + target_size = (64, 64) + + # Losses + lpipsloss = lpips.LPIPS(net='vgg').to(device) + mseloss = nn.MSELoss() + metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'meds': [], 'ari_hidden': [], 'fari_hidden': [], 'miou_hidden': []} + loss_next = 0.0 + loss_cur = 0.0 + num_steps = 0 + plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}') + os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True) + + with th.no_grad(): + for i, input in enumerate(valloader): + + # 
Load data + tensor = input[0].float().to(device) + background_fix = input[1].to(device) + gt_pos = input[2].to(device) + gt_mask = input[3].to(device) + gt_pres_mask = input[4].to(device) + gt_hidden_mask = input[5].to(device) + sequence_len = tensor.shape[1] + + # placehodlers + mask_cur = None + mask_last = None + rawmask_last = None + position_last = None + gestalt_last = None + priority_last = None + gt_positions_target = None + slots_occlusionfactor = None + error_last = None + + # Memory + cfg_net = cfg.model + num_objects_bb = gt_pos.shape[2] + pred_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device) + gt_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device) + pred_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + gt_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + pred_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + pred_hidden_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + # Counters + num_rollout = 0 + num_burnin = 0 + + # loop through frames + for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, burn_in_length+rollout_length)): + + # Move to next frame + t_run = max(t, 0) + input = tensor[:,t_run] + target_cur = tensor[:,t_run] + target = th.clip(tensor[:,t_run+1], 0, 1) + gt_pos_t = gt_pos[:,t_run+1]/32-1 + gt_pos_t = th.concat((gt_pos_t, th.ones_like(gt_pos_t[:,:,:1])), dim=2) + + rollout_index = t_run - burn_in_length + rollout_active = False + if t>=0: + if rollout_index >= 0: + num_rollout += 1 + if (evaluation_mode == 'vidpred_black'): + input = output_next * 0 + error_last = error_last * 0 + rollout_active = True + elif (evaluation_mode == 'vidpred_auto'): + input = output_next + error_last = error_last * 0 + rollout_active = True + else: + 
num_burnin += 1 + + # obtain prediction + ( + output_next, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor, + output_cur, + position_cur, + gestalt_cur, + priority_cur, + mask_cur, + rawmask_cur, + object_cur, + position_encoder_cur, + slots_bounded, + slots_partially_occluded_cur, + slots_occluded_cur, + slots_partially_occluded_next, + slots_occluded_next, + slots_closed, + output_hidden, + largest_object, + rawmask_hidden, + object_hidden + ) = net( + input, + error_last, + mask_last, + rawmask_last, + position_last, + gestalt_last, + priority_last, + background_fix, + slots_occlusionfactor, + reset = (t == -cfg.defaults.teacher_forcing), + evaluate=True, + warmup = (t < 0), + shuffleslots = True, + reset_mask = (t <= 0), + allow_spawn = True, + show_hidden = False, + clean_slots = False, + ) + + # 1. Track error + if t >= 0: + + if (rollout_index >= 0): + # store positions per batch + if use_meds: + if False: + pred_pos_batch[:,rollout_index] = rearrange(position_next, 'b (o c) -> b o c', o=cfg_net.num_objects)[:,:,:2] + else: + pred_pos_batch[:,rollout_index] = compute_position_from_mask(rawmask_next) + + gt_pos_batch[:,rollout_index] = gt_pos_t[:,:,:2] + + pred_img_batch[:,rollout_index] = output_next + gt_img_batch[:,rollout_index] = target + + # Here we compute only the foreground segmentation mask + pred_mask_batch[:,rollout_index] = postproc_mask(mask_next[:,None,:,None])[:, 0] + + # Here we compute the hidden segmentation + occluded_cur = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1] + occluded_sum_cur = 1-(reduce(occluded_cur, 'b c h w -> b h w', 'max') > 0.5).float() + occluded_cur = th.cat((occluded_cur, occluded_sum_cur[:,None]), dim=1) + pred_hidden_mask_batch[:,rollout_index] = postproc_mask(occluded_cur[:,None,:,None])[:, 0] + + # 2. 
Remember output + mask_last = mask_next.clone() + rawmask_last = rawmask_next.clone() + position_last = position_next.clone() + gestalt_last = gestalt_next.clone() + priority_last = priority_next.clone() + + # 3. Error for next frame + # background error + bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + # prediction error + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next.clone() + + # Prediction and encoder Loss + loss_next += mseloss(output_next * bg_error_next, target * bg_error_next) + loss_cur += mseloss(output_cur * bg_error_cur, input * bg_error_cur) + num_steps += 1 + + # PLotting + if i == 0: + openings = net.get_openings() + img_tensor = plot_timestep(cfg, cfg_net, input, target_cur, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_pos_t, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode=rollout_active, openings=openings) + + for b in range(cfg_net.batch_size): + + # perceptual similarity from slotformer paper + metric_dict = pred_eval_step( + gt = gt_img_batch[b:b+1], + pred = pred_img_batch[b:b+1], + pred_mask = pred_mask_batch.long()[b:b+1], + pred_mask_hidden = pred_hidden_mask_batch.long()[b:b+1], + pred_bbox = None, + gt_mask = gt_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1], + gt_mask_hidden = gt_hidden_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1], + gt_pres_mask = gt_pres_mask[b:b+1, 
burn_in_length+1:burn_in_length+rollout_length+1], + gt_bbox = None, + lpips_fn = lpipsloss, + eval_traj = True, + ) + + metric_dict['meds'] = distance_eval_step(gt_pos_batch[b], pred_pos_batch[b]) + metric_complete = append_statistics(metric_dict, metric_complete) + + # sanity check + if (num_rollout != rollout_length) and (num_burnin != burn_in_length): + raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.') + + dic = compute_statistics_summary(metric_complete, evaluation_mode, consider_first_n_frames=rollout_length_stats) + writer.add_scalar('Val/Meds', dic['meds_complete_sum'], epoch) + + writer.add_scalar('Val/ARI_hidden', dic['ari_hidden_complete_average'], epoch) + writer.add_scalar('Val/ARI', dic['ari_complete_average'], epoch) + writer.add_scalar('Val/LPIPS', dic['percept_dist_complete_average'], epoch) + + writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch) + writer.add_scalar('Val/Encoding Loss', loss_cur / num_steps, epoch) + + net.train() pass
\ No newline at end of file |