aboutsummaryrefslogtreecommitdiff
path: root/data/datasets/CLEVRER
diff options
context:
space:
mode:
authorfredeee2023-11-02 10:47:21 +0100
committerfredeee2023-11-02 10:47:21 +0100
commitf8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree87288be6f851ab69405e524b81940c501c52789a /data/datasets/CLEVRER
parentf16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initiaƶ commit
Diffstat (limited to 'data/datasets/CLEVRER')
-rw-r--r--data/datasets/CLEVRER/dataset.py205
1 files changed, 205 insertions, 0 deletions
diff --git a/data/datasets/CLEVRER/dataset.py b/data/datasets/CLEVRER/dataset.py
new file mode 100644
index 0000000..f08c974
--- /dev/null
+++ b/data/datasets/CLEVRER/dataset.py
@@ -0,0 +1,205 @@
+from torch.utils import data
+from typing import Tuple, Union, List
+import numpy as np
+import json
+import math
+import cv2
+import h5py
+import os
+import pickle
+
+
+class RamImage():
+ def __init__(self, path):
+
+ fd = open(path, 'rb')
+ img_str = fd.read()
+ fd.close()
+
+ self.img_raw = np.frombuffer(img_str, np.uint8)
+
+ def to_numpy(self):
+ return cv2.imdecode(self.img_raw, cv2.IMREAD_COLOR)
+
+class ClevrerSample(data.Dataset):
+ def __init__(self, root_path: str, data_path: str, size: Tuple[int, int]):
+
+ self.size = size
+ self.data_path = root_path
+ self.video_id = int(data_path.split('_')[1])
+
+ frames = []
+ for frame in os.listdir(os.path.join(root_path, data_path)):
+ if os.path.isfile(os.path.join(root_path, data_path, frame)) and frame.endswith('.jpg'):
+ frames.append(os.path.join(root_path, data_path, frame))
+ frames.sort()
+
+ self.imgs = []
+ for path in frames:
+ self.imgs.append(RamImage(path))
+
+ def get_data(self):
+
+ frames = np.zeros((128,3,self.size[1], self.size[0]),dtype=np.float32)
+ for i in range(len(self.imgs)):
+ img = self.imgs[i].to_numpy()
+ frames[i] = img.transpose(2, 0, 1).astype(np.float32) / 255.0
+
+ return frames
+
+ def downsample(self, size):
+ self.size = size
+ imgs = []
+ path = os.path.join(self.data_path, 'tmp.jpg')
+ for image_large in self.imgs:
+ img_small = cv2.resize(image_large.to_numpy(), dsize=(self.size[0], self.size[1]), interpolation=cv2.INTER_CUBIC)
+ cv2.imwrite(path, img_small)
+ imgs.append(RamImage(path))
+ self.imgs = imgs
+
+ # remove tmp.jpg
+ os.remove(path)
+
+ return self
+
+class ClevrerDataset(data.Dataset):
+
+ def save(self):
+ with open(self.file, "wb") as outfile:
+ pickle.dump(self.samples, outfile)
+
+ def load(self):
+ with open(self.file, "rb") as infile:
+ self.samples = pickle.load(infile)
+
+ def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], full_size: Tuple[int, int] = None, use_slotformer: bool = False, evaluation: bool = False):
+
+ data_path = f'data/data/video/{dataset_name}'
+ data_path = os.path.join(root_path, data_path)
+ self.file = os.path.join(data_path, f'dataset-{size[0]}x{size[1]}-{type}.pickle')
+
+ self.samples = []
+
+ if os.path.exists(self.file):
+ self.load()
+ else:
+
+ if (full_size is None) or (size == full_size):
+ sample_path = os.path.join(data_path, type, 'images')
+ samples = []
+ print(f'sample_path:', sample_path)
+
+ for directory in os.listdir(sample_path):
+ path = os.path.join(sample_path, directory)
+ if os.path.isdir(path):
+ samples.append(directory)
+
+ samples.sort()
+ num_samples = len(samples)
+ print(f'num_samples:', num_samples)
+
+ for i, dir in enumerate(samples):
+ self.samples.append(ClevrerSample(sample_path, dir, size))
+
+ print(f"Loading CLEVRER [{i * 100 / num_samples:.2f}]", flush=True)
+
+ else:
+ # load full size dataset
+ full_dataset = ClevrerDataset(root_path, dataset_name, type, full_size, full_size)
+
+ # downsample
+ for i, sample in enumerate(full_dataset.samples):
+ self.samples.append(sample.downsample(size))
+
+ print(f"Loading CLEVRER {type} [{i * 100 / len(full_dataset.samples):.2f}]", flush=True)
+
+ self.save()
+
+ self.length = len(self.samples)
+ self.background = None
+ if "background.jpg" in os.listdir(data_path):
+ self.background = cv2.imread(os.path.join(data_path, "background.jpg"))
+ self.background = cv2.resize(self.background, dsize=size, interpolation=cv2.INTER_CUBIC)
+ self.background = self.background.transpose(2, 0, 1).astype(np.float32) / 255.0
+ self.background = self.background.reshape(self.background.shape[0], self.background.shape[1], self.background.shape[2])
+
+ self.use_slotformer = use_slotformer
+ self.eval = evaluation
+ if self.use_slotformer:
+ with open(f'{data_path}/slotformer/valid_idx_{type}.pt', 'rb') as f:
+ self.slotformer_idx = pickle.load(f)
+ self.length = len(self.slotformer_idx)
+ self.video_ids = [sample.video_id for sample in self.samples]
+ self.burn_in_length = 6
+ self.rollout_length = 10
+ self.skip_length = 2
+
+ if self.eval:
+ self.gt_mask = np.load(f'{data_path}/slotformer/gt_mask.npy').astype(np.int8)
+ self.gt_bbox = np.load(f'{data_path}/slotformer/gt_bbox.npy').astype(np.int16)
+ self.gt_pres_mask = np.load(f'{data_path}/slotformer/gt_pres_mask.npy').astype(bool)
+ #self.gt = np.load(f'{data_path}/slotformer/gt.npy').astype(np.float32)
+
+ print(f"ClevrerDataset: {self.length}")
+
+ if len(self) == 0:
+ raise FileNotFoundError(f'Found no dataset at {self.data_path}')
+
+ if False:
+ for sample in self.samples:
+ frames = sample.get_data()
+ i = 0
+ for frame in frames:
+ i += 1
+ if i == 29:
+ print('test')
+ frame = np.einsum('chw->hwc', frame)
+ cv2.imshow('frame', frame)
+ cv2.waitKey(0)
+
+ def __len__(self):
+ return self.length
+
+ def __getitem__(self, index: int):
+
+ if self.use_slotformer:
+ index_org = index
+ _, start_idx, video_id = self.slotformer_idx[index]
+ index = self.video_ids.index(video_id)
+
+ #print(f"Loading CLEVRER {index} {start_idx} {video_id}")
+
+ if video_id != self.samples[index].video_id:
+ raise ValueError(f"Video ID mismatch {video_id} != {self.samples[index].video_id}")
+
+ selec = range(start_idx, start_idx+(self.burn_in_length+self.rollout_length)*self.skip_length, self.skip_length)
+ frames = self.samples[index].get_data()[selec]
+
+ if self.eval:
+ gt_mask = self.gt_mask[index_org]
+ gt_bbox = self.gt_bbox[index_org]
+ gt_pres_mask = self.gt_pres_mask[index_org]
+ #gt = self.gt[index_org]
+
+ return (
+ frames,
+ self.background,
+ gt_mask,
+ gt_bbox,
+ gt_pres_mask,
+ #gt
+ )
+
+ else:
+ frames = self.samples[index].get_data()
+
+ return (
+ frames,
+ self.background
+ )
+
+#if __name__ == "__main__":
+ #dataset = ClevrerDataset('./', "CLEVRER", "train", [320, 240])
+
+ #dataset = ClevrerDataset('./', "CLEVRER", "train", [480, 320])
+ #dataset = ClevrerDataset('./', "CLEVRER", "train", [120, 80], [480, 320])