Diffstat (limited to 'scripts')
-rw-r--r--  scripts/evaluation_adept.py                                                            |  11
-rw-r--r--  scripts/evaluation_adept_baselines.py (renamed from scripts/evaluation_adept_savi.py)  | 143
-rw-r--r--  scripts/evaluation_bb.py                                                               | 385
-rw-r--r--  scripts/evaluation_clevrer.py                                                          |  45
-rw-r--r--  scripts/exec/eval.py                                                                   |  32
-rw-r--r--  scripts/exec/eval_baselines.py (renamed from scripts/exec/eval_savi.py)                |   5
-rw-r--r--  scripts/exec/train.py                                                                  |  15
-rw-r--r--  scripts/training.py                                                                    | 324
-rw-r--r--  scripts/utils/eval_adept.py                                                            | 169
-rw-r--r--  scripts/utils/eval_metrics.py                                                          |  51
-rw-r--r--  scripts/utils/eval_utils.py                                                            |  92
-rw-r--r--  scripts/utils/io.py                                                                    |  72
-rw-r--r--  scripts/utils/plot_utils.py                                                            | 162
-rw-r--r--  scripts/validation.py                                                                  | 270
14 files changed, 1511 insertions(+), 265 deletions(-)
diff --git a/scripts/evaluation_adept.py b/scripts/evaluation_adept.py index eab3ad9..b2a18c0 100644 --- a/scripts/evaluation_adept.py +++ b/scripts/evaluation_adept.py @@ -2,6 +2,7 @@ import torch as th from torch.utils.data import Dataset, DataLoader, Subset from torch import nn import os +from scripts.utils.eval_adept import eval_adept from scripts.utils.plot_utils import plot_timestep from scripts.utils.configuration import Configuration from scripts.utils.io import init_device @@ -21,10 +22,12 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p # Config cfg_net = cfg.model cfg_net.batch_size = 1 + cfg_net.inner_loop_enabled = (cfg.max_updates > cfg.phases.start_inner_loop) # Load model net = load_model(cfg, cfg_net, file, device) net.eval() + net.predictor.enable_att_weights() # Plot config object_view = True @@ -216,7 +219,9 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p # 4. Plot if (t % plot_frequency == 0) and (i < plot_first_samples) and (t >= 0): - plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i) + att = net.predictor.get_att_weights() + openings = net.get_openings() + plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, att= att, openings=openings) # fill jumping statistics statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1)) @@ -235,8 +240,11 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv') pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv') pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv') + if object_view and os.path.exists(f'{root_path}/tmp.jpg'): os.remove(f'{root_path}/tmp.jpg') + + eval_adept(f'{root_path}/statistics') pass @@ -309,6 +317,7 @@ def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_n def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_num_slots, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask): # tracking utils + gt_positions_target = gt_positions_target.clone() pdist = nn.PairwiseDistance(p=2).to(position_cur.device) # 1. 
association of newly bounded slots to ground truth objects diff --git a/scripts/evaluation_adept_savi.py b/scripts/evaluation_adept_baselines.py index 6a2d5c7..e5a8e7d 100644 --- a/scripts/evaluation_adept_savi.py +++ b/scripts/evaluation_adept_baselines.py @@ -5,15 +5,16 @@ import cv2 import numpy as np import pandas as pd import os -from data.datasets.ADEPT.dataset import AdeptDataset import motmetrics as mm from scripts.evaluation_adept import calculate_tracking_error, get_evaluation_sets, update_mota_acc -from scripts.utils.eval_utils import setup_result_folders, store_statistics +from scripts.utils.eval_utils import boxes_to_centroids, masks_to_boxes, setup_result_folders, store_statistics from scripts.utils.plot_utils import write_image FG_THRE = 0.95 -def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 2): +def evaluate(dataset: Dataset, file, n, model, plot_frequency= 1, plot_first_samples = 2): + + assert model in ['savi', 'gswm'] # read pkl file masks_complete = pd.read_pickle(file) @@ -21,8 +22,12 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = # plot config color_list = [[255,0,0], [0,255,0], [0,0,255], [255,255,0], [0,255,255], [255,0,255], [255,255,255]] dot_size = 2 - skip_frames = 2 - offset = 15 + if model == 'savi': + skip_frames = 2 + offset = 15 + elif model == 'gswm': + skip_frames = 2 + offset = 0 # memory statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': []} @@ -52,33 +57,53 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = tensor = tensor[:,range(0, tensor.shape[1], skip_frames)] sequence_len = tensor.shape[1] - # load data - masks = th.tensor(masks_complete['test'][f'control_{i}.mp4']) - masks_before_softmax = th.tensor(masks_complete['test_raw'][f'control_{i}.mp4']) - - # calculate rawmasks - bg_mask = masks_before_softmax.mean(dim=1) - masks_raw = compute_maskraw(masks_before_softmax, bg_mask) - slots_bound = compute_slots_bound(masks_raw) + if model == 'savi': + # load data + masks = th.tensor(masks_complete['test'][f'control_{i}.mp4']) # N, O, 1, H, W + masks_before_softmax = th.tensor(masks_complete['test_raw'][f'control_{i}.mp4']) + imgs_model = None + recons_model = None + + # calculate rawmasks + bg_mask = masks_before_softmax.mean(dim=1) + masks_raw = compute_maskraw(masks_before_softmax, bg_mask, n_slots=7) + slots_bound = compute_slots_bound(masks_raw) + + elif model == 'gswm': + # load data + masks = masks_complete[i]['visibility_mask'].squeeze(0) + masks_raw = masks_complete[i]['object_mask'].squeeze(0) + slots_bound = masks_complete[i]['z_pres'].squeeze(0) + slots_bound = (slots_bound > 0.9).float() + + imgs_model = masks_complete[i]['imgs'].squeeze(0) + imgs_model[:] = imgs_model[:, [2,1,0]] + recons_model = masks_complete[i]['recon'].squeeze(0) + recons_model[:] = recons_model[:, [2,1,0]] + + # consider only the first 7 slots + masks = masks[:,:7] + masks_raw = masks_raw[:,:7] + slots_bound = slots_bound[:,:7] + + n_slots = masks.shape[1] # threshold masks and calculate centroids masks_binary = (masks_raw > FG_THRE).float() masks2 = rearrange(masks_binary, 't o 1 h w -> (t o) h w') boxes = masks_to_boxes(masks2.long()) - boxes = boxes.reshape(1, masks.shape[0], 7, 4) + boxes = boxes.reshape(1, masks.shape[0], n_slots, 
4) centroids = boxes_to_centroids(boxes) # get rid of batch dimension - association_table = th.ones(7) * -1 + association_table = th.ones(n_slots) * -1 # iterate over frames for t_index in range(offset,min(sequence_len,masks.shape[0])): # move to next frame input = tensor[:,t_index] - target = th.clip(tensor[:,t_index+1], 0, 1) gt_positions_target = gt_object_positions[:,t_index] - gt_positions_target_next = gt_object_positions[:,t_index+1] gt_visibility_target = gt_object_visibility[:,t_index] position_cur = centroids[t_index] @@ -87,32 +112,32 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = slots_bound_cur = rearrange(slots_bound_cur, 'o c -> 1 (o c)') # calculate tracking error - tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, 7, slots_bound_cur, None, association_table, gt_occluder_mask) + tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, n_slots, slots_bound_cur, None, association_table, gt_occluder_mask) rawmask_size = reduce(masks_raw[t_index], 'o 1 h w-> 1 o', 'sum') mask_size = reduce(masks[t_index], 'o 1 h w-> 1 o', 'sum') statistics_complete_slots = store_statistics(statistics_complete_slots, - ['control'] * 7, - ['control'] * 7, - [control_samples[i]] * 7, - [t_index] * 7, - range(7), + ['control'] * n_slots, + ['control'] * n_slots, + [control_samples[i]] * n_slots, + [t_index] * n_slots, + range(n_slots), tracking_error_perslot.cpu().numpy().flatten(), slots_visible.cpu().numpy().flatten().astype(int), slots_bound_cur.cpu().numpy().flatten().astype(int), slots_occluder.cpu().numpy().flatten().astype(int), slots_in_image.cpu().numpy().flatten().astype(int), - [0] * 7, + [0] * n_slots, mask_size.cpu().numpy().flatten(), rawmask_size.cpu().numpy().flatten(), - [0] * 7, - [0] * 7, - [0] * 7, + [0] * n_slots, + [0] * n_slots, + [0] * n_slots, association_table[0].cpu().numpy().flatten().astype(int), extend = True) - acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bound_cur, 7, gt_occluder_mask, slots_occluder, None) + acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bound_cur, n_slots, gt_occluder_mask, slots_occluder, None) # plot_option if (t_index % plot_frequency == 0) and (i < plot_first_samples) and (t_index >= 0): @@ -141,7 +166,12 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = slot_frame_single = mask.transpose((1,2,0)).repeat(3, axis=2) slot_frame = np.concatenate((slot_frame, slot_frame_single), axis=1) - frame = np.concatenate((frame, slot_frame), axis=1) + if imgs_model is not None: + frame_model = imgs_model[t_index].numpy().transpose(1,2,0) + recon_model = recons_model[t_index].numpy().transpose(1,2,0) + frame = np.concatenate((frame, frame_model, recon_model, slot_frame), axis=1) + else: + frame = np.concatenate((frame, slot_frame), axis=1) cv2.imwrite(f'{plot_path}object/objects-{i:04d}-{t_index:03d}.jpg', frame*255) acc_memory_eval.append(acc) @@ -153,57 +183,6 @@ def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = pd.DataFrame(summary).to_csv(os.path.join(root_path, 'statistics' , 'accframe.csv')) pd.DataFrame(statistics_complete_slots).to_csv(os.path.join(root_path, 'statistics' , 'slotframe.csv')) -def masks_to_boxes(masks: th.Tensor) -> th.Tensor: - 
""" - Compute the bounding boxes around the provided masks. - - Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with - ``0 <= x1 < x2`` and ``0 <= y1 < y2``. - - Args: - masks (Tensor[N, H, W]): masks to transform where N is the number of masks - and (H, W) are the spatial dimensions. - - Returns: - Tensor[N, 4]: bounding boxes - """ - if masks.numel() == 0: - return th.zeros((0, 4), device=masks.device, dtype=th.float) - - n = masks.shape[0] - - bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float) - - for index, mask in enumerate(masks): - if mask.sum() > 0: - y, x = th.where(mask != 0) - - bounding_boxes[index, 0] = th.min(x) - bounding_boxes[index, 1] = th.min(y) - bounding_boxes[index, 2] = th.max(x) - bounding_boxes[index, 3] = th.max(y) - - return bounding_boxes - -def boxes_to_centroids(boxes): - """Post-process masks instead of directly taking argmax. - - Args: - bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2] - - Returns: - centroids: [B, T, N, 2], 2: [x, y] - """ - - centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2 - centroids = centroids.squeeze(0) - - # scale to [-1, 1] - centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1 - centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1 - - return centroids - def compute_slots_bound(masks): # take sum over axis 3,4 with th @@ -211,7 +190,7 @@ def compute_slots_bound(masks): slots_bound = (masks_sum > FG_THRE).float() return slots_bound -def compute_maskraw(mask, bg_mask): +def compute_maskraw(mask, bg_mask, n_slots): # d is a diagonal matrix which defines what to take the softmax over d_mask = th.diag(th.ones(8)) @@ -224,7 +203,7 @@ def compute_maskraw(mask, bg_mask): maskraw = th.cat((mask, bg_mask), dim=1) maskraw = repeat(maskraw, 'b o h w -> b r o h w', r = 8) maskraw = maskraw[:,d_mask.bool()] - maskraw = rearrange(maskraw, 'b (o r) h w -> b o r h w', o = 7) + maskraw = rearrange(maskraw, 'b (o r) h w -> b o r h w', o = n_slots) # take softmax between each object mask and the background mask maskraw = th.squeeze(th.softmax(maskraw, dim=2)[:,:,0], dim=2) diff --git a/scripts/evaluation_bb.py b/scripts/evaluation_bb.py new file mode 100644 index 0000000..1d21cfe --- /dev/null +++ b/scripts/evaluation_bb.py @@ -0,0 +1,385 @@ +import pickle +import cv2 +import torch as th +from torch.utils.data import Dataset, DataLoader, Subset +from torch import nn +import os +from scripts.evaluation_adept import calculate_tracking_error +from scripts.evaluation_clevrer import compute_statistics_summary +from scripts.utils.plot_utils import plot_timestep +from scripts.utils.eval_metrics import masks_to_boxes, pred_eval_step, postproc_mask +from scripts.utils.eval_utils import append_statistics, compute_position_from_mask, load_model, setup_result_folders, store_statistics +from scripts.utils.configuration import Configuration +from scripts.utils.io import init_device +import numpy as np +from einops import rearrange, repeat, reduce +from copy import deepcopy +import lpips +import torchvision.transforms as transforms +import motmetrics as mm + +def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_first_samples = 0): + + # Set up cpu or gpu training + device, verbose = init_device(cfg) + + # Config + cfg_net = cfg.model + cfg_net.batch_size = 2 if verbose else 32 + #cfg_net.num_objects = 3 + cfg_net.inner_loop_enabled = True + if 'num_objects_test' in cfg_net: + cfg_net.num_objects = cfg_net.num_objects_test + dataset = Subset(dataset, range(4)) if verbose else dataset + + 
# Load model + net = load_model(cfg, cfg_net, file, device) + net.eval() + net.predictor.enable_att_weights() + + # config + object_view = True + individual_views = False + root_path = None + use_meds = True + + # get evaluation sets + set_test_array, evaluation_modes = get_evaluation_sets(dataset) + + # memory + statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error_mse': []} + statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'bound': [], 'slot_error': [], 'rawmask_size': [], 'alpha_pos': [], 'alpha_ges': []} + metric_complete = None + + # Evaluation Specifics + burn_in_length = 10 + rollout_length = 90 + rollout_length_stats = 10 # only consider the first 10 frames for statistics + target_size = (64, 64) + + # Losses + lpipsloss = lpips.LPIPS(net='vgg').to(device) + mseloss = nn.MSELoss() + + for set_test in set_test_array: + + for evaluation_mode in evaluation_modes: + print(f'Start evaluation loop: {evaluation_mode}') + + # load data + dataloader = DataLoader( + dataset, + num_workers = 0, + pin_memory = False, + batch_size = cfg_net.batch_size, + shuffle = False, + drop_last = True, + ) + + # memory + root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views) + metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'meds': [], 'ari_hidden': [], 'fari_hidden': [], 'miou_hidden': []} + video_list = [] + + # set seed: if there is a number in the evaluation mode, use it as seed + plot_mode = True + if evaluation_mode[-1].isdigit(): + seed = int(evaluation_mode[-1]) + th.manual_seed(seed) + np.random.seed(seed) + print(f'Set seed to {seed}') + if int(evaluation_mode[-1]) > 1: + plot_mode = False + + with th.no_grad(): + for i, input in enumerate(dataloader): + print(f'Processing sample {i+1}/{len(dataloader)}', flush=True) + + # Load data + tensor = input[0].float().to(device) + background_fix = input[1].to(device) + gt_pos = input[2].to(device) + gt_mask = input[3].to(device) + gt_pres_mask = input[4].to(device) + gt_hidden_mask = input[5].to(device) + sequence_len = tensor.shape[1] + + # Placehodlers + mask_cur = None + mask_last = None + rawmask_last = None + position_last = None + gestalt_last = None + priority_last = None + slots_occlusionfactor = None + error_last = None + + # Memory + statistics_batch = deepcopy(statistics_template) + pred_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, cfg_net.num_objects, 2)).to(device) + gt_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, cfg_net.num_objects, 2)).to(device) + pred_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + gt_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + pred_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + pred_hidden_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + + # Counters + num_rollout = 0 + num_burnin = 0 + + # Loop through frames + for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)): + + # Move to next frame + t_run = max(t, 0) + input = tensor[:,t_run] + target_cur = tensor[:,t_run] + target = th.clip(tensor[:,t_run+1], 0, 1) + gt_pos_t = gt_pos[:,t_run+1]/32-1 + gt_pos_t = th.concat((gt_pos_t, th.ones_like(gt_pos_t[:,:,:1])), dim=2) + + 
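Note on the gt_pos_t normalization above: positions are divided by 32 and shifted by -1, i.e. pixel coordinates in a 64x64 frame (cf. target_size) are mapped to [-1, 1], and a constant column of ones is appended. A one-line check of the mapping:

    x_px = 32.0
    x_norm = x_px / 32.0 - 1.0   # 0 px -> -1.0, 32 px -> 0.0, 64 px -> +1.0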
rollout_index = t_run - burn_in_length + rollout_active = False + if t>=0: + if rollout_index >= 0: + num_rollout += 1 + if ('vidpred_black' in evaluation_mode): + input = output_next * 0 + rollout_active = True + elif ('vidpred_auto' in evaluation_mode): + input = output_next + rollout_active = True + else: + num_burnin += 1 + + # obtain prediction + ( + output_next, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor, + output_cur, + position_cur, + gestalt_cur, + priority_cur, + mask_cur, + rawmask_cur, + object_cur, + position_encoder_cur, + slots_bounded, + slots_partially_occluded_cur, + slots_occluded_cur, + slots_partially_occluded_next, + slots_occluded_next, + slots_closed, + output_hidden, + largest_object, + rawmask_hidden, + object_hidden + ) = net( + input, + error_last, + mask_last, + rawmask_last, + position_last, + gestalt_last, + priority_last, + background_fix, + slots_occlusionfactor, + reset = (t == -cfg.defaults.teacher_forcing), + evaluate=True, + warmup = (t < 0), + shuffleslots = True, + reset_mask = (t <= 0), + allow_spawn = True, + show_hidden = False, + clean_slots = False, + ) + + # 1. Track error for plots + if t >= 0: + + if (rollout_index >= 0): + # store positions per batch + if use_meds: + if False: + pred_pos_batch[:,rollout_index] = rearrange(position_next, 'b (o c) -> b o c', o=cfg_net.num_objects)[:,:,:2] + else: + pred_pos_batch[:,rollout_index] = compute_position_from_mask(rawmask_next) + + gt_pos_batch[:,rollout_index] = gt_pos_t[:,:,:2] + + pred_img_batch[:,rollout_index] = output_next + gt_img_batch[:,rollout_index] = target + + # Here we compute only the foreground segmentation mask + pred_mask_batch[:,rollout_index] = postproc_mask(mask_next[:,None,:,None])[:, 0] + + # Here we compute the hidden segmentation + occluded_cur = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1] + occluded_sum_cur = 1-(reduce(occluded_cur, 'b c h w -> b h w', 'max') > 0.5).float() + occluded_cur = th.cat((occluded_cur, occluded_sum_cur[:,None]), dim=1) + pred_hidden_mask_batch[:,rollout_index] = postproc_mask(occluded_cur[:,None,:,None])[:, 0] + + # 2. Remember output + mask_last = mask_next.clone() + rawmask_last = rawmask_next.clone() + position_last = position_next.clone() + gestalt_last = gestalt_next.clone() + priority_last = priority_next.clone() + + # 3. 
Error for next frame + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + # prediction error + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next.clone() + + # PLotting + if i == 0 and plot_mode: + att = net.predictor.get_att_weights() + openings = net.get_openings() + img_tensor = plot_timestep(cfg, cfg_net, input, target_cur, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, individual_views, None, None, sequence_len, root_path, None, t_index, t, i, rollout_mode=rollout_active, num_vid=plot_first_samples, att= att, openings=None) + video_list.append(img_tensor) + + # log video + if i == 0 and plot_mode: + video_tensor = rearrange(th.stack(video_list, dim=0), 't b c h w -> b t h w c') + save_videos(video_tensor, f'{plot_path}/object', verbose=verbose, trace_plot=True) + + # Compute prediction accuracy based on Slotformer metrics (ARI, FARI, mIoU, AP, AR) + for b in range(cfg_net.batch_size): + + # perceptual similarity from slotformer paper + metric_dict = pred_eval_step( + gt = gt_img_batch[b:b+1], + pred = pred_img_batch[b:b+1], + pred_mask = pred_mask_batch.long()[b:b+1], + pred_mask_hidden = pred_hidden_mask_batch.long()[b:b+1], + pred_bbox = None, + gt_mask = gt_mask.long()[b:b+1, burn_in_length+1:], + gt_mask_hidden = gt_hidden_mask.long()[b:b+1, burn_in_length+1:], + gt_pres_mask = gt_pres_mask[b:b+1, burn_in_length+1:], + gt_bbox = None, + lpips_fn = lpipsloss, + eval_traj = True, + ) + + metric_dict['meds'] = distance_eval_step(gt_pos_batch[b], pred_pos_batch[b]) + metric_complete = append_statistics(metric_dict, metric_complete) + + # sanity check + if (num_rollout != rollout_length) and (num_burnin != burn_in_length) and ('vidpred' in evaluation_mode): + raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.') + + + average_dic = compute_statistics_summary(metric_complete, evaluation_mode, root_path=root_path, consider_first_n_frames=rollout_length_stats) + + # Store statistics + with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f: + pickle.dump(metric_complete, f) + with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.pkl'), 'wb') as f: + pickle.dump(average_dic, f) + + print('-- Evaluation Done --') + if object_view and os.path.exists(f'{root_path}/tmp.jpg'): + os.remove(f'{root_path}/tmp.jpg') + pass + +# store videos as jpgs and then use ffmpeg to convert to video +def save_videos(video_tensor, plot_path, verbose=False, fps=10, trace_plot=False): + video_tensor = video_tensor.cpu().numpy() + img_path = plot_path + '/img' + for b in range(video_tensor.shape[0]): + os.makedirs(img_path, exist_ok=True) + video = video_tensor[b] + video = (video).astype(np.uint8) + for t in range(video.shape[0]): + cv2.imwrite(f'{img_path}/{b}_{t:04d}.jpg', video[t]) + + if verbose: + os.system(f"ffmpeg -r {fps} -pattern_type glob -i '{img_path}/*.jpg' -c:v libx264 -y {plot_path}/{b}.mp4") + os.system(f'rm -rf {img_path}') + + if trace_plot: + # trace plot + start = 15 + length = 20 + frame = np.zeros_like(video[0]) + for 
i in range(start,start+length): + current = video[i] * (0.1 + (i-start)/length) + frame = np.max(np.stack((frame, current)), axis=0) + cv2.imwrite(f'{plot_path}/{b}_trace.jpg', frame) + + +def distance_eval_step(gt_pos, pred_pos): + meds_per_timestep = [] + gt_pred_pairings = None + for t in range(pred_pos.shape[0]): + frame_gt = gt_pos[t].cpu().numpy() + frame_pred = pred_pos[t].cpu().numpy() + frame_gt = (frame_gt + 1) * 0.5 + frame_pred = (frame_pred + 1) * 0.5 + + distances = mm.distances.norm2squared_matrix(frame_gt, frame_pred, max_d2=1) + if gt_pred_pairings is None: + frame_gt_ids = list(range(frame_gt.shape[0])) + frame_pred_ids = list(range(frame_pred.shape[0])) + gt_pred_pairings = [(frame_gt_ids[g], frame_pred_ids[p]) for g, p in zip(*mm.lap.linear_sum_assignment(distances))] + + med = 0 + for gt_id, pred_id in gt_pred_pairings: + curr_med = np.sqrt(((frame_gt[gt_id] - frame_pred[pred_id])**2).sum()) + med += curr_med + if len(gt_pred_pairings) > 0: + meds_per_timestep.append(med / len(gt_pred_pairings)) + else: + meds_per_timestep.append(np.nan) + return meds_per_timestep + + +def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden): + statistics_batch = store_statistics(statistics_batch, + set_test['type'], + evaluation_mode, + set_test['samples'][i], + t, + mseloss(output_next, target).item() + ) + + # compute slot-wise prediction error + output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects) + target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects) + slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean') + + # compute rawmask_size + rawmask_size = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum') + + statistics_complete_slots = store_statistics(statistics_complete_slots, + [set_test['type']] * cfg_net.num_objects, + [evaluation_mode] * cfg_net.num_objects, + [set_test['samples'][i]] * cfg_net.num_objects, + [t] * cfg_net.num_objects, + range(cfg_net.num_objects), + slots_bounded.cpu().numpy().flatten().astype(int), + slot_error.cpu().numpy().flatten(), + rawmask_size.cpu().numpy().flatten(), + slots_closed[:, :, 1].cpu().numpy().flatten(), + slots_closed[:, :, 0].cpu().numpy().flatten(), + extend = True) + + return statistics_complete_slots,statistics_batch + +def get_evaluation_sets(dataset): + + set = {"samples": np.arange(len(dataset), dtype=int), "type": "test"} + evaluation_modes = ['open', 'vidpred_auto', 'vidpred_black_1', 'vidpred_black_2', 'vidpred_black_3', 'vidpred_black_4', 'vidpred_black_5'] # use 'open' for no blackouts + set_test_array = [set] + + return set_test_array, evaluation_modes diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py index a43d50c..168fba3 100644 --- a/scripts/evaluation_clevrer.py +++ b/scripts/evaluation_clevrer.py @@ -237,7 +237,7 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.') if not plotting_mode: - average_dic = compute_statistics_summary(metric_complete, evaluation_mode) + average_dic = compute_statistics_summary(metric_complete, evaluation_mode, root_path=root_path) # Store statistics with open(os.path.join(f'{root_path}/statistics', 
f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f: @@ -250,24 +250,45 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p os.remove(f'{root_path}/tmp.jpg') pass -def compute_statistics_summary(metric_complete, evaluation_mode): +def compute_statistics_summary(metric_complete, evaluation_mode, root_path=None, consider_first_n_frames = None): + string = '' + def add_text(string, text, last=False): + string = string + ' \n ' + text + return string + average_dic = {} + if consider_first_n_frames is not None: + for key in metric_complete: + for sample in range(len(metric_complete[key])): + metric_complete[key][sample] = metric_complete[key][sample][:consider_first_n_frames] + for key in metric_complete: # take average over all frames - average_dic[key + 'complete_average'] = np.mean(metric_complete[key]) - average_dic[key + 'complete_std'] = np.std(metric_complete[key]) - print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}') + average_dic[key + '_complete_average'] = np.mean(metric_complete[key]) + average_dic[key + '_complete_std'] = np.std(metric_complete[key]) + average_dic[key + '_complete_sum'] = np.sum(np.mean(metric_complete[key], axis=0)) # checked with GSWM code! + string = add_text(string, f'{key} complete average: {average_dic[key + "_complete_average"]:.4f} +/- {average_dic[key + "_complete_std"]:.4f} (sum: {average_dic[key + "_complete_sum"]:.4f})') + #print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f} (sum: {average_dic[key + "complete_sum"]:.4f})') if evaluation_mode == 'blackout': - # take average only for frames where blackout occurs + # take average only for frames where blackout occurs blackout_mask = np.array(metric_complete['blackout']) > 0 - average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask]) - average_dic[key + 'blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask]) - average_dic[key + 'visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False]) - average_dic[key + 'visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False]) + average_dic[key + '_blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + '_blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + '_visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False]) + average_dic[key + '_visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False]) - print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}') - print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}') + #print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}') + #print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}') + string = add_text(string, f'{key} blackout average: {average_dic[key + "_blackout_average"]:.4f} +/- {average_dic[key + "_blackout_std"]:.4f}') + string = add_text(string, f'{key} visible average: {average_dic[key + "_visible_average"]:.4f} +/- {average_dic[key + "_visible_std"]:.4f}') + + print(string) + if root_path is not None: + f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl' + with 
open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.txt'), 'w') as f:
+            f.write(string)
+
     return average_dic
 
 def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden):
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py
index 96ed32d..bc461ec 100644
--- a/scripts/exec/eval.py
+++ b/scripts/exec/eval.py
@@ -1,10 +1,29 @@
 import argparse
 import sys
 from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset
 from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
 from scripts.utils.configuration import Configuration
-from scripts import evaluation_adept, evaluation_clevrer
+from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb
 
+def main(load, n, cfg):
+
+    size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))
+
+    # Load dataset
+    if cfg.datatype == "clevrer":
+        testset = ClevrerDataset("./", cfg.dataset, 'val', size, use_slotformer=True, evaluation=True)
+        evaluation_clevrer.evaluate(cfg, testset, load, n, plot_frequency= 2, plot_first_samples = 3) # only plotting
+        evaluation_clevrer.evaluate(cfg, testset, load, n, plot_first_samples = 0) # evaluation
+    elif cfg.datatype == "adept":
+        testset = AdeptDataset("./", cfg.dataset, 'createdown', size)
+        evaluation_adept.evaluate(cfg, testset, load, n, plot_frequency= 1, plot_first_samples = 2)
+    elif cfg.datatype == "bouncingballs":
+        testset = BouncingBallDataset("./", cfg.dataset, "test", size, type_name = cfg.scenario)
+        evaluation_bb.evaluate(cfg, testset, load, n, plot_first_samples = 4)
+    else:
+        raise Exception("Dataset not supported")
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-cfg", default="", help='path to the configuration file')
@@ -16,13 +35,4 @@ if __name__ == "__main__":
     cfg = Configuration(args.cfg)
     print(f'Evaluating model {args.load}')
 
-    # Load dataset
-    if cfg.datatype == "clevrer":
-        testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True)
-        evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting
-        evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation
-    elif cfg.datatype == "adept":
-        testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
-        evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
-    else:
-        raise Exception("Dataset not supported")
\ No newline at end of file
+    main(args.load, args.n, cfg)
\ No newline at end of file
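Note on the size computation in main above: size = (latent_w * 2**(2*level), latent_h * 2**(2*level)). Assuming, for illustration, level = 2 and latent_size = [20, 30]:

    size = (30 * 2**(2*2), 20 * 2**(2*2))   # -> (480, 320)

This matches the resolution hard-coded in scripts/exec/eval_baselines.py below.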
diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_baselines.py
index 87d4e59..9e451ab 100644
--- a/scripts/exec/eval_savi.py
+++ b/scripts/exec/eval_baselines.py
@@ -1,12 +1,13 @@
 import argparse
 import sys
 from data.datasets.ADEPT.dataset import AdeptDataset
-from scripts.evaluation_adept_savi import evaluate
+from scripts.evaluation_adept_baselines import evaluate
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-load", default="", type=str, help='path to savi slots')
     parser.add_argument("-n", default="", type=str, help='results name')
+    parser.add_argument("-model", default="", type=str, help='model, either savi or gswm')
 
     # Load configuration
     args = parser.parse_args(sys.argv[1:])
@@ -14,4 +15,4 @@ if __name__ == "__main__":
 
     # Load dataset
     testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2)))
-    evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
\ No newline at end of file
+    evaluate(testset, args.load, args.n, args.model, plot_frequency= 1, plot_first_samples = 2)
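Usage note (reviewer sketch): with the new -model flag the script is run once per baseline. The module-style invocation and file paths below are illustrative; only the flag names and the two accepted model values come from the argparse setup and assertion above.

    # python -m scripts.exec.eval_baselines -load out/savi_slots.pkl -n savi_run -model savi
    # python -m scripts.exec.eval_baselines -load out/gswm_slots.pkl -n gswm_run -model gswm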
\ No newline at end of file diff --git a/scripts/exec/train.py b/scripts/exec/train.py index f6e78fd..6fa6176 100644 --- a/scripts/exec/train.py +++ b/scripts/exec/train.py @@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration from scripts import training from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset +from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -20,12 +22,17 @@ if __name__ == "__main__": print(f'Training model {cfg.model_path}') # Load dataset + size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)) if cfg.datatype == "clevrer": - trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False) - valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True) + trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False) + valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True) elif cfg.datatype == "adept": - trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) - valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + trainset = AdeptDataset("./", cfg.dataset, "train", size) + valset = AdeptDataset("./", cfg.dataset, "test", size) + valset.train = True + elif cfg.datatype == "bouncingballs": + trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario) + valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario) else: raise Exception("Dataset not supported") diff --git a/scripts/training.py b/scripts/training.py index b807e54..fd4a4bf 100644 --- a/scripts/training.py +++ b/scripts/training.py @@ -1,3 +1,4 @@ +from copy import deepcopy import torch as th from torch import nn from torch.utils.data import Dataset, DataLoader, Subset @@ -5,15 +6,18 @@ import cv2 import numpy as np import os from einops import rearrange, repeat, reduce -from scripts.utils.configuration import Configuration -from scripts.utils.io import init_device, model_path, LossLogger +from scripts.utils.configuration import Configuration, Dict +from scripts.utils.io import WriterWrapper, init_device, model_path, LossLogger from scripts.utils.optimizers import RAdam from model.loci import Loci import random from scripts.utils.io import Timer -from scripts.utils.plot_utils import color_mask -from scripts.validation import validation_clevrer, validation_adept +from scripts.utils.plot_utils import color_mask, plot_timestep +from scripts.validation import validation_bb, validation_clevrer, validation_adept +from scripts.exec.eval import main as eval_main +os.environ["WANDB__SERVICE_WAIT"] = "300" +os.environ["WANDB_ DISABLE_SERVICE"] = "true" def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): @@ -21,6 +25,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): device, verbose = init_device(cfg) if verbose: valset = Subset(valset, range(0, 
8)) + use_wandb = (verbose == False) + + # generate random seed if not set + if not ('seed' in cfg.defaults): + cfg.defaults['seed'] = random.randint(0, 100) + th.manual_seed(cfg.defaults.seed) # Define model path path = model_path(cfg, overwrite=False) @@ -34,12 +44,16 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): teacher_forcing = cfg.defaults.teacher_forcing ) net = net.to(device=device) + #net = th.compile(net) + #net = th.jit.script(net) # Log model size log_modelsize(net) + writer = WriterWrapper(use_wandb, cfg) # Init Optimizers optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net) + scheduler = init_lr_scheduler_StepLR(cfg, optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update) # Option to load model if file != "": @@ -58,9 +72,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): print(f'loaded {file}', flush=True) # Set up data loaders - trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True) - valset.train = True #valset.dataset.train = True - valloader = get_loader(cfg, valset, cfg_net, shuffle=False) + trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True, verbose=verbose) + valloader = get_loader(cfg, valset, cfg_net, shuffle=False, verbose=verbose) # initial save save_model( @@ -80,16 +93,24 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): print('!!! Net init status: ', net.get_init_status()) # Set up statistics - loss_tracker = LossLogger() + loss_tracker = LossLogger(writer) + + # Init loss function + imageloss = initialize_loss_function(cfg) # Set up training variables num_time_steps = 0 bptt_steps = cfg.bptt.bptt_steps + if not 'plot_interval' in cfg.defaults: + cfg.defaults.plot_interval = 20000 + blackout_rate = cfg.blackout.blackout_rate if ('blackout' in cfg) else 0.0 + rollout_length = cfg.vp.rollout_length if ('vp' in cfg) else 0 + burnin_length = cfg.vp.burnin_length if ('vp' in cfg) else 0 increase_bptt_steps = False background_blendin_factor = 0.0 th.backends.cudnn.benchmark = True + plot_next_sample = False timer = Timer() - bceloss = nn.BCELoss() # Init net to current num_updates if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1: @@ -101,7 +122,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3: net.inc_init_level() for param in optimizer_init.param_groups: - param['lr'] = cfg.learning_rate + param['lr'] = cfg.learning_rate.lr if num_updates > cfg.phases.start_inner_loop: net.cfg.inner_loop_enabled = True @@ -117,16 +138,18 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): # Validation every epoch if epoch >= 0: if cfg.datatype == 'adept': - validation_adept(valloader, net, cfg, device) + validation_adept(valloader, net, cfg, device, writer, epoch, path) elif cfg.datatype == 'clevrer': - validation_clevrer(valloader, net, cfg, device) + validation_clevrer(valloader, net, cfg, device, writer, epoch, path) + elif cfg.datatype == "bouncingballs" and (epoch % 2 == 0): + validation_bb(valloader, net, cfg, device, writer, epoch, path) # Start epoch training print('Start epoch:', epoch) # Backprop through time steps if increase_bptt_steps: - bptt_steps = max(bptt_steps + 1, cfg.bptt.bptt_steps_max) + bptt_steps = min(bptt_steps + 
1, cfg.bptt.bptt_steps_max) print('Increase closed loop steps to', bptt_steps) increase_bptt_steps = False @@ -149,7 +172,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): slots_occlusionfactor = None # Apply skip frames to sequence - selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames) + start = random.randrange(cfg.defaults.skip_frames) if rollout_length == 0 else random.randrange(0, (tensor.shape[1]-(rollout_length+burnin_length))) + selec = range(start, tensor.shape[1], cfg.defaults.skip_frames) tensor = tensor[:,selec] sequence_len = tensor.shape[1] @@ -159,6 +183,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): target = th.clip(input, 0, 1).detach() error_last = None + # plotting mode + plot_this_sample = plot_next_sample + plot_next_sample = False + video_list = [] + num_rollout = 0 + # First apply teacher forcing for the first x frames for t in range(-cfg.defaults.teacher_forcing, sequence_len-1): @@ -185,13 +215,27 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout: error_last = th.zeros_like(error_last) - # Apply sensation blackout when training clevrere - if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer': - if t >= 10: - blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device) + # Apply sensation blackout when training clevrer + if net.cfg.inner_loop_enabled and blackout_rate > 0: + if t >= cfg.blackout.blackout_start_timestep: + blackout = th.tensor((np.random.rand(cfg_net.batch_size) < blackout_rate)[:,None,None,None]).float().to(device) input = blackout * (input * 0) + (1-blackout) * input error_last = blackout * (error_last * 0) + (1-blackout) * error_last - + + # Apply rollout training + if net.cfg.inner_loop_enabled and (rollout_length > 0) and (burnin_length-1) < t: + input = input * 0 + error_last = error_last * 0 + num_rollout += 1 + + if (burnin_length + rollout_length -1) == t: + run_optimizers = True + detach = True + + if (burnin_length + rollout_length) == t: + break + + # Forward Pass ( output_next, @@ -206,7 +250,9 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): slots_occlusionfactor, position_loss, time_loss, - slots_closed + latent_loss, + slots_closed, + slots_bounded ) = net( input, # current frame error_last, # error of last frame --> missing object @@ -228,6 +274,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): # Loss weighting position_loss = position_loss * cfg_net.position_regularizer time_loss = time_loss * cfg_net.time_regularizer + if latent_loss.item() > 0: + latent_loss = latent_loss * cfg_net.latent_regularizer # Compute background error bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() @@ -253,10 +301,14 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): background_blendin_factor = min(1, background_blendin_factor + 0.001) # Final Loss computation - encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer - cliped_output_next = th.clip(output_next, 0, 1) - loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss + encoding_loss = imageloss(output_cur, input) * cfg_net.encoder_regularizer + prediction_loss = imageloss(output_next, target) + loss = 
prediction_loss + encoding_loss + position_loss + time_loss + latent_loss + # apply loss decay according to num_rollout + if rollout_length > 0: + loss = loss * 0.75**(num_rollout-1) + # Accumulate loss over BPP steps summed_loss = loss if summed_loss is None else summed_loss + loss mask = mask.detach() @@ -295,6 +347,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): # Update net status update_net_status(num_updates, net, cfg, optimizer_init) + step_lr_scheduler(scheduler) if num_updates == cfg.phases.start_inner_loop: print('Start inner loop') @@ -303,38 +356,39 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0): increase_bptt_steps = True + if net.cfg.inner_loop_enabled and ('blackout' in cfg) and (cfg.blackout.blackout_increase_every > 0) and ((num_updates-cfg.num_updates) % cfg.blackout.blackout_increase_every == 0) and ((num_updates-cfg.num_updates) > 0): + blackout_rate = min(blackout_rate + cfg.blackout.blackout_increase_rate, cfg.blackout.blackout_rate_max) + # Plots for online evaluation - if num_updates % 20000 == 0: - plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next) + if num_updates % cfg.defaults.plot_interval == 0: + plot_next_sample = True + if plot_this_sample: + img_tensor = plot_online(cfg, path, f'{epoch}_{batch_index}', input, background, mask, sequence_len, t, output_next, bg_error_next) + video_list.append(img_tensor) # Track statisitcs if t >= cfg.defaults.statistics_offset: - track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next) - loss_tracker.update_average_loss(loss.item()) + track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates) + loss_tracker.update_average_loss(loss.item(), num_updates) + writer.add_scalar('Train/BPTT_steps', bptt_steps, num_updates) + writer.add_scalar('Train/Background_Blendin', background_blendin_factor, num_updates) + writer.add_scalar('Train/Blackout_Rate', blackout_rate, num_updates) # Logging if num_updates % 100 == 0 and run_optimizers: print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True) # Training finished + net_path = None if num_updates > cfg.max_updates: - save_model( - os.path.join(path, 'nets', 'net_final.pt'), - net, - optimizer_init, - optimizer_encoder, - optimizer_decoder, - optimizer_predictor, - optimizer_background - ) - print("Training finished") - return - - # Checkpointing + net_path = os.path.join(path, 'nets', 'net_final.pt') if num_updates % 50000 == 0 and run_optimizers: + net_path = os.path.join(path, 'nets', f'net_{num_updates}.pt') + + if net_path is not None: save_model( - os.path.join(path, 'nets', f'net_{num_updates}.pt'), + net_path, net, optimizer_init, optimizer_encoder, @@ -342,33 +396,56 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): optimizer_predictor, optimizer_background ) - pass + if ('final' in net_path) or 
('num_objects_test' in cfg.model): + eval_main(net_path, 1, deepcopy(cfg)) -def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next): - - # Prediction Loss - mseloss = nn.MSELoss() - loss_next = mseloss(output_next * bg_error_next, target * bg_error_next) + if 'final' in net_path: + print("Training finished") + writer.flush() + return - # Encoder Loss (only for stats) - loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur) + if plot_this_sample: + video_tensor = rearrange(th.stack(video_list, dim=0)[None], 'b t h w c -> b t c h w') + writer.add_video('Train/Video', video_tensor, num_updates) + + pass +def track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates): + # area of foreground mask num_objects = th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item() # difference in shape - _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects) - _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects) + _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects) + _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects) max_mask = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float() avg_gestalt = (th.sum(_gestalt * max_mask) / (1e-16 + th.sum(max_mask))) avg_gestalt2 = (th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask))) avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1)) # udpdate gates - avg_update_gestalt = slots_closed[:,:,0].mean() - avg_update_position = slots_closed[:,:,1].mean() + num_bounded = reduce(slots_bounded, 'b o -> b', 'sum').mean().item() + slots_closed = slots_closed * slots_bounded[:,:,None] + avg_update_gestalt = slots_closed[:,:,0].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0 + avg_update_position = slots_closed[:,:,1].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0 + avg_update_gestalt = float(avg_update_gestalt) + avg_update_position = float(avg_update_position) + + # Prediction Loss + Encoder Loss as MSE + only foreground pixels + mseloss = nn.MSELoss() + loss_next = mseloss(output_next * bg_error_next, target * bg_error_next) + loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur) + + # learning rate + if scheduler is not None: + lr = scheduler[0].get_last_lr()[0] + else: + lr = cfg.learning_rate.lr - loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position) + # gatelORD openings + openings = th.mean(net.get_openings()).item() + + loss_tracker.update_complete(position_loss, time_loss, latent_loss, loss_cur, loss_next, num_objects, openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, num_bounded, lr, num_updates) pass def log_modelsize(net): @@ -381,15 +458,49 @@ def log_modelsize(net): print("\n") pass +def initialize_loss_function(cfg): + if not ('loss' in cfg.defaults): + cfg.defaults.loss = 'bce' + + if 
cfg.defaults.loss == 'mse': + imageloss = nn.MSELoss() + elif cfg.defaults.loss == 'bce': + imageloss = nn.BCELoss() + else: + raise NotImplementedError + + return imageloss + def init_optimizer(cfg, net): - optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30) - optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate) - optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate) - optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate) - optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate) - optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = cfg.learning_rate) + # backward compability: + if not isinstance(cfg.learning_rate, dict): + cfg.learning_rate = Dict({'lr': cfg.learning_rate}) + + lr = cfg.learning_rate.lr + optimizer_init = RAdam(net.initial_states.parameters(), lr = lr * 30) + optimizer_encoder = RAdam(net.encoder.parameters(), lr = lr) + optimizer_decoder = RAdam(net.decoder.parameters(), lr = lr) + optimizer_predictor = RAdam(net.predictor.parameters(), lr = lr) + optimizer_background = RAdam([net.background.mask], lr = lr) + optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = lr) return optimizer_init,optimizer_encoder,optimizer_decoder,optimizer_predictor,optimizer_background,optimizer_update +def init_lr_scheduler_StepLR(cfg, *optimizer_list): + scheduler_list = None + if 'deacrease_lr_every' in cfg.learning_rate: + print('Init lr scheduler') + scheduler_list = [] + for optimizer in optimizer_list: + scheduler = th.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.learning_rate.deacrease_lr_every, gamma=cfg.learning_rate.deacrease_lr_factor) + scheduler_list.append(scheduler) + return scheduler_list + +def step_lr_scheduler(scheduler_list): + if scheduler_list is not None: + for scheduler in scheduler_list: + scheduler.step() + pass + def save_model( file, net, @@ -432,23 +543,23 @@ def load_model( if load_optimizers: optimizer_init.load_state_dict(state[f'optimizer_init']) for n in range(len(optimizer_init.param_groups)): - optimizer_init.param_groups[n]['lr'] = cfg.learning_rate + optimizer_init.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_encoder.load_state_dict(state[f'optimizer_encoder']) for n in range(len(optimizer_encoder.param_groups)): - optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate + optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_decoder.load_state_dict(state[f'optimizer_decoder']) for n in range(len(optimizer_decoder.param_groups)): - optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate + optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_predictor.load_state_dict(state['optimizer_predictor']) for n in range(len(optimizer_predictor.param_groups)): - optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate + optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_background.load_state_dict(state['optimizer_background']) for n in range(len(optimizer_background.param_groups)): - optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate + optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate.lr # 1. 
Fill model with values of net model = {} @@ -484,40 +595,61 @@ def update_net_status(num_updates, net, cfg, optimizer_init): if num_updates == cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3: net.inc_init_level() for param in optimizer_init.param_groups: - param['lr'] = cfg.learning_rate + param['lr'] = cfg.learning_rate.lr pass def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next): - plot_path = os.path.join(path, 'plots', f'net_{num_updates}') - if not os.path.exists(plot_path): + + # highlight error + grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114 + object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1) + highlited_input = grayscale * (1 - object_mask_cur) + highlited_input += grayscale * object_mask_cur * 0.3333333 + cmask = color_mask(mask[:,:-1]) + highlited_input = highlited_input + cmask * 0.6666666 + + input_ = rearrange(input[0], 'c h w -> h w c').detach().cpu() + background_ = rearrange(background[0], 'c h w -> h w c').detach().cpu() + mask_ = rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu() + output_next_ = rearrange(output_next[0], 'c h w -> h w c').detach().cpu() + bg_error_next_ = rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu() + highlited_input_ = rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu() + + if False: + plot_path = os.path.join(path, 'plots', f'net_{num_updates}') os.makedirs(plot_path, exist_ok=True) - - # highlight error - grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114 - object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1) - highlited_input = grayscale * (1 - object_mask_cur) - highlited_input += grayscale * object_mask_cur * 0.3333333 - cmask = color_mask(mask[:,:-1]) - highlited_input = highlited_input + cmask * 0.6666666 - - cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - pass - -def get_loader(cfg, set, cfg_net, shuffle = True): - loader = DataLoader( - set, - pin_memory = True, - num_workers = cfg.defaults.num_workers, - batch_size = cfg_net.batch_size, - shuffle = shuffle, - drop_last = True, - prefetch_factor = cfg.defaults.prefetch_factor, - persistent_workers = True - ) + cv2.imwrite(os.path.join(plot_path, f'input-{t+cfg.defaults.teacher_forcing:03d}.jpg'), input_.numpy() * 255) + 
cv2.imwrite(os.path.join(plot_path, f'background-{t+cfg.defaults.teacher_forcing:03d}.jpg'), background_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'error_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), bg_error_next_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'background_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), mask_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'output_next-{t+cfg.defaults.teacher_forcing:03d}.jpg'), output_next_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'output_highlight-{t+cfg.defaults.teacher_forcing:03d}.jpg'), highlited_input_.numpy() * 255) + + # stack input, highlighted input and predicted output into one image horizontally + img_tensor = th.cat([input_, highlited_input_, output_next_], dim=1) + return img_tensor + +def get_loader(cfg, set, cfg_net, shuffle = True, verbose=False): + if ((cfg.datatype == 'bouncingballs') and verbose) or ((cfg.datatype == 'adept') and not shuffle): + loader = DataLoader( + set, + pin_memory = True, + num_workers = 0, + batch_size = cfg_net.batch_size, + shuffle = shuffle, + drop_last = True, + persistent_workers = False + ) + else: + loader = DataLoader( + set, + pin_memory = True, + num_workers = cfg.defaults.num_workers, + batch_size = cfg_net.batch_size, + shuffle = shuffle, + drop_last = True, + prefetch_factor = cfg.defaults.prefetch_factor, + persistent_workers = True + ) return loader
\ No newline at end of file diff --git a/scripts/utils/eval_adept.py b/scripts/utils/eval_adept.py new file mode 100644 index 0000000..7b8dfa8 --- /dev/null +++ b/scripts/utils/eval_adept.py @@ -0,0 +1,169 @@ +import pandas as pd +import warnings +import os +import argparse + +warnings.simplefilter(action='ignore', category=FutureWarning) +pd.options.mode.chained_assignment = None + +def eval_adept(path): + net = 'net1' + + # read csv files + tf = pd.DataFrame() + sf = pd.DataFrame() + af = pd.DataFrame() + + with open(os.path.join(path, 'trialframe.csv'), 'rb') as f: + tf_temp = pd.read_csv(f, index_col=0) + tf_temp['net'] = net + tf = pd.concat([tf,tf_temp]) + + with open(os.path.join(path, 'slotframe.csv'), 'rb') as f: + sf_temp = pd.read_csv(f, index_col=0) + sf_temp['net'] = net + sf = pd.concat([sf,sf_temp]) + + with open(os.path.join(path, 'accframe.csv'), 'rb') as f: + af_temp = pd.read_csv(f, index_col=0) + af_temp['net'] = net + af = pd.concat([af,af_temp]) + + # cast variables + sf['visible'] = sf['visible'].astype(bool) + sf['bound'] = sf['bound'].astype(bool) + sf['occluder'] = sf['occluder'].astype(bool) + sf['inimage'] = sf['inimage'].astype(bool) + sf['vanishing'] = sf['vanishing'].astype(bool) + sf['alpha_pos'] = 1-sf['alpha_pos'] + sf['alpha_ges'] = 1-sf['alpha_ges'] + + # scale to percentage + sf['TE'] = sf['TE'] * 100 + + # dummy-code the control (vs. surprise) condition + tf['control'] = [('control' in set) for set in tf['set']] + sf['control'] = [('control' in set) for set in sf['set']] + + + + # STATS: + tracking_error_visible = 0 + tracking_error_occluded = 0 + num_positive_trackings = 0 + mota = 0 + gate_openings_visible = 0 + gate_openings_occluded = 0 + + + print('Tracking Error ------------------------------') + grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control) + + def get_stats(col): + return f' M: {col.mean():.3} , STD: {col.std():.3}, Count: {col.count()}' + + # When Visible + temp = sf[grouping & sf.visible] + print(f'Tracking Error when visible:' + get_stats(temp['TE'])) + tracking_error_visible = temp['TE'].mean() + + # When Occluded + temp = sf[grouping & ~sf.visible] + print(f'Tracking Error when occluded:' + get_stats(temp['TE'])) + tracking_error_occluded = temp['TE'].mean() + + + + + + + print('Positive Trackings ------------------------------') + # successful trackings: In the last visible moment of the target, the slot was less than 10% away from the target + # determine last visible frame numeric + grouping_factors = ['net','set','evalmode','scene','slot'] + ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).max() + ff.rename(columns = {'frame':'last_visible'}, inplace = True) + sf = sf.merge(ff[['last_visible']], on=grouping_factors, how='left') + + # same for first bound frame + ff = sf[sf.visible & sf.bound & sf.inimage].groupby(grouping_factors).min() + ff.rename(columns = {'frame':'first_visible'}, inplace = True) + sf = sf.merge(ff[['first_visible']], on=grouping_factors, how='left') + + # add dummy variable to sf + sf['last_visible'] = (sf['last_visible'] == sf['frame']) + + # extract the trials where the target was last visible and threshold the TE + ff = sf[sf['last_visible']] + ff['tracked_pos'] = (ff['TE'] < 10) + ff['tracked_neg'] = (ff['TE'] >= 10) + + # fill NaN with 0 + sf = sf.merge(ff[grouping_factors + ['tracked_pos', 'tracked_neg']], on=grouping_factors, how='left') + sf['tracked_pos'].fillna(False, inplace=True) + sf['tracked_neg'].fillna(False, inplace=True) + + # Aggregate over all scenes + temp = 
sf[(sf['frame']== 1) & ~sf.occluder & sf.control & (sf.first_visible < 20)] + temp = temp.groupby(['set', 'evalmode']).sum() + temp = temp[['tracked_pos', 'tracked_neg']] + temp = temp.reset_index() + + temp['tracked_pos_pro'] = temp['tracked_pos'] / (temp['tracked_pos'] + temp['tracked_neg']) + temp['tracked_neg_pro'] = temp['tracked_neg'] / (temp['tracked_pos'] + temp['tracked_neg']) + print(temp) + num_positive_trackings = temp['tracked_pos_pro'] + + + + + + print('Mostly Tracked / MOTA ------------------------------') + temp = af[af.index == 'OVERALL'] + temp['mostly_tracked'] = temp['mostly_tracked'] / temp['num_unique_objects'] + temp['partially_tracked'] = temp['partially_tracked'] / temp['num_unique_objects'] + temp['mostly_lost'] = temp['mostly_lost'] / temp['num_unique_objects'] + print(temp) + mota = temp['mota'] + + + print('Openings ------------------------------') + grouping = (sf.inimage & sf.bound & ~sf.occluder & sf.control) + temp = sf[grouping & sf.visible] + print(f'Percept gate openings when visible:' + get_stats(temp['alpha_pos'] + temp['alpha_ges'])) + gate_openings_visible = temp['alpha_pos'].mean() + temp['alpha_ges'].mean() + + temp = sf[grouping & ~sf.visible] + print(f'Percept gate openings when occluded:' + get_stats(temp['alpha_pos'] + temp['alpha_ges'])) + gate_openings_occluded = temp['alpha_pos'].mean() + temp['alpha_ges'].mean() + + + print('------------------------------------------------') + print('------------------------------------------------') + summary = '' + summary += f'net: {net}\n' + summary += f'Tracking Error when visible: {tracking_error_visible:.3}\n' + summary += f'Tracking Error when occluded: {tracking_error_occluded:.3}\n' + summary += 'Positive Trackings: ' + ', '.join(f'{val:.3}' for val in num_positive_trackings) + '\n' + summary += 'MOTA: ' + ', '.join(f'{val:.3}' for val in mota) + '\n' + summary += f'Percept gate openings when visible: {gate_openings_visible:.3}\n' + summary += f'Percept gate openings when occluded: {gate_openings_occluded:.3}\n' + + print(summary) + + # write summary string to file + with open(os.path.join(path, 'results.txt'), 'w') as f: + f.write(summary) + + + +if __name__ == "__main__": + + # use argparse to get the path to the results folder + parser = argparse.ArgumentParser() + parser.add_argument('--path', type=str, default='') + args = parser.parse_args() + + # setting path to results folder + path = args.path + eval_adept(path)
\ No newline at end of file diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py index 6f5106d..98af468 100644 --- a/scripts/utils/eval_metrics.py +++ b/scripts/utils/eval_metrics.py @@ -3,6 +3,8 @@ vp_utils.py SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer ''' +import cv2 +from einops import rearrange import numpy as np from scipy.optimize import linear_sum_assignment from skimage.metrics import structural_similarity, peak_signal_noise_ratio @@ -11,7 +13,7 @@ import torch import torch.nn.functional as F import torchvision.ops as vops -FG_THRE = 0.5 +FG_THRE = 0.3 # 0.5 for the rest, 0.3 for bouncingballs PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0), (255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200), (255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200), @@ -272,6 +274,8 @@ def pred_eval_step( gt_bbox=None, pred_bbox=None, eval_traj=True, + gt_mask_hidden=None, + pred_mask_hidden=None, ): """Both of shape [B, T, C, H, W], torch.Tensor. masks of shape [B, T, H, W]. @@ -288,13 +292,14 @@ def pred_eval_step( if eval_traj: assert len(gt_mask.shape) == len(pred_mask.shape) == 4 assert gt_mask.shape == pred_mask.shape - if eval_traj: + if eval_traj and gt_bbox is not None: assert len(gt_pres_mask.shape) == 3 assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4 T = gt.shape[1] # compute perceptual dist & mask metrics before converting to numpy all_percept_dist, all_ari, all_fari, all_miou = [], [], [], [] + all_ari_hidden, all_fari_hidden, all_miou_hidden = [], [], [] for t in range(T): one_gt, one_pred = gt[:, t], pred[:, t] percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item() @@ -307,6 +312,29 @@ def pred_eval_step( all_ari.append(ari) all_fari.append(fari) all_miou.append(miou) + + # hidden: + if gt_mask_hidden is not None: + one_gt_mask, one_pred_mask = gt_mask_hidden[:, t], pred_mask_hidden[:, t] + ari = ARI_metric(one_gt_mask, one_pred_mask) + fari = fARI_metric(one_gt_mask, one_pred_mask) + miou = miou_metric(one_gt_mask, one_pred_mask) + all_ari_hidden.append(ari) + all_fari_hidden.append(fari) + all_miou_hidden.append(miou) + + + # display masks with cv2 + if False: + one_gt_mask_, one_pred_mask_ = one_gt_mask.clone(), one_pred_mask.clone() + + # concat one_gt_mask and one_pred_mask horizontally + frame = np.concatenate((one_gt_mask_, one_pred_mask_), axis=1) + frame = rearrange(frame, 'c h w -> h w c') * (255/6) + frame = frame.astype(np.uint8) + cv2.imshow('frame', frame) + cv2.waitKey(0) + else: all_ari.append(0.) all_fari.append(0.) @@ -315,14 +343,14 @@ def pred_eval_step( # compute bbox metrics all_ap, all_ar = [], [] for t in range(T): - if not eval_traj: + if not eval_traj or gt_bbox is None: all_ap.append(0.) all_ar.append(0.) 
continue one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \ gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t] ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox, - one_pred_bbox) + one_pred_bbox) all_ap.append(ap) all_ar.append(ar) @@ -333,11 +361,13 @@ def pred_eval_step( one_gt, one_pred = gt[:, t], pred[:, t] mse = mse_metric(one_gt, one_pred) psnr = psnr_metric(one_gt, one_pred) - ssim = ssim_metric(one_gt, one_pred) + #ssim = ssim_metric(one_gt, one_pred) + ssim = 0 # SSIM disabled here; recorded as 0 all_mse.append(mse) all_ssim.append(ssim) all_psnr.append(psnr) - return { + + res = { 'mse': all_mse, 'ssim': all_ssim, 'psnr': all_psnr, @@ -346,5 +376,12 @@ 'fari': all_fari, 'miou': all_miou, 'ap': all_ap, - 'ar': all_ar, + 'ar': all_ar } + + if gt_mask_hidden is not None: + res['ari_hidden'] = all_ari_hidden + res['fari_hidden'] = all_fari_hidden + res['miou_hidden'] = all_miou_hidden + + return res diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py index faab7ec..a01ffd0 100644 --- a/scripts/utils/eval_utils.py +++ b/scripts/utils/eval_utils.py @@ -1,18 +1,92 @@ import os +import shutil +from einops import rearrange import torch as th from model.loci import Loci +def masks_to_boxes(masks: th.Tensor) -> th.Tensor: + """ + Compute the bounding boxes around the provided masks. + + Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + masks (Tensor[N, H, W]): masks to transform where N is the number of masks + and (H, W) are the spatial dimensions. + + Returns: + Tensor[N, 4]: bounding boxes + """ + if masks.numel() == 0: + return th.zeros((0, 4), device=masks.device, dtype=th.float) + + n = masks.shape[0] + + bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float) + + for index, mask in enumerate(masks): + if mask.sum() > 0: + y, x = th.where(mask != 0) + + bounding_boxes[index, 0] = th.min(x) + bounding_boxes[index, 1] = th.min(y) + bounding_boxes[index, 2] = th.max(x) + bounding_boxes[index, 3] = th.max(y) + + return bounding_boxes + +def boxes_to_centroids(boxes): + """Convert bounding boxes to centroid coordinates. + + Args: + boxes: [B, T, N, 4], 4: [x1, y1, x2, y2] + + Returns: + centroids: [B, T, N, 2], 2: [x, y] + """ + + centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2 + centroids = centroids.squeeze(0) + + # scale to [-1, 1] (assumes 64x64 frames) + centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1 + centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1 + + return centroids + +def compute_position_from_mask(mask): + """ + Compute the position of the object from the mask. + + Args: + mask (Tensor[B, N, H, W]): masks to transform where N is the number of masks + and (H, W) are the spatial dimensions. 
+ + Returns: + Tensor[B, N, 2]: position of the object + + """ + masks_binary = (mask > 0.8).float()[:, :-1] + b, o, h, w = masks_binary.shape + masks2 = rearrange(masks_binary, 'b o h w -> (b o) h w') + boxes = masks_to_boxes(masks2.long()) + boxes = rearrange(boxes, '(b o) c -> b 1 o c', b=b, o=o) + centroids = boxes_to_centroids(boxes) + centroids = centroids[:, :, :, [1, 0]].squeeze(1) + return centroids + def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views): net_name = file.split('/')[-1].split('.')[0] #root_path = file.split('nets')[0] - root_path = os.path.join(*file.split('/')[0:-1]) + root_path = os.path.join(*file.split('/')[0:-2]) root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type']) plot_path = os.path.join(root_path, evaluation_mode) # create directories - #if os.path.exists(plot_path): - # shutil.rmtree(plot_path) + if os.path.exists(plot_path): + shutil.rmtree(plot_path) os.makedirs(plot_path, exist_ok = True) if object_view: os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True) @@ -59,13 +133,21 @@ def load_model(cfg, cfg_net, file, device): print(f"load {file} to device {device}") state = th.load(file, map_location=device) - # backward compatibility + # 1. Get keys of current model while ensuring backward compatibility model = {} + allowed_keys = [] + rand_state = net.state_dict() + for key, value in rand_state.items(): + allowed_keys.append(key) + + # 2. Overwrite with values from file for key, value in state["model"].items(): # replace update_module with percept_gate_controller in key string: key = key.replace("update_module", "percept_gate_controller") + if key in allowed_keys: + model[key.replace(".module.", ".")] = value + net.load_state_dict(model) 
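
A quick sanity check for the masks_to_boxes helper added above — a sketch, not part of the patch; the import assumes the repository layout shown in the diff headers, and the expected boxes are worked out by hand from the function's (x1, y1, x2, y2) convention:

    import torch as th
    from scripts.utils.eval_utils import masks_to_boxes

    masks = th.zeros(2, 64, 64)
    masks[0, 10:20, 30:40] = 1.0   # object 0 covers rows 10..19, cols 30..39
    masks[1, 40:50, 5:15] = 1.0    # object 1 covers rows 40..49, cols 5..14

    boxes = masks_to_boxes(masks)  # pixel coordinates, one box per mask
    print(boxes)  # expected: [[30., 10., 39., 19.], [ 5., 40., 14., 49.]]
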
diff --git a/scripts/utils/io.py b/scripts/utils/io.py index 9bd8158..787c575 100644 --- a/scripts/utils/io.py +++ b/scripts/utils/io.py @@ -65,7 +65,7 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True): :param move_old: Moves old folder with the same name to an old folder, if not overwrite :return: Model path """ - _path = os.path.join('out') + _path = os.path.join('out', cfg.dataset) path = os.path.join(_path, cfg.model_path) if not os.path.exists(_path): @@ -95,14 +95,15 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True): class LossLogger: - def __init__(self): + def __init__(self, writer): self.avgloss = UEMA() self.avg_position_loss = UEMA() self.avg_time_loss = UEMA() - self.avg_encoder_loss = UEMA() - self.avg_mse_object_loss = UEMA() - self.avg_long_mse_object_loss = UEMA(33333) + self.avg_latent_loss = UEMA() + self.avg_encoding_loss = UEMA() + self.avg_prediction_loss = UEMA() + self.avg_prediction_loss_long = UEMA(33333) self.avg_num_objects = UEMA() self.avg_openings = UEMA() self.avg_gestalt = UEMA() @@ -110,28 +111,73 @@ class LossLogger: self.avg_gestalt_mean = UEMA() self.avg_update_gestalt = UEMA() self.avg_update_position = UEMA() + self.avg_num_bounded = UEMA() + + self.writer = writer - def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position): + def update_complete(self, avg_position_loss, avg_time_loss, avg_latent_loss, avg_encoding_loss, avg_prediction_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, avg_num_bounded, lr, num_updates): self.avg_position_loss.update(avg_position_loss.item()) self.avg_time_loss.update(avg_time_loss.item()) - self.avg_encoder_loss.update(avg_encoder_loss.item()) - self.avg_mse_object_loss.update(avg_mse_object_loss.item()) - self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item()) + self.avg_latent_loss.update(avg_latent_loss.item()) + self.avg_encoding_loss.update(avg_encoding_loss.item()) + self.avg_prediction_loss.update(avg_prediction_loss.item()) + self.avg_prediction_loss_long.update(avg_prediction_loss.item()) self.avg_num_objects.update(avg_num_objects) self.avg_openings.update(avg_openings) self.avg_gestalt.update(avg_gestalt.item()) self.avg_gestalt2.update(avg_gestalt2.item()) self.avg_gestalt_mean.update(avg_gestalt_mean.item()) - self.avg_update_gestalt.update(avg_update_gestalt.item()) - self.avg_update_position.update(avg_update_position.item()) + self.avg_update_gestalt.update(avg_update_gestalt) + self.avg_update_position.update(avg_update_position) + self.avg_num_bounded.update(avg_num_bounded) + + self.writer.add_scalar("Train/Position Loss", avg_position_loss.item(), num_updates) + self.writer.add_scalar("Train/Time Loss", avg_time_loss.item(), num_updates) + self.writer.add_scalar("Train/Latent Loss", avg_latent_loss.item(), num_updates) + self.writer.add_scalar("Train/Encoder Loss", avg_encoding_loss.item(), num_updates) + self.writer.add_scalar("Train/Prediction Loss", avg_prediction_loss.item(), num_updates) + self.writer.add_scalar("Train/Number of Objects", avg_num_objects, num_updates) + self.writer.add_scalar("Train/Openings", avg_openings, num_updates) + self.writer.add_scalar("Train/Gestalt", avg_gestalt.item(), num_updates) + self.writer.add_scalar("Train/Gestalt2", avg_gestalt2.item(), num_updates) + 
self.writer.add_scalar("Train/Gestalt Mean", avg_gestalt_mean.item(), num_updates) + self.writer.add_scalar("Train/Update Gestalt", avg_update_gestalt, num_updates) + self.writer.add_scalar("Train/Update Position", avg_update_position, num_updates) + self.writer.add_scalar("Train/Number Bounded", avg_num_bounded, num_updates) + self.writer.add_scalar("Train/Learning Rate", lr, num_updates) + pass - def update_average_loss(self, avgloss): + def update_average_loss(self, avgloss, num_updates): self.avgloss.update(avgloss) + self.writer.add_scalar("Train/Loss", avgloss, num_updates) pass def get_log(self): - info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}' + info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_prediction_loss):.2e}|{float(self.avg_prediction_loss_long):.2e}, reg: {float(self.avg_encoding_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_latent_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}' return info + +class WriterWrapper(): + + def __init__(self, use_wandb: bool, cfg: Configuration): + if use_wandb: + from torch.utils.tensorboard import SummaryWriter + import wandb + wandb.init(project=f'Loci_Looped_{cfg.dataset}', name= cfg.model_path, sync_tensorboard=True, config=cfg) + self.writer = SummaryWriter() + else: + self.writer = None + + def add_scalar(self, name, value, step): + if self.writer is not None: + self.writer.add_scalar(name, value, step) + + def add_video(self, name, value, step): + if self.writer is not None: + self.writer.add_video(name, value, step) + + def flush(self): + if self.writer is not None: + self.writer.flush() diff --git a/scripts/utils/plot_utils.py b/scripts/utils/plot_utils.py index 5cb20de..1c83456 100644 --- a/scripts/utils/plot_utils.py +++ b/scripts/utils/plot_utils.py @@ -151,9 +151,10 @@ def to_rgb(tensor: th.Tensor): tensor ), dim=1) -def visualise_gate(gate, h, w): +def visualise_gate(gate, h, w, invert = False): bar = th.ones((1,h,w), device=gate.device) * 0.9 black = int(w*gate.item()) + black = w-black if invert else black if black > 0: bar[:,:, -black:] = 0 return bar @@ -266,28 +267,30 @@ def plot_online_error(error, error_name, target, t, i, sequence_len, root_path, return error_plot -def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object): +def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, rollout_mode=False, 
openings=None): # add ground truth positions of objects to image + target = target.clone() if gt_positions_target_next is not None: for o in range(gt_positions_target_next.shape[1]): position = gt_positions_target_next[0, o] position = position/2 + 0.5 - if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0: - width = 5 - w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item() + if (len(position) < 3 or position[2] > 0.0) and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0: + width = int(target.shape[2]*0.05) + w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item() # made for bouncing balls h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item() col = get_color(o).view(3,1,1) target[0,:,(w-width):(w+width), (h-width):(h+width)] = col # add these positions to the associated slots velocity_next2d illustration - slots = (association_table[0] == o).nonzero() - for s in slots.flatten(): - velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col + if association_table is not None: + slots = (association_table[0] == o).nonzero() + for s in slots.flatten(): + velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col - if output_hidden is not None and s != largest_object: - output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col + if output_hidden is not None and s != largest_object: + output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col gateheight = 60 ch = 40 @@ -296,22 +299,28 @@ gh_margin = int((gh-gh_bar)/2) margin = 20 slots_margin = 10 - height = size[0] * 6 + 18*5 + height = size[0] * 6 + 18*6 width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1) img = th.ones((3, height, width), device = object_next.device) * 0.4 row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin]) col1 = range(margin, margin + size[1]*2) col2 = range(width-(margin+size[1]*2), width-margin) + # add frame around image + if rollout_mode: + img[0,margin-2:margin+size[0]*2+2, margin-2:margin+size[1]*2+2] = 1 + img[:,row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0] img[:,row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0] img[:,row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0] + # add large error plots to image if error_plot is not None: img[:,row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize= True) if error_plot2 is not None: img[:,row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize= True) + # fill columns with slots for o in range(num_objects): col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin @@ -321,23 +330,35 @@ if (error_plot_slots2 is not None) and len(error_plot_slots2) > o: img[:,margin:margin+ch, col] = get_color(o).view(3,1,1).to(object_next.device) + # gestalt gate img[:,margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:,o, 0].to(object_next.device), h=gh_bar, w=len(col)) offset = gh+margin-gh_margin+ch+2*margin row = (lambda row_index: [offset+(size[0]+6)*row_index, 
offset+size[0]*(row_index+1)+6*row_index]) img[:,row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0,o].to(object_next.device)) img[:,row(1)[0]:row(1)[1], col] = preprocess(object_next[:,o].to(object_next.device)) + + # small error plots top row if (error_plot_slots2 is not None) and len(error_plot_slots2) > o: img[:,row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True) + # switch to bottom row offset = margin*2-8 row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index]) + + # position gate img[:,row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:,o, 1].to(object_next.device), h=gh_bar, w=len(col)) img[:,row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0] + + # small error plots bottom row if (error_plot_slots is not None) and len(error_plot_slots) > o: img[:,row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True) - img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy() + # add gatelord gate visualisation to image + if openings is not None: + img[:,row(5)[1]+gh_margin:row(5)[1]+gh-gh_margin, col] = visualise_gate(openings[:,o].to(object_next.device), h=gh_bar, w=len(col), invert = True) + + img = rearrange(img * 255, 'c h w -> h w c').cpu() return img @@ -347,8 +368,57 @@ def write_image(file, img): pass -def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i): - +def extract_element(tensor, index): + if tensor is None: + return None + return tensor[index:index+1] + +def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, sample_i, rollout_mode=False, num_vid=2, att=None, openings=None): + + if len(input) > 1: + img_list = None + for i in range(min(len(input), num_vid)): + _input = extract_element(input, i) + _target = extract_element(target, i) + _mask_cur = extract_element(mask_cur, i) + _mask_next = extract_element(mask_next, i) + _output_next = extract_element(output_next, i) + _position_encoder_cur = extract_element(position_encoder_cur, i) + _position_next = extract_element(position_next, i) + _rawmask_hidden = extract_element(rawmask_hidden, i) + _rawmask_cur = extract_element(rawmask_cur, i) + _rawmask_next = extract_element(rawmask_next, i) + _largest_object = extract_element(largest_object, i) + _object_cur = extract_element(object_cur, i) + _object_next = extract_element(object_next, i) + _object_hidden = extract_element(object_hidden, i) + _slots_bounded = extract_element(slots_bounded, i) + 
_slots_partially_occluded_cur = extract_element(slots_partially_occluded_cur, i) + _slots_occluded_cur = extract_element(slots_occluded_cur, i) + _slots_partially_occluded_next = extract_element(slots_partially_occluded_next, i) + _slots_occluded_next = extract_element(slots_occluded_next, i) + _slots_closed = extract_element(slots_closed, i) + _gt_positions_target_next = extract_element(gt_positions_target_next, i) + _association_table = extract_element(association_table, i) + _error_next = extract_element(error_next, i) + _output_hidden = extract_element(output_hidden, i) + _att = extract_element(att, i) + _openings = extract_element(openings, i) + + img = plot_timestep_single(cfg, cfg_net, _input, _target, _mask_cur, _mask_next, _output_next, _position_encoder_cur, _position_next, _rawmask_hidden, _rawmask_cur, _rawmask_next, _largest_object, _object_cur, _object_next, _object_hidden, _slots_bounded, _slots_partially_occluded_cur, _slots_occluded_cur, _slots_partially_occluded_next, _slots_occluded_next, _slots_closed, _gt_positions_target_next, _association_table, _error_next, _output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode, att=_att, openings=_openings) + if img_list is None: + img_list = img.unsqueeze(0) + else: + img_list = th.cat((img_list, img.unsqueeze(0)), dim=0) + + return img_list.permute(0, 3, 1, 2) + + else: + img = plot_timestep_single(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, sample_i, rollout_mode, att=att, openings=openings) + return img + +def plot_timestep_single(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode=False, att=None, openings=None): + # Create position helpers size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device) @@ -364,7 +434,7 @@ def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next) # compute occlusion - if (cfg.datatype == "adept"): + if (cfg.datatype == "adept") and rawmask_hidden is not None: rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale) rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale) rawmask_cur_h[:,largest_object] = rawmask_cur_l[:,largest_object] @@ -385,30 +455,39 @@ def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, mask_next = rearrange(mask_next, 'b (o 1) h w -> b 
o 1 h w') if object_view: - if (cfg.datatype == "adept"): + if (cfg.datatype == "adept") and statistics_complete_slots is not None: num_objects = 4 error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded) - error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001) + #error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001) error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path) error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path) - img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object) - else: + att_histogram = plot_attention_histogram(att, target, root_path) + img = plot_object_view(error_plot, error_plot2, error_plot_slots, att_histogram, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, openings=openings) + elif (cfg.datatype == "clevrer") and statistics_complete_slots is not None: num_objects = cfg_net.num_objects error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001) error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path) - img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object) - - cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img) + img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, openings=openings) + else: + num_objects = cfg_net.num_objects + att_histogram = plot_attention_histogram(att, target, root_path) + img = plot_object_view(None, None, att_histogram, None, input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object, rollout_mode=rollout_mode, openings=openings) + + if plot_path is not None: + cv2.imwrite(f'{plot_path}/object/{i:04d}-{t_index:03d}.jpg', img.numpy()) if individual_views: # ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']: write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0]) + 
write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0]) + write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', target[0]) write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0,-1]) - write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0]) + #write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0]) write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0]) + for o in range(len(rawmask_next[0])): + write_image(f'{plot_path}/individual/rgb/object-{i:04d}-{o}-{t_index:03d}.jpg', object_next[0][o]) + write_image(f'{plot_path}/individual/rawmask/rawmask-{i:04d}-{o}-{t_index:03d}.jpg', rawmask_next[0][o]) - pass + return img def get_position_helper(cfg_net, device): size = cfg_net.input_size @@ -425,4 +504,35 @@ def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cu slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None,None,None] slots_occluded_next = th.squeeze(slots_occluded_next)[..., None,None,None] - return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next
\ No newline at end of file + return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next + +def plot_attention_histogram(att, target, root_path): + att_plots = [] + if (att is not None) and (len(att) > 0): + att = att[0] + for object_attention in att: + + fig, ax = plt.subplots(figsize=(round(target.shape[3]/100,2), round(target.shape[2]/100,2))) + + # plot a bar chart over the object slots + num_objects = len(object_attention) + ax.bar(range(num_objects), object_attention.cpu()) + ax.set_ylim([0,1]) + ax.set_xlim([-1,num_objects]) + ax.set_xticks(range(num_objects)) + #ax.set_xticklabels(['1','2','3','4','5','6']) + ax.set_ylabel('attention') + ax.set_xlabel('object') + ax.set_title('Attention histogram') + + fig.tight_layout() + plt.savefig(f'{root_path}/tmp.jpg') + fig.canvas.draw() # render the canvas so it can be read back as raw RGB + plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) + plot = th.from_numpy(np.array(plot).transpose(2,0,1)) + plt.close(fig) + att_plots.append(plot) + + return att_plots + else: + return None
\ No newline at end of file diff --git a/scripts/validation.py b/scripts/validation.py index 5b1638d..b59cc39 100644 --- a/scripts/validation.py +++ b/scripts/validation.py @@ -1,8 +1,11 @@ +import os import torch as th from torch import nn from torch.utils.data import DataLoader import numpy as np from einops import rearrange, repeat, reduce +from scripts.evaluation_bb import distance_eval_step +from scripts.evaluation_clevrer import compute_statistics_summary from scripts.utils.configuration import Configuration from model.loci import Loci import time @@ -10,12 +13,19 @@ import lpips from skimage.metrics import structural_similarity as ssimloss from skimage.metrics import peak_signal_noise_ratio as psnrloss -def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device): +from scripts.utils.eval_metrics import masks_to_boxes, postproc_mask, pred_eval_step +from scripts.utils.eval_utils import append_statistics, compute_position_from_mask +from scripts.utils.plot_utils import plot_timestep + +def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path): # memory mseloss = nn.MSELoss() - avgloss = 0 + loss_next = 0 start_time = time.time() + cfg_net = cfg.model + num_steps = 0 + plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}') with th.no_grad(): for i, input in enumerate(valloader): @@ -103,7 +113,8 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic # 1. Track error if t >= 0: loss = mseloss(output_next, target) - avgloss += loss.item() + loss_next += loss.item() + num_steps += 1 # 2. Remember output mask_last = mask_next.clone() @@ -122,12 +133,19 @@ def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, devic error_next = th.sqrt(error_next) * bg_error_next error_last = error_next.clone() - print(f"Validation loss: {avgloss / len(valloader.dataset):.2e}, Time: {time.time() - start_time}") - + # Plotting + if (i == 0) and (t_index % 2 == 0) and (epoch % 3 == 0): + os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True) + openings = net.get_openings() + img_tensor = plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings) + + print(f"Validation loss: {loss_next / num_steps:.2e}, Time: {time.time() - start_time}") + writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch) + pass -def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device): +def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path): # memory mseloss = nn.MSELoss() @@ -140,6 +158,9 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev burn_in_length = 6 rollout_length = 42 + + plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}') + os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True) with th.no_grad(): for i, input in enumerate(valloader): @@ -256,6 +277,243 @@ def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, dev error_next = th.sqrt(error_next) * bg_error_next error_last = error_next.clone() + # Plotting + if i == 0: 
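+ # visualise only the first validation batch of each epoch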
+ openings = net.get_openings() + img_tensor = plot_timestep(cfg, cfg.model, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, openings=openings) + + print(f"MSE loss: {avgloss_mse / len(valloader.dataset):.2e}, LPIPS loss: {avgloss_lpips / len(valloader.dataset):.2e}, PSNR loss: {avgloss_psnr / len(valloader.dataset):.2e}, SSIM loss: {avgloss_ssim / len(valloader.dataset):.2e}, Time: {time.time() - start_time}") + writer.add_scalar('Val/Prediction Loss', avgloss_mse / len(valloader.dataset), epoch) + writer.add_scalar('Val/LPIPS Loss', avgloss_lpips / len(valloader.dataset), epoch) + writer.add_scalar('Val/PSNR Loss', avgloss_psnr / len(valloader.dataset), epoch) + writer.add_scalar('Val/SSIM Loss', avgloss_ssim / len(valloader.dataset), epoch) + + pass + + +def validation_bb(valloader: DataLoader, net: Loci, cfg: Configuration, device, writer, epoch, root_path): + + # memory + start_time = time.time() + net.eval() + evaluation_mode = 'vidpred_black' + use_meds = True + + # Evaluation Specifics + burn_in_length = 10 + rollout_length = 20 + rollout_length_stats = 10 # only consider the first 10 frames for statistics + target_size = (64, 64) + + # Losses + lpipsloss = lpips.LPIPS(net='vgg').to(device) + mseloss = nn.MSELoss() + metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'meds': [], 'ari_hidden': [], 'fari_hidden': [], 'miou_hidden': []} + loss_next = 0.0 + loss_cur = 0.0 + num_steps = 0 + plot_path = os.path.join(root_path, 'plots', f'epoch_{epoch}') + os.makedirs(os.path.join(plot_path, 'object'), exist_ok=True) + + with th.no_grad(): + for i, input in enumerate(valloader): + + # Load data + tensor = input[0].float().to(device) + background_fix = input[1].to(device) + gt_pos = input[2].to(device) + gt_mask = input[3].to(device) + gt_pres_mask = input[4].to(device) + gt_hidden_mask = input[5].to(device) + sequence_len = tensor.shape[1] + + # placeholders + mask_cur = None + mask_last = None + rawmask_last = None + position_last = None + gestalt_last = None + priority_last = None + gt_positions_target = None + slots_occlusionfactor = None + error_last = None + + # Memory + cfg_net = cfg.model + num_objects_bb = gt_pos.shape[2] + pred_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device) + gt_pos_batch = th.zeros((cfg_net.batch_size, rollout_length, num_objects_bb, 2)).to(device) + pred_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + gt_img_batch = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + pred_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + pred_hidden_mask_batch = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + # Counters + num_rollout = 0 + num_burnin = 0 + + # loop through frames + for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, burn_in_length+rollout_length)): + + # Move to next frame + t_run = max(t, 0) + input = tensor[:,t_run] + target_cur = tensor[:,t_run] + target = 
th.clip(tensor[:,t_run+1], 0, 1) + gt_pos_t = gt_pos[:,t_run+1]/32-1 + gt_pos_t = th.concat((gt_pos_t, th.ones_like(gt_pos_t[:,:,:1])), dim=2) + + rollout_index = t_run - burn_in_length + rollout_active = False + if t>=0: + if rollout_index >= 0: + num_rollout += 1 + if (evaluation_mode == 'vidpred_black'): + input = output_next * 0 + error_last = error_last * 0 + rollout_active = True + elif (evaluation_mode == 'vidpred_auto'): + input = output_next + error_last = error_last * 0 + rollout_active = True + else: + num_burnin += 1 + + # obtain prediction + ( + output_next, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor, + output_cur, + position_cur, + gestalt_cur, + priority_cur, + mask_cur, + rawmask_cur, + object_cur, + position_encoder_cur, + slots_bounded, + slots_partially_occluded_cur, + slots_occluded_cur, + slots_partially_occluded_next, + slots_occluded_next, + slots_closed, + output_hidden, + largest_object, + rawmask_hidden, + object_hidden + ) = net( + input, + error_last, + mask_last, + rawmask_last, + position_last, + gestalt_last, + priority_last, + background_fix, + slots_occlusionfactor, + reset = (t == -cfg.defaults.teacher_forcing), + evaluate=True, + warmup = (t < 0), + shuffleslots = True, + reset_mask = (t <= 0), + allow_spawn = True, + show_hidden = False, + clean_slots = False, + ) + + # 1. Track error + if t >= 0: + + if (rollout_index >= 0): + # store positions per batch + if use_meds: + if False: + pred_pos_batch[:,rollout_index] = rearrange(position_next, 'b (o c) -> b o c', o=cfg_net.num_objects)[:,:,:2] + else: + pred_pos_batch[:,rollout_index] = compute_position_from_mask(rawmask_next) + + gt_pos_batch[:,rollout_index] = gt_pos_t[:,:,:2] + + pred_img_batch[:,rollout_index] = output_next + gt_img_batch[:,rollout_index] = target + + # Here we compute only the foreground segmentation mask + pred_mask_batch[:,rollout_index] = postproc_mask(mask_next[:,None,:,None])[:, 0] + + # Here we compute the hidden segmentation + occluded_cur = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1] + occluded_sum_cur = 1-(reduce(occluded_cur, 'b c h w -> b h w', 'max') > 0.5).float() + occluded_cur = th.cat((occluded_cur, occluded_sum_cur[:,None]), dim=1) + pred_hidden_mask_batch[:,rollout_index] = postproc_mask(occluded_cur[:,None,:,None])[:, 0] + + # 2. Remember output + mask_last = mask_next.clone() + rawmask_last = rawmask_next.clone() + position_last = position_next.clone() + gestalt_last = gestalt_next.clone() + priority_last = priority_next.clone() + + # 3. 
Error for next frame + # background error + bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + # prediction error + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next.clone() + + # Prediction and encoder Loss + loss_next += mseloss(output_next * bg_error_next, target * bg_error_next) + loss_cur += mseloss(output_cur * bg_error_cur, input * bg_error_cur) + num_steps += 1 + + # Plotting + if i == 0: + openings = net.get_openings() + img_tensor = plot_timestep(cfg, cfg_net, input, target_cur, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_pos_t, None, error_next, None, True, False, None, None, sequence_len, root_path, plot_path, t_index, t, i, rollout_mode=rollout_active, openings=openings) + + for b in range(cfg_net.batch_size): + + # perceptual similarity from the SlotFormer paper + metric_dict = pred_eval_step( + gt = gt_img_batch[b:b+1], + pred = pred_img_batch[b:b+1], + pred_mask = pred_mask_batch.long()[b:b+1], + pred_mask_hidden = pred_hidden_mask_batch.long()[b:b+1], + pred_bbox = None, + gt_mask = gt_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1], + gt_mask_hidden = gt_hidden_mask.long()[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1], + gt_pres_mask = gt_pres_mask[b:b+1, burn_in_length+1:burn_in_length+rollout_length+1], + gt_bbox = None, + lpips_fn = lpipsloss, + eval_traj = True, + ) + + metric_dict['meds'] = distance_eval_step(gt_pos_batch[b], pred_pos_batch[b]) + metric_complete = append_statistics(metric_dict, metric_complete) + + # sanity check + if (num_rollout != rollout_length) or (num_burnin != burn_in_length): + raise ValueError('Number of rollout and burn-in steps must match the configured lengths.') + + dic = compute_statistics_summary(metric_complete, evaluation_mode, consider_first_n_frames=rollout_length_stats) + writer.add_scalar('Val/Meds', dic['meds_complete_sum'], epoch) + + writer.add_scalar('Val/ARI_hidden', dic['ari_hidden_complete_average'], epoch) + writer.add_scalar('Val/ARI', dic['ari_complete_average'], epoch) + writer.add_scalar('Val/LPIPS', dic['percept_dist_complete_average'], epoch) + + writer.add_scalar('Val/Prediction Loss', loss_next / num_steps, epoch) + writer.add_scalar('Val/Encoding Loss', loss_cur / num_steps, epoch) + + net.train() pass
\ No newline at end of file
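
A standalone usage sketch for the new ADEPT evaluation entry point (the statistics path below is hypothetical; it must contain the trialframe.csv, slotframe.csv and accframe.csv files written by scripts/evaluation_adept.py):

    from scripts.utils.eval_adept import eval_adept

    # prints tracking error, positive trackings, MOTA and gate-opening stats,
    # and writes a results.txt summary next to the input csv files
    eval_adept('out/adept/results/net1/test/statistics')  # hypothetical path

    # equivalently, from the command line:
    #   python scripts/utils/eval_adept.py --path out/adept/results/net1/test/statistics
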