from torch.utils import data
from typing import Tuple, Union, List, Optional
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle


class RamImage:
    """An image kept in memory in its encoded (e.g. JPEG) form and decoded on demand.

    Only the compressed bytes are stored (and pickled); every call to
    :meth:`to_numpy` pays the decode cost.
    """

    def __init__(self, path):
        """Read the encoded image file at *path* into memory."""
        # ``with`` closes the handle even if the read raises.
        with open(path, 'rb') as fd:
            self.img_raw = np.frombuffer(fd.read(), np.uint8)

    @classmethod
    def from_bytes(cls, encoded: bytes) -> 'RamImage':
        """Build a RamImage directly from already-encoded image bytes (no file I/O)."""
        image = cls.__new__(cls)
        image.img_raw = np.frombuffer(encoded, np.uint8)
        return image

    def to_numpy(self):
        """Decode the stored bytes with OpenCV into a 3-channel BGR ``uint8`` array (H, W, 3)."""
        return cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR)


class ClevrerSample(data.Dataset):
    """All JPEG frames of one CLEVRER video, kept encoded in RAM.

    Instances are pickled into the dataset cache by ``ClevrerDataset.save``, so the
    attribute names (``size``, ``data_path``, ``video_id``, ``imgs``) must stay
    stable for existing cache files to keep loading.
    """

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int]):
        """Load every ``*.jpg`` file in ``root_path/data_path``, sorted by path.

        Args:
            root_path: directory containing the per-video frame directories.
            data_path: name of this video's directory; the integer after its first
                ``_`` becomes ``video_id`` (presumably names like ``video_01234`` —
                verify against the dataset layout).
            size: ``(width, height)`` of the frames; shapes the ``get_data`` output.
        """
        self.size = size
        # NOTE: holds the *parent* directory (root_path), not this video's folder.
        self.data_path = root_path
        self.video_id = int(data_path.split('_')[1])
        video_dir = os.path.join(root_path, data_path)
        frame_paths = sorted(
            os.path.join(video_dir, name)
            for name in os.listdir(video_dir)
            if name.endswith('.jpg') and os.path.isfile(os.path.join(video_dir, name))
        )
        self.imgs = [RamImage(path) for path in frame_paths]

    def get_data(self, num_frames: int = 128):
        """Decode all frames into a float32 array of shape (T, 3, height, width) in [0, 1].

        Channels are in OpenCV's BGR order. ``T`` is ``max(num_frames, len(self.imgs))``:
        videos shorter than ``num_frames`` (128 by default) are zero-padded at the
        end, and longer videos are returned whole instead of overflowing the buffer.
        """
        total = max(num_frames, len(self.imgs))
        frames = np.zeros((total, 3, self.size[1], self.size[0]), dtype=np.float32)
        for i, image in enumerate(self.imgs):
            # HWC uint8 -> CHW float32 scaled to [0, 1].
            frames[i] = image.to_numpy().transpose(2, 0, 1).astype(np.float32) / 255.0
        return frames

    def downsample(self, size):
        """Resize every frame to *size* ``(width, height)`` in place and return ``self``.

        Resized frames are re-encoded as JPEG in memory, so no temporary file is
        written next to the dataset (nothing to clean up, nothing shared between
        concurrent runs) and videos without frames are handled.

        Raises:
            RuntimeError: if OpenCV fails to JPEG-encode a resized frame.
        """
        dsize = (size[0], size[1])
        resized = []
        for image_large in self.imgs:
            img_small = cv2.resize(
                image_large.to_numpy(), dsize=dsize, interpolation=cv2.INTER_CUBIC
            )
            ok, encoded = cv2.imencode('.jpg', img_small)
            if not ok:
                raise RuntimeError(
                    f'Failed to encode a downsampled frame of video {self.video_id}'
                )
            resized.append(RamImage.from_bytes(encoded.tobytes()))
        # Commit size and frames together so a failure above leaves self untouched.
        self.size = size
        self.imgs = resized
        return self


class ClevrerDataset(data.Dataset):
    """CLEVRER video frames as a PyTorch dataset, backed by an on-disk pickle cache.

    Expected layout, relative to ``root_path``::

        data/data/video/<dataset_name>/
            <type>/images/<video_dir>/*.jpg          raw frames
            dataset-<W>x<H>-<type>.pickle            cache written by save()
            background.jpg                           optional background image
            slotformer/valid_idx_<type>.pt           SlotFormer clip index (optional)
            slotformer/gt_{mask,bbox,pres_mask}.npy  SlotFormer eval targets (optional)

    Items are ``(frames, background)`` or, in SlotFormer evaluation mode,
    ``(frames, background, gt_mask, gt_bbox, gt_pres_mask)``; ``background`` is
    ``None`` when no ``background.jpg`` exists.
    """

    def save(self):
        """Pickle ``self.samples`` to the cache file ``self.file``."""
        with open(self.file, "wb") as outfile:
            pickle.dump(self.samples, outfile)

    def load(self):
        """Restore ``self.samples`` from the cache file ``self.file``."""
        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # point this at cache files produced by save().
        with open(self.file, "rb") as infile:
            self.samples = pickle.load(infile)

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int],
                 full_size: Optional[Tuple[int, int]] = None, use_slotformer: bool = False,
                 evaluation: bool = False):
        """Load (or build and cache) the samples of one split.

        Args:
            root_path: project root containing ``data/data/video/<dataset_name>``.
            dataset_name: dataset folder name, e.g. ``"CLEVRER"``.
            type: split name (``"train"``, ...); selects images, cache and index files.
            size: target frame size ``(width, height)``.
            full_size: size of the frames stored on disk. When given and different
                from *size*, the full-size dataset is loaded (and cached) first and
                then downsampled to *size*.
            use_slotformer: index items by SlotFormer clips instead of by video.
            evaluation: with *use_slotformer*, also return ground-truth arrays.

        Raises:
            FileNotFoundError: if no samples (or SlotFormer clips) were found, or if
                the image directory of the split does not exist.
        """
        data_path = os.path.join(root_path, f'data/data/video/{dataset_name}')
        # Kept on the instance so the empty-dataset error below can report it.
        self.data_path = data_path
        self.file = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}.pickle')

        self.samples = []
        if os.path.exists(self.file):
            self.load()
        else:
            # Compare as tuples so e.g. [480, 320] and (480, 320) count as equal.
            if full_size is None or tuple(size) == tuple(full_size):
                self.samples = self._load_samples(data_path, type, size)
            else:
                self.samples = self._downsample_samples(
                    root_path, dataset_name, type, size, full_size
                )
            # Don't cache an empty result: it would mask the data on later runs.
            if self.samples:
                self.save()

        self.length = len(self.samples)
        self.background = self._load_background(data_path, size)

        self.use_slotformer = use_slotformer
        self.eval = evaluation
        if self.use_slotformer:
            self._load_slotformer(data_path, type)

        print(f"ClevrerDataset: {self.length}")

        if len(self) == 0:
            raise FileNotFoundError(f'Found no dataset at {self.data_path}')

    @staticmethod
    def _load_samples(data_path: str, split: str, size: Tuple[int, int]) -> List[ClevrerSample]:
        """Read every video directory under ``<data_path>/<split>/images`` in sorted order."""
        sample_path = os.path.join(data_path, split, 'images')
        print(f'sample_path:', sample_path)
        video_dirs = sorted(
            entry for entry in os.listdir(sample_path)
            if os.path.isdir(os.path.join(sample_path, entry))
        )
        num_samples = len(video_dirs)
        print(f'num_samples:', num_samples)
        samples = []
        for i, video_dir in enumerate(video_dirs):
            samples.append(ClevrerSample(sample_path, video_dir, size))
            print(f"Loading CLEVRER [{i * 100 / num_samples:.2f}]", flush=True)
        return samples

    @staticmethod
    def _downsample_samples(root_path: str, dataset_name: str, split: str,
                            size: Tuple[int, int],
                            full_size: Tuple[int, int]) -> List[ClevrerSample]:
        """Load the full-size dataset (building its own cache) and downsample it to *size*."""
        full_dataset = ClevrerDataset(root_path, dataset_name, split, full_size, full_size)
        num_samples = len(full_dataset.samples)
        samples = []
        for i, sample in enumerate(full_dataset.samples):
            # downsample() mutates and returns the same sample object.
            samples.append(sample.downsample(size))
            print(f"Loading CLEVRER {split} [{i * 100 / num_samples:.2f}]", flush=True)
        return samples

    @staticmethod
    def _load_background(data_path: str, size: Tuple[int, int]) -> Optional[np.ndarray]:
        """Return ``background.jpg`` resized to *size* as float32 (3, H, W) in [0, 1], else None."""
        background_path = os.path.join(data_path, "background.jpg")
        if not os.path.isfile(background_path):
            return None
        background = cv2.imread(background_path)
        background = cv2.resize(background, dsize=tuple(size), interpolation=cv2.INTER_CUBIC)
        return background.transpose(2, 0, 1).astype(np.float32) / 255.0

    def _load_slotformer(self, data_path: str, split: str):
        """Load the SlotFormer clip index and, in evaluation mode, the ground-truth arrays."""
        # NOTE(review): despite the ``.pt`` suffix this file is read with pickle,
        # which executes arbitrary code — only use trusted index files.
        with open(f'{data_path}/slotformer/valid_idx_{split}.pt', 'rb') as f:
            self.slotformer_idx = pickle.load(f)
        # One item per SlotFormer clip rather than one per video.
        self.length = len(self.slotformer_idx)
        self.video_ids = [sample.video_id for sample in self.samples]
        # video id -> position of its first sample (what list.index would return),
        # so __getitem__ avoids a linear scan per item.
        self._video_index = {}
        for position, video_id in enumerate(self.video_ids):
            self._video_index.setdefault(video_id, position)
        self.burn_in_length = 6
        self.rollout_length = 10
        self.skip_length = 2
        if self.eval:
            self.gt_mask = np.load(f'{data_path}/slotformer/gt_mask.npy').astype(np.int8)
            self.gt_bbox = np.load(f'{data_path}/slotformer/gt_bbox.npy').astype(np.int16)
            self.gt_pres_mask = np.load(f'{data_path}/slotformer/gt_pres_mask.npy').astype(bool)

    def __len__(self):
        """Number of videos, or of SlotFormer clips when ``use_slotformer`` is set."""
        return self.length

    def __getitem__(self, index: int):
        """Return the item at *index* (tuple layout described in the class docstring).

        Without SlotFormer, *index* selects a video and its whole zero-padded frame
        array is returned. With SlotFormer, *index* selects an entry
        ``(_, start_idx, video_id)`` of ``slotformer_idx``, and
        ``burn_in_length + rollout_length`` frames are taken every ``skip_length``
        frames starting at ``start_idx``.

        Raises:
            ValueError: if a SlotFormer entry refers to a video id not in the dataset.
        """
        if not self.use_slotformer:
            return (self.samples[index].get_data(), self.background)

        _, start_idx, video_id = self.slotformer_idx[index]
        sample_index = self._video_index.get(video_id)
        if sample_index is None:
            # Equality-based fallback (ids whose hash differs from their value, e.g.
            # tensors); raises ValueError when the id is genuinely unknown.
            sample_index = self.video_ids.index(video_id)
        sample = self.samples[sample_index]
        # Defensive consistency check between the index and the loaded samples.
        if video_id != sample.video_id:
            raise ValueError(f"Video ID mismatch {video_id} != {sample.video_id}")
        clip_end = start_idx + (self.burn_in_length + self.rollout_length) * self.skip_length
        frames = sample.get_data()[range(start_idx, clip_end, self.skip_length)]
        if self.eval:
            # Ground-truth arrays are indexed by clip (the original index), not video.
            return (
                frames,
                self.background,
                self.gt_mask[index],
                self.gt_bbox[index],
                self.gt_pres_mask[index],
            )
        return (frames, self.background)


# Example usage (builds and caches the pickle files on first run):
# if __name__ == "__main__":
#     dataset = ClevrerDataset('./', "CLEVRER", "train", [320, 240])
#     dataset = ClevrerDataset('./', "CLEVRER", "train", [480, 320])
#     dataset = ClevrerDataset('./', "CLEVRER", "train", [120, 80], [480, 320])