From 6bcf6b8306ce4903734fb31824799a50281cea69 Mon Sep 17 00:00:00 2001 From: fredeee Date: Sat, 23 Mar 2024 13:27:00 +0100 Subject: add bouncingball experiment and ablation studies --- scripts/exec/eval_baselines.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) create mode 100644 scripts/exec/eval_baselines.py (limited to 'scripts/exec/eval_baselines.py') diff --git a/scripts/exec/eval_baselines.py b/scripts/exec/eval_baselines.py new file mode 100644 index 0000000..9e451ab --- /dev/null +++ b/scripts/exec/eval_baselines.py @@ -0,0 +1,18 @@ +import argparse +import sys +from data.datasets.ADEPT.dataset import AdeptDataset +from scripts.evaluation_adept_baselines import evaluate + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-load", default="", type=str, help='path to savi slots') + parser.add_argument("-n", default="", type=str, help='results name') + parser.add_argument("-model", default="", type=str, help='model, either savi or gswm') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + print(f'Evaluating savi slots {args.load}') + + # Load dataset + testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2))) + evaluate(testset, args.load, args.n, args.model, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file -- cgit v1.2.3