aboutsummaryrefslogtreecommitdiff
path: root/data/datasets/BOUNCINGBALLS/dataset.py
blob: 3460d0294adb5cfc1dfff20b5daf977d6b9e26cb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
from pickletools import int4
from torch.utils import data
from typing import Tuple, Union, List
import numpy as np
import json
import math
import cv2
import h5py
import os
import pickle
import sys
import yaml
import warnings
from PIL import Image
from einops import reduce, rearrange, repeat
import torch as th


class BouncingBallDataset(data.Dataset):

    def __init__(self, root_path: str, dataset_name: str, type: str, size: Tuple[int, int], type_name: str = None, full_size: Tuple[int, int] = None, create_dataset: bool = False):

        assert type in ["train", "test", "val"]
        assert type_name in ["interaction", "occlusion", "twolayer", "twolayerdense", "twolayer_ood", "threelayer_ood", "twolayer_ood_3balls"]

        data_path  = f'data/data/video/{dataset_name}'
        data_path  = os.path.join(root_path, data_path)
        self.file  = os.path.join(data_path, f'balls_{type_name}-{type}-{size[0]}x{size[1]}-v1.hdf5')
        self.train = (type == "train")
        self.samples    = []

        if os.path.exists(self.file):
            self.hdf5_file = h5py.File(self.file, "r")
        
        # load dataset
        self.length     = self.hdf5_file['sequence_indices'].shape[0]
        self.background = np.zeros((3, size[0], size[1]), dtype=np.uint8)

        # set number of objects
        if (type_name == "twolayer") or (type_name == "threelayer_ood" and type != "test"):
            self.num_objects = 6
        elif (type_name == "twolayer_ood" and type == "test") or (type_name == "twolayer_ood_3balls" and type == "test"):
            self.num_objects = 4
        elif type_name == "twolayer_ood":
            self.num_objects = 2
        elif type_name == "twolayer_ood_3balls":
            self.num_objects = 3
        elif (type_name == "threelayer_ood" and type == "test"):
            self.num_objects = 9
        else:
            self.num_objects = 3

        if len(self) == 0:
            raise FileNotFoundError(f'Found no dataset at {data_path}')
        
        # loop trough own dataset by calling __getitem__
        if False:
            for i in range(len(self)):
                self[i]

    def add_one_timestep(self, x):
        return np.concatenate((x, np.zeros_like(x[:1])), axis=0)
        
    def __len__(self):
        return self.length

    def __getitem__(self, index: int):

        index_start, length = self.hdf5_file['sequence_indices'][index]
        rgb_images   = self.hdf5_file["rgb_images"][index_start:index_start+length]

        if rgb_images[0].dtype == np.uint8:
            images = []
            for i in range(len(rgb_images)):
                img = cv2.imdecode(rgb_images[i], 1)
                images.append(img.transpose(2, 0, 1).astype(np.float32) / 255.0)

            rgb_images   = np.stack(images)

        rgb_images   = th.from_numpy(rgb_images)

        if self.train:
            return (
                rgb_images,
                self.background
            )

        # EVALUATION
        num_objects = self.num_objects
        instance_positions = self.hdf5_file['instance_positions'][index_start*num_objects:(index_start+length)*num_objects]
        instance_positions = rearrange(instance_positions, '(t o) c -> t o c', o=num_objects)
        instance_positions = instance_positions[:, :, ::-1] # IMPORTANT: flip x and y axis

        instance_pres = self.hdf5_file['instance_incamera'][index_start*num_objects:(index_start+length)*num_objects]
        instance_pres = rearrange(instance_pres, '(t o) c -> t o c', o=num_objects).squeeze(-1)

        instance_bounding_boxes = self.hdf5_file['instance_mask_bboxes'][index_start*num_objects:(index_start+length)*num_objects]
        instance_bounding_boxes = rearrange(instance_bounding_boxes, '(t o) c -> t o c', o=num_objects)
        instance_bounding_boxes = instance_bounding_boxes[:, :, [1, 0, 3, 2]]

        foreground_mask   = self.hdf5_file['foreground_mask'][index_start:(index_start+length)]
        foreground_mask   = rearrange(foreground_mask, 't 1 h w -> t h w')/255

        instance_masks = self.hdf5_file['instance_masks'][index_start*num_objects:(index_start+length)*num_objects]
        instance_masks = rearrange(instance_masks, '(t o) 1 h w -> t o 1 h w', o=num_objects).squeeze()/255

        # CUSTOM
        # use instance masks to to create hidden masks
        hidden_mask = reduce(instance_masks, 't o h w -> t 1 h w', 'sum').squeeze()
        hidden_mask = (hidden_mask > 1).astype(np.uint8)

        # segmentation_masks: index gives which object is visible at that pixel
        segmentation_mask = np.argmax(instance_masks[:, ::-1], axis=1) + 1
        segmentation_mask = foreground_mask * segmentation_mask

        # segmentation mask but only for hidden objects
        segementation_mask_hidden = np.argmax(instance_masks[:, :3], axis=1) + 1 # TODO only works for 6 objects
        segementation_mask_hidden = hidden_mask * segementation_mask_hidden
                
        # add one dummy timestep at the end
        instance_positions = self.add_one_timestep(instance_positions)
        rgb_images = self.add_one_timestep(rgb_images)
        foreground_mask = self.add_one_timestep(foreground_mask)
        hidden_mask = self.add_one_timestep(hidden_mask)
        instance_pres = self.add_one_timestep(instance_pres)
        instance_bounding_boxes = self.add_one_timestep(instance_bounding_boxes)
        instance_masks = self.add_one_timestep(instance_masks)
        segmentation_mask = self.add_one_timestep(segmentation_mask)
        segementation_mask_hidden = self.add_one_timestep(segementation_mask_hidden)

        if False:
            video       = np.array(rgb_images)
            locations   =  np.array(instance_positions)
            fg_masks    = np.array(foreground_mask)
            bb          = np.array(instance_bounding_boxes)
            h_masks     = np.array(hidden_mask)

            # loop through video frames and show them using cv2
            for t in range(video.shape[0]):
                frame = rearrange(video[t], 'c h w -> h w c') * 255

                for loc in locations[t]:
                    x, y = loc
                    #cv2.circle(frame, (int(x), int(y)), 2, (255, 0, 0), -1) # did not work properly
                    x_max = int(min(int(x + 2), frame.shape[1]))
                    x_min = int(max(int(x - 2), 0))
                    y_max = int(min(int(y + 2), frame.shape[0]))
                    y_min = int(max(int(y - 2), 0))
                    frame[x_min:x_max, y_min:y_max] = [255, 0, 0]

                # draw the bounding boxes into the frame
                for i, b in enumerate(bb[t]):
                    x_min, y_min, x_max, y_max = b
                    
                    x_min = int(max(x_min, 0))
                    y_min = int(max(y_min, 0))
                    x_max = int(min(x_max, frame.shape[0]-1))
                    y_max = int(min(y_max, frame.shape[0]-1))

                    # dont't use c2 rectangeel function here but draw it manually
                    for pixel in range(x_min, x_max):
                        frame[pixel, y_min, 0] = 255
                        frame[pixel, y_max, 0] = 255
                    for pixel in range(y_min, y_max):
                        frame[x_min, pixel, 0] = 255
                        frame[x_max, pixel, 0] = 255
                
                fg_mask = repeat(fg_masks[t], 'h w -> h w 3') * 255
                h_mask = repeat(h_masks[t], 'h w -> h w 3') * 255
                s_mask = repeat(segmentation_mask[t], 'h w -> h w 3') * (255/6)
                s_mask_hidden = repeat(segementation_mask_hidden[t], 'h w -> h w 3') * (255/3)
                frame = np.concatenate((frame, fg_mask, s_mask, h_mask, s_mask_hidden), axis=1)

                # add instance masks to visualisation
                for i, mask in enumerate(instance_masks[t]):
                    mask = repeat(mask, 'h w -> h w 3') * 255
                    # add border to the right side
                    mask[:, -1, 0] = 255
                    frame = np.concatenate((frame, mask), axis=1)

                frame = frame.astype(np.uint8)
                cv2.imshow('frame', frame)
                cv2.waitKey(0)

        return (
            rgb_images,
            self.background,
            instance_positions,
            segmentation_mask,
            instance_pres,
            segementation_mask_hidden
        )
    

#a = BouncingBallDataset("./", 'BOUNCINGBALLS', "train", (64,64))