From 6bcf6b8306ce4903734fb31824799a50281cea69 Mon Sep 17 00:00:00 2001 From: fredeee Date: Sat, 23 Mar 2024 13:27:00 +0100 Subject: add bouncingball experiment and ablation studies --- scripts/exec/eval.py | 32 +++++++++++++++++++++----------- scripts/exec/eval_baselines.py | 18 ++++++++++++++++++ scripts/exec/eval_savi.py | 17 ----------------- scripts/exec/train.py | 15 +++++++++++---- 4 files changed, 50 insertions(+), 32 deletions(-) create mode 100644 scripts/exec/eval_baselines.py delete mode 100644 scripts/exec/eval_savi.py (limited to 'scripts/exec') diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py index 96ed32d..bc461ec 100644 --- a/scripts/exec/eval.py +++ b/scripts/exec/eval.py @@ -1,10 +1,29 @@ import argparse import sys from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage from scripts.utils.configuration import Configuration -from scripts import evaluation_adept, evaluation_clevrer +from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb +def main(load, n, cfg): + + size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)) + + # Load dataset + if cfg.datatype == "clevrer": + testset = ClevrerDataset("./", cfg.dataset, 'val', size, use_slotformer=True, evaluation=True) + evaluation_clevrer.evaluate(cfg, testset, load, n, plot_frequency= 2, plot_first_samples = 3) # only plotting + evaluation_clevrer.evaluate(cfg, testset, load, n, plot_first_samples = 0) # evaluation + elif cfg.datatype == "adept": + testset = AdeptDataset("./", cfg.dataset, 'createdown', size) + evaluation_adept.evaluate(cfg, testset, load, n, plot_frequency= 1, plot_first_samples = 2) + elif cfg.datatype == "bouncingballs": + testset = BouncingBallDataset("./", cfg.dataset, "test", size, type_name = cfg.scenario) + evaluation_bb.evaluate(cfg, testset, load, n, plot_first_samples = 4) + else: + raise Exception("Dataset not supported") + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-cfg", default="", help='path to the configuration file') @@ -16,13 +35,4 @@ if __name__ == "__main__": cfg = Configuration(args.cfg) print(f'Evaluating model {args.load}') - # Load dataset - if cfg.datatype == "clevrer": - testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True) - evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting - evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation - elif cfg.datatype == "adept": - testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) - evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) - else: - raise Exception("Dataset not supported") \ No newline at end of file + main(args.load, args.n, cfg) \ No newline at end of file diff --git a/scripts/exec/eval_baselines.py b/scripts/exec/eval_baselines.py new file mode 100644 index 0000000..9e451ab --- /dev/null +++ b/scripts/exec/eval_baselines.py @@ -0,0 +1,18 @@ +import argparse +import sys +from data.datasets.ADEPT.dataset import AdeptDataset +from scripts.evaluation_adept_baselines import evaluate + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-load", default="", type=str, help='path to savi slots') + parser.add_argument("-n", default="", type=str, help='results name') + parser.add_argument("-model", default="", type=str, help='model, either savi or gswm') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + print(f'Evaluating savi slots {args.load}') + + # Load dataset + testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2))) + evaluate(testset, args.load, args.n, args.model, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_savi.py deleted file mode 100644 index 87d4e59..0000000 --- a/scripts/exec/eval_savi.py +++ /dev/null @@ -1,17 +0,0 @@ -import argparse -import sys -from data.datasets.ADEPT.dataset import AdeptDataset -from scripts.evaluation_adept_savi import evaluate - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("-load", default="", type=str, help='path to savi slots') - parser.add_argument("-n", default="", type=str, help='results name') - - # Load configuration - args = parser.parse_args(sys.argv[1:]) - print(f'Evaluating savi slots {args.load}') - - # Load dataset - testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2))) - evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file diff --git a/scripts/exec/train.py b/scripts/exec/train.py index f6e78fd..6fa6176 100644 --- a/scripts/exec/train.py +++ b/scripts/exec/train.py @@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration from scripts import training from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset +from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -20,12 +22,17 @@ if __name__ == "__main__": print(f'Training model {cfg.model_path}') # Load dataset + size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)) if cfg.datatype == "clevrer": - trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False) - valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True) + trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False) + valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True) elif cfg.datatype == "adept": - trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) - valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + trainset = AdeptDataset("./", cfg.dataset, "train", size) + valset = AdeptDataset("./", cfg.dataset, "test", size) + valset.train = True + elif cfg.datatype == "bouncingballs": + trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario) + valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario) else: raise Exception("Dataset not supported") -- cgit v1.2.3