diff options
Diffstat (limited to 'data/datasets/BOUNCINGBALLS/dataset.py')
-rw-r--r-- | data/datasets/BOUNCINGBALLS/dataset.py | 195 |
1 file changed, 195 insertions, 0 deletions
from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import sys
import yaml
import warnings
from PIL import Image
from einops import reduce, rearrange, repeat
import torch as th


class BouncingBallDataset(data.Dataset):
    """Bouncing-balls video dataset backed by a pre-rendered HDF5 file.

    In training mode ``__getitem__`` returns ``(rgb_images, background)``.
    In evaluation mode it additionally returns per-object positions,
    presence flags and segmentation masks, each padded with one zero-filled
    dummy timestep at the end.
    """

    def __init__(self,
                 root_path: str,
                 dataset_name: str,
                 type: str,
                 size: Tuple[int, int],
                 type_name: str = None,
                 full_size: Tuple[int, int] = None,
                 create_dataset: bool = False,
                 debug: bool = False):
        """
        Args:
            root_path: repository root; the HDF5 file is looked up below it.
            dataset_name: sub-directory name under ``data/data/video``.
            type: dataset split, one of ``train`` / ``test`` / ``val``.
            size: (height, width) of the rendered frames; part of the filename.
            type_name: dataset variant (e.g. ``twolayer``); part of the filename.
            full_size: unused, kept for interface compatibility.
            create_dataset: unused, kept for interface compatibility.
            debug: if True, ``__getitem__`` shows an annotated OpenCV
                visualisation of every evaluation sample.

        Raises:
            FileNotFoundError: if the HDF5 file is missing or empty.
        """
        assert type in ["train", "test", "val"]
        assert type_name in ["interaction", "occlusion", "twolayer", "twolayerdense", "twolayer_ood", "threelayer_ood", "twolayer_ood_3balls"]

        data_path = f'data/data/video/{dataset_name}'
        data_path = os.path.join(root_path, data_path)
        self.file = os.path.join(data_path, f'balls_{type_name}-{type}-{size[0]}x{size[1]}-v1.hdf5')
        self.train = (type == "train")
        self.debug = debug

        # Fail loudly right away if the dataset file is missing.  Previously a
        # missing file silently produced a half-initialised object that only
        # crashed later in __len__ with an AttributeError.
        if not os.path.exists(self.file):
            raise FileNotFoundError(f'Found no dataset at {data_path}')

        self.hdf5_file = h5py.File(self.file, "r")

        # one entry in 'sequence_indices' per video sequence
        self.length = self.hdf5_file['sequence_indices'].shape[0]
        self.background = np.zeros((3, size[0], size[1]), dtype=np.uint8)

        # Number of balls per frame depends on the dataset variant and split.
        if (type_name == "twolayer") or (type_name == "threelayer_ood" and type != "test"):
            self.num_objects = 6
        elif (type_name == "twolayer_ood" and type == "test") or (type_name == "twolayer_ood_3balls" and type == "test"):
            self.num_objects = 4
        elif type_name == "twolayer_ood":
            self.num_objects = 2
        elif type_name == "twolayer_ood_3balls":
            self.num_objects = 3
        elif (type_name == "threelayer_ood" and type == "test"):
            self.num_objects = 9
        else:
            self.num_objects = 3

        if len(self) == 0:
            raise FileNotFoundError(f'Found no dataset at {data_path}')

    def add_one_timestep(self, x):
        """Append one zero-filled dummy timestep along the first (time) axis."""
        return np.concatenate((x, np.zeros_like(x[:1])), axis=0)

    def __len__(self):
        """Number of video sequences in the HDF5 file."""
        return self.length

    def __getitem__(self, index: int):
        """Load one sequence; see the class docstring for the returned tuple."""
        index_start, length = self.hdf5_file['sequence_indices'][index]
        rgb_images = self.hdf5_file["rgb_images"][index_start:index_start + length]

        # Frames may be stored image-encoded as uint8 buffers; decode them and
        # normalise to float32 CHW in [0, 1].
        if rgb_images[0].dtype == np.uint8:
            decoded = [
                cv2.imdecode(frame, 1).transpose(2, 0, 1).astype(np.float32) / 255.0
                for frame in rgb_images
            ]
            rgb_images = np.stack(decoded)

        rgb_images = th.from_numpy(rgb_images)

        if self.train:
            return (
                rgb_images,
                self.background
            )

        # ------------------------------------------------------------------
        # EVALUATION: additionally load per-object annotations.
        # ------------------------------------------------------------------
        num_objects = self.num_objects
        obj_start = index_start * num_objects
        obj_end = (index_start + length) * num_objects

        instance_positions = self.hdf5_file['instance_positions'][obj_start:obj_end]
        instance_positions = rearrange(instance_positions, '(t o) c -> t o c', o=num_objects)
        instance_positions = instance_positions[:, :, ::-1]  # IMPORTANT: flip x and y axis

        instance_pres = self.hdf5_file['instance_incamera'][obj_start:obj_end]
        instance_pres = rearrange(instance_pres, '(t o) c -> t o c', o=num_objects).squeeze(-1)

        instance_bounding_boxes = self.hdf5_file['instance_mask_bboxes'][obj_start:obj_end]
        instance_bounding_boxes = rearrange(instance_bounding_boxes, '(t o) c -> t o c', o=num_objects)
        # reorder box coordinates to match the flipped x/y axes above
        instance_bounding_boxes = instance_bounding_boxes[:, :, [1, 0, 3, 2]]

        foreground_mask = self.hdf5_file['foreground_mask'][index_start:(index_start + length)]
        foreground_mask = rearrange(foreground_mask, 't 1 h w -> t h w') / 255

        instance_masks = self.hdf5_file['instance_masks'][obj_start:obj_end]
        instance_masks = rearrange(instance_masks, '(t o) 1 h w -> t o 1 h w', o=num_objects).squeeze() / 255

        # A pixel covered by more than one instance mask is occluded (hidden).
        hidden_mask = reduce(instance_masks, 't o h w -> t 1 h w', 'sum').squeeze()
        hidden_mask = (hidden_mask > 1).astype(np.uint8)

        # segmentation_mask: index gives which object is visible at that pixel
        segmentation_mask = np.argmax(instance_masks[:, ::-1], axis=1) + 1
        segmentation_mask = foreground_mask * segmentation_mask

        # segmentation mask but only for hidden objects
        segmentation_mask_hidden = np.argmax(instance_masks[:, :3], axis=1) + 1  # TODO only works for 6 objects
        segmentation_mask_hidden = hidden_mask * segmentation_mask_hidden

        # add one dummy timestep at the end
        instance_positions = self.add_one_timestep(instance_positions)
        rgb_images = self.add_one_timestep(rgb_images)
        foreground_mask = self.add_one_timestep(foreground_mask)
        hidden_mask = self.add_one_timestep(hidden_mask)
        instance_pres = self.add_one_timestep(instance_pres)
        instance_bounding_boxes = self.add_one_timestep(instance_bounding_boxes)
        instance_masks = self.add_one_timestep(instance_masks)
        segmentation_mask = self.add_one_timestep(segmentation_mask)
        segmentation_mask_hidden = self.add_one_timestep(segmentation_mask_hidden)

        if self.debug:
            self._visualize(rgb_images, instance_positions, foreground_mask,
                            instance_bounding_boxes, hidden_mask, instance_masks,
                            segmentation_mask, segmentation_mask_hidden)

        return (
            rgb_images,
            self.background,
            instance_positions,
            segmentation_mask,
            instance_pres,
            segmentation_mask_hidden
        )

    def _visualize(self, rgb_images, instance_positions, foreground_mask,
                   instance_bounding_boxes, hidden_mask, instance_masks,
                   segmentation_mask, segmentation_mask_hidden):
        """Debug-only: show each frame annotated with positions, boxes and masks.

        Blocks on ``cv2.waitKey(0)`` per frame; only called when ``self.debug``.
        """
        video = np.array(rgb_images)
        locations = np.array(instance_positions)
        fg_masks = np.array(foreground_mask)
        bb = np.array(instance_bounding_boxes)
        h_masks = np.array(hidden_mask)

        # loop through video frames and show them using cv2
        for t in range(video.shape[0]):
            frame = rearrange(video[t], 'c h w -> h w c') * 255

            # mark ball centres with small squares
            # (cv2.circle did not work properly)
            for loc in locations[t]:
                x, y = loc
                x_max = int(min(int(x + 2), frame.shape[1]))
                x_min = int(max(int(x - 2), 0))
                y_max = int(min(int(y + 2), frame.shape[0]))
                y_min = int(max(int(y - 2), 0))
                frame[x_min:x_max, y_min:y_max] = [255, 0, 0]

            # draw the bounding boxes into the frame manually
            # (don't use the cv2 rectangle function here)
            for i, b in enumerate(bb[t]):
                x_min, y_min, x_max, y_max = b

                x_min = int(max(x_min, 0))
                y_min = int(max(y_min, 0))
                x_max = int(min(x_max, frame.shape[0] - 1))
                y_max = int(min(y_max, frame.shape[0] - 1))

                for pixel in range(x_min, x_max):
                    frame[pixel, y_min, 0] = 255
                    frame[pixel, y_max, 0] = 255
                for pixel in range(y_min, y_max):
                    frame[x_min, pixel, 0] = 255
                    frame[x_max, pixel, 0] = 255

            fg_mask = repeat(fg_masks[t], 'h w -> h w 3') * 255
            h_mask = repeat(h_masks[t], 'h w -> h w 3') * 255
            s_mask = repeat(segmentation_mask[t], 'h w -> h w 3') * (255 / 6)
            s_mask_hidden = repeat(segmentation_mask_hidden[t], 'h w -> h w 3') * (255 / 3)
            frame = np.concatenate((frame, fg_mask, s_mask, h_mask, s_mask_hidden), axis=1)

            # add instance masks to visualisation
            for i, mask in enumerate(instance_masks[t]):
                mask = repeat(mask, 'h w -> h w 3') * 255
                # add border to the right side
                mask[:, -1, 0] = 255
                frame = np.concatenate((frame, mask), axis=1)

            frame = frame.astype(np.uint8)
            cv2.imshow('frame', frame)
            cv2.waitKey(0)


#a = BouncingBallDataset("./", 'BOUNCINGBALLS', "train", (64,64))