diff options
Diffstat (limited to 'scripts/utils/eval_utils.py')
-rw-r--r-- | scripts/utils/eval_utils.py | 92 |
1 files changed, 87 insertions, 5 deletions
diff --git a/scripts/utils/eval_utils.py b/scripts/utils/eval_utils.py index faab7ec..a01ffd0 100644 --- a/scripts/utils/eval_utils.py +++ b/scripts/utils/eval_utils.py @@ -1,18 +1,92 @@ import os +import shutil +from einops import rearrange import torch as th from model.loci import Loci +def masks_to_boxes(masks: th.Tensor) -> th.Tensor: + """ + Compute the bounding boxes around the provided masks. + + Returns a [N, 4] tensor containing bounding boxes. The boxes are in ``(x1, y1, x2, y2)`` format with + ``0 <= x1 < x2`` and ``0 <= y1 < y2``. + + Args: + masks (Tensor[N, H, W]): masks to transform where N is the number of masks + and (H, W) are the spatial dimensions. + + Returns: + Tensor[N, 4]: bounding boxes + """ + if masks.numel() == 0: + return th.zeros((0, 4), device=masks.device, dtype=th.float) + + n = masks.shape[0] + + bounding_boxes = th.zeros((n, 4), device=masks.device, dtype=th.float) + + for index, mask in enumerate(masks): + if mask.sum() > 0: + y, x = th.where(mask != 0) + + bounding_boxes[index, 0] = th.min(x) + bounding_boxes[index, 1] = th.min(y) + bounding_boxes[index, 2] = th.max(x) + bounding_boxes[index, 3] = th.max(y) + + return bounding_boxes + +def boxes_to_centroids(boxes): + """Post-process masks instead of directly taking argmax. + + Args: + bboxes: [B, T, N, 4], 4: [x1, y1, x2, y2] + + Returns: + centroids: [B, T, N, 2], 2: [x, y] + """ + + centroids = (boxes[:, :, :, :2] + boxes[:, :, :, 2:]) / 2 + centroids = centroids.squeeze(0) + + # scale to [-1, 1] + centroids[:, :, 0] = centroids[:, :, 0] / 64 * 2 - 1 + centroids[:, :, 1] = centroids[:, :, 1] / 64 * 2 - 1 + + return centroids + +def compute_position_from_mask(mask): + """ + Compute the position of the object from the mask. + + Args: + mask (Tensor[B, N, H, W]): masks to transform where N is the number of masks + and (H, W) are the spatial dimensions. + + Returns: + Tensor[B, N, 2]: position of the object + + """ + masks_binary = (mask > 0.8).float()[:, :-1] + b, o, h, w = masks_binary.shape + masks2 = rearrange(masks_binary, 'b o h w -> (b o) h w') + boxes = masks_to_boxes(masks2.long()) + boxes = rearrange(boxes, '(b o) c -> b 1 o c', b=b, o=o) + centroids = boxes_to_centroids(boxes) + centroids = centroids[:, :, :, [1, 0]].squeeze(1) + return centroids + def setup_result_folders(file, name, set_test, evaluation_mode, object_view, individual_views): net_name = file.split('/')[-1].split('.')[0] #root_path = file.split('nets')[0] - root_path = os.path.join(*file.split('/')[0:-1]) + root_path = os.path.join(*file.split('/')[0:-2]) root_path = os.path.join(root_path, f'results{name}', net_name, set_test['type']) plot_path = os.path.join(root_path, evaluation_mode) # create directories - #if os.path.exists(plot_path): - # shutil.rmtree(plot_path) + if os.path.exists(plot_path): + shutil.rmtree(plot_path) os.makedirs(plot_path, exist_ok = True) if object_view: os.makedirs(os.path.join(plot_path, 'object'), exist_ok = True) @@ -59,13 +133,21 @@ def load_model(cfg, cfg_net, file, device): print(f"load {file} to device {device}") state = th.load(file, map_location=device) - # backward compatibility + # 1. Get keys of current model while ensuring backward compatibility model = {} + allowed_keys = [] + rand_state = net.state_dict() + for key, value in rand_state.items(): + allowed_keys.append(key) + + # 2. Overwrite with values from file for key, value in state["model"].items(): # replace update_module with percept_gate_controller in key string: key = key.replace("update_module", "percept_gate_controller") - model[key.replace(".module.", ".")] = value + if key in allowed_keys: + model[key.replace(".module.", ".")] = value + net.load_state_dict(model) # ??? |