diff options
Diffstat (limited to 'scripts/exec')
-rw-r--r-- | scripts/exec/eval.py | 32 | ||||
-rw-r--r-- | scripts/exec/eval_baselines.py (renamed from scripts/exec/eval_savi.py) | 5 | ||||
-rw-r--r-- | scripts/exec/train.py | 15 |
3 files changed, 35 insertions, 17 deletions
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py index 96ed32d..bc461ec 100644 --- a/scripts/exec/eval.py +++ b/scripts/exec/eval.py @@ -1,10 +1,29 @@ import argparse import sys from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage from scripts.utils.configuration import Configuration -from scripts import evaluation_adept, evaluation_clevrer +from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb +def main(load, n, cfg): + + size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)) + + # Load dataset + if cfg.datatype == "clevrer": + testset = ClevrerDataset("./", cfg.dataset, 'val', size, use_slotformer=True, evaluation=True) + evaluation_clevrer.evaluate(cfg, testset, load, n, plot_frequency= 2, plot_first_samples = 3) # only plotting + evaluation_clevrer.evaluate(cfg, testset, load, n, plot_first_samples = 0) # evaluation + elif cfg.datatype == "adept": + testset = AdeptDataset("./", cfg.dataset, 'createdown', size) + evaluation_adept.evaluate(cfg, testset, load, n, plot_frequency= 1, plot_first_samples = 2) + elif cfg.datatype == "bouncingballs": + testset = BouncingBallDataset("./", cfg.dataset, "test", size, type_name = cfg.scenario) + evaluation_bb.evaluate(cfg, testset, load, n, plot_first_samples = 4) + else: + raise Exception("Dataset not supported") + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-cfg", default="", help='path to the configuration file') @@ -16,13 +35,4 @@ if __name__ == "__main__": cfg = Configuration(args.cfg) print(f'Evaluating model {args.load}') - # Load dataset - if cfg.datatype == "clevrer": - testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True) - evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting - evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation - elif cfg.datatype == "adept": - testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) - evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) - else: - raise Exception("Dataset not supported")
\ No newline at end of file + main(args.load, args.n, cfg)
\ No newline at end of file diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_baselines.py index 87d4e59..9e451ab 100644 --- a/scripts/exec/eval_savi.py +++ b/scripts/exec/eval_baselines.py @@ -1,12 +1,13 @@ import argparse import sys from data.datasets.ADEPT.dataset import AdeptDataset -from scripts.evaluation_adept_savi import evaluate +from scripts.evaluation_adept_baselines import evaluate if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("-load", default="", type=str, help='path to savi slots') parser.add_argument("-n", default="", type=str, help='results name') + parser.add_argument("-model", default="", type=str, help='model, either savi or gswm') # Load configuration args = parser.parse_args(sys.argv[1:]) @@ -14,4 +15,4 @@ if __name__ == "__main__": # Load dataset testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2))) - evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
\ No newline at end of file + evaluate(testset, args.load, args.n, args.model, plot_frequency= 1, plot_first_samples = 2)
\ No newline at end of file diff --git a/scripts/exec/train.py b/scripts/exec/train.py index f6e78fd..6fa6176 100644 --- a/scripts/exec/train.py +++ b/scripts/exec/train.py @@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration from scripts import training from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset +from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -20,12 +22,17 @@ if __name__ == "__main__": print(f'Training model {cfg.model_path}') # Load dataset + size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)) if cfg.datatype == "clevrer": - trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False) - valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True) + trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False) + valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True) elif cfg.datatype == "adept": - trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) - valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + trainset = AdeptDataset("./", cfg.dataset, "train", size) + valset = AdeptDataset("./", cfg.dataset, "test", size) + valset.train = True + elif cfg.datatype == "bouncingballs": + trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario) + valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario) else: raise Exception("Dataset not supported") |