aboutsummaryrefslogtreecommitdiff
path: root/scripts/validation.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/validation.py')
-rw-r--r--scripts/validation.py270
1 files changed, 264 insertions, 6 deletions
diff --git a/scripts/validation.py b/scripts/validation.py
index 5b1638d..b59cc39 100644
--- a/scripts/validation.py
+++ b/scripts/validation.py
@@ -1,8 +1,11 @@
+import os
import torch as th
from torch import nn
from torch.utils.data import DataLoader
import numpy as np
from einops import rearrange, repeat, reduce
+from scripts.evaluation_bb import distance_eval_step
+from scripts.evaluation_clevrer import compute_statistics_summary
from scripts.utils.configuration import Configuration
from model.loci import Loci
import time
@@ -10,12 +13,19 @@ import lpips
from skimage.metrics import structural_similarity as ssimloss
from skimage.metrics import peak_signal_noise_ratio as psnrloss
-def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+from scripts.utils.eval_metrics import masks_to_boxes, postproc_mask, pred_eval_step
+from scripts.utils.eval_utils import append_statistics, compute_position_from_mask
+from scripts.utils.plot_utils import plot_timestep
+
+def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path):
# memory
mseloss = nn.MSELoss()
- avgloss = 0
+ loss_next = 0
start_time = time.time()
+ cfg_net = cfg.model
+ num_steps = 0
+ plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}')
with th.no_grad():
for i, input in enumerate(valloader):
@@ -103,7 +113,8 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic
# 1. Track error
if t >= 0:
loss = mseloss(output_next, target)
- avgloss += loss.item()
+ loss_next += loss.item()
+ num_steps += 1
# 2. Remember output
mask_last = mask_next.clone()
@@ -122,12 +133,19 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic
error_next = th.sqrt(error_next) * bg_error_next
error_last = error_next.clone()
- print(f"Validation loss: {avgloss / len(valloader.dataset):.2e}, Time: {time.time() - start_time}")
-
+ # Plotting
+ if (i == 0) and (t_index % 2 == 0) and (epoch % 3 == 0):
+ os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
+ openings = net.get_openings()
+ img_tensor = plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings)
+
+ print(f"Validation loss: {loss_next / num_steps:.2e}, Time: {time.time() - start_time}")
+ writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch)
+
pass
-def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path):
# memory
mseloss = nn.MSELoss()
@@ -140,6 +158,9 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev
burn_in_length = 6
rollout_length = 42
+
+ plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}')
+ os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
with th.no_grad():
for i, input in enumerate(valloader):
@@ -256,6 +277,243 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev
error_next = th.sqrt(error_next) * bg_error_next
error_last = error_next.clone()
+ # Plotting
+ if i == 0:
+ openings = net.get_openings()
+ img_tensor = plot_timestep(cfg, cfg.model, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings)
+
+
print(f"MSE loss: {avgloss_mse / len(valloader.dataset):.2e}, LPIPS loss: {avgloss_lpips / len(valloader.dataset):.2e}, PSNR loss: {avgloss_psnr / len(valloader.dataset):.2e}, SSIM loss: {avgloss_ssim / len(valloader.dataset):.2e}, Time: {time.time() - start_time}")
+ writer.add_scalar('Val/Prediction Loss', avgloss_mse / len(valloader.dataset), epoch)
+ writer.add_scalar('Val/LPIPS Loss', avgloss_lpips / len(valloader.dataset), epoch)
+ writer.add_scalar('Val/PSNR Loss', avgloss_psnr / len(valloader.dataset), epoch)
+ writer.add_scalar('Val/SSIM Loss', avgloss_ssim / len(valloader.dataset), epoch)
+
+ pass
+
+
+def validation_bb(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path):
+    """Run one validation pass with a burn-in/rollout protocol and log the results.
+
+    The network is put into eval mode for the duration of the call and back
+    into train mode at the end. Every batch is processed frame by frame under
+    ``th.no_grad()``: first ``cfg.defaults.teacher_forcing`` warm-up steps on
+    frame 0 (``t < 0``), then ``burn_in_length`` steps on real frames, then
+    ``rollout_length`` steps in which (for ``evaluation_mode ==
+    'vidpred_black'``) the network input and the error map are replaced by
+    zeros. Predictions made during the rollout are collected, scored per
+    sample with ``pred_eval_step`` / ``distance_eval_step`` and summarised
+    with ``compute_statistics_summary``.
+
+    Scalars written to ``writer`` (all with ``epoch`` as step): 'Val/Meds',
+    'Val/ARI_hidden', 'Val/ARI', 'Val/LPIPS', 'Val/Prediction Loss' and
+    'Val/Encoding Loss'. Plots for the first batch are produced through
+    ``plot_timestep`` with ``<root_path>/plots/epoch_<epoch>`` as plot path.
+
+    Args:
+        valloader: iterable of batches; each batch is indexed 0..5 and bound,
+            in that order, to the video tensor, the fixed background, the
+            ground-truth positions, masks, presence masks and hidden masks.
+        net: model that is called once per frame; ``net.get_openings()`` is
+            queried for plotting.
+        cfg: configuration; ``cfg.model`` (``batch_size``, ``num_objects``)
+            and ``cfg.defaults.teacher_forcing`` are read.
+        device: device the batch tensors and result buffers are moved to.
+        writer: object providing ``add_scalar(tag, value, step)``.
+        epoch: current epoch; used as logging step and in the plot directory.
+        root_path: base directory for the plot output.
+
+    Raises:
+        ValueError: if a batch was processed with a number of rollout or
+            burn-in steps that differs from the configured lengths.
+    """
+
+    # memory
+    start_time = time.time()
+    net.eval()
+    evaluation_mode = 'vidpred_black'
+    use_meds = True
+
+    # Evaluation Specifics
+    burn_in_length = 10
+    rollout_length = 20
+    rollout_length_stats = 10 # only consider the first 10 frames for statistics
+    target_size = (64, 64)
+
+    # Losses
+    lpipsloss = lpips.LPIPS(net='vgg').to(device)
+    mseloss = nn.MSELoss()
+    metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'meds': [], 'ari_hidden': [], 'fari_hidden': [], 'miou_hidden': []}
+    loss_next = 0.0
+    loss_cur = 0.0
+    num_steps = 0
+    plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}')
+    os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True)
+
+    with th.no_grad():
+        for i, input in enumerate(valloader):
+
+            # Load data
+            tensor = input[0].float().to(device)
+            background_fix = input[1].to(device)
+            gt_pos = input[2].to(device)
+            gt_mask = input[3].to(device)
+            gt_pres_mask = input[4].to(device)
+            gt_hidden_mask = input[5].to(device)
+            sequence_len = tensor.shape[1]
+
+            # placeholders (recurrent state handed back to the network on the next frame)
+            mask_cur = None
+            mask_last = None
+            rawmask_last = None
+            position_last = None
+            gestalt_last = None
+            priority_last = None
+            gt_positions_target = None
+            slots_occlusionfactor = None
+            error_last = None
+
+            # Memory
+            # NOTE(review): the buffers are sized with cfg_net.batch_size, so this
+            # assumes every batch holds exactly that many sequences - TODO confirm
+            # the loader drops incomplete batches.
+            cfg_net = cfg.model
+            num_objects_bb = gt_pos.shape[2]
+            pred_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device)
+            gt_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device)
+            pred_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+            gt_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+            pred_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
+            pred_hidden_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
+            # Counters
+            num_rollout = 0
+            num_burnin = 0
+
+            # loop through frames
+            for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, burn_in_length+rollout_length)):
+
+                # Move to next frame
+                # negative t (warm-up) keeps presenting frame 0
+                t_run = max(t, 0)
+                input = tensor[:,t_run]
+                target_cur = tensor[:,t_run]
+                target = th.clip(tensor[:,t_run+1], 0, 1)
+                # presumably rescales pixel coordinates to [-1, 1] - TODO confirm
+                gt_pos_t = gt_pos[:,t_run+1]/32-1
+                gt_pos_t = th.concat((gt_pos_t, th.ones_like(gt_pos_t[:,:,:1])), dim=2)
+
+                rollout_index = t_run - burn_in_length
+                rollout_active = False
+                if t>=0:
+                    if rollout_index >= 0:
+                        num_rollout += 1
+                        if (evaluation_mode == 'vidpred_black'):
+                            # feed an all-zero frame with the shape of the last prediction
+                            input = output_next * 0
+                            error_last = error_last * 0
+                            rollout_active = True
+                        elif (evaluation_mode == 'vidpred_auto'):
+                            # feed the last prediction back in
+                            input = output_next
+                            error_last = error_last * 0
+                            rollout_active = True
+                    else:
+                        num_burnin += 1
+
+                # obtain prediction
+                (
+                    output_next,
+                    position_next,
+                    gestalt_next,
+                    priority_next,
+                    mask_next,
+                    rawmask_next,
+                    object_next,
+                    background,
+                    slots_occlusionfactor,
+                    output_cur,
+                    position_cur,
+                    gestalt_cur,
+                    priority_cur,
+                    mask_cur,
+                    rawmask_cur,
+                    object_cur,
+                    position_encoder_cur,
+                    slots_bounded,
+                    slots_partially_occluded_cur,
+                    slots_occluded_cur,
+                    slots_partially_occluded_next,
+                    slots_occluded_next,
+                    slots_closed,
+                    output_hidden,
+                    largest_object,
+                    rawmask_hidden,
+                    object_hidden
+                ) = net(
+                    input,
+                    error_last,
+                    mask_last,
+                    rawmask_last,
+                    position_last,
+                    gestalt_last,
+                    priority_last,
+                    background_fix,
+                    slots_occlusionfactor,
+                    reset = (t == -cfg.defaults.teacher_forcing),
+                    evaluate=True,
+                    warmup = (t < 0),
+                    shuffleslots = True,
+                    reset_mask = (t <= 0),
+                    allow_spawn = True,
+                    show_hidden = False,
+                    clean_slots = False,
+                )
+
+                # 1. Track error
+                if t >= 0:
+
+                    if (rollout_index >= 0):
+                        # store positions per batch
+                        if use_meds:
+                            # the first branch is switched off; positions are taken from the raw masks
+                            if False:
+                                pred_pos_batch[:,rollout_index] = rearrange(position_next, 'b (o c) -> b o c', o=cfg_net.num_objects)[:,:,:2]
+                            else:
+                                pred_pos_batch[:,rollout_index] = compute_position_from_mask(rawmask_next)
+
+                            gt_pos_batch[:,rollout_index] = gt_pos_t[:,:,:2]
+
+                        pred_img_batch[:,rollout_index] = output_next
+                        gt_img_batch[:,rollout_index] = target
+
+                        # Here we compute only the foreground segmentation mask
+                        pred_mask_batch[:,rollout_index] = postproc_mask(mask_next[:,None,:,None])[:, 0]
+
+                        # Here we compute the hidden segmentation
+                        occluded_cur = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1]
+                        occluded_sum_cur = 1-(reduce(occluded_cur, 'b c h w -> b h w', 'max') > 0.5).float()
+                        occluded_cur = th.cat((occluded_cur, occluded_sum_cur[:,None]), dim=1)
+                        pred_hidden_mask_batch[:,rollout_index] = postproc_mask(occluded_cur[:,None,:,None])[:, 0]
+
+                # 2. Remember output
+                mask_last = mask_next.clone()
+                rawmask_last = rawmask_next.clone()
+                position_last = position_next.clone()
+                gestalt_last = gestalt_next.clone()
+                priority_last = priority_next.clone()
+
+                # 3. Error for next frame
+                # background error
+                bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+                # prediction error
+                error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+                error_next = th.sqrt(error_next) * bg_error_next
+                error_last = error_next.clone()
+
+                # Prediction and encoder Loss
+                # accumulate plain floats (as validation_adept does) instead of 0-dim tensors
+                loss_next += mseloss(output_next * bg_error_next, target * bg_error_next).item()
+                loss_cur += mseloss(output_cur * bg_error_cur, input * bg_error_cur).item()
+                num_steps += 1
+
+                # Plotting (first batch only)
+                if i == 0:
+                    openings = net.get_openings()
+                    img_tensor = plot_timestep(cfg, cfg_net, input, target_cur, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_pos_t, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode=rollout_active, openings=openings)
+
+            for b in range(cfg_net.batch_size):
+
+                # perceptual similarity from slotformer paper
+                # ground truth is sliced to the frames that served as rollout targets (t_run+1)
+                metric_dict = pred_eval_step(
+                    gt = gt_img_batch[b:b+1],
+                    pred = pred_img_batch[b:b+1],
+                    pred_mask = pred_mask_batch.long()[b:b+1],
+                    pred_mask_hidden = pred_hidden_mask_batch.long()[b:b+1],
+                    pred_bbox = None,
+                    gt_mask = gt_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1],
+                    gt_mask_hidden = gt_hidden_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1],
+                    gt_pres_mask = gt_pres_mask[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1],
+                    gt_bbox = None,
+                    lpips_fn = lpipsloss,
+                    eval_traj = True,
+                )
+
+                metric_dict['meds'] = distance_eval_step(gt_pos_batch[b], pred_pos_batch[b])
+                metric_complete = append_statistics(metric_dict, metric_complete)
+
+            # sanity check: fail as soon as EITHER counter deviates from its
+            # configured length (the former `and` only fired when both were wrong)
+            if (num_rollout != rollout_length) or (num_burnin != burn_in_length):
+                raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.')
+
+    dic = compute_statistics_summary(metric_complete, evaluation_mode, consider_first_n_frames=rollout_length_stats)
+    writer.add_scalar('Val/Meds', dic['meds_complete_sum'], epoch)
+
+    writer.add_scalar('Val/ARI_hidden', dic['ari_hidden_complete_average'], epoch)
+    writer.add_scalar('Val/ARI', dic['ari_complete_average'], epoch)
+    writer.add_scalar('Val/LPIPS', dic['percept_dist_complete_average'], epoch)
+
+    writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch)
+    writer.add_scalar('Val/Encoding Loss', loss_cur / num_steps, epoch)
+
+    net.train()
pass \ No newline at end of file