Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/evaluation_adept.py | 398
-rw-r--r-- | scripts/evaluation_adept_savi.py | 233
-rw-r--r-- | scripts/evaluation_clevrer.py | 311
-rw-r--r-- | scripts/exec/eval.py | 28
-rw-r--r-- | scripts/exec/eval_savi.py | 17
-rw-r--r-- | scripts/exec/train.py | 33
-rw-r--r-- | scripts/training.py | 523
-rw-r--r-- | scripts/utils/configuration.py | 83
-rw-r--r-- | scripts/utils/eval_metrics.py | 350
-rw-r--r-- | scripts/utils/eval_utils.py | 78
-rw-r--r-- | scripts/utils/io.py | 137
-rw-r--r-- | scripts/utils/optimizers.py | 100
-rw-r--r-- | scripts/utils/plot_utils.py | 428
-rw-r--r-- | scripts/validation.py | 261
14 files changed, 2980 insertions, 0 deletions
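Note: the entry points added under scripts/exec/ are thin argparse wrappers around the training and evaluation loops below. A minimal invocation sketch, assuming they are run as modules from the repository root (config paths, checkpoint names, and run names are illustrative):

    # train a model from a configuration file; -n optionally tags the run
    python -m scripts.exec.train -cfg cfg.json -n 0
    # evaluate a trained checkpoint
    python -m scripts.exec.eval -cfg cfg.json -load nets/net_final.pt -n myrun
    # evaluate pre-computed SAVi slots stored in a pickle file
    python -m scripts.exec.eval_savi -load savi_slots.pkl -n myrun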
diff --git a/scripts/evaluation_adept.py b/scripts/evaluation_adept.py new file mode 100644 index 0000000..eab3ad9 --- /dev/null +++ b/scripts/evaluation_adept.py @@ -0,0 +1,398 @@ +import torch as th +from torch.utils.data import Dataset, DataLoader, Subset +from torch import nn +import os +from scripts.utils.plot_utils import plot_timestep +from scripts.utils.configuration import Configuration +from scripts.utils.io import init_device +import numpy as np +from einops import rearrange, repeat, reduce +import motmetrics as mm +from copy import deepcopy +import pandas as pd +from scripts.utils.eval_utils import append_statistics, load_model, setup_result_folders, store_statistics + + +def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 2): + + # Set up cpu or gpu training + device, verbose = init_device(cfg) + + # Config + cfg_net = cfg.model + cfg_net.batch_size = 1 + + # Load model + net = load_model(cfg, cfg_net, file, device) + net.eval() + + # Plot config + object_view = True + individual_views = False + root_path = None + + # get evaluation sets for control and surprise condition + set_test_array, evaluation_modes = get_evaluation_sets(dataset) + + # memory + statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error': [], 'TE': []} + statistics_complete = deepcopy(statistics_template) + statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': [], 'vanishing': []} + acc_memory_complete = None + + for set_test in set_test_array: + + for evaluation_mode in evaluation_modes: + print(f'Start evaluation loop: {evaluation_mode} - {set_test["type"]}') + + # Load data + dataloader = DataLoader( + Subset(dataset, set_test['samples']), + num_workers = 1, + pin_memory = False, + batch_size = 1, + shuffle = False + ) + + # memory + mseloss = nn.MSELoss() + root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views) + acc_memory_eval = [] + + with th.no_grad(): + for i, input in enumerate(dataloader): + print(f'Processing sample {i+1}/{len(dataloader)}', flush=True) + + # Data + tensor = input[0].float().to(device) + background_fix = input[1].to(device) + gt_object_positions = input[3].to(device) + gt_object_visibility = input[4].to(device) + gt_occluder_mask = input[5].to(device) + + # Apply skip frames + gt_object_positions = gt_object_positions[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)] + gt_object_visibility = gt_object_visibility[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)] + tensor = tensor[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)] + sequence_len = tensor.shape[1] + + # Placehodlers + mask_cur = None + mask_last = None + rawmask_last = None + position_last = None + gestalt_last = None + priority_last = None + gt_positions_target = None + slots_occlusionfactor = None + error_last = None + + # Memory + association_table = th.ones(cfg_net.num_objects).to(device) * -1 + acc = mm.MOTAccumulator(auto_id=True) + statistics_batch = deepcopy(statistics_template) + slots_vanishing_memory = np.zeros(cfg_net.num_objects) + + # loop through frames + for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)): + + # Move to next frame + t_run = max(t, 0) + input = tensor[:,t_run] + target = 
th.clip(tensor[:,t_run+1], 0, 1) + gt_positions_target = gt_object_positions[:,t_run] + gt_positions_target_next = gt_object_positions[:,t_run+1] + gt_visibility_target = gt_object_visibility[:,t_run] + + # Forward Pass + ( + output_next, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor, + output_cur, + position_cur, + gestalt_cur, + priority_cur, + mask_cur, + rawmask_cur, + object_cur, + position_encoder_cur, + slots_bounded, + slots_partially_occluded_cur, + slots_occluded_cur, + slots_partially_occluded_next, + slots_occluded_next, + slots_closed, + output_hidden, + largest_object, + rawmask_hidden, + object_hidden + ) = net( + input, + error_last, + mask_last, + rawmask_last, + position_last, + gestalt_last, + priority_last, + background_fix, + slots_occlusionfactor, + reset = (t == -cfg.defaults.teacher_forcing), + evaluate=True, + warmup = (t < 0), + shuffleslots = False, + reset_mask = (t <= 0), + allow_spawn = True, + show_hidden = True, + clean_slots = (t <= 0), + ) + + # 1. Track error + if t >= 0: + + # Position error: MSE between predicted position and target position + tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_net.num_objects, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask) + + statistics_batch = store_statistics(statistics_batch, + set_test['type'], + evaluation_mode, + set_test['samples'][i], + t, + mseloss(output_next, target).item(), + tracking_error) + + # Compute rawmask_size + rawmask_size, rawmask_size_hidden, mask_size = compute_mask_sizes(mask_next, rawmask_next, rawmask_hidden) + + # Compute slot-wise prediction error + slot_error = compute_slot_error(cfg_net, target, output_next, mask_next, mask_size) + + # Check if objects vanishes: they leave the scene suprsingly in the suprrise condition + slots_vanishing = compute_vanishing_slots(gt_positions_target, association_table, gt_positions_target_next) + slots_vanishing_memory = slots_vanishing + slots_vanishing_memory + + # Store slot statistics + statistics_complete_slots = store_statistics(statistics_complete_slots, + [set_test['type']] * cfg_net.num_objects, + [evaluation_mode] * cfg_net.num_objects, + [set_test['samples'][i]] * cfg_net.num_objects, + [t] * cfg_net.num_objects, + range(cfg_net.num_objects), + tracking_error_perslot.cpu().numpy().flatten(), + slots_visible.cpu().numpy().flatten().astype(int), + slots_bounded.cpu().numpy().flatten().astype(int), + slots_occluder.cpu().numpy().flatten().astype(int), + slots_in_image.cpu().numpy().flatten().astype(int), + slot_error.cpu().numpy().flatten(), + mask_size.cpu().numpy().flatten(), + rawmask_size.cpu().numpy().flatten(), + rawmask_size_hidden.cpu().numpy().flatten(), + slots_closed[:, :, 1].cpu().numpy().flatten(), + slots_closed[:, :, 0].cpu().numpy().flatten(), + association_table[0].cpu().numpy().flatten().astype(int), + extend = True) + + # Compute MOTA + acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bounded, cfg_net.num_objects, gt_occluder_mask, slots_occluder, rawmask_next) + + # 2. Remember output + mask_last = mask_next.clone() + rawmask_last = rawmask_next.clone() + position_last = position_next.clone() + gestalt_last = gestalt_next.clone() + priority_last = priority_next.clone() + + # 3. 
Error for next frame + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next.clone() + + # 4. Plot + if (t % plot_frequency == 0) and (i < plot_first_samples) and (t >= 0): + plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i) + + # fill jumping statistics + statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1)) + + # store batch statistics in complete statistics + acc_memory_eval.append(acc) + statistics_complete = append_statistics(statistics_complete, statistics_batch, extend = True) + + summary = mm.metrics.create().compute_many(acc_memory_eval, metrics=mm.metrics.motchallenge_metrics, generate_overall=True) + summary['set'] = set_test['type'] + summary['evalmode'] = evaluation_mode + acc_memory_complete = summary.copy() if acc_memory_complete is None else pd.concat([acc_memory_complete, summary]) + + + print('-- Evaluation Done --') + pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv') + pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv') + pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv') + if object_view and os.path.exists(f'{root_path}/tmp.jpg'): + os.remove(f'{root_path}/tmp.jpg') + pass + + +def compute_vanishing_slots(gt_positions_target, association_table, gt_positions_target_next): + objects_vanishing = th.abs(gt_positions_target[:,:,2] - gt_positions_target_next[:,:,2]) > 0.2 + objects_vanishing = th.where(objects_vanishing.flatten())[0] + slots_vanishing = [(obj.item() in objects_vanishing) for obj in association_table[0]] + return slots_vanishing + +def compute_slot_error(cfg_net, target, output_next, mask_next, mask_size): + output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects) + target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects) + slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean')/mask_size + return slot_error + +def compute_mask_sizes(mask_next, rawmask_next, rawmask_hidden): + rawmask_size = reduce(rawmask_next[:, :-1], 'b o h w-> b o', 'sum') + rawmask_size_hidden = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum') + mask_size = reduce(mask_next[:, :-1], 'b o h w-> b o', 'sum') + return rawmask_size,rawmask_size_hidden,mask_size + +def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_num_objects, gt_occluder_mask, slots_occluder, rawmask, ignore_occluder = False): + + # num objects + num_objects = len(gt_positions[0]) + + # get rid of batch dimension and priority dimension + pos = rearrange(estimated_positions.detach()[0], '(o c) -> o c', o=cfg_num_objects)[:, :2] + targets = gt_positions[0, :, :2] + + # stretch positions to 
account for frame ratio, Specific for ADEPT! + pos = th.clip(pos, -1, 1) + pos[:, 0] = pos[:, 0] * 1.5 + targets[:, 0] = targets[:, 0] * 1.5 + + # remove objects that are not in the image + edge = 1 + in_image = th.cat([targets[:, 0] < (1.5 * edge), targets[:, 0] > (-1.5 * edge), targets[:, 1] < (1 * edge), targets[:, 1] > (-1 * edge)]) + in_image = th.all(rearrange(in_image, '(c o) -> o c', o=num_objects), dim=1) + + if ignore_occluder: + in_image = (gt_occluder_mask[0] == 0) * in_image + targets = targets[in_image] + + # test if position estimates in image + in_image_pos = th.cat([pos[:, 0] < (1.5 * edge), pos[:, 0] > (-1.5 * edge), pos[:, 1] < (1 * edge), pos[:, 1] > (-1 * edge)]) + in_image_pos = th.all(rearrange(in_image_pos, '(c o) -> c o', o=cfg_num_objects), dim=0, keepdim=True) + + # only position estimates that are in image and bound + if rawmask is not None: + rawmask_size = reduce(rawmask[:, :-1], 'b o h w-> b o', 'sum') + m = (slots_bounded * in_image_pos * (rawmask_size > 100)).bool() + else: + m = (slots_bounded * in_image_pos).bool() + if ignore_occluder: + m = (m * (1 - slots_occluder)).bool() + + pos = pos[repeat(m, '1 o -> o 2')] + pos = rearrange(pos, '(o c) -> o c', c = 2) + + # compute pairwise distances + diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2)).item() + C = mm.distances.norm2squared_matrix(targets.cpu().numpy(), pos.cpu().numpy(), max_d2=diagonal_length*0.1) + + # upadate accumulator + acc.update( (th.where(in_image)[0]).cpu(), (th.where(m)[1]).cpu(), C) + + return acc + +def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_num_slots, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask): + + # tracking utils + pdist = nn.PairwiseDistance(p=2).to(position_cur.device) + + # 1. association of newly bounded slots to ground truth objects + # num objects + num_objects = len(gt_positions_target[0]) + + # get rid of batch dimension and priority dimension + pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_num_slots)[:, :2] + targets = gt_positions_target[0, :, :2] + + # stretch positions to account for frame ratio, Specific for ADEPT! + pos = th.clip(pos, -1, 1) + pos[:, 0] = pos[:, 0] * 1.5 + targets[:, 0] = targets[:, 0] * 1.5 + diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2)) + + # reshape and repeat for comparison + pos = repeat(pos, 'o c -> (o r) c', r=num_objects) + targets = repeat(targets, 'o c -> (r o) c', r=cfg_num_slots) + + # comparison + distance = pdist(pos, targets) + distance = rearrange(distance, '(o r) -> o r', r=num_objects) + + # find closest target for each slot + distance = th.min(distance, dim=1, keepdim=True) + + # update association table + slots_newly_bounded = slots_bounded * (association_table == -1) + if slots_occluded_cur is not None: + slots_newly_bounded = slots_newly_bounded * (1-slots_occluded_cur) + association_table = association_table * (1-slots_newly_bounded) + slots_newly_bounded * distance[1].T + + # 2. position error + # get rid of batch dimension and priority dimension + pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_num_slots)[:, :2] + targets = gt_positions_target[0, :, :3] + + # stretch positions to account for frame ratio, Specific for ADEPT! 
+ pos[:, 0] = pos[:, 0] * 1.5 + targets[:, 0] = targets[:, 0] * 1.5 + + # gather targets according to association table + targets = targets[association_table.long()][0] + + # determine which slots are within the image + slots_in_image = th.cat([targets[:, 0] < 1.5, targets[:, 0] > -1.5, targets[:, 1] < 1, targets[:, 1] > -1, targets[:, 2] > 0]) + slots_in_image = rearrange(slots_in_image, '(c o) -> o c', o=cfg_num_slots) + slots_in_image = th.all(slots_in_image, dim=1) + + # define which slots to consider for tracking error + slots_to_track = slots_bounded * slots_in_image + + # compute position error + targets = targets[:, :2] + tracking_error_perslot = th.sqrt(th.sum((pos - targets)**2, dim=1))/diagonal_length + tracking_error_perslot = tracking_error_perslot[None, :] * slots_to_track + tracking_error = th.sum(tracking_error_perslot).item()/max(th.sum(slots_to_track).item(), 1) + + # compute which slots are visible + visible_objects = th.where(gt_visibility_target[0] == 1)[0] + slots_visible = th.tensor([[int(obj.item()) in visible_objects for obj in association_table[0]]]).float().to(slots_to_track.device) + slots_visible = slots_visible * slots_to_track + + # determine which objects are bound to the occluder + occluder_objects = th.where(gt_occluder_mask[0] == 1)[0] + slots_occluder = th.tensor([[int(obj.item()) in occluder_objects for obj in association_table[0]]]).float().to(slots_to_track.device) + slots_occluder = slots_occluder * slots_to_track + + return tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder + +def get_evaluation_sets(dataset): + + # Standard evaluation + evaluation_modes = ['open'] + + # Important! + # filter different scenarios: 1 as surprise and 0,3 as control (see Smith et al. 2020) + surprise_mask = [(sample.case in [1]) for i,sample in enumerate(dataset.samples)] + control_mask = [(sample.case in [0,3]) for i,sample in enumerate(dataset.samples)] + + # Create test sets + set_surprise = {"samples": np.where(surprise_mask)[0].tolist(), "type": 'surprise'} + set_control = {"samples": np.where(control_mask)[0].tolist(), "type": 'control'} + set_test_array = [set_control, set_surprise] + + return set_test_array, evaluation_modes
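Note: calculate_tracking_error above reports per-slot Euclidean distances normalized by the frame diagonal. Because the x-coordinate is stretched by 1.5, the ADEPT frame spans [-1.5, 1.5] x [-1, 1] (width 3, height 2), so the diagonal is sqrt(3^2 + 2^2) ~ 3.61. A minimal standalone sketch of that normalization (positions are illustrative):

    import torch as th

    pos = th.tensor([[0.3, 0.1]])     # predicted slot position, x already stretched by 1.5
    target = th.tensor([[0.0, 0.1]])  # associated ground-truth position
    diagonal_length = th.sqrt(th.sum(th.tensor([2.0, 3.0])**2))            # ~3.61, as in the code
    err = th.sqrt(th.sum((pos - target)**2, dim=1)) / diagonal_length      # ~0.083, dimensionless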
\ No newline at end of file diff --git a/scripts/evaluation_adept_savi.py b/scripts/evaluation_adept_savi.py new file mode 100644 index 0000000..6a2d5c7 --- /dev/null +++ b/scripts/evaluation_adept_savi.py @@ -0,0 +1,233 @@ +from einops import rearrange, reduce, repeat +import torch as th +from torch.utils.data import Dataset, DataLoader, Subset +import cv2 +import numpy as np +import pandas as pd +import os +from data.datasets.ADEPT.dataset import AdeptDataset +import motmetrics as mm +from scripts.evaluation_adept import calculate_tracking_error, get_evaluation_sets, update_mota_acc +from scripts.utils.eval_utils import setup_result_folders, store_statistics +from scripts.utils.plot_utils import write_image + +FG_THRE = 0.95 + +def evaluate(dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 2): + + # read pkl file + masks_complete = pd.read_pickle(file) + + # plot config + color_list = [[255,0,0], [0,255,0], [0,0,255], [255,255,0], [0,255,255], [255,0,255], [255,255,255]] + dot_size = 2 + skip_frames = 2 + offset = 15 + + # memory + statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': []} + acc_memory_eval = [] + + # load adept dataset + set_test_array, evaluation_modes = get_evaluation_sets(dataset) + control_samples = set_test_array[0]['samples'] # only consider control set + evalset = Subset(dataset, control_samples) + root_path, plot_path = setup_result_folders(file, n, set_test_array[0], evaluation_modes[0], True, False) + + for i in range(len(evalset)): + print(f'Processing sample {i+1}/{len(evalset)}', flush=True) + input = evalset[i] + acc = mm.MOTAccumulator(auto_id=True) + + # get input frame and target frame + tensor = th.tensor(input[0]).float().unsqueeze(0) + background_fix = th.tensor(input[1]).unsqueeze(0) + gt_object_positions = th.tensor(input[3]).unsqueeze(0) + gt_object_visibility = th.tensor(input[4]).unsqueeze(0) + gt_occluder_mask = th.tensor(input[5]).unsqueeze(0) + + # apply skip frames + gt_object_positions = gt_object_positions[:,range(0, tensor.shape[1], skip_frames)] + gt_object_visibility = gt_object_visibility[:,range(0, tensor.shape[1], skip_frames)] + tensor = tensor[:,range(0, tensor.shape[1], skip_frames)] + sequence_len = tensor.shape[1] + + # load data + masks = th.tensor(masks_complete['test'][f'control_{i}.mp4']) + masks_before_softmax = th.tensor(masks_complete['test_raw'][f'control_{i}.mp4']) + + # calculate rawmasks + bg_mask = masks_before_softmax.mean(dim=1) + masks_raw = compute_maskraw(masks_before_softmax, bg_mask) + slots_bound = compute_slots_bound(masks_raw) + + # threshold masks and calculate centroids + masks_binary = (masks_raw > FG_THRE).float() + masks2 = rearrange(masks_binary, 't o 1 h w -> (t o) h w') + boxes = masks_to_boxes(masks2.long()) + boxes = boxes.reshape(1, masks.shape[0], 7, 4) + centroids = boxes_to_centroids(boxes) + + # get rid of batch dimension + association_table = th.ones(7) * -1 + + # iterate over frames + for t_index in range(offset,min(sequence_len,masks.shape[0])): + + # move to next frame + input = tensor[:,t_index] + target = th.clip(tensor[:,t_index+1], 0, 1) + gt_positions_target = gt_object_positions[:,t_index] + gt_positions_target_next = gt_object_positions[:,t_index+1] + gt_visibility_target = gt_object_visibility[:,t_index] + + position_cur = 
centroids[t_index] + position_cur = rearrange(position_cur, 'o c -> 1 (o c)') + slots_bound_cur = slots_bound[t_index] + slots_bound_cur = rearrange(slots_bound_cur, 'o c -> 1 (o c)') + + # calculate tracking error + tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, 7, slots_bound_cur, None, association_table, gt_occluder_mask) + + rawmask_size = reduce(masks_raw[t_index], 'o 1 h w-> 1 o', 'sum') + mask_size = reduce(masks[t_index], 'o 1 h w-> 1 o', 'sum') + + statistics_complete_slots = store_statistics(statistics_complete_slots, + ['control'] * 7, + ['control'] * 7, + [control_samples[i]] * 7, + [t_index] * 7, + range(7), + tracking_error_perslot.cpu().numpy().flatten(), + slots_visible.cpu().numpy().flatten().astype(int), + slots_bound_cur.cpu().numpy().flatten().astype(int), + slots_occluder.cpu().numpy().flatten().astype(int), + slots_in_image.cpu().numpy().flatten().astype(int), + [0] * 7, + mask_size.cpu().numpy().flatten(), + rawmask_size.cpu().numpy().flatten(), + [0] * 7, + [0] * 7, + [0] * 7, + association_table[0].cpu().numpy().flatten().astype(int), + extend = True) + + acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bound_cur, 7, gt_occluder_mask, slots_occluder, None) + + # plot_option + if (t_index % plot_frequency == 0) and (i < plot_first_samples) and (t_index >= 0): + masks_to_display = masks_binary.numpy() # masks_binary.numpy() + + frame = tensor[0, t_index] + frame = frame.numpy().transpose(1,2,0) + frame = cv2.resize(frame, (64,64)) + + centroids_frame = centroids[t_index] + centroids_frame[:,0] = (centroids_frame[:,0] + 1) * 64 / 2 + centroids_frame[:,1] = (centroids_frame[:,1] + 1) * 64 / 2 + + bound_frame = slots_bound[t_index] + for c_index,centroid_slot in enumerate(centroids_frame): + if bound_frame[c_index] == 1: + frame[int(centroid_slot[1]-dot_size):int(centroid_slot[1]+dot_size), int(centroid_slot[0]-dot_size):int(centroid_slot[0]+dot_size)] = color_list[c_index] + + # slot images + slot_frame = masks_to_display[t_index].max(axis=0) + slot_frame = slot_frame.reshape((64,64,1)).repeat(3, axis=2) + + if True: + for mask in masks_to_display[t_index]: + #slot_frame_single = mask.reshape((64,64,1)).repeat(3, axis=2) + slot_frame_single = mask.transpose((1,2,0)).repeat(3, axis=2) + slot_frame = np.concatenate((slot_frame, slot_frame_single), axis=1) + + frame = np.concatenate((frame, slot_frame), axis=1) + cv2.imwrite(f'{plot_path}object/objects-{i:04d}-{t_index:03d}.jpg', frame*255) + + acc_memory_eval.append(acc) + + mh = mm.metrics.create() + summary = mh.compute_many(acc_memory_eval, metrics=mm.metrics.motchallenge_metrics, generate_overall=True) + summary['set'] = 'control' + summary['evalmode'] = 'control' + pd.DataFrame(summary).to_csv(os.path.join(root_path, 'statistics' , 'accframe.csv')) + pd.DataFrame(statistics_complete_slots).to_csv(os.path.join(root_path, 'statistics' , 'slotframe.csv')) + +def masks_to_boxes(masks: th.Tensor) -> th.Tensor: + """ + Compute the bounding boxes around the provided masks. + + Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + masks (Tensor[N, H, W]): masks to transform where N is the number of masks + and (H, W) are the spatial dimensions. 
+ + Returns: + Tensor[N, 4]: bounding boxes + """ + if masks.numel() == 0: + return th.zeros((0, 4), device=masks.device, dtype=th.float) + + n = masks.shape[0] + + bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float) + + for index, mask in enumerate(masks): + if mask.sum() > 0: + y, x = th.where(mask != 0) + + bounding_boxes[index, 0] = th.min(x) + bounding_boxes[index, 1] = th.min(y) + bounding_boxes[index, 2] = th.max(x) + bounding_boxes[index, 3] = th.max(y) + + return bounding_boxes + +def boxes_to_centroids(boxes): + """Convert bounding boxes to their centroids. + + Args: + boxes: [B, T, N, 4], 4: [x1, y1, x2, y2] + + Returns: + centroids: [B, T, N, 2], 2: [x, y] + """ + + centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2 + centroids = centroids.squeeze(0) + + # scale to [-1, 1] + centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1 + centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1 + + return centroids + +def compute_slots_bound(masks): + + # take the maximum over the spatial axes (3, 4) + masks_max = masks.amax(dim=(3,4)) + slots_bound = (masks_max > FG_THRE).float() + return slots_bound + +def compute_maskraw(mask, bg_mask): + + # d is a diagonal matrix which defines what to take the softmax over + d_mask = th.diag(th.ones(8)) + d_mask[:,-1] = 1 + d_mask[-1,-1] = 0 + + mask = mask.squeeze(2) + + # take subset of maskraw with the diagonal matrix + maskraw = th.cat((mask, bg_mask), dim=1) + maskraw = repeat(maskraw, 'b o h w -> b r o h w', r = 8) + maskraw = maskraw[:,d_mask.bool()] + maskraw = rearrange(maskraw, 'b (o r) h w -> b o r h w', o = 7) + + # take softmax between each object mask and the background mask + maskraw = th.squeeze(th.softmax(maskraw, dim=2)[:,:,0], dim=2) + maskraw = maskraw.unsqueeze(2) + + return maskraw
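Note: the SAVi evaluation recovers slot positions purely from masks: raw masks are thresholded at FG_THRE, a box is fitted around the foreground pixels, and the box midpoint is rescaled from the 64x64 pixel grid to [-1, 1]. A minimal sketch of this pipeline for a single, hypothetical mask:

    import torch as th

    mask = th.zeros(64, 64)
    mask[10:20, 30:40] = 1.0                               # one binary object mask on the 64x64 grid
    y, x = th.where(mask != 0)
    x1, y1, x2, y2 = x.min(), y.min(), x.max(), y.max()    # bounding box (x1, y1, x2, y2)
    cx, cy = (x1 + x2) / 2, (y1 + y2) / 2                  # box midpoint in pixels
    cx, cy = cx / 64 * 2 - 1, cy / 64 * 2 - 1              # rescale to [-1, 1]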
\ No newline at end of file diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py new file mode 100644 index 0000000..a43d50c --- /dev/null +++ b/scripts/evaluation_clevrer.py @@ -0,0 +1,311 @@ +import pickle +import torch as th +from torch.utils.data import Dataset, DataLoader, Subset +from torch import nn +import os +from scripts.utils.plot_utils import plot_timestep +from scripts.utils.eval_metrics import masks_to_boxes, pred_eval_step, postproc_mask +from scripts.utils.eval_utils import append_statistics, load_model, setup_result_folders, store_statistics +from scripts.utils.configuration import Configuration +from scripts.utils.io import init_device +import numpy as np +from einops import rearrange, repeat, reduce +from copy import deepcopy +import lpips +import torchvision.transforms as transforms + +def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, plot_first_samples = 0): + + # Set up cpu or gpu training + device, verbose = init_device(cfg) + + # Config + cfg_net = cfg.model + cfg_net.batch_size = 2 if verbose else 32 + cfg_net.batch_size = 1 if plot_first_samples > 0 else cfg_net.batch_size # only plotting first samples + + # Load model + net = load_model(cfg, cfg_net, file, device) + net.eval() + + # config + object_view = True + individual_views = False + root_path = None + plotting_mode = (cfg_net.batch_size == 1) and (plot_first_samples > 0) + + # get evaluation sets + set_test_array, evaluation_modes = get_evaluation_sets(dataset) + + # memory + statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error_mse': []} + statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'bound': [], 'slot_error': [], 'rawmask_size': [], 'alpha_pos': [], 'alpha_ges': []} + metric_complete = None + + # Evaluation Specifics + burn_in_length = 6 + rollout_length = 42 + cfg.defaults.skip_frames = 2 + blackout_p = 0.2 + target_size = (64,64) + dataset.burn_in_length = burn_in_length + dataset.rollout_length = rollout_length + dataset.skip_length = cfg.defaults.skip_frames + + # Transformation utils + to_small = transforms.Resize(target_size) + to_normalize = transforms.Normalize((0.5, ), (0.5, )) + to_smallnorm = transforms.Compose([to_small, to_normalize]) + + # Losses + lpipsloss = lpips.LPIPS(net='vgg').to(device) + mseloss = nn.MSELoss() + + for set_test in set_test_array: + + for evaluation_mode in evaluation_modes: + print(f'Start evaluation loop: {evaluation_mode}') + + # load data + dataloader = DataLoader( + Subset(dataset, range(plot_first_samples)) if plotting_mode else dataset, + num_workers = 1, + pin_memory = False, + batch_size = cfg_net.batch_size, + shuffle = False, + drop_last = True, + ) + + # memory + root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views) + metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'blackout': []} + + with th.no_grad(): + for i, input in enumerate(dataloader): + print(f'Processing sample {i+1}/{len(dataloader)}', flush=True) + + # Load data + tensor = input[0].float().to(device) + background_fix = input[1].to(device) + gt_mask = input[2].to(device) + gt_bbox = input[3].to(device) + gt_pres_mask = input[4].to(device) + sequence_len = tensor.shape[1] + + # Placehodlers + mask_cur = None + mask_last = None + rawmask_last = None + position_last = None + gestalt_last = None + priority_last = None + 
slots_occlusionfactor = None + error_last = None + + # Memory + pred = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + gt = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device) + pred_mask = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device) + statistics_batch = deepcopy(statistics_template) + + # Counters + num_rollout = 0 + num_burnin = 0 + blackout_mem = [0] + + # Loop through frames + for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)): + + # Move to next frame + t_run = max(t, 0) + input = tensor[:,t_run] + target = th.clip(tensor[:,t_run+1], 0, 1) + + rollout_index = t_run - burn_in_length + if (rollout_index >= 0) and (evaluation_mode == 'blackout'): + blackout = (th.rand(1) < blackout_p).float().to(device) + input = blackout * (input * 0) + (1-blackout) * input + error_last = blackout * (error_last * 0) + (1-blackout) * error_last + blackout_mem.append(blackout.int().cpu().item()) + + elif t>=0: + num_burnin += 1 + + if (rollout_index >= 0) and (evaluation_mode != 'blackout'): + blackout_mem.append(0) + + # obtain prediction + ( + output_next, + position_next, + gestalt_next, + priority_next, + mask_next, + rawmask_next, + object_next, + background, + slots_occlusionfactor, + output_cur, + position_cur, + gestalt_cur, + priority_cur, + mask_cur, + rawmask_cur, + object_cur, + position_encoder_cur, + slots_bounded, + slots_partially_occluded_cur, + slots_occluded_cur, + slots_partially_occluded_next, + slots_occluded_next, + slots_closed, + output_hidden, + largest_object, + rawmask_hidden, + object_hidden + ) = net( + input, + error_last, + mask_last, + rawmask_last, + position_last, + gestalt_last, + priority_last, + background_fix, + slots_occlusionfactor, + reset = (t == -cfg.defaults.teacher_forcing), + evaluate=True, + warmup = (t < 0), + shuffleslots = False, + reset_mask = (t <= 0), + allow_spawn = True, + show_hidden = plotting_mode, + clean_slots = (t <= 0), + ) + + # 1. Track error for plots + if t >= 0: + + # save prediction + if rollout_index >= -1: + pred[:,rollout_index+1] = to_smallnorm(output_next) + gt[:,rollout_index+1] = to_smallnorm(target) + pred_mask[:,rollout_index+1] = postproc_mask(to_small(mask_next)[:,None,:,None])[:, 0] + num_rollout += 1 + + if plotting_mode and object_view: + statistics_complete_slots, statistics_batch = compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden) + + # 2. Remember output + mask_last = mask_next.clone() + rawmask_last = rawmask_next.clone() + position_last = position_next.clone() + gestalt_last = gestalt_next.clone() + priority_last = priority_next.clone() + + # 3. Error for next frame + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + # prediction error + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next.clone() + + # 4. 
plot preparation + if plotting_mode and (t % plot_frequency == 0): + plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i) + + # Compute prediction accuracy based on Slotformer metrics (ARI, FARI, mIoU, AP, AR) + if not plotting_mode: + for b in range(cfg_net.batch_size): + metric_dict = pred_eval_step( + gt=gt[b:b+1], + pred=pred[b:b+1], + pred_mask=pred_mask.long()[b:b+1], + pred_bbox=masks_to_boxes(pred_mask.long()[b:b+1], cfg_net.num_objects+1), + gt_mask=gt_mask.long()[b:b+1], + gt_pres_mask=gt_pres_mask[b:b+1], + gt_bbox=gt_bbox[b:b+1], + lpips_fn=lpipsloss, + eval_traj=True, + ) + metric_dict['blackout'] = blackout_mem + metric_complete = append_statistics(metric_dict, metric_complete) + + # sanity check + if (num_rollout != rollout_length) and (num_burnin != burn_in_length) and (evaluation_mode == 'rollout'): + raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.') + + if not plotting_mode: + average_dic = compute_statistics_summary(metric_complete, evaluation_mode) + + # Store statistics + with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f: + pickle.dump(metric_complete, f) + with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.pkl'), 'wb') as f: + pickle.dump(average_dic, f) + + print('-- Evaluation Done --') + if object_view and os.path.exists(f'{root_path}/tmp.jpg'): + os.remove(f'{root_path}/tmp.jpg') + pass + +def compute_statistics_summary(metric_complete, evaluation_mode): + average_dic = {} + for key in metric_complete: + # take average over all frames + average_dic[key + 'complete_average'] = np.mean(metric_complete[key]) + average_dic[key + 'complete_std'] = np.std(metric_complete[key]) + print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}') + + if evaluation_mode == 'blackout': + # take average only for frames where blackout occurs + blackout_mask = np.array(metric_complete['blackout']) > 0 + average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + 'blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + 'visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False]) + average_dic[key + 'visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False]) + + print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}') + print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}') + return average_dic + +def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden): + statistics_batch = store_statistics(statistics_batch, + set_test['type'], + evaluation_mode, + set_test['samples'][i], + t, + mseloss(output_next, target).item() + ) + + # compute slot-wise 
prediction error + output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects) + target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects) + slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean') + + # compute rawmask_size + rawmask_size = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum') + + statistics_complete_slots = store_statistics(statistics_complete_slots, + [set_test['type']] * cfg_net.num_objects, + [evaluation_mode] * cfg_net.num_objects, + [set_test['samples'][i]] * cfg_net.num_objects, + [t] * cfg_net.num_objects, + range(cfg_net.num_objects), + slots_bounded.cpu().numpy().flatten().astype(int), + slot_error.cpu().numpy().flatten(), + rawmask_size.cpu().numpy().flatten(), + slots_closed[:, :, 1].cpu().numpy().flatten(), + slots_closed[:, :, 0].cpu().numpy().flatten(), + extend = True) + + return statistics_complete_slots,statistics_batch + +def get_evaluation_sets(dataset): + + set = {"samples": np.arange(len(dataset), dtype=int), "type": "test"} + evaluation_modes = ['blackout', 'open'] # use 'open' for no blackouts + set_test_array = [set] + + return set_test_array, evaluation_modes diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py new file mode 100644 index 0000000..96ed32d --- /dev/null +++ b/scripts/exec/eval.py @@ -0,0 +1,28 @@ +import argparse +import sys +from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage +from scripts.utils.configuration import Configuration +from scripts import evaluation_adept, evaluation_clevrer + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", default="", help='path to the configuration file') + parser.add_argument("-load", default="", type=str, help='path to model') + parser.add_argument("-n", default="", type=str, help='results name') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + cfg = Configuration(args.cfg) + print(f'Evaluating model {args.load}') + + # Load dataset + if cfg.datatype == "clevrer": + testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True) + evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting + evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation + elif cfg.datatype == "adept": + testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) + else: + raise Exception("Dataset not supported")
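Note: in the 'blackout' evaluation mode above, the first burn_in_length = 6 frames are observed, and each of the following rollout_length = 42 frames is dropped with probability blackout_p = 0.2, so the network must predict through missing observations. A minimal sketch of the gating applied to each rollout frame (shapes illustrative):

    import torch as th

    blackout_p = 0.2
    input = th.rand(1, 3, 64, 64)                  # current input frame
    error_last = th.rand(1, 1, 64, 64)             # error signal carried over from the last step
    blackout = (th.rand(1) < blackout_p).float()   # 1.0 with probability blackout_p
    input = blackout * (input * 0) + (1 - blackout) * input                  # zero the frame...
    error_last = blackout * (error_last * 0) + (1 - blackout) * error_last   # ...and its error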
\ No newline at end of file diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_savi.py new file mode 100644 index 0000000..87d4e59 --- /dev/null +++ b/scripts/exec/eval_savi.py @@ -0,0 +1,17 @@ +import argparse +import sys +from data.datasets.ADEPT.dataset import AdeptDataset +from scripts.evaluation_adept_savi import evaluate + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-load", default="", type=str, help='path to savi slots') + parser.add_argument("-n", default="", type=str, help='results name') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + print(f'Evaluating savi slots {args.load}') + + # Load dataset + testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2))) + evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
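Note: the hard-coded size above appears to follow the latent_size * 2**(level*2) convention used in scripts/exec/eval.py and train.py; with level 2, each latent cell covers 2**4 = 16 pixels per side, so a 30 x 20 latent grid corresponds to a 480 x 320 input:

    level, latent = 2, (30, 20)
    size = (latent[0] * 2**(level * 2), latent[1] * 2**(level * 2))  # (480, 320)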
\ No newline at end of file diff --git a/scripts/exec/train.py b/scripts/exec/train.py new file mode 100644 index 0000000..f6e78fd --- /dev/null +++ b/scripts/exec/train.py @@ -0,0 +1,33 @@ +import argparse +import sys +from scripts.utils.configuration import Configuration +from scripts import training +from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage +from data.datasets.ADEPT.dataset import AdeptDataset + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", default="", help='path to the configuration file') + parser.add_argument("-n", default=-1, type=int, help='optional run number') + parser.add_argument("-load", default="", type=str, help='path to pretrained model or checkpoint') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + cfg = Configuration(args.cfg) + cfg.model_path = f"{cfg.model_path}" + if args.n >= 0: + cfg.model_path = f"{cfg.model_path}.run{args.n}" + print(f'Training model {cfg.model_path}') + + # Load dataset + if cfg.datatype == "clevrer": + trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False) + valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True) + elif cfg.datatype == "adept": + trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + else: + raise Exception("Dataset not supported") + + # Final call + training.train_loci(cfg, trainset, valset, args.load)
\ No newline at end of file diff --git a/scripts/training.py b/scripts/training.py new file mode 100644 index 0000000..b807e54 --- /dev/null +++ b/scripts/training.py @@ -0,0 +1,523 @@ +import torch as th +from torch import nn +from torch.utils.data import Dataset, DataLoader, Subset +import cv2 +import numpy as np +import os +from einops import rearrange, repeat, reduce +from scripts.utils.configuration import Configuration +from scripts.utils.io import init_device, model_path, LossLogger +from scripts.utils.optimizers import RAdam +from model.loci import Loci +import random +from scripts.utils.io import Timer +from scripts.utils.plot_utils import color_mask +from scripts.validation import validation_clevrer, validation_adept + + +def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): + + # Set up cpu or gpu training + device, verbose = init_device(cfg) + if verbose: + valset = Subset(valset, range(0, 8)) + + # Define model path + path = model_path(cfg, overwrite=False) + cfg.save(path) + os.makedirs(os.path.join(path, 'nets'), exist_ok=True) + + # Configure model + cfg_net = cfg.model + net = Loci( + cfg = cfg_net, + teacher_forcing = cfg.defaults.teacher_forcing + ) + net = net.to(device=device) + + # Log model size + log_modelsize(net) + + # Init Optimizers + optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net) + + # Option to load model + if file != "": + load_model( + file, + cfg, + net, + optimizer_init, + optimizer_encoder, + optimizer_decoder, + optimizer_predictor, + optimizer_background, + cfg.defaults.load_optimizers, + only_encoder_decoder = (cfg.num_updates == 0) # only load encoder and decoder for initial training + ) + print(f'loaded {file}', flush=True) + + # Set up data loaders + trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True) + valset.train = True #valset.dataset.train = True + valloader = get_loader(cfg, valset, cfg_net, shuffle=False) + + # initial save + save_model( + os.path.join(path, 'nets', 'net0.pt'), + net, + optimizer_init, + optimizer_encoder, + optimizer_decoder, + optimizer_predictor, + optimizer_background + ) + + # Start training at num_updates + num_updates = cfg.num_updates + if num_updates > 0: + print('!!! Start training at num_updates: ', num_updates) + print('!!! 
Net init status: ', net.get_init_status()) + + # Set up statistics + loss_tracker = LossLogger() + + # Set up training variables + num_time_steps = 0 + bptt_steps = cfg.bptt.bptt_steps + increase_bptt_steps = False + background_blendin_factor = 0.0 + th.backends.cudnn.benchmark = True + timer = Timer() + bceloss = nn.BCELoss() + + # Init net to current num_updates + if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1: + net.inc_init_level() + + if num_updates >= cfg.phases.entity_pretraining_phase1_end and net.get_init_status() < 2: + net.inc_init_level() + + if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3: + net.inc_init_level() + for param in optimizer_init.param_groups: + param['lr'] = cfg.learning_rate + + if num_updates > cfg.phases.start_inner_loop: + net.cfg.inner_loop_enabled = True + + if num_updates >= cfg.phases.entity_pretraining_phase1_end: + background_blendin_factor = max(min((num_updates - cfg.phases.entity_pretraining_phase1_end)/30000, 1.0), 0.0) + + + # --- Start Training + print('Start training') + for epoch in range(cfg.max_epochs): + + # Validation every epoch + if epoch >= 0: + if cfg.datatype == 'adept': + validation_adept(valloader, net, cfg, device) + elif cfg.datatype == 'clevrer': + validation_clevrer(valloader, net, cfg, device) + + # Start epoch training + print('Start epoch:', epoch) + + # Backprop through time steps + if increase_bptt_steps: + bptt_steps = max(bptt_steps + 1, cfg.bptt.bptt_steps_max) + print('Increase closed loop steps to', bptt_steps) + increase_bptt_steps = False + + for batch_index, input in enumerate(trainloader): + + # Extract input and background + tensor = input[0] + background = input[1].to(device) + shuffleslots = (num_updates <= cfg.phases.shufleslots_end) + + # Placeholders + position = None + gestalt = None + priority = None + mask = None + object = None + rawmask = None + loss = th.tensor(0) + summed_loss = None + slots_occlusionfactor = None + + # Apply skip frames to sequence + selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames) + tensor = tensor[:,selec] + sequence_len = tensor.shape[1] + + # Initial frame + input = tensor[:,0].to(device) + input_next = input + target = th.clip(input, 0, 1).detach() + error_last = None + + # First apply teacher forcing for the first x frames + for t in range(-cfg.defaults.teacher_forcing, sequence_len-1): + + # Set update scheme for backprop through time + if t >= cfg.bptt.bptt_start_timestep: + t_run = (t - cfg.bptt.bptt_start_timestep) + run_optimizers = t_run % bptt_steps == bptt_steps - 1 + detach = (t_run % bptt_steps == 0) or t == -cfg.defaults.teacher_forcing + else: + run_optimizers = True + detach = True + + if verbose: + print(f't: {t}, run_optimizers: {run_optimizers}, detach: {detach}') + + if t >= 0: + # Skip to next frame + num_time_steps += 1 + input = input_next + input_next = tensor[:,t+1].to(device) + target = th.clip(input_next, 0, 1) + + # Apply error dropout + if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout: + error_last = th.zeros_like(error_last) + + # Apply sensation blackout when training clevrere + if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer': + if t >= 10: + blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device) + input = blackout * (input * 0) + (1-blackout) * input + error_last = blackout * (error_last * 0) + 
(1-blackout) * error_last + + # Forward Pass + ( + output_next, + output_cur, + position, + gestalt, + priority, + mask, + rawmask, + object, + background, + slots_occlusionfactor, + position_loss, + time_loss, + slots_closed + ) = net( + input, # current frame + error_last, # error of last frame --> missing object + mask, # masks of current frame + rawmask, # raw masks of current frame + position, # positions of objects of next frame + gestalt, # gestalt of objects of next frame + priority, # priority of objects of next frame + background, + slots_occlusionfactor, + reset = (t == -cfg.defaults.teacher_forcing), # new sequence + warmup = (t < 0), # teacher forcing + detach = detach, + shuffleslots = shuffleslots or ((cfg.datatype == 'clevrer') and (t<=0)), + reset_mask = (t <= 0), + clean_slots = (t <= 0 and not shuffleslots), + ) + + # Loss weighting + position_loss = position_loss * cfg_net.position_regularizer + time_loss = time_loss * cfg_net.time_regularizer + + # Compute background error + bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + # Compute next-frame prediction error + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next + + # Initially focus on foreground learning + if background_blendin_factor < 1: + fg_mask_next = th.gt(bg_error_next, 0.1).float().detach() + fg_mask_next[fg_mask_next == 0] = background_blendin_factor + target = th.clip(target * fg_mask_next, 0, 1) + + fg_mask_cur = th.gt(bg_error_cur, 0.1).float().detach() + fg_mask_cur[fg_mask_cur == 0] = background_blendin_factor + input = th.clip(input * fg_mask_cur, 0, 1) + + # Gradually blend in background for more stable training + if num_updates % 30 == 0 and num_updates >= cfg.phases.entity_pretraining_phase1_end: + background_blendin_factor = min(1, background_blendin_factor + 0.001) + + # Final Loss computation + encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer + cliped_output_next = th.clip(output_next, 0, 1) + loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss + + # Accumulate loss over BPP steps + summed_loss = loss if summed_loss is None else summed_loss + loss + mask = mask.detach() + + if run_optimizers: + + # detach gradients for next step + position = position.detach() + gestalt = gestalt.detach() + rawmask = rawmask.detach() + object = object.detach() + priority = priority.detach() + + # zero grad + optimizer_init.zero_grad() + optimizer_encoder.zero_grad() + optimizer_decoder.zero_grad() + optimizer_predictor.zero_grad() + optimizer_background.zero_grad() + if net.cfg.inner_loop_enabled: + optimizer_update.zero_grad() + + # optimize + summed_loss.backward() + optimizer_init.step() + optimizer_encoder.step() + optimizer_decoder.step() + optimizer_predictor.step() + optimizer_background.step() + if net.cfg.inner_loop_enabled: + optimizer_update.step() + + # Reset loss + num_updates += 1 + summed_loss = None + + # Update net status + update_net_status(num_updates, net, cfg, optimizer_init) + + if num_updates == cfg.phases.start_inner_loop: + print('Start inner loop') + net.cfg.inner_loop_enabled = True + + if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0): + 
increase_bptt_steps = True + + # Plots for online evaluation + if num_updates % 20000 == 0: + plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next) + + + # Track statisitcs + if t >= cfg.defaults.statistics_offset: + track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next) + loss_tracker.update_average_loss(loss.item()) + + # Logging + if num_updates % 100 == 0 and run_optimizers: + print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True) + + # Training finished + if num_updates > cfg.max_updates: + save_model( + os.path.join(path, 'nets', 'net_final.pt'), + net, + optimizer_init, + optimizer_encoder, + optimizer_decoder, + optimizer_predictor, + optimizer_background + ) + print("Training finished") + return + + # Checkpointing + if num_updates % 50000 == 0 and run_optimizers: + save_model( + os.path.join(path, 'nets', f'net_{num_updates}.pt'), + net, + optimizer_init, + optimizer_encoder, + optimizer_decoder, + optimizer_predictor, + optimizer_background + ) + pass + +def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next): + + # Prediction Loss + mseloss = nn.MSELoss() + loss_next = mseloss(output_next * bg_error_next, target * bg_error_next) + + # Encoder Loss (only for stats) + loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur) + + # area of foreground mask + num_objects = th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item() + + # difference in shape + _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects) + _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects) + max_mask = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float() + avg_gestalt = (th.sum(_gestalt * max_mask) / (1e-16 + th.sum(max_mask))) + avg_gestalt2 = (th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask))) + avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1)) + + # udpdate gates + avg_update_gestalt = slots_closed[:,:,0].mean() + avg_update_position = slots_closed[:,:,1].mean() + + loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position) + pass + +def log_modelsize(net): + print(f'Loaded model with {sum([param.numel() for param in net.parameters()]):7d} parameters', flush=True) + print(f' States: {sum([param.numel() for param in net.initial_states.parameters()]):7d} parameters', flush=True) + print(f' Encoder: {sum([param.numel() for param in net.encoder.parameters()]):7d} parameters', flush=True) + print(f' Decoder: {sum([param.numel() for param in net.decoder.parameters()]):7d} parameters', flush=True) + print(f' predictor: {sum([param.numel() for param in net.predictor.parameters()]):7d} parameters', flush=True) + print(f' background: {sum([param.numel() for param in net.background.parameters()]):7d} parameters', flush=True) + print("\n") + pass + +def init_optimizer(cfg, net): + optimizer_init = 
RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30) + optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate) + optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate) + optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate) + optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate) + optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = cfg.learning_rate) + return optimizer_init,optimizer_encoder,optimizer_decoder,optimizer_predictor,optimizer_background,optimizer_update + +def save_model( + file, + net, + optimizer_init, + optimizer_encoder, + optimizer_decoder, + optimizer_predictor, + optimizer_background +): + + state = { } + + state['optimizer_init'] = optimizer_init.state_dict() + state['optimizer_encoder'] = optimizer_encoder.state_dict() + state['optimizer_decoder'] = optimizer_decoder.state_dict() + state['optimizer_predictor'] = optimizer_predictor.state_dict() + state['optimizer_background'] = optimizer_background.state_dict() + + state["model"] = net.state_dict() + th.save(state, file) + pass + +def load_model( + file, + cfg, + net, + optimizer_init, + optimizer_encoder, + optimizer_decoder, + optimizer_predictor, + optimizer_background, + load_optimizers = True, + only_encoder_decoder = False +): + device = th.device(cfg.device) + state = th.load(file, map_location=device) + print(f"load {file} to device {device}, only encoder/decoder: {only_encoder_decoder}") + print(f"load optimizers: {load_optimizers}") + + if load_optimizers: + optimizer_init.load_state_dict(state[f'optimizer_init']) + for n in range(len(optimizer_init.param_groups)): + optimizer_init.param_groups[n]['lr'] = cfg.learning_rate + + optimizer_encoder.load_state_dict(state[f'optimizer_encoder']) + for n in range(len(optimizer_encoder.param_groups)): + optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate + + optimizer_decoder.load_state_dict(state[f'optimizer_decoder']) + for n in range(len(optimizer_decoder.param_groups)): + optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate + + optimizer_predictor.load_state_dict(state['optimizer_predictor']) + for n in range(len(optimizer_predictor.param_groups)): + optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate + + optimizer_background.load_state_dict(state['optimizer_background']) + for n in range(len(optimizer_background.param_groups)): + optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate + + # 1. Fill model with values of net + model = {} + allowed_keys = [] + rand_state = net.state_dict() + for key, value in rand_state.items(): + allowed_keys.append(key) + model[key.replace(".module.", ".")] = value + + # 2. 
Overwrite with values from file + for key, value in state["model"].items(): + # replace update_module with percept_gate_controller in key string: + key = key.replace("update_module", "percept_gate_controller") + + if key in allowed_keys: + if only_encoder_decoder: + if ('encoder' in key) or ('decoder' in key): + model[key.replace(".module.", ".")] = value + else: + model[key.replace(".module.", ".")] = value + + net.load_state_dict(model) + + pass + +def update_net_status(num_updates, net, cfg, optimizer_init): + if num_updates == cfg.phases.background_pretraining_end and net.get_init_status() < 1: + net.inc_init_level() + + if num_updates == cfg.phases.entity_pretraining_phase1_end and net.get_init_status() < 2: + net.inc_init_level() + + if num_updates == cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3: + net.inc_init_level() + for param in optimizer_init.param_groups: + param['lr'] = cfg.learning_rate + + pass + +def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next): + plot_path = os.path.join(path, 'plots', f'net_{num_updates}') + if not os.path.exists(plot_path): + os.makedirs(plot_path, exist_ok=True) + + # highlight error + grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114 + object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1) + highlited_input = grayscale * (1 - object_mask_cur) + highlited_input += grayscale * object_mask_cur * 0.3333333 + cmask = color_mask(mask[:,:-1]) + highlited_input = highlited_input + cmask * 0.6666666 + + cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255) + pass + +def get_loader(cfg, set, cfg_net, shuffle = True): + loader = DataLoader( + set, + pin_memory = True, + num_workers = cfg.defaults.num_workers, + batch_size = cfg_net.batch_size, + shuffle = shuffle, + drop_last = True, + prefetch_factor = cfg.defaults.prefetch_factor, + persistent_workers = True + ) + return loader
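For orientation, here is a minimal sketch of how the helpers above compose; `cfg`, `net`, `path`, and `trainset` are hypothetical stand-ins, and the real wiring lives in scripts/exec/train.py and the training loop:

    # Sketch only: the cfg fields follow their usage in this file and may
    # differ from the actual configuration files.
    (optimizer_init, optimizer_encoder, optimizer_decoder,
     optimizer_predictor, optimizer_background, optimizer_update) = init_optimizer(cfg, net)

    log_modelsize(net)
    trainloader = get_loader(cfg, trainset, cfg.model, shuffle=True)

    # Checkpoints persist the model plus five of the six optimizers; note
    # that optimizer_update (the gate controller) is not saved by save_model.
    save_model(os.path.join(path, 'nets', 'net_0.pt'), net,
               optimizer_init, optimizer_encoder, optimizer_decoder,
               optimizer_predictor, optimizer_background)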
\ No newline at end of file diff --git a/scripts/utils/configuration.py b/scripts/utils/configuration.py new file mode 100644 index 0000000..7c38095 --- /dev/null +++ b/scripts/utils/configuration.py @@ -0,0 +1,83 @@ +import json +import jsmin +import os + +class Dict(dict): + """ + Dictionary that allows to access per attributes and to except names from being loaded + """ + def __init__(self, dictionary: dict = None): + super(Dict, self).__init__() + + if dictionary is not None: + self.load(dictionary) + + def __getattr__(self, item): + try: + return self[item] if item in self else getattr(super(Dict, self), item) + except AttributeError: + raise AttributeError(f'This dictionary has no attribute "{item}"') + + def load(self, dictionary: dict, name_list: list = None): + """ + Loads a dictionary + :param dictionary: Dictionary to be loaded + :param name_list: List of names to be updated + """ + for name in dictionary: + data = dictionary[name] + if name_list is None or name in name_list: + if isinstance(data, dict): + if name in self: + self[name].load(data) + else: + self[name] = Dict(data) + elif isinstance(data, list): + self[name] = list() + for item in data: + if isinstance(item, dict): + self[name].append(Dict(item)) + else: + self[name].append(item) + else: + self[name] = data + + def save(self, path): + """ + Saves the dictionary into a json file + :param path: Path of the json file + """ + os.makedirs(path, exist_ok=True) + + path = os.path.join(path, 'cfg.json') + + with open(path, 'w') as file: + json.dump(self, file, indent=True) + + +class Configuration(Dict): + """ + Configuration loaded from a json file + """ + def __init__(self, path: str, default_path=None): + super(Configuration, self).__init__() + + if default_path is not None: + self.load(default_path) + + self.load(path) + + def load_model(self, path: str): + self.load(path, name_list=["model"]) + + def load(self, path: str, name_list: list = None): + """ + Loads attributes from a json file + :param path: Path of the json file + :param name_list: List of names to be updated + :return: + """ + with open(path) as file: + data = json.loads(jsmin.jsmin(file.read())) + + super(Configuration, self).load(data, name_list) diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py new file mode 100644 index 0000000..6f5106d --- /dev/null +++ b/scripts/utils/eval_metrics.py @@ -0,0 +1,350 @@ +''' +vp_utils.py +SCRIPT TAKEN FROM https://github.com/pairlab/SlotFormer +''' + +import numpy as np +from scipy.optimize import linear_sum_assignment +from skimage.metrics import structural_similarity, peak_signal_noise_ratio + +import torch +import torch.nn.functional as F +import torchvision.ops as vops + +FG_THRE = 0.5 +PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0), + (255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200), + (255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200), + (50, 50, 50)] +PALETTE_np = np.array(PALETTE, dtype=np.uint8) +PALETTE_torch = torch.from_numpy(PALETTE_np).float() / 255. * 2. - 1. + +def to_rgb_from_tensor(x): + """Reverse the Normalize operation in torchvision.""" + return (x * 0.5 + 0.5).clamp(0, 1) + +def postproc_mask(batch_masks): + """Post-process masks instead of directly taking argmax. 
+ + Args: + batch_masks: [B, T, N, 1, H, W] + + Returns: + masks: [B, T, H, W] + """ + batch_masks = batch_masks.clone() + B, T, N, _, H, W = batch_masks.shape + batch_masks = batch_masks.reshape(B * T, N, H * W) + slots_max = batch_masks.max(-1)[0] # [B*T, N] + bg_idx = slots_max.argmin(-1) # [B*T] + spatial_max = batch_masks.max(1)[0] # [B*T, H*W] + bg_mask = (spatial_max < FG_THRE) # [B*T, H*W] + idx_mask = torch.zeros((B * T, N)).type_as(bg_mask) + idx_mask[torch.arange(B * T), bg_idx] = True + # set the background mask score to 1 + batch_masks[idx_mask.unsqueeze(-1) * bg_mask.unsqueeze(1)] = 1. + masks = batch_masks.argmax(1) # [B*T, H*W] + return masks.reshape(B, T, H, W) + + +def masks_to_boxes_w_empty_mask(binary_masks): + """binary_masks: [B, H, W].""" + B = binary_masks.shape[0] + obj_mask = (binary_masks.sum([-1, -2]) > 0) # [B] + bboxes = torch.ones((B, 4)).float().to(binary_masks.device) * -1. + bboxes[obj_mask] = vops.masks_to_boxes(binary_masks[obj_mask]) + return bboxes + + +def masks_to_boxes(masks, num_boxes=7): + """Convert seg_masks to bboxes. + + Args: + masks: [B, T, H, W], output after taking argmax + num_boxes: number of boxes to generate (num_slots) + + Returns: + bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2] + """ + B, T, H, W = masks.shape + binary_masks = F.one_hot(masks, num_classes=num_boxes) # [B, T, H, W, N] + binary_masks = binary_masks.permute(0, 1, 4, 2, 3) # [B, T, N, H, W] + binary_masks = binary_masks.contiguous().flatten(0, 2) + bboxes = masks_to_boxes_w_empty_mask(binary_masks) # [B*T*N, 4] + bboxes = bboxes.reshape(B, T, num_boxes, 4) # [B, T, N, 4] + return bboxes + + +def mse_metric(x, y): + """x/y: [B, C, H, W]""" + # people often sum over spatial dimensions in video prediction MSE + # see e.g. https://github.com/Yunbo426/predrnn-pp/blob/40764c52f433290aa02414b5f25c09c72b98e0af/train.py#L244 + return ((x - y)**2).sum(-1).sum(-1).mean() + + +def psnr_metric(x, y): + """x/y: [B, C, H, W]""" + psnrs = [ + peak_signal_noise_ratio( + x[i], + y[i], + data_range=1., + ) for i in range(x.shape[0]) + ] + return np.mean(psnrs) + + +def ssim_metric(x, y): + """x/y: [B, C, H, W]""" + x = x * 255. + y = y * 255. + ssims = [ + structural_similarity( + x[i], + y[i], + channel_axis=0, + gaussian_weights=True, + sigma=1.5, + use_sample_covariance=False, + data_range=255, + ) for i in range(x.shape[0]) + ] + return np.mean(ssims) + + +def perceptual_dist(x, y, loss_fn): + """x/y: [B, C, H, W]""" + return loss_fn(x, y).mean() + + +def adjusted_rand_index(true_ids, pred_ids, ignore_background=False): + """Computes the adjusted Rand index (ARI), a clustering similarity score. + + Code borrowed from https://github.com/google-research/slot-attention-video/blob/e8ab54620d0f1934b332ddc09f1dba7bc07ff601/savi/lib/metrics.py#L111 + + Args: + true_ids: An integer-valued array of shape + [batch_size, seq_len, H, W]. The true cluster assignment encoded + as integer ids. + pred_ids: An integer-valued array of shape + [batch_size, seq_len, H, W]. The predicted cluster assignment + encoded as integer ids. + ignore_background: Boolean, if True, then ignore all pixels where + true_ids == 0 (default: False). + + Returns: + ARI scores as a float32 array of shape [batch_size]. + """ + if len(true_ids.shape) == 3: + true_ids = true_ids.unsqueeze(1) + if len(pred_ids.shape) == 3: + pred_ids = pred_ids.unsqueeze(1) + + true_oh = F.one_hot(true_ids).float() + pred_oh = F.one_hot(pred_ids).float() + + if ignore_background: + true_oh = true_oh[..., 1:] # Remove the background row. 
+ + N = torch.einsum("bthwc,bthwk->bck", true_oh, pred_oh) + A = torch.sum(N, dim=-1) # row-sum (batch_size, c) + B = torch.sum(N, dim=-2) # col-sum (batch_size, k) + num_points = torch.sum(A, dim=1) + + rindex = torch.sum(N * (N - 1), dim=[1, 2]) + aindex = torch.sum(A * (A - 1), dim=1) + bindex = torch.sum(B * (B - 1), dim=1) + expected_rindex = aindex * bindex / torch.clamp( + num_points * (num_points - 1), min=1) + max_rindex = (aindex + bindex) / 2 + denominator = max_rindex - expected_rindex + ari = (rindex - expected_rindex) / denominator + + # There are two cases for which the denominator can be zero: + # 1. If both label_pred and label_true assign all pixels to a single cluster. + # (max_rindex == expected_rindex == rindex == num_points * (num_points-1)) + # 2. If both label_pred and label_true assign max 1 point to each cluster. + # (max_rindex == expected_rindex == rindex == 0) + # In both cases, we want the ARI score to be 1.0: + return torch.where(denominator != 0, ari, torch.tensor(1.).type_as(ari)) + + +def ARI_metric(x, y): + """x/y: [B, H, W], both are seg_masks after argmax.""" + assert 'int' in str(x.dtype) + assert 'int' in str(y.dtype) + return adjusted_rand_index(x, y).mean().item() + + +def fARI_metric(x, y): + """x/y: [B, H, W], both are seg_masks after argmax.""" + assert 'int' in str(x.dtype) + assert 'int' in str(y.dtype) + return adjusted_rand_index(x, y, ignore_background=True).mean().item() + + +def bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox, ovthresh=0.5): + """Compute the precision of predicted bounding boxes. + + Args: + gt_pres_mask: A boolean tensor of shape [N] + gt_bbox: A tensor of shape [N, 4] + pred_bbox: A tensor of shape [M, 4] + """ + gt_bbox, pred_bbox = gt_bbox.clone(), pred_bbox.clone() + gt_bbox = gt_bbox[gt_pres_mask.bool()] + pred_bbox = pred_bbox[pred_bbox[:, 0] >= 0.] + N, M = gt_bbox.shape[0], pred_bbox.shape[0] + assert gt_bbox.shape[1] == pred_bbox.shape[1] == 4 + # assert M >= N + tp, fp = 0, 0 + bbox_used = [False] * pred_bbox.shape[0] + bbox_ious = vops.box_iou(gt_bbox, pred_bbox) # [N, M] + + # Find the best iou match for each ground truth bbox. + for i in range(N): + best_iou_idx = bbox_ious[i].argmax().item() + best_iou = bbox_ious[i, best_iou_idx].item() + if best_iou >= ovthresh and not bbox_used[best_iou_idx]: + tp += 1 + bbox_used[best_iou_idx] = True + else: + fp += 1 + + # compute precision and recall + precision = tp / float(M) + recall = tp / float(N) + return precision, recall + + +def batch_bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox): + """Compute the precision of predicted bounding boxes over batch.""" + aps, ars = [], [] + for i in range(gt_pres_mask.shape[0]): + ap, ar = bbox_precision_recall(gt_pres_mask[i], gt_bbox[i], + pred_bbox[i]) + aps.append(ap) + ars.append(ar) + return np.mean(aps), np.mean(ars) + + +def hungarian_miou(gt_mask, pred_mask): + """both mask: [H*W] after argmax, 0 is gt background index.""" + true_oh = F.one_hot(gt_mask).float()[..., 1:] # only foreground, [HW, N] + pred_oh = F.one_hot(pred_mask).float() # [HW, M] + N, M = true_oh.shape[-1], pred_oh.shape[-1] + # compute all pairwise IoU + intersect = (true_oh[:, :, None] * pred_oh[:, None, :]).sum(0) # [N, M] + union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :] # [N, M] + iou = intersect / (union + 1e-8) # [N, M] + iou = iou.detach().cpu().numpy() + # find the best match for each gt + row_ind, col_ind = linear_sum_assignment(iou, maximize=True) + # there are two possibilities here + # 1. 
M >= N, just take the best match mean + # 2. M < N, some objects are not detected, their iou is 0 + if M >= N: + assert (row_ind == np.arange(N)).all() + return iou[row_ind, col_ind].mean() + return iou[row_ind, col_ind].sum() / float(N) + + +def miou_metric(gt_mask, pred_mask): + """both mask: [B, H, W], both are seg_masks after argmax.""" + assert 'int' in str(gt_mask.dtype) + assert 'int' in str(pred_mask.dtype) + gt_mask, pred_mask = gt_mask.flatten(1, 2), pred_mask.flatten(1, 2) + ious = [ + hungarian_miou(gt_mask[i], pred_mask[i]) + for i in range(gt_mask.shape[0]) + ] + return np.mean(ious) + + +@torch.no_grad() +def pred_eval_step( + gt, + pred, + lpips_fn, + gt_mask=None, + pred_mask=None, + gt_pres_mask=None, + gt_bbox=None, + pred_bbox=None, + eval_traj=True, +): + """Both of shape [B, T, C, H, W], torch.Tensor. + masks of shape [B, T, H, W]. + pres_mask of shape [B, T, N]. + bboxes of shape [B, T, N/M, 4]. + + eval_traj: whether to evaluate the trajectory (measured by bbox and mask). + + Compute metrics for every timestep. + """ + assert len(gt.shape) == len(pred.shape) == 5 + assert gt.shape == pred.shape + assert gt.shape[2] == 3 + if eval_traj: + assert len(gt_mask.shape) == len(pred_mask.shape) == 4 + assert gt_mask.shape == pred_mask.shape + if eval_traj: + assert len(gt_pres_mask.shape) == 3 + assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4 + T = gt.shape[1] + + # compute perceptual dist & mask metrics before converting to numpy + all_percept_dist, all_ari, all_fari, all_miou = [], [], [], [] + for t in range(T): + one_gt, one_pred = gt[:, t], pred[:, t] + percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item() + all_percept_dist.append(percept_dist) + if eval_traj: + one_gt_mask, one_pred_mask = gt_mask[:, t], pred_mask[:, t] + ari = ARI_metric(one_gt_mask, one_pred_mask) + fari = fARI_metric(one_gt_mask, one_pred_mask) + miou = miou_metric(one_gt_mask, one_pred_mask) + all_ari.append(ari) + all_fari.append(fari) + all_miou.append(miou) + else: + all_ari.append(0.) + all_fari.append(0.) + all_miou.append(0.) + + # compute bbox metrics + all_ap, all_ar = [], [] + for t in range(T): + if not eval_traj: + all_ap.append(0.) + all_ar.append(0.) 
+ continue + one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \ + gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t] + ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox, + one_pred_bbox) + all_ap.append(ap) + all_ar.append(ar) + + gt = to_rgb_from_tensor(gt).cpu().numpy() + pred = to_rgb_from_tensor(pred).cpu().numpy() + all_mse, all_ssim, all_psnr = [], [], [] + for t in range(T): + one_gt, one_pred = gt[:, t], pred[:, t] + mse = mse_metric(one_gt, one_pred) + psnr = psnr_metric(one_gt, one_pred) + ssim = ssim_metric(one_gt, one_pred) + all_mse.append(mse) + all_ssim.append(ssim) + all_psnr.append(psnr) + return { + 'mse': all_mse, + 'ssim': all_ssim, + 'psnr': all_psnr, + 'percept_dist': all_percept_dist, + 'ari': all_ari, + 'fari': all_fari, + 'miou': all_miou, + 'ap': all_ap, + 'ar': all_ar, + } diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py new file mode 100644 index 0000000..faab7ec --- /dev/null +++ b/scripts/utils/eval_utils.py @@ -0,0 +1,78 @@ +import os +import torch as th +from model.loci import Loci + +def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views): + + net_name = file.split('/')[-1].split('.')[0] + #root_path = file.split('nets')[0] + root_path = os.path.join(*file.split('/')[0:-1]) + root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type']) + plot_path = os.path.join(root_path, evaluation_mode) + + # create directories + #if os.path.exists(plot_path): + # shutil.rmtree(plot_path) + os.makedirs(plot_path, exist_ok = True) + if object_view: + os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True) + if individual_views: + os.makedirs(os.path.join(plot_path, 'individual'), exist_ok = True) + for group in ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination']: + os.makedirs(os.path.join(plot_path, 'individual', group), exist_ok = True) + os.makedirs(os.path.join(root_path, 'statistics'), exist_ok = True) + + # final directory + plot_path = plot_path + '/' + print(f"save plots to {plot_path}") + + return root_path, plot_path + +def store_statistics(memory, *args, extend=False): + for i,key in enumerate(memory.keys()): + if i >= len(args): + break + if extend: + memory[key].extend(args[i]) + else: + memory[key].append(args[i]) + return memory + +def append_statistics(memory1, memory2, ignore=[], extend=False): + for key in memory1: + if key not in ignore: + if extend: + memory2[key] = memory2[key] + memory1[key] + else: + memory2[key].append(memory1[key]) + return memory2 + +def load_model(cfg, cfg_net, file, device): + + net = Loci( + cfg_net, + teacher_forcing = cfg.defaults.teacher_forcing + ) + + # load model + if file != '': + print(f"load {file} to device {device}") + state = th.load(file, map_location=device) + + # backward compatibility + model = {} + for key, value in state["model"].items(): + # replace update_module with percept_gate_controller in key string: + key = key.replace("update_module", "percept_gate_controller") + model[key.replace(".module.", ".")] = value + + net.load_state_dict(model) + + # ??? + if net.get_init_status() < 1: + net.inc_init_level() + + # set network to evaluation mode + net = net.to(device=device) + + return net
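Note that store_statistics fills a statistics dict positionally: it relies on the insertion order of the keys (guaranteed for dicts since Python 3.7), so each call must mirror the order in which the template dict was built. A small self-contained illustration with made-up values:

    stats = {'set': [], 'evalmode': [], 'scene': [], 'frame': []}

    # One value per key, in the order the keys were inserted above.
    stats = store_statistics(stats, 'test', 'open', 3, 17)
    assert stats == {'set': ['test'], 'evalmode': ['open'], 'scene': [3], 'frame': [17]}

    # append_statistics merges one such dict into an accumulator;
    # extend=True concatenates the lists instead of nesting them.
    total = {'set': [], 'evalmode': [], 'scene': [], 'frame': []}
    total = append_statistics(stats, total, ignore=['frame'], extend=True)
    assert total['scene'] == [3] and total['frame'] == []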
\ No newline at end of file diff --git a/scripts/utils/io.py b/scripts/utils/io.py new file mode 100644 index 0000000..9bd8158 --- /dev/null +++ b/scripts/utils/io.py @@ -0,0 +1,137 @@ +import os +from scripts.utils.configuration import Configuration +import time +import torch as th +import numpy as np +from einops import rearrange, repeat, reduce + + +def init_device(cfg): + print(f'Cuda available: {th.cuda.is_available()} Cuda count: {th.cuda.device_count()}') + if th.cuda.is_available(): + device = th.device("cuda:0") + verbose = False + cfg.device = "cuda:0" + cfg.model.device = "cuda:0" + print('!!! USING GPU !!!') + else: + device = th.device("cpu") + verbose = True + cfg.device = "cpu" + cfg.model.device = "cpu" + cfg.model.batch_size = 2 + cfg.defaults.teacher_forcing = 4 + print('!!! USING CPU !!!') + return device,verbose + +class Timer: + + def __init__(self): + self.last = time.time() + self.passed = 0 + self.sum = 0 + + def __str__(self): + self.passed = self.passed * 0.99 + time.time() - self.last + self.sum = self.sum * 0.99 + 1 + passed = self.passed / self.sum + self.last = time.time() + + if passed > 1: + return f"{passed:.2f}s/it" + + return f"{1.0/passed:.2f}it/s" + +class UEMA: + + def __init__(self, memory = 100): + self.value = 0 + self.sum = 1e-30 + self.decay = np.exp(-1 / memory) + + def update(self, value): + self.value = self.value * self.decay + value + self.sum = self.sum * self.decay + 1 + + def __float__(self): + return self.value / self.sum + + +def model_path(cfg: Configuration, overwrite=False, move_old=True): + """ + Makes the model path, option to not overwrite + :param cfg: Configuration file with the model path + :param overwrite: Overwrites the files in the directory, else makes a new directory + :param move_old: Moves old folder with the same name to an old folder, if not overwrite + :return: Model path + """ + _path = os.path.join('out') + path = os.path.join(_path, cfg.model_path) + + if not os.path.exists(_path): + os.makedirs(_path) + + if not overwrite: + if move_old: + # Moves existing directory to an old folder + if os.path.exists(path): + old_path = os.path.join(_path, f'{cfg.model_path}_old') + if not os.path.exists(old_path): + os.makedirs(old_path) + _old_path = os.path.join(old_path, cfg.model_path) + i = 0 + while os.path.exists(_old_path): + i = i + 1 + _old_path = os.path.join(old_path, f'{cfg.model_path}_{i}') + os.renames(path, _old_path) + else: + # Increases number after directory name for each new path + i = 0 + while os.path.exists(path): + i = i + 1 + path = os.path.join(_path, f'{cfg.model_path}_{i}') + + return path + +class LossLogger: + + def __init__(self): + + self.avgloss = UEMA() + self.avg_position_loss = UEMA() + self.avg_time_loss = UEMA() + self.avg_encoder_loss = UEMA() + self.avg_mse_object_loss = UEMA() + self.avg_long_mse_object_loss = UEMA(33333) + self.avg_num_objects = UEMA() + self.avg_openings = UEMA() + self.avg_gestalt = UEMA() + self.avg_gestalt2 = UEMA() + self.avg_gestalt_mean = UEMA() + self.avg_update_gestalt = UEMA() + self.avg_update_position = UEMA() + + + def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position): + + self.avg_position_loss.update(avg_position_loss.item()) + self.avg_time_loss.update(avg_time_loss.item()) + self.avg_encoder_loss.update(avg_encoder_loss.item()) + 
self.avg_mse_object_loss.update(avg_mse_object_loss.item()) + self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item()) + self.avg_num_objects.update(avg_num_objects) + self.avg_openings.update(avg_openings) + self.avg_gestalt.update(avg_gestalt.item()) + self.avg_gestalt2.update(avg_gestalt2.item()) + self.avg_gestalt_mean.update(avg_gestalt_mean.item()) + self.avg_update_gestalt.update(avg_update_gestalt.item()) + self.avg_update_position.update(avg_update_position.item()) + pass + + def update_average_loss(self, avgloss): + self.avgloss.update(avgloss) + pass + + def get_log(self): + info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}' + return info diff --git a/scripts/utils/optimizers.py b/scripts/utils/optimizers.py new file mode 100644 index 0000000..49283b3 --- /dev/null +++ b/scripts/utils/optimizers.py @@ -0,0 +1,100 @@ +import math +import torch as th +import numpy as np +from torch.optim.optimizer import Optimizer + +""" +Liyuan Liu , Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020). +On the Variance of the Adaptive Learning Rate and Beyond. the Eighth International Conference on Learning +Representations. +""" +class RAdam(Optimizer): + + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, degenerated_to_sgd=False): + if not 0.0 <= lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 <= eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + + self.degenerated_to_sgd = degenerated_to_sgd + if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict): + for param in params: + if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]): + param['buffer'] = [[None, None, None] for _ in range(10)] + defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) + super(RAdam, self).__init__(params, defaults) + + def __setstate__(self, state): + super(RAdam, self).__setstate__(state) + + def step(self, closure=None): + + loss = None + if closure is not None: + loss = closure() + + for group in self.param_groups: + + for p in group['params']: + if p.grad is None: + continue + grad = p.grad.data.float() + if grad.is_sparse: + raise RuntimeError('RAdam does not support sparse gradients') + + p_data_fp32 = p.data.float() + + state = self.state[p] + + if len(state) == 0: + state['step'] = 0 + state['exp_avg'] = th.zeros_like(p_data_fp32) + state['exp_avg_sq'] = th.zeros_like(p_data_fp32) + else: + state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32) + state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32) + + exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq'] + beta1, beta2 = group['betas'] + + exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value 
= 1 - beta2) + exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1) + + state['step'] += 1 + buffered = group['buffer'][int(state['step'] % 10)] + if state['step'] == buffered[0]: + N_sma, step_size = buffered[1], buffered[2] + else: + buffered[0] = state['step'] + beta2_t = beta2 ** state['step'] + N_sma_max = 2 / (1 - beta2) - 1 + N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t) + buffered[1] = N_sma + + # more conservative since it's an approximated value + if N_sma >= 5: + step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step']) + elif self.degenerated_to_sgd: + step_size = 1.0 / (1 - beta1 ** state['step']) + else: + step_size = -1 + buffered[2] = step_size + + # more conservative since it's an approximated value + if N_sma >= 5: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + denom = exp_avg_sq.sqrt().add_(group['eps']) + p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr']) + p.data.copy_(p_data_fp32) + elif step_size > 0: + if group['weight_decay'] != 0: + p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32) + p_data_fp32.add_(-step_size * group['lr'], exp_avg) + p.data.copy_(p_data_fp32) + + return loss
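While the rectification term is not yet trustworthy (N_sma < 5), this implementation skips the parameter update entirely unless degenerated_to_sgd=True, in which case it falls back to a bias-corrected SGD step. A minimal smoke test on a hypothetical toy model:

    model = th.nn.Linear(8, 1)
    opt = RAdam(model.parameters(), lr=1e-3)

    x, y = th.randn(16, 8), th.randn(16, 1)
    for step in range(100):
        opt.zero_grad()
        loss = ((model(x) - y) ** 2).mean()
        loss.backward()
        opt.step()  # the first few steps are no-ops here since degenerated_to_sgd=False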
\ No newline at end of file diff --git a/scripts/utils/plot_utils.py b/scripts/utils/plot_utils.py new file mode 100644 index 0000000..5cb20de --- /dev/null +++ b/scripts/utils/plot_utils.py @@ -0,0 +1,428 @@ +import torch as th +from torch import nn +import numpy as np +import cv2 +from einops import rearrange, repeat +import matplotlib.pyplot as plt +import PIL +from model.utils.nn_utils import Gaus2D, Vector2D + +def preprocess(tensor, scale=1, normalize=False, mean_std_normalize=False): + + if tensor is None: + return None + + if normalize: + min_ = th.min(tensor) + max_ = th.max(tensor) + tensor = (tensor - min_) / (max_ - min_) + + if mean_std_normalize: + mean = th.mean(tensor) + std = th.std(tensor) + tensor = th.clip((tensor - mean) / (2 * std), -1, 1) * 0.5 + 0.5 + + if scale > 1: + upsample = nn.Upsample(scale_factor=scale).to(tensor[0].device) + tensor = upsample(tensor) + + return tensor + +def preprocess_multi(*args, scale): + return [preprocess(a, scale) for a in args] + +def color_mask(mask): + + colors = th.tensor([ + [ 255, 0, 0 ], + [ 0, 0, 255 ], + [ 255, 255, 0 ], + [ 255, 0, 255 ], + [ 0, 255, 255 ], + [ 0, 255, 0 ], + [ 255, 128, 0 ], + [ 128, 255, 0 ], + [ 128, 0, 255 ], + [ 255, 0, 128 ], + [ 0, 255, 128 ], + [ 0, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 255 ], + [ 128, 255, 255 ], + [ 128, 255, 255 ], + [ 255, 255, 128 ], + [ 255, 255, 128 ], + [ 255, 128, 255 ], + [ 128, 0, 0 ], + [ 0, 0, 128 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + ], device = mask.device) / 255.0 + + colors = colors.view(1, -1, 3, 1, 1) + mask = mask.unsqueeze(dim=2) + + return th.sum(colors[:,:mask.shape[1]] * mask, dim=1) + +def get_color(o): + colors = th.tensor([ + [ 255, 0, 0 ], + [ 0, 0, 255 ], + [ 255, 255, 0 ], + [ 255, 0, 255 ], + [ 0, 255, 255 ], + [ 0, 255, 0 ], + [ 255, 128, 0 ], + [ 128, 255, 0 ], + [ 128, 0, 255 ], + [ 255, 0, 128 ], + [ 0, 255, 128 ], + [ 0, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 128 ], + [ 128, 255, 128 ], + [ 128, 128, 255 ], + [ 255, 128, 255 ], + [ 128, 255, 255 ], + [ 128, 255, 255 ], + [ 255, 255, 128 ], + [ 255, 255, 128 ], + [ 255, 128, 255 ], + [ 128, 0, 0 ], + [ 0, 0, 128 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 128, 0 ], + [ 128, 0, 128 ], + [ 128, 0, 128 ], + [ 0, 128, 128 ], + [ 0, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + [ 128, 128, 128 ], + ]) / 255.0 + + colors = colors.view(48,3) + return colors[o] + +def to_rgb(tensor: th.Tensor): + return th.cat(( + tensor * 0.6 + 0.4, + tensor, + tensor + ), dim=1) + +def visualise_gate(gate, h, w): + bar = th.ones((1,h,w), device=gate.device) * 0.9 + black = int(w*gate.item()) + if black > 0: + bar[:,:, -black:] = 0 + return bar + +def get_highlighted_input(input, mask_cur): + + # highlight 
error + highlighted_input = input + if mask_cur is not None: + grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114 + object_mask_cur = th.sum(mask_cur[:,:-1], dim=1).unsqueeze(dim=1) + highlighted_input = grayscale * (1 - object_mask_cur) + highlighted_input += grayscale * object_mask_cur * 0.3333333 + cmask = color_mask(mask_cur[:,:-1]) + highlighted_input = highlighted_input + cmask * 0.6666666 + + return highlighted_input + +def color_slots(image, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur): + + image = (1-image) * slots_bounded + image * (1-slots_bounded) + image = th.clip(image - 0.3, 0,1) * slots_partially_occluded_cur + image * (1-slots_partially_occluded_cur) + image = th.clip(image - 0.3, 0,1) * slots_occluded_cur + image * (1-slots_occluded_cur) + + return image + +def compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale): + + # compute occlusion mask + occluded_cur = th.clip(rawmask_cur - mask_cur, 0, 1)[:,:-1] + occluded_next = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1] + + # to rgb + rawmask_cur = repeat(rawmask_cur[:,:-1], 'b o h w -> b (o 3) h w') + rawmask_next = repeat(rawmask_next[:,:-1], 'b o h w -> b (o 3) h w') + + # scale + occluded_next = preprocess(occluded_next, scale) + occluded_cur = preprocess(occluded_cur, scale) + rawmask_cur = preprocess(rawmask_cur, scale) + rawmask_next = preprocess(rawmask_next, scale) + + # set occlusion to red + rawmask_cur = rearrange(rawmask_cur, 'b (o c) h w -> b o c h w', c = 3) + rawmask_cur[:,:,0] = rawmask_cur[:,:,0] * (1 - occluded_next) + rawmask_cur[:,:,1] = rawmask_cur[:,:,1] * (1 - occluded_next) + + rawmask_next = rearrange(rawmask_next, 'b (o c) h w -> b o c h w', c = 3) + rawmask_next[:,:,0] = rawmask_next[:,:,0] * (1 - occluded_next) + rawmask_next[:,:,1] = rawmask_next[:,:,1] * (1 - occluded_next) + + return rawmask_cur, rawmask_next + +def plot_online_error_slots(errors, error_name, target, sequence_len, root_path, visibilty_memory, slots_bounded, ylim=0.3): + error_plots = [] + if len(errors) > 0: + num_slots = int(th.sum(slots_bounded).item()) + errors = rearrange(np.array(errors), '(l o) -> o l', o=len(slots_bounded))[:num_slots] + visibilty_memory = rearrange(np.array(visibilty_memory), '(l o) -> o l', o=len(slots_bounded))[:num_slots] + for error,visibility in zip(errors, visibilty_memory): + + if len(error) < sequence_len: + fig, ax = plt.subplots(figsize=(round(target.shape[3]/100,2), round(target.shape[2]/100,2))) + plt.plot(error, label=error_name) + + visibility = np.concatenate((visibility, np.ones(sequence_len-len(error)))) + ax.fill_between(range(sequence_len), 0, 1, where=visibility==0, color='orange', alpha=0.3, transform=ax.get_xaxis_transform()) + + plt.xlim((0,sequence_len)) + plt.ylim((0,ylim)) + fig.tight_layout() + plt.savefig(f'{root_path}/tmp.jpg') + + error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb()) + error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1)) + plt.close(fig) + + error_plots.append(error_plot) + + return error_plots + +def plot_online_error(error, error_name, target, t, i, sequence_len, root_path, online_surprise = False): + + fig = plt.figure(figsize=( round(target.shape[3]/50,2), round(target.shape[2]/50,2) )) + plt.plot(error, label=error_name) + + if online_surprise: + # compute moving average of error + moving_average_length = 10 + if t > moving_average_length: + moving_average_length += 1 + average_error = 
np.mean(error[-moving_average_length:-1])
+            current_sd = np.std(error[-moving_average_length:-1])
+            current_error = error[-1]
+
+            if current_error > average_error + 2 * current_sd:
+                fig.set_facecolor('orange')
+
+    plt.xlim((0,sequence_len))
+    plt.legend()
+    # increase title size
+    plt.title(f'{error_name}', fontsize=20)
+    plt.xlabel('timestep')
+    plt.ylabel('error')
+    plt.savefig(f'{root_path}/tmp.jpg')
+
+    error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
+    error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1))
+    plt.close(fig)
+
+    return error_plot
+
+def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
+
+    # add ground truth positions of objects to image
+    if gt_positions_target_next is not None:
+        for o in range(gt_positions_target_next.shape[1]):
+            position = gt_positions_target_next[0, o]
+            position = position/2 + 0.5
+
+            if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
+                width = 5
+                w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item()
+                h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item()
+                col = get_color(o).view(3,1,1)
+                target[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+                # add these positions to the associated slots' velocity_next2d illustration
+                slots = (association_table[0] == o).nonzero()
+                for s in slots.flatten():
+                    velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+                    if output_hidden is not None and s != largest_object:
+                        output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+    gateheight = 60
+    ch = 40
+    gh = 40
+    gh_bar = gh-20
+    gh_margin = int((gh-gh_bar)/2)
+    margin = 20
+    slots_margin = 10
+    height = size[0] * 6 + 18*5
+    width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1)
+    img = th.ones((3, height, width), device = object_next.device) * 0.4
+    row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin])
+    col1 = range(margin, margin + size[1]*2)
+    col2 = range(width-(margin+size[1]*2), width-margin)
+
+    img[:,row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0]
+    img[:,row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0]
+    img[:,row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0]
+
+    if error_plot is not None:
+        img[:,row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize=True)
+    if error_plot2 is not None:
+        img[:,row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize=True)
+
+    for o in range(num_objects):
+
+        col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin
+        col = range(col, col + size[1])
+
+        # color bar for the gate
+        if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
+            img[:,margin:margin+ch, col] = get_color(o).view(3,1,1).to(object_next.device)
+
+        img[:,margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:,o, 0].to(object_next.device), h=gh_bar, w=len(col))
+        offset = gh+margin-gh_margin+ch+2*margin
+        row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+
+        img[:,row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0,o].to(object_next.device))
+        img[:,row(1)[0]:row(1)[1], col] = preprocess(object_next[:,o].to(object_next.device))
+        if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
+            img[:,row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True)
+
+        offset = margin*2-8
+        row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+        img[:,row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:,o, 1].to(object_next.device), h=gh_bar, w=len(col))
+        img[:,row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0]
+        if (error_plot_slots is not None) and len(error_plot_slots) > o:
+            img[:,row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True)
+
+    img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+
+    return img
+
+def write_image(file, img):
+    img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+    cv2.imwrite(file, img)
+
+    pass
+
+def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i):
+
+    # Create position helpers
+    size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device)
+
+    # Compute plot content
+    highlighted_input = get_highlighted_input(input, mask_cur)
+    output = th.clip(output_next, 0, 1)
+    position_cur2d = gaus2d(rearrange(position_encoder_cur, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
+    velocity_next2d = vector2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
+
+    # color slots
+    slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next = reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next)
+    position_cur2d = color_slots(position_cur2d, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur)
+    velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)
+
+    # compute occlusion
+    if (cfg.datatype == "adept"):
+        rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
+        rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale)
+        rawmask_cur_h[:,largest_object] = rawmask_cur_l[:,largest_object]
+        rawmask_next_h[:,largest_object] = rawmask_next_l[:,largest_object]
+        rawmask_cur = rawmask_cur_h
+        rawmask_next = rawmask_next_h
+        object_hidden[:, largest_object] = object_next[:, largest_object]
+        object_next = object_hidden
+    else:
+        rawmask_cur, rawmask_next = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
+
+    # scale plot content
+    input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next = preprocess_multi(input, target, output, highlighted_input, object_next, 
object_cur, mask_next, error_next, output_hidden, output_next, scale=scale) + + # reshape + object_next = rearrange(object_next, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels) + object_cur = rearrange(object_cur, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels) + mask_next = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w') + + if object_view: + if (cfg.datatype == "adept"): + num_objects = 4 + error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded) + error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001) + error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path) + error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path) + img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object) + else: + num_objects = cfg_net.num_objects + error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001) + error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path) + img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object) + + cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img) + + if individual_views: + # ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']: + write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0]) + write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0]) + write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0,-1]) + write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0]) + write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0]) + + pass + +def get_position_helper(cfg_net, device): + size = cfg_net.input_size + gaus2d = Gaus2D(size).to(device) + vector2d = Vector2D(size).to(device) + scale = size[0] // (cfg_net.latent_size[0] * 2**(cfg_net.level*2)) + return size,gaus2d,vector2d,scale + +def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next): + + slots_bounded = th.squeeze(slots_bounded)[..., None,None,None] + slots_partially_occluded_cur = th.squeeze(slots_partially_occluded_cur)[..., None,None,None] + slots_occluded_cur = th.squeeze(slots_occluded_cur)[..., None,None,None] + slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None,None,None] + slots_occluded_next = 
th.squeeze(slots_occluded_next)[..., None,None,None] + + return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next
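Most helpers in this file expect batched CHW tensors in [0, 1]. A short sketch of the mask-colouring path with random stand-in data (the shapes are assumptions based on how plot_timestep calls these functions):

    # 1 batch element, 3 object slots, 64x64 masks that sum to one per pixel
    mask = th.softmax(th.randn(1, 3, 64, 64), dim=1)

    rgb = color_mask(mask)          # [1, 3, 64, 64], one palette colour per slot
    rgb = preprocess(rgb, scale=2)  # nearest-neighbour upsampling to 128x128
    write_image('colored_mask.jpg', rgb[0])  # write_image expects CHW in [0, 1]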
\ No newline at end of file
diff --git a/scripts/validation.py b/scripts/validation.py
new file mode 100644
index 0000000..5b1638d
--- /dev/null
+++ b/scripts/validation.py
@@ -0,0 +1,261 @@
+import torch as th
+from torch import nn
+from torch.utils.data import DataLoader
+import numpy as np
+from einops import rearrange, repeat, reduce
+from scripts.utils.configuration import Configuration
+from model.loci import Loci
+import time
+import lpips
+from skimage.metrics import structural_similarity as ssimloss
+from skimage.metrics import peak_signal_noise_ratio as psnrloss
+
+def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+
+    # memory
+    mseloss = nn.MSELoss()
+    avgloss = 0
+    start_time = time.time()
+
+    with th.no_grad():
+        for i, input in enumerate(valloader):
+
+            # get input frame and target frame
+            tensor = input[0].float().to(device)
+            background_fix = input[1].to(device)
+
+            # apply skip frames
+            tensor = tensor[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+            sequence_len = tensor.shape[1]
+
+            # initial frame
+            input = tensor[:,0]
+            target = th.clip(tensor[:,0], 0, 1)
+            error_last = None
+
+            # placeholders
+            mask_cur = None
+            mask_last = None
+            rawmask_last = None
+            position_last = None
+            gestalt_last = None
+            priority_last = None
+            gt_positions_target = None
+            slots_occlusionfactor = None
+
+            # loop through frames
+            for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)):
+
+                # move to next frame
+                t_run = max(t, 0)
+                input = tensor[:,t_run]
+                target = th.clip(tensor[:,t_run+1], 0, 1)
+
+                # obtain prediction
+                (
+                    output_next,
+                    position_next,
+                    gestalt_next,
+                    priority_next,
+                    mask_next,
+                    rawmask_next,
+                    object_next,
+                    background,
+                    slots_occlusionfactor,
+                    output_cur,
+                    position_cur,
+                    gestalt_cur,
+                    priority_cur,
+                    mask_cur,
+                    rawmask_cur,
+                    object_cur,
+                    position_encoder_cur,
+                    slots_bounded,
+                    slots_partially_occluded_cur,
+                    slots_occluded_cur,
+                    slots_partially_occluded_next,
+                    slots_occluded_next,
+                    slots_closed,
+                    output_hidden,
+                    largest_object,
+                    rawmask_hidden,
+                    object_hidden
+                ) = net(
+                    input,
+                    error_last,
+                    mask_last,
+                    rawmask_last,
+                    position_last,
+                    gestalt_last,
+                    priority_last,
+                    background_fix,
+                    slots_occlusionfactor,
+                    reset = (t == -cfg.defaults.teacher_forcing),
+                    evaluate=True,
+                    warmup = (t < 0),
+                    shuffleslots = False,
+                    reset_mask = (t <= 0),
+                    allow_spawn = True,
+                    show_hidden = False,
+                    clean_slots = (t <= 0),
+                )
+
+                # 1. Track error
+                if t >= 0:
+                    loss = mseloss(output_next, target)
+                    avgloss += loss.item()
+
+                # 2. Remember output
+                mask_last = mask_next.clone()
+                rawmask_last = rawmask_next.clone()
+                position_last = position_next.clone()
+                gestalt_last = gestalt_next.clone()
+                priority_last = priority_next.clone()
+
+                # 3. Error for next frame
+                # background error
+                bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+                # prediction error
+                error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+                error_next = th.sqrt(error_next) * bg_error_next
+                error_last = error_next.clone()
+
+    print(f"Validation loss: {avgloss / len(valloader.dataset):.2e}, Time: {time.time() - start_time}")
+
+    pass
+
+
+def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+
+    # memory
+    mseloss = nn.MSELoss()
+    lpipsloss = lpips.LPIPS(net='vgg').to(device)
+    avgloss_mse = 0
+    avgloss_lpips = 0
+    avgloss_psnr = 0
+    avgloss_ssim = 0
+    start_time = time.time()
+
+    burn_in_length = 6
+    rollout_length = 42
+
+    with th.no_grad():
+        for i, input in enumerate(valloader):
+
+            # get input frame and target frame
+            tensor = input[0].float().to(device)
+            background_fix = input[1].to(device)
+
+            # apply skip frames
+            tensor = tensor[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+            sequence_len = tensor.shape[1]
+
+            # initial frame
+            input = tensor[:,0]
+            target = th.clip(tensor[:,0], 0, 1)
+            error_last = None
+
+            # placeholders
+            mask_cur = None
+            mask_last = None
+            rawmask_last = None
+            position_last = None
+            gestalt_last = None
+            priority_last = None
+            gt_positions_target = None
+            slots_occlusionfactor = None
+
+            # loop through frames
+            for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, min(burn_in_length + rollout_length-1, sequence_len-1))):
+
+                # move to next frame
+                t_run = max(t, 0)
+                input = tensor[:,t_run]
+                target = th.clip(tensor[:,t_run+1], 0, 1)
+                if t_run >= burn_in_length:
+                    blackout = th.tensor((np.random.rand(valloader.batch_size) < 0.2)[:,None,None,None]).float().to(device)
+                    input = blackout * (input * 0) + (1-blackout) * input
+                    error_last = blackout * (error_last * 0) + (1-blackout) * error_last
+
+                # obtain prediction
+                (
+                    output_next,
+                    position_next,
+                    gestalt_next,
+                    priority_next,
+                    mask_next,
+                    rawmask_next,
+                    object_next,
+                    background,
+                    slots_occlusionfactor,
+                    output_cur,
+                    position_cur,
+                    gestalt_cur,
+                    priority_cur,
+                    mask_cur,
+                    rawmask_cur,
+                    object_cur,
+                    position_encoder_cur,
+                    slots_bounded,
+                    slots_partially_occluded_cur,
+                    slots_occluded_cur,
+                    slots_partially_occluded_next,
+                    slots_occluded_next,
+                    slots_closed,
+                    output_hidden,
+                    largest_object,
+                    rawmask_hidden,
+                    object_hidden
+                ) = net(
+                    input,
+                    error_last,
+                    mask_last,
+                    rawmask_last,
+                    position_last,
+                    gestalt_last,
+                    priority_last,
+                    background_fix,
+                    slots_occlusionfactor,
+                    reset = (t == -cfg.defaults.teacher_forcing),
+                    evaluate=True,
+                    warmup = (t < 0),
+                    shuffleslots = False,
+                    reset_mask = (t <= 0),
+                    allow_spawn = True,
+                    show_hidden = False,
+                    clean_slots = (t <= 0),
+                )
+
+                # 1. Track error
+                if t >= 0:
+                    loss_mse = mseloss(output_next, target)
+                    loss_ssim = np.sum([ssimloss(output_next[i].cpu().numpy(), target[i].cpu().numpy(), channel_axis=0, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1) for i in range(output_next.shape[0])])
+                    loss_psnr = np.sum([psnrloss(output_next[i].cpu().numpy(), target[i].cpu().numpy(), data_range=1) for i in range(output_next.shape[0])])
+                    loss_lpips = th.sum(lpipsloss(output_next*2-1, target*2-1))
+
+                    avgloss_mse += loss_mse.item()
+                    avgloss_ssim += loss_ssim.item()
+                    avgloss_psnr += loss_psnr.item()
+                    avgloss_lpips += loss_lpips.item()
+
+                # 2. 
Remember output + mask_last = mask_next.clone() + rawmask_last = rawmask_next.clone() + position_last = position_next.clone() + gestalt_last = gestalt_next.clone() + priority_last = priority_next.clone() + + # 3. Error for next frame + # background error + bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() + + # prediction error + error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach() + error_next = th.sqrt(error_next) * bg_error_next + error_last = error_next.clone() + + print(f"MSE loss: {avgloss_mse / len(valloader.dataset):.2e}, LPIPS loss: {avgloss_lpips / len(valloader.dataset):.2e}, PSNR loss: {avgloss_psnr / len(valloader.dataset):.2e}, SSIM loss: {avgloss_ssim / len(valloader.dataset):.2e}, Time: {time.time() - start_time}") + + pass
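Both routines are called with a held-out loader; a hypothetical call site (the actual dispatch lives in scripts/training.py and may differ):

    valloader = DataLoader(valset, batch_size=cfg.model.batch_size,
                           num_workers=cfg.defaults.num_workers, shuffle=False)

    if cfg.datatype == "adept":
        validation_adept(valloader, net, cfg, device)    # prints the average prediction MSE
    else:
        validation_clevrer(valloader, net, cfg, device)  # prints MSE/LPIPS/PSNR/SSIM over burn-in + rollout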
\ No newline at end of file