diff options
author | fredeee | 2023-11-02 10:47:21 +0100 |
---|---|---|
committer | fredeee | 2023-11-02 10:47:21 +0100 |
commit | f8302ee886ef9b631f11a52900dac964a61350e1 (patch) | |
tree | 87288be6f851ab69405e524b81940c501c52789a /data/datasets/CLEVRER | |
parent | f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff) |
initial commit
Diffstat (limited to 'data/datasets/CLEVRER')
-rw-r--r-- | data/datasets/CLEVRER/dataset.py | 205 |
1 files changed, 205 insertions, 0 deletions
# Source: data/datasets/CLEVRER/dataset.py (new file mode 100644, index 0000000..f08c974)
from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle


class RamImage():
    """An image file kept in memory in its encoded (undecoded) form."""

    def __init__(self, path):
        """Read the raw bytes of the file at ``path`` without decoding them."""
        # Context manager so the handle is released even if read() raises.
        with open(path, 'rb') as fd:
            img_str = fd.read()

        self.img_raw = np.frombuffer(img_str, np.uint8)

    def to_numpy(self):
        """Decode the stored bytes with ``cv2.imdecode(..., cv2.IMREAD_COLOR)``."""
        return cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR)


class ClevrerSample(data.Dataset):
    """All ``.jpg`` frames of one video directory, held encoded in memory.

    Args:
        root_path: Directory that contains the per-video directories.
        data_path: Name of the video directory inside ``root_path``. The
            integer after the first ``_`` in this name becomes ``video_id``.
        size: ``(width, height)`` of the stored frames.
    """

    # Length of the frame axis of the array returned by get_data(); samples
    # with fewer images leave the remaining frames zero-filled.
    num_frames = 128

    def __init__(self, root_path: str, data_path: str, size: Tuple[int, int]):
        self.size = size
        # NOTE: this is the parent directory of the video directory; it is
        # where downsample() places its temporary file.
        self.data_path = root_path
        self.video_id = int(data_path.split('_')[1])

        video_dir = os.path.join(root_path, data_path)
        frame_paths = sorted(
            os.path.join(video_dir, frame)
            for frame in os.listdir(video_dir)
            if frame.endswith('.jpg') and os.path.isfile(os.path.join(video_dir, frame))
        )

        self.imgs = [RamImage(path) for path in frame_paths]

    def get_data(self):
        """Decode every frame into one float32 array.

        Returns:
            Array of shape ``(num_frames, 3, size[1], size[0])`` with values
            divided by 255.0. Frames beyond ``len(self.imgs)`` stay zero.
        """
        frames = np.zeros((self.num_frames, 3, self.size[1], self.size[0]), dtype=np.float32)
        for i, img in enumerate(self.imgs):
            # cv2 decodes to (H, W, C); move channels first.
            frames[i] = img.to_numpy().transpose(2, 0, 1).astype(np.float32) / 255.0

        return frames

    def downsample(self, size):
        """Resize all frames in place to ``size`` (``(width, height)``).

        Each resized frame is re-encoded by writing it to a temporary
        ``tmp.jpg`` inside ``self.data_path`` and reading the bytes back.

        Returns:
            ``self``, to allow chaining.
        """
        imgs = []
        path = os.path.join(self.data_path, 'tmp.jpg')
        try:
            for image_large in self.imgs:
                img_small = cv2.resize(
                    image_large.to_numpy(),
                    dsize=(size[0], size[1]),
                    interpolation=cv2.INTER_CUBIC,
                )
                cv2.imwrite(path, img_small)
                imgs.append(RamImage(path))
        finally:
            # The temp file does not exist when there were no images (or the
            # first write failed), so only remove it if it was created.
            if os.path.exists(path):
                os.remove(path)

        # Commit the new state only after every frame was converted.
        self.size = size
        self.imgs = imgs

        return self


class ClevrerDataset(data.Dataset):
    """CLEVRER video dataset with an on-disk pickle cache of encoded frames.

    Args:
        root_path: Project root; data is looked up under
            ``<root_path>/data/data/video/<dataset_name>``.
        dataset_name: Name of the dataset directory.
        type: Split name; used as sub-directory (``<type>/images``) and in
            the cache / slotformer index file names.
        size: ``(width, height)`` of the frames to serve.
        full_size: If given and different from ``size``, the dataset is built
            at ``full_size`` first and then downsampled to ``size``.
        use_slotformer: Serve sub-sequences selected by the index stored in
            ``slotformer/valid_idx_<type>.pt`` instead of whole videos.
        evaluation: In slotformer mode, additionally load and return the
            ground-truth masks, boxes and presence masks.

    Raises:
        FileNotFoundError: If the resulting dataset is empty.
    """

    def save(self):
        """Pickle ``self.samples`` to ``self.file``."""
        with open(self.file, "wb") as outfile:
            pickle.dump(self.samples, outfile)

    def load(self):
        """Load ``self.samples`` from the pickle cache at ``self.file``."""
        # NOTE(security): pickle executes arbitrary code on load; only use
        # cache files that were produced locally by save().
        with open(self.file, "rb") as infile:
            self.samples = pickle.load(infile)

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], full_size: Tuple[int, int] = None, use_slotformer: bool = False, evaluation: bool = False):

        data_path = f'data/data/video/{dataset_name}'
        data_path = os.path.join(root_path, data_path)
        # Needed by the "no dataset found" error message below.
        self.data_path = data_path
        self.file = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}.pickle')

        self.samples = []

        if os.path.exists(self.file):
            self.load()
        else:

            if (full_size is None) or (size == full_size):
                sample_path = os.path.join(data_path, type, 'images')
                print('sample_path:', sample_path)

                samples = sorted(
                    directory
                    for directory in os.listdir(sample_path)
                    if os.path.isdir(os.path.join(sample_path, directory))
                )
                num_samples = len(samples)
                print('num_samples:', num_samples)

                for i, directory in enumerate(samples):
                    self.samples.append(ClevrerSample(sample_path, directory, size))

                    print(f"Loading CLEVRER [{i * 100 / num_samples:.2f}]", flush=True)

            else:
                # load full size dataset
                full_dataset = ClevrerDataset(root_path, dataset_name, type, full_size, full_size)
                num_full = len(full_dataset.samples)

                # downsample (samples are modified in place and reused)
                for i, sample in enumerate(full_dataset.samples):
                    self.samples.append(sample.downsample(size))

                    print(f"Loading CLEVRER {type} [{i * 100 / num_full:.2f}]", flush=True)

            # Do not cache an empty result: it would be loaded on every later
            # run and hide data that was added in the meantime.
            if self.samples:
                self.save()

        self.length = len(self.samples)
        self.background = None
        background_path = os.path.join(data_path, "background.jpg")
        if os.path.isfile(background_path):
            self.background = cv2.imread(background_path)
            self.background = cv2.resize(self.background, dsize=tuple(size), interpolation=cv2.INTER_CUBIC)
            # (H, W, C) uint8 -> (C, H, W) float32 scaled by 1/255
            self.background = self.background.transpose(2, 0, 1).astype(np.float32) / 255.0

        self.use_slotformer = use_slotformer
        self.eval = evaluation
        if self.use_slotformer:
            # NOTE(security): pickle executes arbitrary code on load; the
            # index file must come from a trusted source.
            with open(f'{data_path}/slotformer/valid_idx_{type}.pt', 'rb') as f:
                self.slotformer_idx = pickle.load(f)
            self.length = len(self.slotformer_idx)
            self.video_ids = [sample.video_id for sample in self.samples]
            self.burn_in_length = 6
            self.rollout_length = 10
            self.skip_length = 2

            # Ground truth is only read by __getitem__ in slotformer mode.
            if self.eval:
                self.gt_mask = np.load(f'{data_path}/slotformer/gt_mask.npy').astype(np.int8)
                self.gt_bbox = np.load(f'{data_path}/slotformer/gt_bbox.npy').astype(np.int16)
                self.gt_pres_mask = np.load(f'{data_path}/slotformer/gt_pres_mask.npy').astype(bool)
                #self.gt = np.load(f'{data_path}/slotformer/gt.npy').astype(np.float32)

        print(f"ClevrerDataset: {self.length}")

        if len(self) == 0:
            raise FileNotFoundError(f'Found no dataset at {self.data_path}')

    def __len__(self):
        """Number of items: videos, or index entries in slotformer mode."""
        return self.length

    def __getitem__(self, index: int):
        """Return the frames (and background) for item ``index``.

        Returns:
            ``(frames, background)``; in slotformer evaluation mode
            ``(frames, background, gt_mask, gt_bbox, gt_pres_mask)``.
            ``background`` is ``None`` when no ``background.jpg`` exists.

        Raises:
            ValueError: In slotformer mode, if the looked-up sample does not
                carry the video id requested by the index entry.
        """

        if self.use_slotformer:
            index_org = index
            _, start_idx, video_id = self.slotformer_idx[index]
            index = self.video_ids.index(video_id)

            if video_id != self.samples[index].video_id:
                raise ValueError(f"Video ID mismatch {video_id} != {self.samples[index].video_id}")

            # burn-in + rollout frames, taking every skip_length-th frame
            selec = range(start_idx, start_idx+(self.burn_in_length+self.rollout_length)*self.skip_length, self.skip_length)
            frames = self.samples[index].get_data()[selec]

            if self.eval:
                # Ground truth is addressed by the position in the slotformer
                # index, not by the position of the video in self.samples.
                gt_mask = self.gt_mask[index_org]
                gt_bbox = self.gt_bbox[index_org]
                gt_pres_mask = self.gt_pres_mask[index_org]
                #gt = self.gt[index_org]

                return (
                    frames,
                    self.background,
                    gt_mask,
                    gt_bbox,
                    gt_pres_mask,
                    #gt
                )

        else:
            frames = self.samples[index].get_data()

        return (
            frames,
            self.background
        )

#if __name__ == "__main__":
    #dataset = ClevrerDataset('./', "CLEVRER", "train", [320, 240])

    #dataset = ClevrerDataset('./', "CLEVRER", "train", [480, 320])
    #dataset = ClevrerDataset('./', "CLEVRER", "train", [120, 80], [480, 320])