diff options
author | fredeee | 2024-03-23 13:27:00 +0100 |
---|---|---|
committer | fredeee | 2024-03-23 13:27:00 +0100 |
commit | 6bcf6b8306ce4903734fb31824799a50281cea69 (patch) | |
tree | 0545ff1b8beb051993c2d75fd81306db1a22274d /scripts/evaluation_adept.py | |
parent | ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff) |
add bouncingball experiment and ablation studies
Diffstat (limited to 'scripts/evaluation_adept.py')
-rw-r--r-- | scripts/evaluation_adept.py | 11 |
1 files changed, 10 insertions, 1 deletions
diff --git a/scripts/evaluation_adept.py b/scripts/evaluation_adept.py index eab3ad9..b2a18c0 100644 --- a/scripts/evaluation_adept.py +++ b/scripts/evaluation_adept.py @@ -2,6 +2,7 @@ import torch as th from torch.utils.data import Dataset, DataLoader, Subset from torch import nn import os +from scripts.utils.eval_adept import eval_adept from scripts.utils.plot_utils import plot_timestep from scripts.utils.configuration import Configuration from scripts.utils.io import init_device @@ -21,10 +22,12 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p # Config cfg_net = cfg.model cfg_net.batch_size = 1 + cfg_net.inner_loop_enabled = (cfg.max_updates > cfg.phases.start_inner_loop) # Load model net = load_model(cfg, cfg_net, file, device) net.eval() + net.predictor.enable_att_weights() # Plot config object_view = True @@ -216,7 +219,9 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p # 4. Plot if (t % plot_frequency == 0) and (i < plot_first_samples) and (t >= 0): - plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i) + att = net.predictor.get_att_weights() + openings = net.get_openings() + plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, att= att, openings=openings) # fill jumping statistics statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1)) @@ -235,8 +240,11 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv') pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv') pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv') + if object_view and os.path.exists(f'{root_path}/tmp.jpg'): os.remove(f'{root_path}/tmp.jpg') + + eval_adept(f'{root_path}/statistics') pass @@ -309,6 +317,7 @@ def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_n def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_num_slots, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask): # tracking utils + gt_positions_target = gt_positions_target.clone() pdist = nn.PairwiseDistance(p=2).to(position_cur.device) # 1. association of newly bounded slots to ground truth objects |