diff options
Diffstat (limited to 'scripts/utils/plot_utils.py')
-rw-r--r-- | scripts/utils/plot_utils.py | 428 |
1 files changed, 428 insertions, 0 deletions
"""Plotting helpers for visualising object-slot predictions, masks, gates and
online error curves.  All image tensors are assumed to be CHW (or BCHW) float
tensors in [0, 1] unless a function normalizes them itself."""

import torch as th
from torch import nn
import numpy as np
import cv2
from einops import rearrange, repeat
import matplotlib.pyplot as plt
import PIL
from model.utils.nn_utils import Gaus2D, Vector2D

# Shared 48-entry slot color palette (RGB in [0, 1]).  Previously this table
# was duplicated verbatim in color_mask() and get_color(); it is defined once
# here so both stay in sync.  Entries beyond index 36 intentionally repeat
# gray as a catch-all for surplus slots.
COLORS = th.tensor([
    [ 255,   0,   0 ],
    [   0,   0, 255 ],
    [ 255, 255,   0 ],
    [ 255,   0, 255 ],
    [   0, 255, 255 ],
    [   0, 255,   0 ],
    [ 255, 128,   0 ],
    [ 128, 255,   0 ],
    [ 128,   0, 255 ],
    [ 255,   0, 128 ],
    [   0, 255, 128 ],
    [   0, 128, 255 ],
    [ 255, 128, 128 ],
    [ 128, 255, 128 ],
    [ 128, 128, 255 ],
    [ 255, 128, 128 ],
    [ 128, 255, 128 ],
    [ 128, 128, 255 ],
    [ 255, 128, 255 ],
    [ 128, 255, 255 ],
    [ 128, 255, 255 ],
    [ 255, 255, 128 ],
    [ 255, 255, 128 ],
    [ 255, 128, 255 ],
    [ 128,   0,   0 ],
    [   0,   0, 128 ],
    [ 128, 128,   0 ],
    [ 128,   0, 128 ],
    [   0, 128, 128 ],
    [   0, 128,   0 ],
    [ 128, 128,   0 ],
    [ 128, 128,   0 ],
    [ 128,   0, 128 ],
    [ 128,   0, 128 ],
    [   0, 128, 128 ],
    [   0, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
    [ 128, 128, 128 ],
]) / 255.0

def preprocess(tensor, scale=1, normalize=False, mean_std_normalize=False):
    """Optionally normalize a tensor to [0, 1] and upsample it by ``scale``.

    ``normalize`` applies min-max scaling; ``mean_std_normalize`` clips to
    +/- 2 standard deviations around the mean and remaps into [0, 1].
    Returns None when given None so callers can pass optional tensors through.
    """
    if tensor is None:
        return None

    if normalize:
        min_ = th.min(tensor)
        max_ = th.max(tensor)
        tensor = (tensor - min_) / (max_ - min_)

    if mean_std_normalize:
        mean = th.mean(tensor)
        std = th.std(tensor)
        tensor = th.clip((tensor - mean) / (2 * std), -1, 1) * 0.5 + 0.5

    if scale > 1:
        # nearest-neighbour upsampling on the tensor's own device
        upsample = nn.Upsample(scale_factor=scale).to(tensor[0].device)
        tensor = upsample(tensor)

    return tensor

def preprocess_multi(*args, scale):
    """Apply preprocess(..., scale) to every argument; keyword-only scale."""
    return [preprocess(a, scale) for a in args]

def color_mask(mask):
    """Collapse a per-slot mask (B, O, H, W) into one RGB image (B, 3, H, W)
    by summing each slot's mask weighted with its palette color."""
    colors = COLORS.to(mask.device).view(1, -1, 3, 1, 1)
    mask = mask.unsqueeze(dim=2)

    return th.sum(colors[:, :mask.shape[1]] * mask, dim=1)

def get_color(o):
    """Return the RGB palette color (shape (3,), values in [0, 1]) of slot o."""
    return COLORS[o]

def to_rgb(tensor: th.Tensor):
    """Expand a single-channel tensor to 3 channels, brightening the red
    channel (r = 0.6 * x + 0.4) so zero regions render as dark red."""
    return th.cat((
        tensor * 0.6 + 0.4,
        tensor,
        tensor
    ), dim=1)

def visualise_gate(gate, h, w):
    """Render a scalar gate in [0, 1] as a horizontal bar image (1, h, w):
    light gray background with a black segment growing from the right."""
    bar = th.ones((1, h, w), device=gate.device) * 0.9
    black = int(w * gate.item())
    if black > 0:
        bar[:, :, -black:] = 0
    return bar

def get_highlighted_input(input, mask_cur):
    """Overlay the colored slot masks on a grayscale version of ``input``.

    The last mask channel is treated as background and excluded.  Returns the
    input unchanged when no mask is provided.
    """
    highlighted_input = input
    if mask_cur is not None:
        # ITU-R 601 luma weights for the grayscale conversion
        grayscale = input[:, 0:1] * 0.299 + input[:, 1:2] * 0.587 + input[:, 2:3] * 0.114
        object_mask_cur = th.sum(mask_cur[:, :-1], dim=1).unsqueeze(dim=1)
        highlighted_input = grayscale * (1 - object_mask_cur)
        highlighted_input += grayscale * object_mask_cur * 0.3333333
        cmask = color_mask(mask_cur[:, :-1])
        highlighted_input = highlighted_input + cmask * 0.6666666

    return highlighted_input

def color_slots(image, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur):
    """Recolor slot images by state: invert bounded slots, then darken
    partially occluded and fully occluded slots (flags are 0/1 masks)."""
    image = (1 - image) * slots_bounded + image * (1 - slots_bounded)
    image = th.clip(image - 0.3, 0, 1) * slots_partially_occluded_cur + image * (1 - slots_partially_occluded_cur)
    image = th.clip(image - 0.3, 0, 1) * slots_occluded_cur + image * (1 - slots_occluded_cur)

    return image

def compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale):
    """Turn raw per-slot masks into upscaled RGB masks with occluded regions
    tinted (channels 0 and 1 suppressed where a slot is occluded).

    Returns (rawmask_cur, rawmask_next) with shape (B, O, 3, H*scale, W*scale);
    the background (last) slot is dropped.
    """
    # occlusion = raw mask area not visible in the final composited mask
    occluded_cur = th.clip(rawmask_cur - mask_cur, 0, 1)[:, :-1]
    occluded_next = th.clip(rawmask_next - mask_next, 0, 1)[:, :-1]

    # to rgb
    rawmask_cur = repeat(rawmask_cur[:, :-1], 'b o h w -> b (o 3) h w')
    rawmask_next = repeat(rawmask_next[:, :-1], 'b o h w -> b (o 3) h w')

    # scale
    occluded_next = preprocess(occluded_next, scale)
    occluded_cur = preprocess(occluded_cur, scale)
    rawmask_cur = preprocess(rawmask_cur, scale)
    rawmask_next = preprocess(rawmask_next, scale)

    # set occlusion to red
    # NOTE(review): both the current and the next raw mask are tinted with
    # occluded_next (occluded_cur is computed but unused) — looks intentional
    # for previewing upcoming occlusion, but confirm against the original.
    rawmask_cur = rearrange(rawmask_cur, 'b (o c) h w -> b o c h w', c=3)
    rawmask_cur[:, :, 0] = rawmask_cur[:, :, 0] * (1 - occluded_next)
    rawmask_cur[:, :, 1] = rawmask_cur[:, :, 1] * (1 - occluded_next)

    rawmask_next = rearrange(rawmask_next, 'b (o c) h w -> b o c h w', c=3)
    rawmask_next[:, :, 0] = rawmask_next[:, :, 0] * (1 - occluded_next)
    rawmask_next[:, :, 1] = rawmask_next[:, :, 1] * (1 - occluded_next)

    return rawmask_cur, rawmask_next

def plot_online_error_slots(errors, error_name, target, sequence_len, root_path, visibilty_memory, slots_bounded, ylim=0.3):
    """Render one small error-over-time curve per bounded slot.

    ``errors`` and ``visibilty_memory`` are flat per-timestep-per-slot lists;
    they are reshaped to (num_slots, timesteps).  Timesteps where the slot was
    invisible are shaded orange.  Returns a list of CHW uint8 tensors (one per
    bounded slot); only sequences shorter than ``sequence_len`` are plotted.
    """
    error_plots = []
    if len(errors) > 0:
        num_slots = int(th.sum(slots_bounded).item())
        errors = rearrange(np.array(errors), '(l o) -> o l', o=len(slots_bounded))[:num_slots]
        visibilty_memory = rearrange(np.array(visibilty_memory), '(l o) -> o l', o=len(slots_bounded))[:num_slots]
        for error, visibility in zip(errors, visibilty_memory):

            if len(error) < sequence_len:
                # figure sized in inches to match the target image footprint
                fig, ax = plt.subplots(figsize=(round(target.shape[3]/100, 2), round(target.shape[2]/100, 2)))
                plt.plot(error, label=error_name)

                # pad visibility so the shading spans the full sequence
                visibility = np.concatenate((visibility, np.ones(sequence_len-len(error))))
                ax.fill_between(range(sequence_len), 0, 1, where=visibility==0, color='orange', alpha=0.3, transform=ax.get_xaxis_transform())

                plt.xlim((0, sequence_len))
                plt.ylim((0, ylim))
                fig.tight_layout()
                plt.savefig(f'{root_path}/tmp.jpg')

                # grab the rendered canvas as an RGB tensor (C, H, W)
                error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
                error_plot = th.from_numpy(np.array(error_plot).transpose(2, 0, 1))
                plt.close(fig)

                error_plots.append(error_plot)

    return error_plots

def plot_online_error(error, error_name, target, t, i, sequence_len, root_path, online_surprise=False):
    """Render a single error curve as an image tensor (C, H, W).

    With ``online_surprise`` the figure background turns orange when the last
    error exceeds the moving average by two standard deviations.
    """
    fig = plt.figure(figsize=(round(target.shape[3]/50, 2), round(target.shape[2]/50, 2)))
    plt.plot(error, label=error_name)

    if online_surprise:
        # compute moving average of error
        moving_average_length = 10
        if t > moving_average_length:
            moving_average_length += 1
            average_error = np.mean(error[-moving_average_length:-1])
            current_sd = np.std(error[-moving_average_length:-1])
            current_error = error[-1]

            if current_error > average_error + 2 * current_sd:
                fig.set_facecolor('orange')

    plt.xlim((0, sequence_len))
    plt.legend()
    # increase title size
    plt.title(f'{error_name}', fontsize=20)
    plt.xlabel('timestep')
    plt.ylabel('error')
    plt.savefig(f'{root_path}/tmp.jpg')

    error_plot = PIL.Image.frombytes('RGB', fig.canvas.get_width_height(), fig.canvas.tostring_rgb())
    error_plot = th.from_numpy(np.array(error_plot).transpose(2, 0, 1))
    plt.close(fig)

    return error_plot

def plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object):
    """Compose the full diagnostic board: input/prediction/target on the left,
    error curves on the right, and one column per object slot in the middle
    (gates, raw mask, appearance, velocity, per-slot errors).

    Returns an HWC uint8-range numpy image ready for cv2.imwrite.
    """
    # add ground truth positions of objects to image
    if gt_positions_target_next is not None:
        for o in range(gt_positions_target_next.shape[1]):
            position = gt_positions_target_next[0, o]
            position = position/2 + 0.5  # from [-1, 1] to [0, 1]

            # only draw markers for visible, in-frame positions
            if position[2] > 0.0 and position[0] > 0.0 and position[0] < 1.0 and position[1] > 0.0 and position[1] < 1.0:
                width = 5
                w = np.clip(int(position[0]*target.shape[2]), width, target.shape[2]-width).item()
                h = np.clip(int(position[1]*target.shape[3]), width, target.shape[3]-width).item()
                col = get_color(o).view(3, 1, 1)
                target[0, :, (w-width):(w+width), (h-width):(h+width)] = col

                # add these positions to the associated slots velocity_next2d illustration
                slots = (association_table[0] == o).nonzero()
                for s in slots.flatten():
                    velocity_next2d[s, :, (w-width):(w+width), (h-width):(h+width)] = col

                    if output_hidden is not None and s != largest_object:
                        output_hidden[0, :, (w-width):(w+width), (h-width):(h+width)] = col

    # layout constants (pixels)
    ch = 40                         # color-bar height
    gh = 40                         # gate area height
    gh_bar = gh-20                  # gate bar height
    gh_margin = int((gh-gh_bar)/2)  # vertical padding around the gate bar
    margin = 20
    slots_margin = 10
    height = size[0] * 6 + 18*5
    width = size[1] * 4 + 18*2 + size[1]*num_objects + 6*(num_objects+1) + slots_margin*(num_objects+1)
    img = th.ones((3, height, width), device=object_next.device) * 0.4
    row = (lambda row_index: [2*size[0]*row_index + (row_index+1)*margin, 2*size[0]*(row_index+1) + (row_index+1)*margin])
    col1 = range(margin, margin + size[1]*2)
    col2 = range(width-(margin+size[1]*2), width-margin)

    # left column: highlighted input, hidden output, target (2x upscaled)
    img[:, row(0)[0]:row(0)[1], col1] = preprocess(highlighted_input.to(object_next.device), 2)[0]
    img[:, row(1)[0]:row(1)[1], col1] = preprocess(output_hidden.to(object_next.device), 2)[0]
    img[:, row(2)[0]:row(2)[1], col1] = preprocess(target.to(object_next.device), 2)[0]

    # right column: global error curves
    if error_plot is not None:
        img[:, row(0)[0]+gh+ch+2*margin-gh_margin:row(0)[1]+gh+ch+2*margin-gh_margin, col2] = preprocess(error_plot.to(object_next.device), normalize= True)
    if error_plot2 is not None:
        img[:, row(2)[0]:row(2)[1], col2] = preprocess(error_plot2.to(object_next.device), normalize= True)

    # middle: one column per object slot
    for o in range(num_objects):

        col = 18+size[1]*2+6+o*(6+size[1])+(o+1)*slots_margin
        col = range(col, col + size[1])

        # color bar for the gate
        if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
            img[:, margin:margin+ch, col] = get_color(o).view(3, 1, 1).to(object_next.device)

        img[:, margin+ch+2*margin:2*margin+gh_bar+ch+margin, col] = visualise_gate(slots_closed[:, o, 0].to(object_next.device), h=gh_bar, w=len(col))
        offset = gh+margin-gh_margin+ch+2*margin
        row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])

        img[:, row(0)[0]:row(0)[1], col] = preprocess(rawmask_next[0, o].to(object_next.device))
        img[:, row(1)[0]:row(1)[1], col] = preprocess(object_next[:, o].to(object_next.device))
        if (error_plot_slots2 is not None) and len(error_plot_slots2) > o:
            img[:, row(2)[0]:row(2)[1], col] = preprocess(error_plot_slots2[o].to(object_next.device), normalize=True)

        offset = margin*2-8
        row = (lambda row_index: [offset+(size[0]+6)*row_index, offset+size[0]*(row_index+1)+6*row_index])
        img[:, row(4)[0]-gh+gh_margin:row(4)[0]-gh_margin, col] = visualise_gate(slots_closed[:, o, 1].to(object_next.device), h=gh_bar, w=len(col))
        img[:, row(4)[0]:row(4)[1], col] = preprocess(velocity_next2d[o].to(object_next.device), normalize=True)[0]
        if (error_plot_slots is not None) and len(error_plot_slots) > o:
            img[:, row(5)[0]:row(5)[1], col] = preprocess(error_plot_slots[o].to(object_next.device), normalize=True)

    img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()

    return img

def write_image(file, img):
    """Write a CHW tensor in [0, 1] to ``file`` as an 8-bit image via OpenCV."""
    img = rearrange(img * 255, 'c h w -> h w c').cpu().numpy()
    cv2.imwrite(file, img)

def plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i):
    """Produce all plots for one timestep: the combined object view (when
    ``object_view``) and/or individual component images (when
    ``individual_views``).  Images are written below ``plot_path``."""
    # Create position helpers
    size, gaus2d, vector2d, scale = get_position_helper(cfg_net, mask_cur.device)

    # Compute plot content
    highlighted_input = get_highlighted_input(input, mask_cur)
    output = th.clip(output_next, 0, 1)
    position_cur2d = gaus2d(rearrange(position_encoder_cur, 'b (o c) -> (b o) c', o=cfg_net.num_objects))
    velocity_next2d = vector2d(rearrange(position_next, 'b (o c) -> (b o) c', o=cfg_net.num_objects))

    # color slots
    slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next = reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next)
    position_cur2d = color_slots(position_cur2d, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur)
    velocity_next2d = color_slots(velocity_next2d, slots_bounded, slots_partially_occluded_next, slots_occluded_next)

    # compute occlusion
    if (cfg.datatype == "adept"):
        # adept: merge hidden-state masks/objects, keeping the largest
        # (background-like) object from the visible prediction
        rawmask_cur_l, rawmask_next_l = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)
        rawmask_cur_h, rawmask_next_h = compute_occlusion_mask(rawmask_cur, rawmask_hidden, mask_cur, mask_next, scale)
        rawmask_cur_h[:, largest_object] = rawmask_cur_l[:, largest_object]
        rawmask_next_h[:, largest_object] = rawmask_next_l[:, largest_object]
        rawmask_cur = rawmask_cur_h
        rawmask_next = rawmask_next_h
        object_hidden[:, largest_object] = object_next[:, largest_object]
        object_next = object_hidden
    else:
        rawmask_cur, rawmask_next = compute_occlusion_mask(rawmask_cur, rawmask_next, mask_cur, mask_next, scale)

    # scale plot content
    input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next = preprocess_multi(input, target, output, highlighted_input, object_next, object_cur, mask_next, error_next, output_hidden, output_next, scale=scale)

    # reshape
    object_next = rearrange(object_next, 'b (o c) h w -> b o c h w', c=cfg_net.img_channels)
    object_cur = rearrange(object_cur, 'b (o c) h w -> b o c h w', c=cfg_net.img_channels)
    mask_next = rearrange(mask_next, 'b (o 1) h w -> b o 1 h w')

    if object_view:
        if (cfg.datatype == "adept"):
            num_objects = 4
            error_plot_slots = plot_online_error_slots(statistics_complete_slots['TE'][-cfg_net.num_objects*(t+1):], 'Tracking error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded)
            error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['visible'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
            error_plot = plot_online_error(statistics_batch['image_error'], 'Prediction error', target, t, i, sequence_len, root_path)
            error_plot2 = plot_online_error(statistics_batch['TE'], 'Tracking error', target, t, i, sequence_len, root_path)
            img = plot_object_view(error_plot, error_plot2, error_plot_slots, error_plot_slots2, highlighted_input, output_hidden, object_next, rawmask_next, velocity_next2d, target, slots_closed, gt_positions_target_next, association_table, size, num_objects, largest_object)
        else:
            num_objects = cfg_net.num_objects
            # NOTE(review): 'slot_error' is passed here as the visibility
            # series (the adept branch uses 'visible') — confirm this is
            # intentional for this datatype.
            error_plot_slots2 = plot_online_error_slots(statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], 'Image error', target, sequence_len, root_path, statistics_complete_slots['slot_error'][-cfg_net.num_objects*(t+1):], slots_bounded, ylim=0.0001)
            error_plot = plot_online_error(statistics_batch['image_error_mse'], 'Prediction error', target, t, i, sequence_len, root_path)
            img = plot_object_view(error_plot, None, None, error_plot_slots2, highlighted_input, output_next, object_next, rawmask_next, velocity_next2d, target, slots_closed, None, None, size, num_objects, largest_object)

        cv2.imwrite(f'{plot_path}object/gpnet-objects-{i:04d}-{t_index:03d}.jpg', img)

    if individual_views:
        # ['error', 'input', 'background', 'prediction', 'position', 'rawmask', 'mask', 'othermask']:
        write_image(f'{plot_path}/individual/error/error-{i:04d}-{t_index:03d}.jpg', error_next[0])
        write_image(f'{plot_path}/individual/input/input-{i:04d}-{t_index:03d}.jpg', input[0])
        write_image(f'{plot_path}/individual/background/background-{i:04d}-{t_index:03d}.jpg', mask_next[0, -1])
        write_image(f'{plot_path}/individual/imagination/imagination-{i:04d}-{t_index:03d}.jpg', output_hidden[0])
        write_image(f'{plot_path}/individual/prediction/prediction-{i:04d}-{t_index:03d}.jpg', output_next[0])

def get_position_helper(cfg_net, device):
    """Build the 2D position renderers and the upsampling factor from config.

    Returns (input_size, Gaus2D renderer, Vector2D renderer, scale) where
    scale maps latent resolution to input resolution.
    """
    size = cfg_net.input_size
    gaus2d = Gaus2D(size).to(device)
    vector2d = Vector2D(size).to(device)
    scale = size[0] // (cfg_net.latent_size[0] * 2**(cfg_net.level*2))
    return size, gaus2d, vector2d, scale

def reshape_slots(slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next):
    """Squeeze each slot-state flag tensor and append three trailing singleton
    dims so the flags broadcast over (C, H, W) image tensors."""
    slots_bounded = th.squeeze(slots_bounded)[..., None, None, None]
    slots_partially_occluded_cur = th.squeeze(slots_partially_occluded_cur)[..., None, None, None]
    slots_occluded_cur = th.squeeze(slots_occluded_cur)[..., None, None, None]
    slots_partially_occluded_next = th.squeeze(slots_partially_occluded_next)[..., None, None, None]
    slots_occluded_next = th.squeeze(slots_occluded_next)[..., None, None, None]

    return slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next