aboutsummaryrefslogtreecommitdiff
path: root/scripts/evaluation_adept.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/evaluation_adept.py')
-rw-r--r--scripts/evaluation_adept.py11
1 files changed, 10 insertions, 1 deletions
diff --git a/scripts/evaluation_adept.py b/scripts/evaluation_adept.py
index eab3ad9..b2a18c0 100644
--- a/scripts/evaluation_adept.py
+++ b/scripts/evaluation_adept.py
@@ -2,6 +2,7 @@ import torch as th
from torch.utils.data import Dataset, DataLoader, Subset
from torch import nn
import os
+from scripts.utils.eval_adept import eval_adept
from scripts.utils.plot_utils import plot_timestep
from scripts.utils.configuration import Configuration
from scripts.utils.io import init_device
@@ -21,10 +22,12 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
# Config
cfg_net = cfg.model
cfg_net.batch_size = 1
+ cfg_net.inner_loop_enabled = (cfg.max_updates > cfg.phases.start_inner_loop)
# Load model
net = load_model(cfg, cfg_net, file, device)
net.eval()
+ net.predictor.enable_att_weights()
# Plot config
object_view = True
@@ -216,7 +219,9 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
# 4. Plot
if (t % plot_frequency == 0) and (i < plot_first_samples) and (t >= 0):
- plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i)
+ att = net.predictor.get_att_weights()
+ openings = net.get_openings()
+ plot_timestep(cfg, cfg_net, input, target, mask_cur, mask_next, output_next, position_encoder_cur, position_next, rawmask_hidden, rawmask_cur, rawmask_next, largest_object, object_cur, object_next, object_hidden, slots_bounded, slots_partially_occluded_cur, slots_occluded_cur, slots_partially_occluded_next, slots_occluded_next, slots_closed, gt_positions_target_next, association_table, error_next, output_hidden, object_view, individual_views, statistics_complete_slots, statistics_batch, sequence_len, root_path, plot_path, t_index, t, i, att= att, openings=openings)
# fill jumping statistics
statistics_complete_slots['vanishing'].extend(np.tile(slots_vanishing_memory.astype(int), t+1))
@@ -235,8 +240,11 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
pd.DataFrame(statistics_complete).to_csv(f'{root_path}/statistics/trialframe.csv')
pd.DataFrame(statistics_complete_slots).to_csv(f'{root_path}/statistics/slotframe.csv')
pd.DataFrame(acc_memory_complete).to_csv(f'{root_path}/statistics/accframe.csv')
+
if object_view and os.path.exists(f'{root_path}/tmp.jpg'):
os.remove(f'{root_path}/tmp.jpg')
+
+ eval_adept(f'{root_path}/statistics')
pass
@@ -309,6 +317,7 @@ def update_mota_acc(acc, gt_positions, estimated_positions, slots_bounded, cfg_n
def calculate_tracking_error(gt_positions_target, gt_visibility_target, position_cur, cfg_num_slots, slots_bounded, slots_occluded_cur, association_table, gt_occluder_mask):
# tracking utils
+ gt_positions_target = gt_positions_target.clone()
pdist = nn.PairwiseDistance(p=2).to(position_cur.device)
# 1. association of newly bounded slots to ground truth objects