Diffstat (limited to 'scripts')
-rw-r--r--  scripts/evaluation_adept.py       398
-rw-r--r--  scripts/evaluation_adept_savi.py  233
-rw-r--r--  scripts/evaluation_clevrer.py     311
-rw-r--r--  scripts/exec/eval.py               28
-rw-r--r--  scripts/exec/eval_savi.py          17
-rw-r--r--  scripts/exec/train.py              33
-rw-r--r--  scripts/training.py               523
-rw-r--r--  scripts/utils/configuration.py     83
-rw-r--r--  scripts/utils/eval_metrics.py     350
-rw-r--r--  scripts/utils/eval_utils.py        78
-rw-r--r--  scripts/utils/io.py               137
-rw-r--r--  scripts/utils/optimizers.py       100
-rw-r--r--  scripts/utils/plot_utils.py       428
-rw-r--r--  scripts/validation.py             261
14 files changed, 2980 insertions, 0 deletions
diff --git a/scripts/evaluation_adept.py b/scripts/evaluation_adept.py
new file mode 100644
index 0000000..eab3ad9
--- /dev/null
+++ b/scripts/evaluation_adept.py
@@ -0,0 +1,398 @@
+import torch as th
+from torch.utils.data import Dataset, DataLoader, Subset
+from torch import nn
+import os
+from scripts.utils.plot_utils import plot_timestep
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device
+import numpy as np
+from einops import rearrange, repeat, reduce
+import motmetrics as mm
+from copy import deepcopy
+import pandas as pd
+from scripts.utils.eval_utils import append_statistics, load_model, setup_result_folders, store_statistics
+
+
+def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency=1, plot_first_samples=2):
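+    """Evaluate a trained model on the ADEPT control and surprise test sets.
+
+    Docstring added for clarity (semantics inferred from the code below): the
+    network is run frame by frame per scene, per-frame image and tracking
+    errors are logged, MOT metrics are accumulated per test set and
+    evaluation mode, and the first `plot_first_samples` scenes are plotted.
+    """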
+
+ # Set up cpu or gpu training
+ device, verbose = init_device(cfg)
+
+ # Config
+ cfg_net = cfg.model
+ cfg_net.batch_size = 1
+
+ # Load model
+ net = load_model(cfg, cfg_net, file, device)
+ net.eval()
+
+ # Plot config
+ object_view = True
+ individual_views = False
+ root_path = None
+
+ # get evaluation sets for control and surprise condition
+ set_test_array, evaluation_modes = get_evaluation_sets(dataset)
+
+ # memory
+ statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error': [], 'TE': []}
+ statistics_complete = deepcopy(statistics_template)
+ statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': [], 'vanishing': []}
+ acc_memory_complete = None
+
+ for set_test in set_test_array:
+
+ for evaluation_mode in evaluation_modes:
+ print(f'Start evaluation loop: {evaluation_mode} - {set_test["type"]}')
+
+ # Load data
+ dataloader = DataLoader(
+ Subset(dataset, set_test['samples']),
+ num_workers = 1,
+ pin_memory = False,
+ batch_size = 1,
+ shuffle = False
+ )
+
+ # memory
+ mseloss = nn.MSELoss()
+ root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views)
+ acc_memory_eval = []
+
+ with th.no_grad():
+ for i, input in enumerate(dataloader):
+ print(f'Processing sample {i+1}/{len(dataloader)}', flush=True)
+
+ # Data
+ tensor = input[0].float().to(device)
+ background_fix = input[1].to(device)
+ gt_object_positions = input[3].to(device)
+ gt_object_visibility = input[4].to(device)
+ gt_occluder_mask = input[5].to(device)
+
+ # Apply skip frames
+ gt_object_positions = gt_object_positions[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ gt_object_visibility = gt_object_visibility[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ tensor = tensor[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ sequence_len = tensor.shape[1]
+
+                # Placeholders
+ mask_cur = None
+ mask_last = None
+ rawmask_last = None
+ position_last = None
+ gestalt_last = None
+ priority_last = None
+ gt_positions_target = None
+ slots_occlusionfactor = None
+ error_last = None
+
+ # Memory
+ association_table = th.ones(cfg_net.num_objects).to(device) * -1
+ acc = mm.MOTAccumulator(auto_id=True)
+ statistics_batch = deepcopy(statistics_template)
+ slots_vanishing_memory = np.zeros(cfg_net.num_objects)
+
+ # loop through frames
+ for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)):
+
+ # Move to next frame
+ t_run = max(t, 0)
+ input = tensor[:,t_run]
+ target = th.clip(tensor[:,t_run+1], 0, 1)
+ gt_positions_target = gt_object_positions[:,t_run]
+ gt_positions_target_next = gt_object_positions[:,t_run+1]
+ gt_visibility_target = gt_object_visibility[:,t_run]
+
+ # Forward Pass
+ (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder_cur,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ ) = net(
+ input,
+ error_last,
+ mask_last,
+ rawmask_last,
+ position_last,
+ gestalt_last,
+ priority_last,
+ background_fix,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing),
+ evaluate=True,
+ warmup = (t < 0),
+ shuffleslots = False,
+ reset_mask = (t <= 0),
+ allow_spawn = True,
+ show_hidden = True,
+ clean_slots = (t <= 0),
+ )
+
+ # 1. Track error
+ if t >= 0:
+
+ # Position error: MSE between predicted position and target position
+ tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_net.num_objects, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask)
+
+ statistics_batch = store_statistics(statistics_batch,
+ set_test['type'],
+ evaluation_mode,
+ set_test['samples'][i],
+ t,
+ mseloss(output_next, target).item(),
+ tracking_error)
+
+ # Compute rawmask_size
+ rawmask_size, rawmask_size_hidden, mask_size = compute_mask_sizes(mask_next, rawmask_next, rawmask_hidden)
+
+ # Compute slot-wise prediction error
+ slot_error = compute_slot_error(cfg_net, target, output_next, mask_next, mask_size)
+
+                        # Check if objects vanish: they leave the scene surprisingly in the surprise condition
+ slots_vanishing = compute_vanishing_slots(gt_positions_target, association_table, gt_positions_target_next)
+ slots_vanishing_memory = slots_vanishing + slots_vanishing_memory
+
+ # Store slot statistics
+ statistics_complete_slots = store_statistics(statistics_complete_slots,
+ [set_test['type']] * cfg_net.num_objects,
+ [evaluation_mode] * cfg_net.num_objects,
+ [set_test['samples'][i]] * cfg_net.num_objects,
+ [t] * cfg_net.num_objects,
+ range(cfg_net.num_objects),
+ tracking_error_perslot.cpu().numpy().flatten(),
+ slots_visible.cpu().numpy().flatten().astype(int),
+ slots_bounded.cpu().numpy().flatten().astype(int),
+ slots_occluder.cpu().numpy().flatten().astype(int),
+ slots_in_image.cpu().numpy().flatten().astype(int),
+ slot_error.cpu().numpy().flatten(),
+ mask_size.cpu().numpy().flatten(),
+ rawmask_size.cpu().numpy().flatten(),
+ rawmask_size_hidden.cpu().numpy().flatten(),
+ slots_closed[:, :, 1].cpu().numpy().flatten(),
+ slots_closed[:, :, 0].cpu().numpy().flatten(),
+ association_table[0].cpu().numpy().flatten().astype(int),
+ extend = True)
+
+ # Compute MOTA
+ acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bounded, cfg_net.num_objects, gt_occluder_mask, slots_occluder, rawmask_next)
+
+ # 2. Remember output
+ mask_last = mask_next.clone()
+ rawmask_last = rawmask_next.clone()
+ position_last = position_next.clone()
+ gestalt_last = gestalt_next.clone()
+ priority_last = priority_next.clone()
+
+ # 3. Error for next frame
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next.clone()
+
+ # 4. Plot
+ if (t % plot_frequency == 0) and (i < plot_first_samples) and (t >= 0):
+ plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i)
+
+                # fill vanishing statistics
+ statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1))
+
+ # store batch statistics in complete statistics
+ acc_memory_eval.append(acc)
+ statistics_complete = append_statistics(statistics_complete, statistics_batch, extend = True)
+
+ summary = mm.metrics.create().compute_many(acc_memory_eval, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
+ summary['set'] = set_test['type']
+ summary['evalmode'] = evaluation_mode
+ acc_memory_complete = summary.copy() if acc_memory_complete is None else pd.concat([acc_memory_complete, summary])
+
+
+ print('-- Evaluation Done --')
+ pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv')
+ pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv')
+ pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv')
+ if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
+ os.remove(f'{root_path}/tmp.jpg')
+ pass
+
+
+def compute_vanishing_slots(gt_positions_target, association_table, gt_positions_target_next):
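+    # Comment added for clarity: an object counts as vanishing when its third
+    # position coordinate changes by more than 0.2 between consecutive frames;
+    # a slot is then marked vanishing if its associated object vanishes.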
+ objects_vanishing = th.abs(gt_positions_target[:,:,2] - gt_positions_target_next[:,:,2]) > 0.2
+ objects_vanishing = th.where(objects_vanishing.flatten())[0]
+ slots_vanishing = [(obj.item() in objects_vanishing) for obj in association_table[0]]
+ return slots_vanishing
+
+def compute_slot_error(cfg_net, target, output_next, mask_next, mask_size):
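+    # Comment added for clarity: the slot-wise prediction error is the MSE
+    # between the masked prediction and the masked target, normalized by the
+    # mask size so that large and small objects are comparable.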
+ output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects)
+ target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects)
+ slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean')/mask_size
+ return slot_error
+
+def compute_mask_sizes(mask_next, rawmask_next, rawmask_hidden):
+ rawmask_size = reduce(rawmask_next[:, :-1], 'b o h w-> b o', 'sum')
+ rawmask_size_hidden = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum')
+ mask_size = reduce(mask_next[:, :-1], 'b o h w-> b o', 'sum')
+ return rawmask_size,rawmask_size_hidden,mask_size
+
+def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_num_objects, gt_occluder_mask, slots_occluder, rawmask, ignore_occluder = False):
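+    # Comment added for clarity (details inferred from the code below):
+    # updates the py-motmetrics accumulator for one frame. Positions are
+    # clipped and stretched along x to account for the ADEPT frame ratio,
+    # out-of-image targets and unbound or too-small slots are discarded, and
+    # squared distances (gated via max_d2) form the cost matrix.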
+
+ # num objects
+ num_objects = len(gt_positions[0])
+
+ # get rid of batch dimension and priority dimension
+ pos = rearrange(estimated_positions.detach()[0], '(o c) -> o c', o=cfg_num_objects)[:, :2]
+ targets = gt_positions[0, :, :2]
+
+    # stretch positions to account for the frame ratio, specific to ADEPT!
+ pos = th.clip(pos, -1, 1)
+ pos[:, 0] = pos[:, 0] * 1.5
+ targets[:, 0] = targets[:, 0] * 1.5
+
+ # remove objects that are not in the image
+ edge = 1
+ in_image = th.cat([targets[:, 0] < (1.5 * edge), targets[:, 0] > (-1.5 * edge), targets[:, 1] < (1 * edge), targets[:, 1] > (-1 * edge)])
+ in_image = th.all(rearrange(in_image, '(c o) -> o c', o=num_objects), dim=1)
+
+ if ignore_occluder:
+ in_image = (gt_occluder_mask[0] == 0) * in_image
+ targets = targets[in_image]
+
+    # test whether the position estimates are in the image
+ in_image_pos = th.cat([pos[:, 0] < (1.5 * edge), pos[:, 0] > (-1.5 * edge), pos[:, 1] < (1 * edge), pos[:, 1] > (-1 * edge)])
+ in_image_pos = th.all(rearrange(in_image_pos, '(c o) -> c o', o=cfg_num_objects), dim=0, keepdim=True)
+
+ # only position estimates that are in image and bound
+ if rawmask is not None:
+ rawmask_size = reduce(rawmask[:, :-1], 'b o h w-> b o', 'sum')
+ m = (slots_bounded * in_image_pos * (rawmask_size > 100)).bool()
+ else:
+ m = (slots_bounded * in_image_pos).bool()
+ if ignore_occluder:
+ m = (m * (1 - slots_occluder)).bool()
+
+ pos = pos[repeat(m, '1 o -> o 2')]
+ pos = rearrange(pos, '(o c) -> o c', c = 2)
+
+ # compute pairwise distances
+ diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2)).item()
+ C = mm.distances.norm2squared_matrix(targets.cpu().numpy(), pos.cpu().numpy(), max_d2=diagonal_length*0.1)
+
+    # update accumulator
+ acc.update( (th.where(in_image)[0]).cpu(), (th.where(m)[1]).cpu(), C)
+
+ return acc
+
+def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_num_slots, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask):
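+    # Comment added for clarity: (1) newly bound slots are associated with
+    # their closest ground-truth object, (2) the tracking error is the mean
+    # distance between slot and associated object positions, normalized by
+    # the image diagonal and averaged over bound slots whose target lies
+    # inside the image.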
+
+ # tracking utils
+ pdist = nn.PairwiseDistance(p=2).to(position_cur.device)
+
+ # 1. association of newly bounded slots to ground truth objects
+ # num objects
+ num_objects = len(gt_positions_target[0])
+
+ # get rid of batch dimension and priority dimension
+ pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_num_slots)[:, :2]
+ targets = gt_positions_target[0, :, :2]
+
+    # stretch positions to account for the frame ratio, specific to ADEPT!
+ pos = th.clip(pos, -1, 1)
+ pos[:, 0] = pos[:, 0] * 1.5
+ targets[:, 0] = targets[:, 0] * 1.5
+ diagonal_length = th.sqrt(th.sum(th.tensor([2,3])**2))
+
+ # reshape and repeat for comparison
+ pos = repeat(pos, 'o c -> (o r) c', r=num_objects)
+ targets = repeat(targets, 'o c -> (r o) c', r=cfg_num_slots)
+
+ # comparison
+ distance = pdist(pos, targets)
+ distance = rearrange(distance, '(o r) -> o r', r=num_objects)
+
+ # find closest target for each slot
+ distance = th.min(distance, dim=1, keepdim=True)
+
+ # update association table
+ slots_newly_bounded = slots_bounded * (association_table == -1)
+ if slots_occluded_cur is not None:
+ slots_newly_bounded = slots_newly_bounded * (1-slots_occluded_cur)
+ association_table = association_table * (1-slots_newly_bounded) + slots_newly_bounded * distance[1].T
+
+ # 2. position error
+ # get rid of batch dimension and priority dimension
+ pos = rearrange(position_cur.clone()[0], '(o c) -> o c', o=cfg_num_slots)[:, :2]
+ targets = gt_positions_target[0, :, :3]
+
+    # stretch positions to account for the frame ratio, specific to ADEPT!
+ pos[:, 0] = pos[:, 0] * 1.5
+ targets[:, 0] = targets[:, 0] * 1.5
+
+ # gather targets according to association table
+ targets = targets[association_table.long()][0]
+
+    # determine which slots are within the image
+ slots_in_image = th.cat([targets[:, 0] < 1.5, targets[:, 0] > -1.5, targets[:, 1] < 1, targets[:, 1] > -1, targets[:, 2] > 0])
+ slots_in_image = rearrange(slots_in_image, '(c o) -> o c', o=cfg_num_slots)
+ slots_in_image = th.all(slots_in_image, dim=1)
+
+ # define which slots to consider for tracking error
+ slots_to_track = slots_bounded * slots_in_image
+
+ # compute position error
+ targets = targets[:, :2]
+ tracking_error_perslot = th.sqrt(th.sum((pos - targets)**2, dim=1))/diagonal_length
+ tracking_error_perslot = tracking_error_perslot[None, :] * slots_to_track
+ tracking_error = th.sum(tracking_error_perslot).item()/max(th.sum(slots_to_track).item(), 1)
+
+ # compute which slots are visible
+ visible_objects = th.where(gt_visibility_target[0] == 1)[0]
+ slots_visible = th.tensor([[int(obj.item()) in visible_objects for obj in association_table[0]]]).float().to(slots_to_track.device)
+ slots_visible = slots_visible * slots_to_track
+
+ # determine which objects are bound to the occluder
+ occluder_objects = th.where(gt_occluder_mask[0] == 1)[0]
+ slots_occluder = th.tensor([[int(obj.item()) in occluder_objects for obj in association_table[0]]]).float().to(slots_to_track.device)
+ slots_occluder = slots_occluder * slots_to_track
+
+ return tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder
+
+def get_evaluation_sets(dataset):
+
+    # Standard evaluation
+ evaluation_modes = ['open']
+
+ # Important!
+    # filter different scenarios: case 1 as surprise and cases 0,3 as control (see Smith et al. 2020)
+    surprise_mask = [(sample.case in [1]) for sample in dataset.samples]
+    control_mask = [(sample.case in [0,3]) for sample in dataset.samples]
+
+    # Create test sets
+    set_surprise = {"samples": np.where(surprise_mask)[0].tolist(), "type": 'surprise'}
+ set_control = {"samples": np.where(control_mask)[0].tolist(), "type": 'control'}
+ set_test_array = [set_control, set_surprise]
+
+    return set_test_array, evaluation_modes
\ No newline at end of file
diff --git a/scripts/evaluation_adept_savi.py b/scripts/evaluation_adept_savi.py
new file mode 100644
index 0000000..6a2d5c7
--- /dev/null
+++ b/scripts/evaluation_adept_savi.py
@@ -0,0 +1,233 @@
+from einops import rearrange, reduce, repeat
+import torch as th
+from torch.utils.data import Dataset, DataLoader, Subset
+import cv2
+import numpy as np
+import pandas as pd
+import os
+from data.datasets.ADEPT.dataset import AdeptDataset
+import motmetrics as mm
+from scripts.evaluation_adept import calculate_tracking_error, get_evaluation_sets, update_mota_acc
+from scripts.utils.eval_utils import setup_result_folders, store_statistics
+from scripts.utils.plot_utils import write_image
+
+FG_THRE = 0.95
+
+def evaluate(dataset: Dataset, file, n, plot_frequency=1, plot_first_samples=2):
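+    """Evaluate precomputed SAVi slot masks on the ADEPT control set.
+
+    Docstring added for clarity (behaviour inferred from the code below):
+    masks are read from a pickle file, centroids are derived from the
+    thresholded raw masks, and the same tracking-error and MOT-metric
+    pipeline as in evaluation_adept is applied.
+    """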
+
+ # read pkl file
+ masks_complete = pd.read_pickle(file)
+
+ # plot config
+ color_list = [[255,0,0], [0,255,0], [0,0,255], [255,255,0], [0,255,255], [255,0,255], [255,255,255]]
+ dot_size = 2
+ skip_frames = 2
+ offset = 15
+
+ # memory
+ statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'TE': [], 'visible': [], 'bound': [], 'occluder': [], 'inimage': [], 'slot_error': [], 'mask_size': [], 'rawmask_size': [], 'rawmask_size_hidden': [], 'alpha_pos': [], 'alpha_ges': [], 'object_id': []}
+ acc_memory_eval = []
+
+ # load adept dataset
+ set_test_array, evaluation_modes = get_evaluation_sets(dataset)
+ control_samples = set_test_array[0]['samples'] # only consider control set
+ evalset = Subset(dataset, control_samples)
+ root_path, plot_path = setup_result_folders(file, n, set_test_array[0], evaluation_modes[0], True, False)
+
+ for i in range(len(evalset)):
+ print(f'Processing sample {i+1}/{len(evalset)}', flush=True)
+ input = evalset[i]
+ acc = mm.MOTAccumulator(auto_id=True)
+
+ # get input frame and target frame
+ tensor = th.tensor(input[0]).float().unsqueeze(0)
+ background_fix = th.tensor(input[1]).unsqueeze(0)
+ gt_object_positions = th.tensor(input[3]).unsqueeze(0)
+ gt_object_visibility = th.tensor(input[4]).unsqueeze(0)
+ gt_occluder_mask = th.tensor(input[5]).unsqueeze(0)
+
+ # apply skip frames
+ gt_object_positions = gt_object_positions[:,range(0, tensor.shape[1], skip_frames)]
+ gt_object_visibility = gt_object_visibility[:,range(0, tensor.shape[1], skip_frames)]
+ tensor = tensor[:,range(0, tensor.shape[1], skip_frames)]
+ sequence_len = tensor.shape[1]
+
+ # load data
+ masks = th.tensor(masks_complete['test'][f'control_{i}.mp4'])
+ masks_before_softmax = th.tensor(masks_complete['test_raw'][f'control_{i}.mp4'])
+
+ # calculate rawmasks
+ bg_mask = masks_before_softmax.mean(dim=1)
+ masks_raw = compute_maskraw(masks_before_softmax, bg_mask)
+ slots_bound = compute_slots_bound(masks_raw)
+
+ # threshold masks and calculate centroids
+ masks_binary = (masks_raw > FG_THRE).float()
+ masks2 = rearrange(masks_binary, 't o 1 h w -> (t o) h w')
+ boxes = masks_to_boxes(masks2.long())
+ boxes = boxes.reshape(1, masks.shape[0], 7, 4)
+ centroids = boxes_to_centroids(boxes)
+
+ # get rid of batch dimension
+ association_table = th.ones(7) * -1
+
+ # iterate over frames
+ for t_index in range(offset,min(sequence_len,masks.shape[0])):
+
+ # move to next frame
+ input = tensor[:,t_index]
+ target = th.clip(tensor[:,t_index+1], 0, 1)
+ gt_positions_target = gt_object_positions[:,t_index]
+ gt_positions_target_next = gt_object_positions[:,t_index+1]
+ gt_visibility_target = gt_object_visibility[:,t_index]
+
+ position_cur = centroids[t_index]
+ position_cur = rearrange(position_cur, 'o c -> 1 (o c)')
+ slots_bound_cur = slots_bound[t_index]
+ slots_bound_cur = rearrange(slots_bound_cur, 'o c -> 1 (o c)')
+
+ # calculate tracking error
+ tracking_error, tracking_error_perslot, association_table, slots_visible, slots_in_image, slots_occluder = calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, 7, slots_bound_cur, None, association_table, gt_occluder_mask)
+
+ rawmask_size = reduce(masks_raw[t_index], 'o 1 h w-> 1 o', 'sum')
+ mask_size = reduce(masks[t_index], 'o 1 h w-> 1 o', 'sum')
+
+ statistics_complete_slots = store_statistics(statistics_complete_slots,
+ ['control'] * 7,
+ ['control'] * 7,
+ [control_samples[i]] * 7,
+ [t_index] * 7,
+ range(7),
+ tracking_error_perslot.cpu().numpy().flatten(),
+ slots_visible.cpu().numpy().flatten().astype(int),
+ slots_bound_cur.cpu().numpy().flatten().astype(int),
+ slots_occluder.cpu().numpy().flatten().astype(int),
+ slots_in_image.cpu().numpy().flatten().astype(int),
+ [0] * 7,
+ mask_size.cpu().numpy().flatten(),
+ rawmask_size.cpu().numpy().flatten(),
+ [0] * 7,
+ [0] * 7,
+ [0] * 7,
+ association_table[0].cpu().numpy().flatten().astype(int),
+ extend = True)
+
+ acc = update_mota_acc(acc, gt_positions_target, position_cur, slots_bound_cur, 7, gt_occluder_mask, slots_occluder, None)
+
+ # plot_option
+ if (t_index % plot_frequency == 0) and (i < plot_first_samples) and (t_index >= 0):
+                masks_to_display = masks_binary.numpy()
+
+ frame = tensor[0, t_index]
+ frame = frame.numpy().transpose(1,2,0)
+ frame = cv2.resize(frame, (64,64))
+
+ centroids_frame = centroids[t_index]
+ centroids_frame[:,0] = (centroids_frame[:,0] + 1) * 64 / 2
+ centroids_frame[:,1] = (centroids_frame[:,1] + 1) * 64 / 2
+
+ bound_frame = slots_bound[t_index]
+ for c_index,centroid_slot in enumerate(centroids_frame):
+ if bound_frame[c_index] == 1:
+ frame[int(centroid_slot[1]-dot_size):int(centroid_slot[1]+dot_size), int(centroid_slot[0]-dot_size):int(centroid_slot[0]+dot_size)] = color_list[c_index]
+
+ # slot images
+ slot_frame = masks_to_display[t_index].max(axis=0)
+ slot_frame = slot_frame.reshape((64,64,1)).repeat(3, axis=2)
+
+                for mask in masks_to_display[t_index]:
+                    slot_frame_single = mask.transpose((1,2,0)).repeat(3, axis=2)
+                    slot_frame = np.concatenate((slot_frame, slot_frame_single), axis=1)
+
+ frame = np.concatenate((frame, slot_frame), axis=1)
+ cv2.imwrite(f'{plot_path}object/objects-{i:04d}-{t_index:03d}.jpg', frame*255)
+
+ acc_memory_eval.append(acc)
+
+ mh = mm.metrics.create()
+ summary = mh.compute_many(acc_memory_eval, metrics=mm.metrics.motchallenge_metrics, generate_overall=True)
+ summary['set'] = 'control'
+ summary['evalmode'] = 'control'
+ pd.DataFrame(summary).to_csv(os.path.join(root_path, 'statistics' , 'accframe.csv'))
+ pd.DataFrame(statistics_complete_slots).to_csv(os.path.join(root_path, 'statistics' , 'slotframe.csv'))
+
+def masks_to_boxes(masks: th.Tensor) -> th.Tensor:
+ """
+ Compute the bounding boxes around the provided masks.
+
+ Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with
+ ``0 <= x1 < x2`` and ``0 <= y1 < y2``.
+
+ Args:
+ masks (Tensor[N, H, W]): masks to transform where N is the number of masks
+ and (H, W) are the spatial dimensions.
+
+ Returns:
+ Tensor[N, 4]: bounding boxes
+ """
+ if masks.numel() == 0:
+ return th.zeros((0, 4), device=masks.device, dtype=th.float)
+
+ n = masks.shape[0]
+
+ bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float)
+
+ for index, mask in enumerate(masks):
+ if mask.sum() > 0:
+ y, x = th.where(mask != 0)
+
+ bounding_boxes[index, 0] = th.min(x)
+ bounding_boxes[index, 1] = th.min(y)
+ bounding_boxes[index, 2] = th.max(x)
+ bounding_boxes[index, 3] = th.max(y)
+
+ return bounding_boxes
+
+def boxes_to_centroids(boxes):
+ """Post-process masks instead of directly taking argmax.
+
+ Args:
+ bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2]
+
+ Returns:
+ centroids: [B, T, N, 2], 2: [x, y]
+ """
+
+ centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2
+ centroids = centroids.squeeze(0)
+
+ # scale to [-1, 1]
+ centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1
+ centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1
+
+ return centroids
+
+def compute_slots_bound(masks):
+
+    # take the max over the spatial axes (3,4)
+ masks_sum = masks.amax(dim=(3,4))
+ slots_bound = (masks_sum > FG_THRE).float()
+ return slots_bound
+
+def compute_maskraw(mask, bg_mask):
+
+ # d is a diagonal matrix which defines what to take the softmax over
+ d_mask = th.diag(th.ones(8))
+ d_mask[:,-1] = 1
+ d_mask[-1,-1] = 0
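+    # Comment added for clarity: e.g. for 2 objects + background, d_mask is
+    # [[1, 0, 1],
+    #  [0, 1, 1],
+    #  [0, 0, 0]]
+    # i.e. row i selects object i together with the background channel, so
+    # the softmax below is taken pairwise between each object and the
+    # background.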
+
+ mask = mask.squeeze(2)
+
+ # take subset of maskraw with the diagonal matrix
+ maskraw = th.cat((mask, bg_mask), dim=1)
+ maskraw = repeat(maskraw, 'b o h w -> b r o h w', r = 8)
+ maskraw = maskraw[:,d_mask.bool()]
+ maskraw = rearrange(maskraw, 'b (o r) h w -> b o r h w', o = 7)
+
+ # take softmax between each object mask and the background mask
+ maskraw = th.squeeze(th.softmax(maskraw, dim=2)[:,:,0], dim=2)
+ maskraw = maskraw.unsqueeze(2)
+
+    return maskraw
\ No newline at end of file
diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py
new file mode 100644
index 0000000..a43d50c
--- /dev/null
+++ b/scripts/evaluation_clevrer.py
@@ -0,0 +1,311 @@
+import pickle
+import torch as th
+from torch.utils.data import Dataset, DataLoader, Subset
+from torch import nn
+import os
+from scripts.utils.plot_utils import plot_timestep
+from scripts.utils.eval_metrics import masks_to_boxes, pred_eval_step, postproc_mask
+from scripts.utils.eval_utils import append_statistics, load_model, setup_result_folders, store_statistics
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device
+import numpy as np
+from einops import rearrange, repeat, reduce
+from copy import deepcopy
+import lpips
+import torchvision.transforms as transforms
+
+def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency=1, plot_first_samples=0):
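+    """Evaluate on CLEVRER following the SlotFormer protocol.
+
+    Docstring added for clarity (values taken from the constants below):
+    6 burn-in frames, 42 rollout frames, optional random frame blackouts,
+    and video prediction / object dynamics metrics (MSE, SSIM, PSNR, LPIPS,
+    ARI, FG-ARI, mIoU, AP, AR).
+    """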
+
+ # Set up cpu or gpu training
+ device, verbose = init_device(cfg)
+
+ # Config
+ cfg_net = cfg.model
+ cfg_net.batch_size = 2 if verbose else 32
+ cfg_net.batch_size = 1 if plot_first_samples > 0 else cfg_net.batch_size # only plotting first samples
+
+ # Load model
+ net = load_model(cfg, cfg_net, file, device)
+ net.eval()
+
+ # config
+ object_view = True
+ individual_views = False
+ root_path = None
+ plotting_mode = (cfg_net.batch_size == 1) and (plot_first_samples > 0)
+
+ # get evaluation sets
+ set_test_array, evaluation_modes = get_evaluation_sets(dataset)
+
+ # memory
+ statistics_template = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'image_error_mse': []}
+ statistics_complete_slots = {'set': [], 'evalmode': [], 'scene': [], 'frame': [], 'slot':[], 'bound': [], 'slot_error': [], 'rawmask_size': [], 'alpha_pos': [], 'alpha_ges': []}
+ metric_complete = None
+
+ # Evaluation Specifics
+ burn_in_length = 6
+ rollout_length = 42
+ cfg.defaults.skip_frames = 2
+ blackout_p = 0.2
+ target_size = (64,64)
+ dataset.burn_in_length = burn_in_length
+ dataset.rollout_length = rollout_length
+ dataset.skip_length = cfg.defaults.skip_frames
+
+ # Transformation utils
+ to_small = transforms.Resize(target_size)
+ to_normalize = transforms.Normalize((0.5, ), (0.5, ))
+ to_smallnorm = transforms.Compose([to_small, to_normalize])
+
+ # Losses
+ lpipsloss = lpips.LPIPS(net='vgg').to(device)
+ mseloss = nn.MSELoss()
+
+ for set_test in set_test_array:
+
+ for evaluation_mode in evaluation_modes:
+ print(f'Start evaluation loop: {evaluation_mode}')
+
+ # load data
+ dataloader = DataLoader(
+ Subset(dataset, range(plot_first_samples)) if plotting_mode else dataset,
+ num_workers = 1,
+ pin_memory = False,
+ batch_size = cfg_net.batch_size,
+ shuffle = False,
+ drop_last = True,
+ )
+
+ # memory
+ root_path, plot_path = setup_result_folders(file, n, set_test, evaluation_mode, object_view, individual_views)
+ metric_complete = {'mse': [], 'ssim': [], 'psnr': [], 'percept_dist': [], 'ari': [], 'fari': [], 'miou': [], 'ap': [], 'ar': [], 'blackout': []}
+
+ with th.no_grad():
+ for i, input in enumerate(dataloader):
+ print(f'Processing sample {i+1}/{len(dataloader)}', flush=True)
+
+ # Load data
+ tensor = input[0].float().to(device)
+ background_fix = input[1].to(device)
+ gt_mask = input[2].to(device)
+ gt_bbox = input[3].to(device)
+ gt_pres_mask = input[4].to(device)
+ sequence_len = tensor.shape[1]
+
+                # Placeholders
+ mask_cur = None
+ mask_last = None
+ rawmask_last = None
+ position_last = None
+ gestalt_last = None
+ priority_last = None
+ slots_occlusionfactor = None
+ error_last = None
+
+ # Memory
+ pred = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+ gt = th.zeros((cfg_net.batch_size, rollout_length, 3, target_size[0], target_size[1])).to(device)
+ pred_mask = th.zeros((cfg_net.batch_size, rollout_length, target_size[0], target_size[1])).to(device)
+ statistics_batch = deepcopy(statistics_template)
+
+ # Counters
+ num_rollout = 0
+ num_burnin = 0
+ blackout_mem = [0]
+
+ # Loop through frames
+ for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)):
+
+ # Move to next frame
+ t_run = max(t, 0)
+ input = tensor[:,t_run]
+ target = th.clip(tensor[:,t_run+1], 0, 1)
+
+ rollout_index = t_run - burn_in_length
+ if (rollout_index >= 0) and (evaluation_mode == 'blackout'):
+ blackout = (th.rand(1) < blackout_p).float().to(device)
+ input = blackout * (input * 0) + (1-blackout) * input
+ error_last = blackout * (error_last * 0) + (1-blackout) * error_last
+ blackout_mem.append(blackout.int().cpu().item())
+
+ elif t>=0:
+ num_burnin += 1
+
+ if (rollout_index >= 0) and (evaluation_mode != 'blackout'):
+ blackout_mem.append(0)
+
+ # obtain prediction
+ (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder_cur,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ ) = net(
+ input,
+ error_last,
+ mask_last,
+ rawmask_last,
+ position_last,
+ gestalt_last,
+ priority_last,
+ background_fix,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing),
+ evaluate=True,
+ warmup = (t < 0),
+ shuffleslots = False,
+ reset_mask = (t <= 0),
+ allow_spawn = True,
+ show_hidden = plotting_mode,
+ clean_slots = (t <= 0),
+ )
+
+ # 1. Track error for plots
+ if t >= 0:
+
+ # save prediction
+ if rollout_index >= -1:
+ pred[:,rollout_index+1] = to_smallnorm(output_next)
+ gt[:,rollout_index+1] = to_smallnorm(target)
+ pred_mask[:,rollout_index+1] = postproc_mask(to_small(mask_next)[:,None,:,None])[:, 0]
+ num_rollout += 1
+
+ if plotting_mode and object_view:
+ statistics_complete_slots, statistics_batch = compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden)
+
+ # 2. Remember output
+ mask_last = mask_next.clone()
+ rawmask_last = rawmask_next.clone()
+ position_last = position_next.clone()
+ gestalt_last = gestalt_next.clone()
+ priority_last = priority_next.clone()
+
+ # 3. Error for next frame
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+ # prediction error
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next.clone()
+
+ # 4. plot preparation
+ if plotting_mode and (t % plot_frequency == 0):
+ plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, None, None, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i)
+
+ # Compute prediction accuracy based on Slotformer metrics (ARI, FARI, mIoU, AP, AR)
+ if not plotting_mode:
+ for b in range(cfg_net.batch_size):
+ metric_dict = pred_eval_step(
+ gt=gt[b:b+1],
+ pred=pred[b:b+1],
+ pred_mask=pred_mask.long()[b:b+1],
+ pred_bbox=masks_to_boxes(pred_mask.long()[b:b+1], cfg_net.num_objects+1),
+ gt_mask=gt_mask.long()[b:b+1],
+ gt_pres_mask=gt_pres_mask[b:b+1],
+ gt_bbox=gt_bbox[b:b+1],
+ lpips_fn=lpipsloss,
+ eval_traj=True,
+ )
+ metric_dict['blackout'] = blackout_mem
+ metric_complete = append_statistics(metric_dict, metric_complete)
+
+ # sanity check
+                if ((num_rollout != rollout_length) or (num_burnin != burn_in_length)) and (evaluation_mode == 'rollout'):
+                    raise ValueError('Number of rollout steps must equal rollout_length and number of burn-in steps must equal burn_in_length.')
+
+ if not plotting_mode:
+ average_dic = compute_statistics_summary(metric_complete, evaluation_mode)
+
+ # Store statistics
+ with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f:
+ pickle.dump(metric_complete, f)
+ with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.pkl'), 'wb') as f:
+ pickle.dump(average_dic, f)
+
+ print('-- Evaluation Done --')
+ if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
+ os.remove(f'{root_path}/tmp.jpg')
+ pass
+
+def compute_statistics_summary(metric_complete, evaluation_mode):
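+    # Comment added for clarity: averages every metric over all frames; in
+    # blackout mode the averages are additionally split into blacked-out and
+    # visible frames.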
+ average_dic = {}
+ for key in metric_complete:
+ # take average over all frames
+ average_dic[key + 'complete_average'] = np.mean(metric_complete[key])
+ average_dic[key + 'complete_std'] = np.std(metric_complete[key])
+ print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}')
+
+ if evaluation_mode == 'blackout':
+ # take average only for frames where blackout occurs
+ blackout_mask = np.array(metric_complete['blackout']) > 0
+ average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask])
+ average_dic[key + 'blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask])
+            average_dic[key + 'visible_average'] = np.mean(np.array(metric_complete[key])[~blackout_mask])
+            average_dic[key + 'visible_std'] = np.std(np.array(metric_complete[key])[~blackout_mask])
+
+ print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}')
+ print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}')
+ return average_dic
+
+def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden):
+ statistics_batch = store_statistics(statistics_batch,
+ set_test['type'],
+ evaluation_mode,
+ set_test['samples'][i],
+ t,
+ mseloss(output_next, target).item()
+ )
+
+ # compute slot-wise prediction error
+ output_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(output_next, 'b c h w -> b o c h w', o=cfg_net.num_objects)
+ target_slot = repeat(mask_next[:,:-1], 'b o h w -> b o 3 h w') * repeat(target, 'b c h w -> b o c h w', o=cfg_net.num_objects)
+ slot_error = reduce((output_slot - target_slot)**2, 'b o c h w -> b o', 'mean')
+
+ # compute rawmask_size
+ rawmask_size = reduce(rawmask_hidden[:, :-1], 'b o h w-> b o', 'sum')
+
+ statistics_complete_slots = store_statistics(statistics_complete_slots,
+ [set_test['type']] * cfg_net.num_objects,
+ [evaluation_mode] * cfg_net.num_objects,
+ [set_test['samples'][i]] * cfg_net.num_objects,
+ [t] * cfg_net.num_objects,
+ range(cfg_net.num_objects),
+ slots_bounded.cpu().numpy().flatten().astype(int),
+ slot_error.cpu().numpy().flatten(),
+ rawmask_size.cpu().numpy().flatten(),
+ slots_closed[:, :, 1].cpu().numpy().flatten(),
+ slots_closed[:, :, 0].cpu().numpy().flatten(),
+ extend = True)
+
+ return statistics_complete_slots,statistics_batch
+
+def get_evaluation_sets(dataset):
+
+    set_test = {"samples": np.arange(len(dataset), dtype=int), "type": "test"}
+    evaluation_modes = ['blackout', 'open'] # use 'open' for no blackouts
+    set_test_array = [set_test]
+
+ return set_test_array, evaluation_modes
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py
new file mode 100644
index 0000000..96ed32d
--- /dev/null
+++ b/scripts/exec/eval.py
@@ -0,0 +1,28 @@
+import argparse
+import sys
+from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
+from scripts.utils.configuration import Configuration
+from scripts import evaluation_adept, evaluation_clevrer
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-cfg", default="", help='path to the configuration file')
+ parser.add_argument("-load", default="", type=str, help='path to model')
+ parser.add_argument("-n", default="", type=str, help='results name')
+
+ # Load configuration
+ args = parser.parse_args(sys.argv[1:])
+ cfg = Configuration(args.cfg)
+ print(f'Evaluating model {args.load}')
+
+ # Load dataset
+ if cfg.datatype == "clevrer":
+ testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True)
+        evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency=2, plot_first_samples=3) # only plotting
+        evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples=0) # evaluation
+ elif cfg.datatype == "adept":
+ testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+        evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency=1, plot_first_samples=2)
+ else:
+        raise Exception("Dataset not supported")
\ No newline at end of file
diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_savi.py
new file mode 100644
index 0000000..87d4e59
--- /dev/null
+++ b/scripts/exec/eval_savi.py
@@ -0,0 +1,17 @@
+import argparse
+import sys
+from data.datasets.ADEPT.dataset import AdeptDataset
+from scripts.evaluation_adept_savi import evaluate
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-load", default="", type=str, help='path to savi slots')
+ parser.add_argument("-n", default="", type=str, help='results name')
+
+ # Load configuration
+ args = parser.parse_args(sys.argv[1:])
+ print(f'Evaluating savi slots {args.load}')
+
+ # Load dataset
+ testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2)))
+    evaluate(testset, args.load, args.n, plot_frequency=1, plot_first_samples=2)
\ No newline at end of file
diff --git a/scripts/exec/train.py b/scripts/exec/train.py
new file mode 100644
index 0000000..f6e78fd
--- /dev/null
+++ b/scripts/exec/train.py
@@ -0,0 +1,33 @@
+import argparse
+import sys
+from scripts.utils.configuration import Configuration
+from scripts import training
+from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
+from data.datasets.ADEPT.dataset import AdeptDataset
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-cfg", default="", help='path to the configuration file')
+ parser.add_argument("-n", default=-1, type=int, help='optional run number')
+ parser.add_argument("-load", default="", type=str, help='path to pretrained model or checkpoint')
+
+ # Load configuration
+ args = parser.parse_args(sys.argv[1:])
+ cfg = Configuration(args.cfg)
+    if args.n >= 0:
+        cfg.model_path = f"{cfg.model_path}.run{args.n}"
+ print(f'Training model {cfg.model_path}')
+
+ # Load dataset
+ if cfg.datatype == "clevrer":
+ trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False)
+ valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True)
+ elif cfg.datatype == "adept":
+ trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ else:
+ raise Exception("Dataset not supported")
+
+ # Final call
+    training.train_loci(cfg, trainset, valset, args.load)
\ No newline at end of file
diff --git a/scripts/training.py b/scripts/training.py
new file mode 100644
index 0000000..b807e54
--- /dev/null
+++ b/scripts/training.py
@@ -0,0 +1,523 @@
+import torch as th
+from torch import nn
+from torch.utils.data import Dataset, DataLoader, Subset
+import cv2
+import numpy as np
+import os
+from einops import rearrange, repeat, reduce
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device, model_path, LossLogger
+from scripts.utils.optimizers import RAdam
+from model.loci import Loci
+import random
+from scripts.utils.io import Timer
+from scripts.utils.plot_utils import color_mask
+from scripts.validation import validation_clevrer, validation_adept
+
+
+def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
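+    """Main training loop.
+
+    Docstring added for clarity: trains Loci with teacher forcing followed
+    by closed-loop prediction, using a staged curriculum (background
+    pretraining, entity pretraining, inner loop), truncated BPTT with a
+    growing window, gradual background blend-in, and periodic validation,
+    plotting and checkpointing.
+    """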
+
+ # Set up cpu or gpu training
+ device, verbose = init_device(cfg)
+ if verbose:
+ valset = Subset(valset, range(0, 8))
+
+ # Define model path
+ path = model_path(cfg, overwrite=False)
+ cfg.save(path)
+ os.makedirs(os.path.join(path, 'nets'), exist_ok=True)
+
+ # Configure model
+ cfg_net = cfg.model
+ net = Loci(
+ cfg = cfg_net,
+ teacher_forcing = cfg.defaults.teacher_forcing
+ )
+ net = net.to(device=device)
+
+ # Log model size
+ log_modelsize(net)
+
+ # Init Optimizers
+ optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net)
+
+ # Option to load model
+ if file != "":
+ load_model(
+ file,
+ cfg,
+ net,
+ optimizer_init,
+ optimizer_encoder,
+ optimizer_decoder,
+ optimizer_predictor,
+ optimizer_background,
+ cfg.defaults.load_optimizers,
+ only_encoder_decoder = (cfg.num_updates == 0) # only load encoder and decoder for initial training
+ )
+ print(f'loaded {file}', flush=True)
+
+ # Set up data loaders
+ trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True)
+ valset.train = True #valset.dataset.train = True
+ valloader = get_loader(cfg, valset, cfg_net, shuffle=False)
+
+ # initial save
+ save_model(
+ os.path.join(path, 'nets', 'net0.pt'),
+ net,
+ optimizer_init,
+ optimizer_encoder,
+ optimizer_decoder,
+ optimizer_predictor,
+ optimizer_background
+ )
+
+ # Start training at num_updates
+ num_updates = cfg.num_updates
+ if num_updates > 0:
+ print('!!! Start training at num_updates: ', num_updates)
+ print('!!! Net init status: ', net.get_init_status())
+
+ # Set up statistics
+ loss_tracker = LossLogger()
+
+ # Set up training variables
+ num_time_steps = 0
+ bptt_steps = cfg.bptt.bptt_steps
+ increase_bptt_steps = False
+ background_blendin_factor = 0.0
+ th.backends.cudnn.benchmark = True
+ timer = Timer()
+ bceloss = nn.BCELoss()
+
+ # Init net to current num_updates
+ if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1:
+ net.inc_init_level()
+
+ if num_updates >= cfg.phases.entity_pretraining_phase1_end and net.get_init_status() < 2:
+ net.inc_init_level()
+
+ if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
+ net.inc_init_level()
+ for param in optimizer_init.param_groups:
+ param['lr'] = cfg.learning_rate
+
+ if num_updates > cfg.phases.start_inner_loop:
+ net.cfg.inner_loop_enabled = True
+
+ if num_updates >= cfg.phases.entity_pretraining_phase1_end:
+ background_blendin_factor = max(min((num_updates - cfg.phases.entity_pretraining_phase1_end)/30000, 1.0), 0.0)
+
+
+ # --- Start Training
+ print('Start training')
+ for epoch in range(cfg.max_epochs):
+
+ # Validation every epoch
+ if epoch >= 0:
+ if cfg.datatype == 'adept':
+ validation_adept(valloader, net, cfg, device)
+ elif cfg.datatype == 'clevrer':
+ validation_clevrer(valloader, net, cfg, device)
+
+ # Start epoch training
+ print('Start epoch:', epoch)
+
+ # Backprop through time steps
+ if increase_bptt_steps:
+            bptt_steps = min(bptt_steps + 1, cfg.bptt.bptt_steps_max)
+ print('Increase closed loop steps to', bptt_steps)
+ increase_bptt_steps = False
+
+ for batch_index, input in enumerate(trainloader):
+
+ # Extract input and background
+ tensor = input[0]
+ background = input[1].to(device)
+ shuffleslots = (num_updates <= cfg.phases.shufleslots_end)
+
+ # Placeholders
+ position = None
+ gestalt = None
+ priority = None
+ mask = None
+ object = None
+ rawmask = None
+ loss = th.tensor(0)
+ summed_loss = None
+ slots_occlusionfactor = None
+
+ # Apply skip frames to sequence
+ selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames)
+ tensor = tensor[:,selec]
+ sequence_len = tensor.shape[1]
+
+ # Initial frame
+ input = tensor[:,0].to(device)
+ input_next = input
+ target = th.clip(input, 0, 1).detach()
+ error_last = None
+
+ # First apply teacher forcing for the first x frames
+ for t in range(-cfg.defaults.teacher_forcing, sequence_len-1):
+
+ # Set update scheme for backprop through time
+ if t >= cfg.bptt.bptt_start_timestep:
+ t_run = (t - cfg.bptt.bptt_start_timestep)
+ run_optimizers = t_run % bptt_steps == bptt_steps - 1
+ detach = (t_run % bptt_steps == 0) or t == -cfg.defaults.teacher_forcing
+ else:
+ run_optimizers = True
+ detach = True
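+                # Example (comment added): with bptt_steps = 3 the graph is
+                # detached at t_run = 0, 3, 6, ... and the optimizers step
+                # at t_run = 2, 5, 8, ...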
+
+ if verbose:
+ print(f't: {t}, run_optimizers: {run_optimizers}, detach: {detach}')
+
+ if t >= 0:
+ # Skip to next frame
+ num_time_steps += 1
+ input = input_next
+ input_next = tensor[:,t+1].to(device)
+ target = th.clip(input_next, 0, 1)
+
+ # Apply error dropout
+ if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout:
+ error_last = th.zeros_like(error_last)
+
+                    # Apply sensation blackout when training on clevrer
+ if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer':
+ if t >= 10:
+ blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device)
+ input = blackout * (input * 0) + (1-blackout) * input
+ error_last = blackout * (error_last * 0) + (1-blackout) * error_last
+
+ # Forward Pass
+ (
+ output_next,
+ output_cur,
+ position,
+ gestalt,
+ priority,
+ mask,
+ rawmask,
+ object,
+ background,
+ slots_occlusionfactor,
+ position_loss,
+ time_loss,
+ slots_closed
+ ) = net(
+ input, # current frame
+ error_last, # error of last frame --> missing object
+ mask, # masks of current frame
+ rawmask, # raw masks of current frame
+ position, # positions of objects of next frame
+ gestalt, # gestalt of objects of next frame
+ priority, # priority of objects of next frame
+ background,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing), # new sequence
+ warmup = (t < 0), # teacher forcing
+ detach = detach,
+ shuffleslots = shuffleslots or ((cfg.datatype == 'clevrer') and (t<=0)),
+ reset_mask = (t <= 0),
+ clean_slots = (t <= 0 and not shuffleslots),
+ )
+
+ # Loss weighting
+ position_loss = position_loss * cfg_net.position_regularizer
+ time_loss = time_loss * cfg_net.time_regularizer
+
+ # Compute background error
+ bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+ # Compute next-frame prediction error
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next
+
+ # Initially focus on foreground learning
+ if background_blendin_factor < 1:
+ fg_mask_next = th.gt(bg_error_next, 0.1).float().detach()
+ fg_mask_next[fg_mask_next == 0] = background_blendin_factor
+ target = th.clip(target * fg_mask_next, 0, 1)
+
+ fg_mask_cur = th.gt(bg_error_cur, 0.1).float().detach()
+ fg_mask_cur[fg_mask_cur == 0] = background_blendin_factor
+ input = th.clip(input * fg_mask_cur, 0, 1)
+
+ # Gradually blend in background for more stable training
+ if num_updates % 30 == 0 and num_updates >= cfg.phases.entity_pretraining_phase1_end:
+ background_blendin_factor = min(1, background_blendin_factor + 0.001)
+
+ # Final Loss computation
+ encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer
+ cliped_output_next = th.clip(output_next, 0, 1)
+ loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss
+
+                # Accumulate loss over BPTT steps
+ summed_loss = loss if summed_loss is None else summed_loss + loss
+ mask = mask.detach()
+
+ if run_optimizers:
+
+ # detach gradients for next step
+ position = position.detach()
+ gestalt = gestalt.detach()
+ rawmask = rawmask.detach()
+ object = object.detach()
+ priority = priority.detach()
+
+ # zero grad
+ optimizer_init.zero_grad()
+ optimizer_encoder.zero_grad()
+ optimizer_decoder.zero_grad()
+ optimizer_predictor.zero_grad()
+ optimizer_background.zero_grad()
+ if net.cfg.inner_loop_enabled:
+ optimizer_update.zero_grad()
+
+ # optimize
+ summed_loss.backward()
+ optimizer_init.step()
+ optimizer_encoder.step()
+ optimizer_decoder.step()
+ optimizer_predictor.step()
+ optimizer_background.step()
+ if net.cfg.inner_loop_enabled:
+ optimizer_update.step()
+
+ # Reset loss
+ num_updates += 1
+ summed_loss = None
+
+ # Update net status
+ update_net_status(num_updates, net, cfg, optimizer_init)
+
+ if num_updates == cfg.phases.start_inner_loop:
+ print('Start inner loop')
+ net.cfg.inner_loop_enabled = True
+
+ if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0):
+ increase_bptt_steps = True
+
+ # Plots for online evaluation
+ if num_updates % 20000 == 0:
+ plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next)
+
+
+                # Track statistics
+ if t >= cfg.defaults.statistics_offset:
+ track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next)
+ loss_tracker.update_average_loss(loss.item())
+
+ # Logging
+ if num_updates % 100 == 0 and run_optimizers:
+ print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True)
+
+ # Training finished
+ if num_updates > cfg.max_updates:
+ save_model(
+ os.path.join(path, 'nets', 'net_final.pt'),
+ net,
+ optimizer_init,
+ optimizer_encoder,
+ optimizer_decoder,
+ optimizer_predictor,
+ optimizer_background
+ )
+ print("Training finished")
+ return
+
+ # Checkpointing
+ if num_updates % 50000 == 0 and run_optimizers:
+ save_model(
+ os.path.join(path, 'nets', f'net_{num_updates}.pt'),
+ net,
+ optimizer_init,
+ optimizer_encoder,
+ optimizer_decoder,
+ optimizer_predictor,
+ optimizer_background
+ )
+ pass
+
+def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next):
+
+ # Prediction Loss
+ mseloss = nn.MSELoss()
+ loss_next = mseloss(output_next * bg_error_next, target * bg_error_next)
+
+ # Encoder Loss (only for stats)
+ loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur)
+
+ # area of foreground mask
+ num_objects = th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item()
+
+ # difference in shape
+ _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
+ _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
+ max_mask = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float()
+ avg_gestalt = (th.sum(_gestalt * max_mask) / (1e-16 + th.sum(max_mask)))
+ avg_gestalt2 = (th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask)))
+ avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1))
+
+    # update gates
+ avg_update_gestalt = slots_closed[:,:,0].mean()
+ avg_update_position = slots_closed[:,:,1].mean()
+
+ loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position)
+ pass
+
+def log_modelsize(net):
+ print(f'Loaded model with {sum([param.numel() for param in net.parameters()]):7d} parameters', flush=True)
+ print(f' States: {sum([param.numel() for param in net.initial_states.parameters()]):7d} parameters', flush=True)
+ print(f' Encoder: {sum([param.numel() for param in net.encoder.parameters()]):7d} parameters', flush=True)
+ print(f' Decoder: {sum([param.numel() for param in net.decoder.parameters()]):7d} parameters', flush=True)
+ print(f' predictor: {sum([param.numel() for param in net.predictor.parameters()]):7d} parameters', flush=True)
+ print(f' background: {sum([param.numel() for param in net.background.parameters()]):7d} parameters', flush=True)
+ print("\n")
+ pass
+
+def init_optimizer(cfg, net):
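+    # Comment added for clarity: one RAdam optimizer per module; the
+    # initial-state optimizer starts at 30x the base learning rate and is
+    # reduced to the base rate once entity pretraining phase 2 ends.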
+ optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30)
+ optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate)
+ optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate)
+ optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate)
+ optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate)
+ optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = cfg.learning_rate)
+ return optimizer_init,optimizer_encoder,optimizer_decoder,optimizer_predictor,optimizer_background,optimizer_update
+
+def save_model(
+ file,
+ net,
+ optimizer_init,
+ optimizer_encoder,
+ optimizer_decoder,
+ optimizer_predictor,
+ optimizer_background
+):
+
+ state = { }
+
+ state['optimizer_init'] = optimizer_init.state_dict()
+ state['optimizer_encoder'] = optimizer_encoder.state_dict()
+ state['optimizer_decoder'] = optimizer_decoder.state_dict()
+ state['optimizer_predictor'] = optimizer_predictor.state_dict()
+ state['optimizer_background'] = optimizer_background.state_dict()
+
+ state["model"] = net.state_dict()
+ th.save(state, file)
+ pass
+
+def load_model(
+ file,
+ cfg,
+ net,
+ optimizer_init,
+ optimizer_encoder,
+ optimizer_decoder,
+ optimizer_predictor,
+ optimizer_background,
+ load_optimizers = True,
+ only_encoder_decoder = False
+):
+ device = th.device(cfg.device)
+ state = th.load(file, map_location=device)
+ print(f"load {file} to device {device}, only encoder/decoder: {only_encoder_decoder}")
+ print(f"load optimizers: {load_optimizers}")
+
+ if load_optimizers:
+        optimizer_init.load_state_dict(state['optimizer_init'])
+ for n in range(len(optimizer_init.param_groups)):
+ optimizer_init.param_groups[n]['lr'] = cfg.learning_rate
+
+        optimizer_encoder.load_state_dict(state['optimizer_encoder'])
+ for n in range(len(optimizer_encoder.param_groups)):
+ optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate
+
+        optimizer_decoder.load_state_dict(state['optimizer_decoder'])
+ for n in range(len(optimizer_decoder.param_groups)):
+ optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate
+
+ optimizer_predictor.load_state_dict(state['optimizer_predictor'])
+ for n in range(len(optimizer_predictor.param_groups)):
+ optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate
+
+ optimizer_background.load_state_dict(state['optimizer_background'])
+ for n in range(len(optimizer_background.param_groups)):
+ optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate
+
+ # 1. Fill model with values of net
+ model = {}
+ allowed_keys = []
+ rand_state = net.state_dict()
+ for key, value in rand_state.items():
+ allowed_keys.append(key)
+ model[key.replace(".module.", ".")] = value
+
+ # 2. Overwrite with values from file
+ for key, value in state["model"].items():
+ # replace update_module with percept_gate_controller in key string:
+ key = key.replace("update_module", "percept_gate_controller")
+
+ if key in allowed_keys:
+ if only_encoder_decoder:
+ if ('encoder' in key) or ('decoder' in key):
+ model[key.replace(".module.", ".")] = value
+ else:
+ model[key.replace(".module.", ".")] = value
+
+ net.load_state_dict(model)
+
+ pass
+
+def update_net_status(num_updates, net, cfg, optimizer_init):
+ if num_updates == cfg.phases.background_pretraining_end and net.get_init_status() < 1:
+ net.inc_init_level()
+
+ if num_updates == cfg.phases.entity_pretraining_phase1_end and net.get_init_status() < 2:
+ net.inc_init_level()
+
+ if num_updates == cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
+ net.inc_init_level()
+ for param in optimizer_init.param_groups:
+ param['lr'] = cfg.learning_rate
+
+ pass
+
+def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next):
+ plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
+ if not os.path.exists(plot_path):
+ os.makedirs(plot_path, exist_ok=True)
+
+ # highlight error
+ grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
+ object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
+ highlighted_input = grayscale * (1 - object_mask_cur)
+ highlighted_input += grayscale * object_mask_cur * 0.3333333
+ cmask = color_mask(mask[:,:-1])
+ highlighted_input = highlighted_input + cmask * 0.6666666
+
+ cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(highlighted_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+ pass
+
+def get_loader(cfg, set, cfg_net, shuffle = True):
+ loader = DataLoader(
+ set,
+ pin_memory = True,
+ num_workers = cfg.defaults.num_workers,
+ batch_size = cfg_net.batch_size,
+ shuffle = shuffle,
+ drop_last = True,
+ prefetch_factor = cfg.defaults.prefetch_factor,
+ persistent_workers = True
+ )
+ return loader
\ No newline at end of file
diff --git a/scripts/utils/configuration.py b/scripts/utils/configuration.py
new file mode 100644
index 0000000..7c38095
--- /dev/null
+++ b/scripts/utils/configuration.py
@@ -0,0 +1,83 @@
+import json
+import jsmin
+import os
+
+class Dict(dict):
+ """
+ Dictionary whose entries can be accessed as attributes and whose loading can be restricted to a given list of names
+ """
+ def __init__(self, dictionary: dict = None):
+ super(Dict, self).__init__()
+
+ if dictionary is not None:
+ self.load(dictionary)
+
+ def __getattr__(self, item):
+ try:
+ return self[item] if item in self else getattr(super(Dict, self), item)
+ except AttributeError:
+ raise AttributeError(f'This dictionary has no attribute "{item}"')
+
+ def load(self, dictionary: dict, name_list: list = None):
+ """
+ Loads a dictionary
+ :param dictionary: Dictionary to be loaded
+ :param name_list: List of names to be updated
+ """
+ for name in dictionary:
+ data = dictionary[name]
+ if name_list is None or name in name_list:
+ if isinstance(data, dict):
+ if name in self:
+ self[name].load(data)
+ else:
+ self[name] = Dict(data)
+ elif isinstance(data, list):
+ self[name] = list()
+ for item in data:
+ if isinstance(item, dict):
+ self[name].append(Dict(item))
+ else:
+ self[name].append(item)
+ else:
+ self[name] = data
+
+ def save(self, path):
+ """
+ Saves the dictionary into a json file
+ :param path: Path of the json file
+ """
+ os.makedirs(path, exist_ok=True)
+
+ path = os.path.join(path, 'cfg.json')
+
+ with open(path, 'w') as file:
+ json.dump(self, file, indent=True)
+
+
+class Configuration(Dict):
+ """
+ Configuration loaded from a json file
+ """
+ def __init__(self, path: str, default_path=None):
+ super(Configuration, self).__init__()
+
+ if default_path is not None:
+ self.load(default_path)
+
+ self.load(path)
+
+ def load_model(self, path: str):
+ self.load(path, name_list=["model"])
+
+ def load(self, path: str, name_list: list = None):
+ """
+ Loads attributes from a json file
+ :param path: Path of the json file
+ :param name_list: List of names to be updated
+ :return:
+ """
+ with open(path) as file:
+ data = json.loads(jsmin.jsmin(file.read()))
+
+ super(Configuration, self).load(data, name_list)
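+
+# Illustrative usage, a sketch with hypothetical json paths:
+# cfg = Configuration('cfg/run.json', default_path='cfg/defaults.json')
+# cfg.model.batch_size   # nested dicts support attribute access
+# cfg.save('out/run')    # writes out/run/cfg.json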
diff --git a/scripts/utils/eval_metrics.py b/scripts/utils/eval_metrics.py
new file mode 100644
index 0000000..6f5106d
--- /dev/null
+++ b/scripts/utils/eval_metrics.py
@@ -0,0 +1,350 @@
+'''
+eval_metrics.py
+Metrics adapted from vp_utils.py in https://github.com/pairlab/SlotFormer
+'''
+
+import numpy as np
+from scipy.optimize import linear_sum_assignment
+from skimage.metrics import structural_similarity, peak_signal_noise_ratio
+
+import torch
+import torch.nn.functional as F
+import torchvision.ops as vops
+
+FG_THRE = 0.5
+PALETTE = [(0, 255, 0), (0, 0, 255), (0, 255, 255), (255, 255, 0),
+ (255, 0, 255), (100, 100, 255), (200, 200, 100), (170, 120, 200),
+ (255, 0, 0), (200, 100, 100), (10, 200, 100), (200, 200, 200),
+ (50, 50, 50)]
+PALETTE_np = np.array(PALETTE, dtype=np.uint8)
+PALETTE_torch = torch.from_numpy(PALETTE_np).float() / 255. * 2. - 1.
+
+def to_rgb_from_tensor(x):
+ """Reverse the Normalize operation in torchvision."""
+ return (x * 0.5 + 0.5).clamp(0, 1)
+
+def postproc_mask(batch_masks):
+ """Post-process masks instead of directly taking argmax.
+
+ Args:
+ batch_masks: [B, T, N, 1, H, W]
+
+ Returns:
+ masks: [B, T, H, W]
+ """
+ batch_masks = batch_masks.clone()
+ B, T, N, _, H, W = batch_masks.shape
+ batch_masks = batch_masks.reshape(B * T, N, H * W)
+ slots_max = batch_masks.max(-1)[0] # [B*T, N]
+ bg_idx = slots_max.argmin(-1) # [B*T]
+ spatial_max = batch_masks.max(1)[0] # [B*T, H*W]
+ bg_mask = (spatial_max < FG_THRE) # [B*T, H*W]
+ idx_mask = torch.zeros((B * T, N)).type_as(bg_mask)
+ idx_mask[torch.arange(B * T), bg_idx] = True
+ # set the background mask score to 1
+ batch_masks[idx_mask.unsqueeze(-1) * bg_mask.unsqueeze(1)] = 1.
+ masks = batch_masks.argmax(1) # [B*T, H*W]
+ return masks.reshape(B, T, H, W)
+
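+# Illustrative shape walk-through, a sketch with random values:
+# masks = postproc_mask(torch.rand(2, 4, 5, 1, 16, 16))  # [B, T, N, 1, H, W]
+# masks.shape  # -> torch.Size([2, 4, 16, 16])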
+
+def masks_to_boxes_w_empty_mask(binary_masks):
+ """binary_masks: [B, H, W]."""
+ B = binary_masks.shape[0]
+ obj_mask = (binary_masks.sum([-1, -2]) > 0) # [B]
+ bboxes = torch.ones((B, 4)).float().to(binary_masks.device) * -1.
+ bboxes[obj_mask] = vops.masks_to_boxes(binary_masks[obj_mask])
+ return bboxes
+
+
+def masks_to_boxes(masks, num_boxes=7):
+ """Convert seg_masks to bboxes.
+
+ Args:
+ masks: [B, T, H, W], output after taking argmax
+ num_boxes: number of boxes to generate (num_slots)
+
+ Returns:
+ bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2]
+ """
+ B, T, H, W = masks.shape
+ binary_masks = F.one_hot(masks, num_classes=num_boxes) # [B, T, H, W, N]
+ binary_masks = binary_masks.permute(0, 1, 4, 2, 3) # [B, T, N, H, W]
+ binary_masks = binary_masks.contiguous().flatten(0, 2)
+ bboxes = masks_to_boxes_w_empty_mask(binary_masks) # [B*T*N, 4]
+ bboxes = bboxes.reshape(B, T, num_boxes, 4) # [B, T, N, 4]
+ return bboxes
+
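+# A sketch: integer argmax masks map to per-slot boxes, and empty slots
+# get [-1, -1, -1, -1]:
+# m = torch.zeros(1, 1, 8, 8, dtype=torch.long)
+# m[0, 0, 2:5, 3:7] = 1
+# masks_to_boxes(m, num_boxes=2)[0, 0, 1]  # -> tensor([3., 2., 6., 4.])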
+
+def mse_metric(x, y):
+ """x/y: [B, C, H, W]"""
+ # people often sum over spatial dimensions in video prediction MSE
+ # see e.g. https://github.com/Yunbo426/predrnn-pp/blob/40764c52f433290aa02414b5f25c09c72b98e0af/train.py#L244
+ return ((x - y)**2).sum(-1).sum(-1).mean()
+
+
+def psnr_metric(x, y):
+ """x/y: [B, C, H, W]"""
+ psnrs = [
+ peak_signal_noise_ratio(
+ x[i],
+ y[i],
+ data_range=1.,
+ ) for i in range(x.shape[0])
+ ]
+ return np.mean(psnrs)
+
+
+def ssim_metric(x, y):
+ """x/y: [B, C, H, W]"""
+ x = x * 255.
+ y = y * 255.
+ ssims = [
+ structural_similarity(
+ x[i],
+ y[i],
+ channel_axis=0,
+ gaussian_weights=True,
+ sigma=1.5,
+ use_sample_covariance=False,
+ data_range=255,
+ ) for i in range(x.shape[0])
+ ]
+ return np.mean(ssims)
+
+
+def perceptual_dist(x, y, loss_fn):
+ """x/y: [B, C, H, W]"""
+ return loss_fn(x, y).mean()
+
+
+def adjusted_rand_index(true_ids, pred_ids, ignore_background=False):
+ """Computes the adjusted Rand index (ARI), a clustering similarity score.
+
+ Code borrowed from https://github.com/google-research/slot-attention-video/blob/e8ab54620d0f1934b332ddc09f1dba7bc07ff601/savi/lib/metrics.py#L111
+
+ Args:
+ true_ids: An integer-valued array of shape
+ [batch_size, seq_len, H, W]. The true cluster assignment encoded
+ as integer ids.
+ pred_ids: An integer-valued array of shape
+ [batch_size, seq_len, H, W]. The predicted cluster assignment
+ encoded as integer ids.
+ ignore_background: Boolean, if True, then ignore all pixels where
+ true_ids == 0 (default: False).
+
+ Returns:
+ ARI scores as a float32 array of shape [batch_size].
+ """
+ if len(true_ids.shape) == 3:
+ true_ids = true_ids.unsqueeze(1)
+ if len(pred_ids.shape) == 3:
+ pred_ids = pred_ids.unsqueeze(1)
+
+ true_oh = F.one_hot(true_ids).float()
+ pred_oh = F.one_hot(pred_ids).float()
+
+ if ignore_background:
+ true_oh = true_oh[..., 1:] # Remove the background row.
+
+ N = torch.einsum("bthwc,bthwk->bck", true_oh, pred_oh)
+ A = torch.sum(N, dim=-1) # row-sum (batch_size, c)
+ B = torch.sum(N, dim=-2) # col-sum (batch_size, k)
+ num_points = torch.sum(A, dim=1)
+
+ rindex = torch.sum(N * (N - 1), dim=[1, 2])
+ aindex = torch.sum(A * (A - 1), dim=1)
+ bindex = torch.sum(B * (B - 1), dim=1)
+ expected_rindex = aindex * bindex / torch.clamp(
+ num_points * (num_points - 1), min=1)
+ max_rindex = (aindex + bindex) / 2
+ denominator = max_rindex - expected_rindex
+ ari = (rindex - expected_rindex) / denominator
+
+ # There are two cases for which the denominator can be zero:
+ # 1. If both label_pred and label_true assign all pixels to a single cluster.
+ # (max_rindex == expected_rindex == rindex == num_points * (num_points-1))
+ # 2. If both label_pred and label_true assign max 1 point to each cluster.
+ # (max_rindex == expected_rindex == rindex == 0)
+ # In both cases, we want the ARI score to be 1.0:
+ return torch.where(denominator != 0, ari, torch.tensor(1.).type_as(ari))
+
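+# Sanity-check sketch: identical segmentations score an ARI of 1.0.
+# seg = torch.randint(0, 3, (2, 4, 8, 8))  # [batch_size, seq_len, H, W]
+# adjusted_rand_index(seg, seg)  # -> tensor([1., 1.])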
+
+def ARI_metric(x, y):
+ """x/y: [B, H, W], both are seg_masks after argmax."""
+ assert 'int' in str(x.dtype)
+ assert 'int' in str(y.dtype)
+ return adjusted_rand_index(x, y).mean().item()
+
+
+def fARI_metric(x, y):
+ """x/y: [B, H, W], both are seg_masks after argmax."""
+ assert 'int' in str(x.dtype)
+ assert 'int' in str(y.dtype)
+ return adjusted_rand_index(x, y, ignore_background=True).mean().item()
+
+
+def bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox, ovthresh=0.5):
+ """Compute the precision of predicted bounding boxes.
+
+ Args:
+ gt_pres_mask: A boolean tensor of shape [N]
+ gt_bbox: A tensor of shape [N, 4]
+ pred_bbox: A tensor of shape [M, 4]
+ """
+ gt_bbox, pred_bbox = gt_bbox.clone(), pred_bbox.clone()
+ gt_bbox = gt_bbox[gt_pres_mask.bool()]
+ pred_bbox = pred_bbox[pred_bbox[:, 0] >= 0.]
+ N, M = gt_bbox.shape[0], pred_bbox.shape[0]
+ assert gt_bbox.shape[1] == pred_bbox.shape[1] == 4
+ # assert M >= N
+ tp, fp = 0, 0
+ bbox_used = [False] * pred_bbox.shape[0]
+ bbox_ious = vops.box_iou(gt_bbox, pred_bbox) # [N, M]
+
+ # Find the best iou match for each ground truth bbox.
+ for i in range(N):
+ best_iou_idx = bbox_ious[i].argmax().item()
+ best_iou = bbox_ious[i, best_iou_idx].item()
+ if best_iou >= ovthresh and not bbox_used[best_iou_idx]:
+ tp += 1
+ bbox_used[best_iou_idx] = True
+ else:
+ fp += 1
+
+ # compute precision and recall
+ precision = tp / float(M)
+ recall = tp / float(N)
+ return precision, recall
+
+
+def batch_bbox_precision_recall(gt_pres_mask, gt_bbox, pred_bbox):
+ """Compute the precision of predicted bounding boxes over batch."""
+ aps, ars = [], []
+ for i in range(gt_pres_mask.shape[0]):
+ ap, ar = bbox_precision_recall(gt_pres_mask[i], gt_bbox[i],
+ pred_bbox[i])
+ aps.append(ap)
+ ars.append(ar)
+ return np.mean(aps), np.mean(ars)
+
+
+def hungarian_miou(gt_mask, pred_mask):
+ """both mask: [H*W] after argmax, 0 is gt background index."""
+ true_oh = F.one_hot(gt_mask).float()[..., 1:] # only foreground, [HW, N]
+ pred_oh = F.one_hot(pred_mask).float() # [HW, M]
+ N, M = true_oh.shape[-1], pred_oh.shape[-1]
+ # compute all pairwise IoU
+ intersect = (true_oh[:, :, None] * pred_oh[:, None, :]).sum(0) # [N, M]
+ union = true_oh.sum(0)[:, None] + pred_oh.sum(0)[None, :] - intersect # [N, M]
+ iou = intersect / (union + 1e-8) # [N, M]
+ iou = iou.detach().cpu().numpy()
+ # find the best match for each gt
+ row_ind, col_ind = linear_sum_assignment(iou, maximize=True)
+ # there are two possibilities here
+ # 1. M >= N, just take the best match mean
+ # 2. M < N, some objects are not detected, their iou is 0
+ if M >= N:
+ assert (row_ind == np.arange(N)).all()
+ return iou[row_ind, col_ind].mean()
+ return iou[row_ind, col_ind].sum() / float(N)
+
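+# Sanity-check sketch: perfect overlap on the foreground classes gives 1.0
+# (gt index 0 is background and is dropped):
+# seg = torch.tensor([0, 1, 1, 2, 2, 2])
+# hungarian_miou(seg, seg)  # -> 1.0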
+
+def miou_metric(gt_mask, pred_mask):
+ """both mask: [B, H, W], both are seg_masks after argmax."""
+ assert 'int' in str(gt_mask.dtype)
+ assert 'int' in str(pred_mask.dtype)
+ gt_mask, pred_mask = gt_mask.flatten(1, 2), pred_mask.flatten(1, 2)
+ ious = [
+ hungarian_miou(gt_mask[i], pred_mask[i])
+ for i in range(gt_mask.shape[0])
+ ]
+ return np.mean(ious)
+
+
+@torch.no_grad()
+def pred_eval_step(
+ gt,
+ pred,
+ lpips_fn,
+ gt_mask=None,
+ pred_mask=None,
+ gt_pres_mask=None,
+ gt_bbox=None,
+ pred_bbox=None,
+ eval_traj=True,
+):
+ """Both of shape [B, T, C, H, W], torch.Tensor.
+ masks of shape [B, T, H, W].
+ pres_mask of shape [B, T, N].
+ bboxes of shape [B, T, N/M, 4].
+
+ eval_traj: whether to evaluate the trajectory (measured by bbox and mask).
+
+ Compute metrics for every timestep.
+ """
+ assert len(gt.shape) == len(pred.shape) == 5
+ assert gt.shape == pred.shape
+ assert gt.shape[2] == 3
+ if eval_traj:
+ assert len(gt_mask.shape) == len(pred_mask.shape) == 4
+ assert gt_mask.shape == pred_mask.shape
+ if eval_traj:
+ assert len(gt_pres_mask.shape) == 3
+ assert len(gt_bbox.shape) == len(pred_bbox.shape) == 4
+ T = gt.shape[1]
+
+ # compute perceptual dist & mask metrics before converting to numpy
+ all_percept_dist, all_ari, all_fari, all_miou = [], [], [], []
+ for t in range(T):
+ one_gt, one_pred = gt[:, t], pred[:, t]
+ percept_dist = perceptual_dist(one_gt, one_pred, lpips_fn).item()
+ all_percept_dist.append(percept_dist)
+ if eval_traj:
+ one_gt_mask, one_pred_mask = gt_mask[:, t], pred_mask[:, t]
+ ari = ARI_metric(one_gt_mask, one_pred_mask)
+ fari = fARI_metric(one_gt_mask, one_pred_mask)
+ miou = miou_metric(one_gt_mask, one_pred_mask)
+ all_ari.append(ari)
+ all_fari.append(fari)
+ all_miou.append(miou)
+ else:
+ all_ari.append(0.)
+ all_fari.append(0.)
+ all_miou.append(0.)
+
+ # compute bbox metrics
+ all_ap, all_ar = [], []
+ for t in range(T):
+ if not eval_traj:
+ all_ap.append(0.)
+ all_ar.append(0.)
+ continue
+ one_gt_pres_mask, one_gt_bbox, one_pred_bbox = \
+ gt_pres_mask[:, t], gt_bbox[:, t], pred_bbox[:, t]
+ ap, ar = batch_bbox_precision_recall(one_gt_pres_mask, one_gt_bbox,
+ one_pred_bbox)
+ all_ap.append(ap)
+ all_ar.append(ar)
+
+ gt = to_rgb_from_tensor(gt).cpu().numpy()
+ pred = to_rgb_from_tensor(pred).cpu().numpy()
+ all_mse, all_ssim, all_psnr = [], [], []
+ for t in range(T):
+ one_gt, one_pred = gt[:, t], pred[:, t]
+ mse = mse_metric(one_gt, one_pred)
+ psnr = psnr_metric(one_gt, one_pred)
+ ssim = ssim_metric(one_gt, one_pred)
+ all_mse.append(mse)
+ all_ssim.append(ssim)
+ all_psnr.append(psnr)
+ return {
+ 'mse': all_mse,
+ 'ssim': all_ssim,
+ 'psnr': all_psnr,
+ 'percept_dist': all_percept_dist,
+ 'ari': all_ari,
+ 'fari': all_fari,
+ 'miou': all_miou,
+ 'ap': all_ap,
+ 'ar': all_ar,
+ }
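+
+# Illustrative call, a sketch that only exercises the shapes; lpips_fn is
+# an lpips.LPIPS instance and eval_traj=False skips the mask/bbox metrics:
+# import lpips
+# loss_fn = lpips.LPIPS(net='vgg')
+# gt = torch.rand(1, 3, 3, 32, 32) * 2 - 1    # [B, T, C, H, W]
+# pred = torch.rand(1, 3, 3, 32, 32) * 2 - 1
+# res = pred_eval_step(gt, pred, loss_fn, eval_traj=False)
+# len(res['mse'])  # == T, one value per timestep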
diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py
new file mode 100644
index 0000000..faab7ec
--- /dev/null
+++ b/scripts/utils/eval_utils.py
@@ -0,0 +1,78 @@
+import os
+import torch as th
+from model.loci import Loci
+
+def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views):
+
+ net_name = file.split('/')[-1].split('.')[0]
+ #root_path = file.split('nets')[0]
+ root_path = os.path.join(*file.split('/')[0:-1])
+ root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type'])
+ plot_path = os.path.join(root_path, evaluation_mode)
+
+ # create directories
+ #if os.path.exists(plot_path):
+ # shutil.rmtree(plot_path)
+ os.makedirs(plot_path, exist_ok = True)
+ if object_view:
+ os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True)
+ if individual_views:
+ os.makedirs(os.path.join(plot_path, 'individual'), exist_ok = True)
+ for group in ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask', 'imagination']:
+ os.makedirs(os.path.join(plot_path, 'individual', group), exist_ok = True)
+ os.makedirs(os.path.join(root_path, 'statistics'), exist_ok = True)
+
+ # final directory
+ plot_path = plot_path + '/'
+ print(f"save plots to {plot_path}")
+
+ return root_path, plot_path
+
+def store_statistics(memory, *args, extend=False):
+ for i,key in enumerate(memory.keys()):
+ if i >= len(args):
+ break
+ if extend:
+ memory[key].extend(args[i])
+ else:
+ memory[key].append(args[i])
+ return memory
+
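+# A sketch: values are appended positionally, one per key.
+# mem = {'set': [], 'scene': []}
+# store_statistics(mem, 'control', 3)  # -> {'set': ['control'], 'scene': [3]}
+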
+def append_statistics(memory1, memory2, ignore=[], extend=False):
+ for key in memory1:
+ if key not in ignore:
+ if extend:
+ memory2[key] = memory2[key] + memory1[key]
+ else:
+ memory2[key].append(memory1[key])
+ return memory2
+
+def load_model(cfg, cfg_net, file, device):
+
+ net = Loci(
+ cfg_net,
+ teacher_forcing = cfg.defaults.teacher_forcing
+ )
+
+ # load model
+ if file != '':
+ print(f"load {file} to device {device}")
+ state = th.load(file, map_location=device)
+
+ # backward compatibility
+ model = {}
+ for key, value in state["model"].items():
+ # replace update_module with percept_gate_controller in key string:
+ key = key.replace("update_module", "percept_gate_controller")
+ model[key.replace(".module.", ".")] = value
+
+ net.load_state_dict(model)
+
+ # make sure the initial-state module is active (init level >= 1)
+ if net.get_init_status() < 1:
+ net.inc_init_level()
+
+ # move network to the target device
+ net = net.to(device=device)
+
+ return net
\ No newline at end of file
diff --git a/scripts/utils/io.py b/scripts/utils/io.py
new file mode 100644
index 0000000..9bd8158
--- /dev/null
+++ b/scripts/utils/io.py
@@ -0,0 +1,137 @@
+import os
+from scripts.utils.configuration import Configuration
+import time
+import torch as th
+import numpy as np
+from einops import rearrange, repeat, reduce
+
+
+def init_device(cfg):
+ print(f'CUDA available: {th.cuda.is_available()} CUDA count: {th.cuda.device_count()}')
+ if th.cuda.is_available():
+ device = th.device("cuda:0")
+ verbose = False
+ cfg.device = "cuda:0"
+ cfg.model.device = "cuda:0"
+ print('!!! USING GPU !!!')
+ else:
+ device = th.device("cpu")
+ verbose = True
+ cfg.device = "cpu"
+ cfg.model.device = "cpu"
+ cfg.model.batch_size = 2
+ cfg.defaults.teacher_forcing = 4
+ print('!!! USING CPU !!!')
+ return device,verbose
+
+class Timer:
+
+ def __init__(self):
+ self.last = time.time()
+ self.passed = 0
+ self.sum = 0
+
+ def __str__(self):
+ self.passed = self.passed * 0.99 + time.time() - self.last
+ self.sum = self.sum * 0.99 + 1
+ passed = self.passed / self.sum
+ self.last = time.time()
+
+ if passed > 1:
+ return f"{passed:.2f}s/it"
+
+ return f"{1.0/passed:.2f}it/s"
+
+class UEMA:
+
+ def __init__(self, memory = 100):
+ self.value = 0
+ self.sum = 1e-30
+ self.decay = np.exp(-1 / memory)
+
+ def update(self, value):
+ self.value = self.value * self.decay + value
+ self.sum = self.sum * self.decay + 1
+
+ def __float__(self):
+ return self.value / self.sum
+
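+# A sketch: an exponentially weighted running mean over roughly `memory` steps.
+# ema = UEMA(memory=100)
+# for v in (1.0, 2.0, 3.0):
+#     ema.update(v)
+# float(ema)  # decay-weighted average, ~2.0 here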
+
+def model_path(cfg: Configuration, overwrite=False, move_old=True):
+ """
+ Builds the model output path, optionally without overwriting an existing directory
+ :param cfg: Configuration file with the model path
+ :param overwrite: Overwrites the files in the directory, else makes a new directory
+ :param move_old: Moves old folder with the same name to an old folder, if not overwrite
+ :return: Model path
+ """
+ _path = os.path.join('out')
+ path = os.path.join(_path, cfg.model_path)
+
+ if not os.path.exists(_path):
+ os.makedirs(_path)
+
+ if not overwrite:
+ if move_old:
+ # Moves existing directory to an old folder
+ if os.path.exists(path):
+ old_path = os.path.join(_path, f'{cfg.model_path}_old')
+ if not os.path.exists(old_path):
+ os.makedirs(old_path)
+ _old_path = os.path.join(old_path, cfg.model_path)
+ i = 0
+ while os.path.exists(_old_path):
+ i = i + 1
+ _old_path = os.path.join(old_path, f'{cfg.model_path}_{i}')
+ os.renames(path, _old_path)
+ else:
+ # Increases number after directory name for each new path
+ i = 0
+ while os.path.exists(path):
+ i = i + 1
+ path = os.path.join(_path, f'{cfg.model_path}_{i}')
+
+ return path
+
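+# A sketch: with cfg.model_path == 'run', repeated calls create out/run,
+# out/run_1, ... when overwrite=False and move_old=False:
+# path = model_path(cfg, overwrite=False, move_old=False)
+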
+class LossLogger:
+
+ def __init__(self):
+
+ self.avgloss = UEMA()
+ self.avg_position_loss = UEMA()
+ self.avg_time_loss = UEMA()
+ self.avg_encoder_loss = UEMA()
+ self.avg_mse_object_loss = UEMA()
+ self.avg_long_mse_object_loss = UEMA(33333)
+ self.avg_num_objects = UEMA()
+ self.avg_openings = UEMA()
+ self.avg_gestalt = UEMA()
+ self.avg_gestalt2 = UEMA()
+ self.avg_gestalt_mean = UEMA()
+ self.avg_update_gestalt = UEMA()
+ self.avg_update_position = UEMA()
+
+
+ def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position):
+
+ self.avg_position_loss.update(avg_position_loss.item())
+ self.avg_time_loss.update(avg_time_loss.item())
+ self.avg_encoder_loss.update(avg_encoder_loss.item())
+ self.avg_mse_object_loss.update(avg_mse_object_loss.item())
+ self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item())
+ self.avg_num_objects.update(avg_num_objects)
+ self.avg_openings.update(avg_openings)
+ self.avg_gestalt.update(avg_gestalt.item())
+ self.avg_gestalt2.update(avg_gestalt2.item())
+ self.avg_gestalt_mean.update(avg_gestalt_mean.item())
+ self.avg_update_gestalt.update(avg_update_gestalt.item())
+ self.avg_update_position.update(avg_update_position.item())
+ pass
+
+ def update_average_loss(self, avgloss):
+ self.avgloss.update(avgloss)
+ pass
+
+ def get_log(self):
+ info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
+ return info
diff --git a/scripts/utils/optimizers.py b/scripts/utils/optimizers.py
new file mode 100644
index 0000000..49283b3
--- /dev/null
+++ b/scripts/utils/optimizers.py
@@ -0,0 +1,100 @@
+import math
+import torch as th
+import numpy as np
+from torch.optim.optimizer import Optimizer
+
+"""
+Liyuan Liu , Haoming Jiang, Pengcheng He, Weizhu Chen, Xiaodong Liu, Jianfeng Gao, and Jiawei Han (2020).
+On the Variance of the Adaptive Learning Rate and Beyond. the Eighth International Conference on Learning
+Representations.
+"""
+class RAdam(Optimizer):
+
+ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-16, weight_decay=0, degenerated_to_sgd=False):
+ if not 0.0 <= lr:
+ raise ValueError("Invalid learning rate: {}".format(lr))
+ if not 0.0 <= eps:
+ raise ValueError("Invalid epsilon value: {}".format(eps))
+ if not 0.0 <= betas[0] < 1.0:
+ raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+ if not 0.0 <= betas[1] < 1.0:
+ raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+ self.degenerated_to_sgd = degenerated_to_sgd
+ if isinstance(params, (list, tuple)) and len(params) > 0 and isinstance(params[0], dict):
+ for param in params:
+ if 'betas' in param and (param['betas'][0] != betas[0] or param['betas'][1] != betas[1]):
+ param['buffer'] = [[None, None, None] for _ in range(10)]
+ defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)])
+ super(RAdam, self).__init__(params, defaults)
+
+ def __setstate__(self, state):
+ super(RAdam, self).__setstate__(state)
+
+ def step(self, closure=None):
+
+ loss = None
+ if closure is not None:
+ loss = closure()
+
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.float()
+ if grad.is_sparse:
+ raise RuntimeError('RAdam does not support sparse gradients')
+
+ p_data_fp32 = p.data.float()
+
+ state = self.state[p]
+
+ if len(state) == 0:
+ state['step'] = 0
+ state['exp_avg'] = th.zeros_like(p_data_fp32)
+ state['exp_avg_sq'] = th.zeros_like(p_data_fp32)
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value = 1 - beta2)
+ exp_avg.mul_(beta1).add_(grad, alpha = 1 - beta1)
+
+ state['step'] += 1
+ buffered = group['buffer'][int(state['step'] % 10)]
+ if state['step'] == buffered[0]:
+ N_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ N_sma_max = 2 / (1 - beta2) - 1
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = N_sma
+
+ # more conservative since it's an approximated value
+ if N_sma >= 5:
+ step_size = math.sqrt((1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (N_sma_max - 2)) / (1 - beta1 ** state['step'])
+ elif self.degenerated_to_sgd:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ else:
+ step_size = -1
+ buffered[2] = step_size
+
+ # more conservative since it's an approximated value
+ if N_sma >= 5:
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr'])
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+ p_data_fp32.addcdiv_(exp_avg, denom, value = -step_size * group['lr'])
+ p.data.copy_(p_data_fp32)
+ elif step_size > 0:
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(p_data_fp32, alpha = -group['weight_decay'] * group['lr'])
+ p_data_fp32.add_(exp_avg, alpha = -step_size * group['lr'])
+ p.data.copy_(p_data_fp32)
+
+ return loss
\ No newline at end of file
diff --git a/scripts/utils/plot_utils.py b/scripts/utils/plot_utils.py
new file mode 100644
index 0000000..5cb20de
--- /dev/null
+++ b/scripts/utils/plot_utils.py
@@ -0,0 +1,428 @@
+import torch as th
+from torch import nn
+import numpy as np
+import cv2
+from einops import rearrange, repeat
+import matplotlib.pyplot as plt
+import PIL
+from model.utils.nn_utils import Gaus2D, Vector2D
+
+def preprocess(tensor, scale=1, normalize=False, mean_std_normalize=False):
+
+ if tensor is None:
+ return None
+
+ if normalize:
+ min_ = th.min(tensor)
+ max_ = th.max(tensor)
+ tensor = (tensor - min_) / (max_ - min_)
+
+ if mean_std_normalize:
+ mean = th.mean(tensor)
+ std = th.std(tensor)
+ tensor = th.clip((tensor - mean) / (2 * std), -1, 1) * 0.5 + 0.5
+
+ if scale > 1:
+ upsample = nn.Upsample(scale_factor=scale).to(tensor[0].device)
+ tensor = upsample(tensor)
+
+ return tensor
+
+def preprocess_multi(*args, scale):
+ return [preprocess(a, scale) for a in args]
+
+def color_mask(mask):
+
+ colors = th.tensor([
+ [ 255, 0, 0 ],
+ [ 0, 0, 255 ],
+ [ 255, 255, 0 ],
+ [ 255, 0, 255 ],
+ [ 0, 255, 255 ],
+ [ 0, 255, 0 ],
+ [ 255, 128, 0 ],
+ [ 128, 255, 0 ],
+ [ 128, 0, 255 ],
+ [ 255, 0, 128 ],
+ [ 0, 255, 128 ],
+ [ 0, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 255 ],
+ [ 128, 255, 255 ],
+ [ 128, 255, 255 ],
+ [ 255, 255, 128 ],
+ [ 255, 255, 128 ],
+ [ 255, 128, 255 ],
+ [ 128, 0, 0 ],
+ [ 0, 0, 128 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ ], device = mask.device) / 255.0
+
+ colors = colors.view(1, -1, 3, 1, 1)
+ mask = mask.unsqueeze(dim=2)
+
+ return th.sum(colors[:,:mask.shape[1]] * mask, dim=1)
+
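+# A sketch: O per-object mask channels blend into one RGB image.
+# mask = th.rand(1, 5, 64, 64)      # [B, O, H, W]
+# color_mask(mask).shape            # -> torch.Size([1, 3, 64, 64])
+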
+def get_color(o):
+ colors = th.tensor([
+ [ 255, 0, 0 ],
+ [ 0, 0, 255 ],
+ [ 255, 255, 0 ],
+ [ 255, 0, 255 ],
+ [ 0, 255, 255 ],
+ [ 0, 255, 0 ],
+ [ 255, 128, 0 ],
+ [ 128, 255, 0 ],
+ [ 128, 0, 255 ],
+ [ 255, 0, 128 ],
+ [ 0, 255, 128 ],
+ [ 0, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 128 ],
+ [ 128, 255, 128 ],
+ [ 128, 128, 255 ],
+ [ 255, 128, 255 ],
+ [ 128, 255, 255 ],
+ [ 128, 255, 255 ],
+ [ 255, 255, 128 ],
+ [ 255, 255, 128 ],
+ [ 255, 128, 255 ],
+ [ 128, 0, 0 ],
+ [ 0, 0, 128 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 128, 0 ],
+ [ 128, 0, 128 ],
+ [ 128, 0, 128 ],
+ [ 0, 128, 128 ],
+ [ 0, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ [ 128, 128, 128 ],
+ ]) / 255.0
+
+ colors = colors.view(48,3)
+ return colors[o]
+
+def to_rgb(tensor: th.Tensor):
+ return th.cat((
+ tensor * 0.6 + 0.4,
+ tensor,
+ tensor
+ ), dim=1)
+
+def visualise_gate(gate, h, w):
+ bar = th.ones((1,h,w), device=gate.device) * 0.9
+ black = int(w*gate.item())
+ if black > 0:
+ bar[:,:, -black:] = 0
+ return bar
+
+def get_highlighted_input(input, mask_cur):
+
+ # highlight error
+ highlighted_input = input
+ if mask_cur is not None:
+ grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
+ object_mask_cur = th.sum(mask_cur[:,:-1], dim=1).unsqueeze(dim=1)
+ highlighted_input = grayscale * (1 - object_mask_cur)
+ highlighted_input += grayscale * object_mask_cur * 0.3333333
+ cmask = color_mask(mask_cur[:,:-1])
+ highlighted_input = highlighted_input + cmask * 0.6666666
+
+ return highlighted_input
+
+def color_slots(image, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur):
+
+ image = (1-image) * slots_bounded + image * (1-slots_bounded)
+ image = th.clip(image - 0.3, 0,1) * slots_partially_occluded_cur + image * (1-slots_partially_occluded_cur)
+ image = th.clip(image - 0.3, 0,1) * slots_occluded_cur + image * (1-slots_occluded_cur)
+
+ return image
+
+def compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale):
+
+ # compute occlusion mask
+ occluded_cur = th.clip(rawmask_cur - mask_cur, 0, 1)[:,:-1]
+ occluded_next = th.clip(rawmask_next - mask_next, 0, 1)[:,:-1]
+
+ # to rgb
+ rawmask_cur = repeat(rawmask_cur[:,:-1], 'b o h w -> b (o 3) h w')
+ rawmask_next = repeat(rawmask_next[:,:-1], 'b o h w -> b (o 3) h w')
+
+ # scale
+ occluded_next = preprocess(occluded_next, scale)
+ occluded_cur = preprocess(occluded_cur, scale)
+ rawmask_cur = preprocess(rawmask_cur, scale)
+ rawmask_next = preprocess(rawmask_next, scale)
+
+ # set occlusion to red
+ rawmask_cur = rearrange(rawmask_cur, 'b (o c) h w -> b o c h w', c = 3)
+ rawmask_cur[:,:,0] = rawmask_cur[:,:,0] * (1 - occluded_next)
+ rawmask_cur[:,:,1] = rawmask_cur[:,:,1] * (1 - occluded_next)
+
+ rawmask_next = rearrange(rawmask_next, 'b (o c) h w -> b o c h w', c = 3)
+ rawmask_next[:,:,0] = rawmask_next[:,:,0] * (1 - occluded_next)
+ rawmask_next[:,:,1] = rawmask_next[:,:,1] * (1 - occluded_next)
+
+ return rawmask_cur, rawmask_next
+
+def plot_online_error_slots(errors, error_name, target, sequence_len, root_path, visibility_memory, slots_bounded, ylim=0.3):
+ error_plots = []
+ if len(errors) > 0:
+ num_slots = int(th.sum(slots_bounded).item())
+ errors = rearrange(np.array(errors), '(l o) -> o l', o=len(slots_bounded))[:num_slots]
+ visibility_memory = rearrange(np.array(visibility_memory), '(l o) -> o l', o=len(slots_bounded))[:num_slots]
+ for error, visibility in zip(errors, visibility_memory):
+
+ if len(error) < sequence_len:
+ fig, ax = plt.subplots(figsize=(round(target.shape[3]/100,2), round(target.shape[2]/100,2)))
+ plt.plot(error, label=error_name)
+
+ visibility = np.concatenate((visibility, np.ones(sequence_len-len(error))))
+ ax.fill_between(range(sequence_len), 0, 1, where=visibility==0, color='orange', alpha=0.3, transform=ax.get_xaxis_transform())
+
+ plt.xlim((0,sequence_len))
+ plt.ylim((0,ylim))
+ fig.tight_layout()
+ plt.savefig(f'{root_path}/tmp.jpg')
+
+ error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
+ error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1))
+ plt.close(fig)
+
+ error_plots.append(error_plot)
+
+ return error_plots
+
+def plot_online_error(error, error_name, target, t, i, sequence_len, root_path, online_surprise = False):
+
+ fig = plt.figure(figsize=( round(target.shape[3]/50,2), round(target.shape[2]/50,2) ))
+ plt.plot(error, label=error_name)
+
+ if online_surprise:
+ # compute moving average of error
+ moving_average_length = 10
+ if t > moving_average_length:
+ moving_average_length += 1
+ average_error = np.mean(error[-moving_average_length:-1])
+ current_sd = np.std(error[-moving_average_length:-1])
+ current_error = error[-1]
+
+ if current_error > average_error + 2 * current_sd:
+ fig.set_facecolor('orange')
+
+ plt.xlim((0,sequence_len))
+ plt.legend()
+ # increase title size
+ plt.title(f'{error_name}', fontsize=20)
+ plt.xlabel('timestep')
+ plt.ylabel('error')
+ plt.savefig(f'{root_path}/tmp.jpg')
+
+ error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(),fig.canvas.tostring_rgb())
+ error_plot = th.from_numpy(np.array(error_plot).transpose(2,0,1))
+ plt.close(fig)
+
+ return error_plot
+
+def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
+
+ # add ground truth positions of objects to image
+ if gt_positions_target_next is not None:
+ for o in range(gt_positions_target_next.shape[1]):
+ position = gt_positions_target_next[0, o]
+ position = position/2 + 0.5
+
+ if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
+ width = 5
+ w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item()
+ h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item()
+ col = get_color(o).view(3,1,1)
+ target[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+ # add these positions to the associated slots' velocity_next2d illustration
+ slots = (association_table[0] == o).nonzero()
+ for s in slots.flatten():
+ velocity_next2d[s,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+ if output_hidden is not None and s != largest_object:
+ output_hidden[0,:,(w-width):(w+width), (h-width):(h+width)] = col
+
+ gateheight = 60
+ ch = 40
+ gh = 40
+ gh_bar = gh-20
+ gh_margin = int((gh-gh_bar)/2)
+ margin = 20
+ slots_margin = 10
+ height = size[0] * 6 + 18*5
+ width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1)
+ img = th.ones((3, height, width), device = object_next.device) * 0.4
+ row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin])
+ col1 = range(margin, margin + size[1]*2)
+ col2 = range(width-(margin+size[1]*2), width-margin)
+
+ img[:,row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0]
+ img[:,row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0]
+ img[:,row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0]
+
+ if error_plot is not None:
+ img[:,row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize= True)
+ if error_plot2 is not None:
+ img[:,row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize= True)
+
+ for o in range(num_objects):
+
+ col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin
+ col = range(col, col + size[1])
+
+ # color bar for the gate
+ if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
+ img[:,margin:margin+ch, col] = get_color(o).view(3,1,1).to(object_next.device)
+
+ img[:,margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:,o, 0].to(object_next.device), h=gh_bar, w=len(col))
+ offset = gh+margin-gh_margin+ch+2*margin
+ row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+
+ img[:,row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0,o].to(object_next.device))
+ img[:,row(1)[0]:row(1)[1], col] = preprocess(object_next[:,o].to(object_next.device))
+ if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
+ img[:,row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True)
+
+ offset = margin*2-8
+ row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
+ img[:,row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:,o, 1].to(object_next.device), h=gh_bar, w=len(col))
+ img[:,row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0]
+ if (error_plot_slots is not None) and len(error_plot_slots) > o:
+ img[:,row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True)
+
+ img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+
+ return img
+
+def write_image(file, img):
+ img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
+ cv2.imwrite(file, img)
+
+ pass
+
+def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i):
+
+ # Create position helpers
+ size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device)
+
+ # Compute plot content
+ highlighted_input = get_highlighted_input(input, mask_cur)
+ output = th.clip(output_next, 0, 1)
+ position_cur2d = gaus2d(rearrange(position_encoder_cur, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
+ velocity_next2d = vector2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
+
+ # color slots
+ slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next = reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next)
+ position_cur2d = color_slots(position_cur2d, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur)
+ velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)
+
+ # compute occlusion
+ if (cfg.datatype == "adept"):
+ rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
+ rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale)
+ rawmask_cur_h[:,largest_object] = rawmask_cur_l[:,largest_object]
+ rawmask_next_h[:,largest_object] = rawmask_next_l[:,largest_object]
+ rawmask_cur = rawmask_cur_h
+ rawmask_next = rawmask_next_h
+ object_hidden[:, largest_object] = object_next[:, largest_object]
+ object_next = object_hidden
+ else:
+ rawmask_cur, rawmask_next = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
+
+ # scale plot content
+ input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next = preprocess_multi(input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next, scale=scale)
+
+ # reshape
+ object_next = rearrange(object_next, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
+ object_cur = rearrange(object_cur, 'b (o c) h w -> b o c h w', c = cfg_net.img_channels)
+ mask_next = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')
+
+ if object_view:
+ if (cfg.datatype == "adept"):
+ num_objects = 4
+ error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded)
+ error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
+ error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path)
+ error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path)
+ img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object)
+ else:
+ num_objects = cfg_net.num_objects
+ error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
+ error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path)
+ img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object)
+
+ cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img)
+
+ if individual_views:
+ # ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']:
+ write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0])
+ write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0])
+ write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0,-1])
+ write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0])
+ write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0])
+
+ pass
+
+def get_position_helper(cfg_net, device):
+ size = cfg_net.input_size
+ gaus2d = Gaus2D(size).to(device)
+ vector2d = Vector2D(size).to(device)
+ scale = size[0] // (cfg_net.latent_size[0] * 2**(cfg_net.level*2))
+ return size,gaus2d,vector2d,scale
+
+def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next):
+
+ slots_bounded = th.squeeze(slots_bounded)[..., None,None,None]
+ slots_partially_occluded_cur = th.squeeze(slots_partially_occluded_cur)[..., None,None,None]
+ slots_occluded_cur = th.squeeze(slots_occluded_cur)[..., None,None,None]
+ slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None,None,None]
+ slots_occluded_next = th.squeeze(slots_occluded_next)[..., None,None,None]
+
+ return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next
\ No newline at end of file
diff --git a/scripts/validation.py b/scripts/validation.py
new file mode 100644
index 0000000..5b1638d
--- /dev/null
+++ b/scripts/validation.py
@@ -0,0 +1,261 @@
+import torch as th
+from torch import nn
+from torch.utils.data import DataLoader
+import numpy as np
+from einops import rearrange, repeat, reduce
+from scripts.utils.configuration import Configuration
+from model.loci import Loci
+import time
+import lpips
+from skimage.metrics import structural_similarity as ssimloss
+from skimage.metrics import peak_signal_noise_ratio as psnrloss
+
+def validation_adept(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+
+ # memory
+ mseloss = nn.MSELoss()
+ avgloss = 0
+ start_time = time.time()
+
+ with th.no_grad():
+ for i, input in enumerate(valloader):
+
+ # get input frame and target frame
+ tensor = input[0].float().to(device)
+ background_fix = input[1].to(device)
+
+ # apply skip frames
+ tensor = tensor[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ sequence_len = tensor.shape[1]
+
+ # initial frame
+ input = tensor[:,0]
+ target = th.clip(tensor[:,0], 0, 1)
+ error_last = None
+
+ # placeholders
+ mask_cur = None
+ mask_last = None
+ rawmask_last = None
+ position_last = None
+ gestalt_last = None
+ priority_last = None
+ gt_positions_target = None
+ slots_occlusionfactor = None
+
+ # loop through frames
+ for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, sequence_len-1)):
+
+ # move to next frame
+ t_run = max(t, 0)
+ input = tensor[:,t_run]
+ target = th.clip(tensor[:,t_run+1], 0, 1)
+
+ # obtain prediction
+ (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder_cur,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ ) = net(
+ input,
+ error_last,
+ mask_last,
+ rawmask_last,
+ position_last,
+ gestalt_last,
+ priority_last,
+ background_fix,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing),
+ evaluate=True,
+ warmup = (t < 0),
+ shuffleslots = False,
+ reset_mask = (t <= 0),
+ allow_spawn = True,
+ show_hidden = False,
+ clean_slots = (t <= 0),
+ )
+
+ # 1. Track error
+ if t >= 0:
+ loss = mseloss(output_next, target)
+ avgloss += loss.item()
+
+ # 2. Remember output
+ mask_last = mask_next.clone()
+ rawmask_last = rawmask_next.clone()
+ position_last = position_next.clone()
+ gestalt_last = gestalt_next.clone()
+ priority_last = priority_next.clone()
+
+ # 3. Error for next frame
+ # background error
+ bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+ # prediction error
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next.clone()
+
+ print(f"Validation loss: {avgloss / len(valloader.dataset):.2e}, Time: {time.time() - start_time}")
+
+ pass
+
+
+def validation_clevrer(valloader: DataLoader, net: Loci, cfg: Configuration, device):
+
+ # memory
+ mseloss = nn.MSELoss()
+ lpipsloss = lpips.LPIPS(net='vgg').to(device)
+ avgloss_mse = 0
+ avgloss_lpips = 0
+ avgloss_psnr = 0
+ avgloss_ssim = 0
+ start_time = time.time()
+
+ burn_in_length = 6
+ rollout_length = 42
+
+ with th.no_grad():
+ for i, input in enumerate(valloader):
+
+ # get input frame and target frame
+ tensor = input[0].float().to(device)
+ background_fix = input[1].to(device)
+
+ # apply skip frames
+ tensor = tensor[:,range(0, tensor.shape[1], cfg.defaults.skip_frames)]
+ sequence_len = tensor.shape[1]
+
+ # initial frame
+ input = tensor[:,0]
+ target = th.clip(tensor[:,0], 0, 1)
+ error_last = None
+
+ # placeholders
+ mask_cur = None
+ mask_last = None
+ rawmask_last = None
+ position_last = None
+ gestalt_last = None
+ priority_last = None
+ gt_positions_target = None
+ slots_occlusionfactor = None
+
+ # loop through frames
+ for t_index,t in enumerate(range(-cfg.defaults.teacher_forcing, min(burn_in_length + rollout_length-1, sequence_len-1))):
+
+ # move to next frame
+ t_run = max(t, 0)
+ input = tensor[:,t_run]
+ target = th.clip(tensor[:,t_run+1], 0, 1)
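+ # during rollout, black out each sample's input with probability 0.2 so the
+ # model must rely on its own predictions (burn-in frames are always shown)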
+ if t_run >= burn_in_length:
+ blackout = th.tensor((np.random.rand(valloader.batch_size) < 0.2)[:,None,None,None]).float().to(device)
+ input = blackout * (input * 0) + (1-blackout) * input
+ error_last = blackout * (error_last * 0) + (1-blackout) * error_last
+
+ # obtain prediction
+ (
+ output_next,
+ position_next,
+ gestalt_next,
+ priority_next,
+ mask_next,
+ rawmask_next,
+ object_next,
+ background,
+ slots_occlusionfactor,
+ output_cur,
+ position_cur,
+ gestalt_cur,
+ priority_cur,
+ mask_cur,
+ rawmask_cur,
+ object_cur,
+ position_encoder_cur,
+ slots_bounded,
+ slots_partially_occluded_cur,
+ slots_occluded_cur,
+ slots_partially_occluded_next,
+ slots_occluded_next,
+ slots_closed,
+ output_hidden,
+ largest_object,
+ rawmask_hidden,
+ object_hidden
+ ) = net(
+ input,
+ error_last,
+ mask_last,
+ rawmask_last,
+ position_last,
+ gestalt_last,
+ priority_last,
+ background_fix,
+ slots_occlusionfactor,
+ reset = (t == -cfg.defaults.teacher_forcing),
+ evaluate=True,
+ warmup = (t < 0),
+ shuffleslots = False,
+ reset_mask = (t <= 0),
+ allow_spawn = True,
+ show_hidden = False,
+ clean_slots = (t <= 0),
+ )
+
+ # 1. Track error
+ if t >= 0:
+ loss_mse = mseloss(output_next, target)
+ loss_ssim = np.sum([ssimloss(output_next[i].cpu().numpy(), target[i].cpu().numpy(), channel_axis=0, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, data_range=1) for i in range(output_next.shape[0])])
+ loss_psnr = np.sum([psnrloss(output_next[i].cpu().numpy(), target[i].cpu().numpy(), data_range=1) for i in range(output_next.shape[0])])
+ loss_lpips = th.sum(lpipsloss(output_next*2-1, target*2-1))
+
+ avgloss_mse += loss_mse.item()
+ avgloss_ssim += loss_ssim.item()
+ avgloss_psnr += loss_psnr.item()
+ avgloss_lpips += loss_lpips.item()
+
+ # 2. Remember output
+ mask_last = mask_next.clone()
+ rawmask_last = rawmask_next.clone()
+ position_last = position_next.clone()
+ gestalt_last = gestalt_next.clone()
+ priority_last = priority_next.clone()
+
+ # 3. Error for next frame
+ # background error
+ bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+ # prediction error
+ error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+ error_next = th.sqrt(error_next) * bg_error_next
+ error_last = error_next.clone()
+
+ print(f"MSE loss: {avgloss_mse / len(valloader.dataset):.2e}, LPIPS loss: {avgloss_lpips / len(valloader.dataset):.2e}, PSNR loss: {avgloss_psnr / len(valloader.dataset):.2e}, SSIM loss: {avgloss_ssim / len(valloader.dataset):.2e}, Time: {time.time() - start_time}")
+
+ pass
\ No newline at end of file