diff options
author | fredeee | 2024-03-23 13:27:00 +0100 |
---|---|---|
committer | fredeee | 2024-03-23 13:27:00 +0100 |
commit | 6bcf6b8306ce4903734fb31824799a50281cea69 (patch) | |
tree | 0545ff1b8beb051993c2d75fd81306db1a22274d /scripts/evaluation_clevrer.py | |
parent | ad0b64a7f0140406151d18b19ab2ed5d19b6c511 (diff) |
add bouncingball experiment and ablation studies
Diffstat (limited to 'scripts/evaluation_clevrer.py')
-rw-r--r-- | scripts/evaluation_clevrer.py | 45 |
1 files changed, 33 insertions, 12 deletions
diff --git a/scripts/evaluation_clevrer.py b/scripts/evaluation_clevrer.py index a43d50c..168fba3 100644 --- a/scripts/evaluation_clevrer.py +++ b/scripts/evaluation_clevrer.py @@ -237,7 +237,7 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p raise ValueError('Number of rollout steps and burnin steps must be equal to the sequence length.') if not plotting_mode: - average_dic = compute_statistics_summary(metric_complete, evaluation_mode) + average_dic = compute_statistics_summary(metric_complete, evaluation_mode, root_path=root_path) # Store statistics with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl'), 'wb') as f: @@ -250,24 +250,45 @@ def evaluate(cfg: Configuration, dataset: Dataset, file, n, plot_frequency= 1, p os.remove(f'{root_path}/tmp.jpg') pass -def compute_statistics_summary(metric_complete, evaluation_mode): +def compute_statistics_summary(metric_complete, evaluation_mode, root_path=None, consider_first_n_frames = None): + string = '' + def add_text(string, text, last=False): + string = string + ' \n ' + text + return string + average_dic = {} + if consider_first_n_frames is not None: + for key in metric_complete: + for sample in range(len(metric_complete[key])): + metric_complete[key][sample] = metric_complete[key][sample][:consider_first_n_frames] + for key in metric_complete: # take average over all frames - average_dic[key + 'complete_average'] = np.mean(metric_complete[key]) - average_dic[key + 'complete_std'] = np.std(metric_complete[key]) - print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f}') + average_dic[key + '_complete_average'] = np.mean(metric_complete[key]) + average_dic[key + '_complete_std'] = np.std(metric_complete[key]) + average_dic[key + '_complete_sum'] = np.sum(np.mean(metric_complete[key], axis=0)) # checked with GSWM code! + string = add_text(string, f'{key} complete average: {average_dic[key + "_complete_average"]:.4f} +/- {average_dic[key + "_complete_std"]:.4f} (sum: {average_dic[key + "_complete_sum"]:.4f})') + #print(f'{key} complete average: {average_dic[key + "complete_average"]:.4f} +/- {average_dic[key + "complete_std"]:.4f} (sum: {average_dic[key + "complete_sum"]:.4f})') if evaluation_mode == 'blackout': - # take average only for frames where blackout occurs + # take average only for frames where blackout occurs blackout_mask = np.array(metric_complete['blackout']) > 0 - average_dic[key + 'blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask]) - average_dic[key + 'blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask]) - average_dic[key + 'visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False]) - average_dic[key + 'visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False]) + average_dic[key + '_blackout_average'] = np.mean(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + '_blackout_std'] = np.std(np.array(metric_complete[key])[blackout_mask]) + average_dic[key + '_visible_average'] = np.mean(np.array(metric_complete[key])[blackout_mask == False]) + average_dic[key + '_visible_std'] = np.std(np.array(metric_complete[key])[blackout_mask == False]) - print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}') - print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}') + #print(f'{key} blackout average: {average_dic[key + "blackout_average"]:.4f} +/- {average_dic[key + "blackout_std"]:.4f}') + #print(f'{key} visible average: {average_dic[key + "visible_average"]:.4f} +/- {average_dic[key + "visible_std"]:.4f}') + string = add_text(string, f'{key} blackout average: {average_dic[key + "_blackout_average"]:.4f} +/- {average_dic[key + "_blackout_std"]:.4f}') + string = add_text(string, f'{key} visible average: {average_dic[key + "_visible_average"]:.4f} +/- {average_dic[key + "_visible_std"]:.4f}') + + print(string) + if root_path is not None: + f'{root_path}/statistics', f'{evaluation_mode}_metric_complete.pkl' + with open(os.path.join(f'{root_path}/statistics', f'{evaluation_mode}_metric_average.txt'), 'w') as f: + f.write(string) + return average_dic def compute_plot_statistics(cfg_net, statistics_complete_slots, mseloss, set_test, evaluation_mode, i, statistics_batch, t, target, output_next, mask_next, slots_bounded, slots_closed, rawmask_hidden): |