aboutsummaryrefslogtreecommitdiff
path: root/scripts/evaluation_clevrer.py
diff options
context:
space:
mode:
authorfredeee2024-03-23 13:27:00 +0100
committerfredeee2024-03-23 13:27:00 +0100
commit6bcf6b8306ce4903734fb31824799a50281cea69 (patch)
tree0545ff1b8beb051993c2d75fd81306db1a22274d /scripts/evaluation_clevrer.py
parentad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff)
add bouncingball experiment and ablation studies
Diffstat (limited to 'scripts/evaluation_clevrer.py')
-rw-r--r--scripts/evaluation_clevrer.py45
1 files changed, 33 insertions, 12 deletions
diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py
index a43d50c..168fba3 100644
--- a/scripts/evaluation_clevrer.py
+++ b/scripts/evaluation_clevrer.py
@@ -237,7 +237,7 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.')
if not plotting_mode:
- average_dic = compute_statistics_summary(metric_complete, evaluation_mode)
+ average_dic = compute_statistics_summary(metric_complete, evaluation_mode, root_path=root_path)
# Store statistics
with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f:
@@ -250,24 +250,45 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p
os.remove(f'{root_path}/tmp.jpg')
pass
-def compute_statistics_summary(metric_complete, evaluation_mode):
+def compute_statistics_summary(metric_complete, evaluation_mode, root_path=None, consider_first_n_frames = None):
+ string = ''
+ def add_text(string, text, last=False):
+ string = string + ' \n ' + text
+ return string
+
average_dic = {}
+ if consider_first_n_frames is not None:
+ for key in metric_complete:
+ for sample in range(len(metric_complete[key])):
+ metric_complete[key][sample] = metric_complete[key][sample][:consider_first_n_frames]
+
for key in metric_complete:
# take average over all frames
- average_dic[key + 'complete_average'] = np.mean(metric_complete[key])
- average_dic[key + 'complete_std'] = np.std(metric_complete[key])
- print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}')
+ average_dic[key + '_complete_average'] = np.mean(metric_complete[key])
+ average_dic[key + '_complete_std'] = np.std(metric_complete[key])
+ average_dic[key + '_complete_sum'] = np.sum(np.mean(metric_complete[key], axis=0)) # checked with GSWM code!
+ string = add_text(string, f'{key} complete average: {average_dic[key + "_complete_average"]:.4f} +/- {average_dic[key + "_complete_std"]:.4f} (sum: {average_dic[key + "_complete_sum"]:.4f})')
+ #print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f} (sum: {average_dic[key + "complete_sum"]:.4f})')
if evaluation_mode == 'blackout':
- # take average only for frames where blackout occurs
+ # take average only for frames where blackout occurs
blackout_mask = np.array(metric_complete['blackout']) > 0
- average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask])
- average_dic[key + 'blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask])
- average_dic[key + 'visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False])
- average_dic[key + 'visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False])
+ average_dic[key + '_blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask])
+ average_dic[key + '_blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask])
+ average_dic[key + '_visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False])
+ average_dic[key + '_visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False])
- print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}')
- print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}')
+ #print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}')
+ #print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}')
+ string = add_text(string, f'{key} blackout average: {average_dic[key + "_blackout_average"]:.4f} +/- {average_dic[key + "_blackout_std"]:.4f}')
+ string = add_text(string, f'{key} visible average: {average_dic[key + "_visible_average"]:.4f} +/- {average_dic[key + "_visible_std"]:.4f}')
+
+ print(string)
+ if root_path is not None:
+ f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'
+ with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.txt'), 'w') as f:
+ f.write(string)
+
return average_dic
def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden):