From 6bcf6b8306ce4903734fb31824799a50281cea69 Mon Sep 17 00:00:00 2001 From: fredeee Date: Sat, 23 Mar 2024 13:27:00 +0100 Subject: add bouncingball experiment and ablation studies --- scripts/exec/train.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) (limited to 'scripts/exec/train.py') diff --git a/scripts/exec/train.py b/scripts/exec/train.py index f6e78fd..6fa6176 100644 --- a/scripts/exec/train.py +++ b/scripts/exec/train.py @@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration from scripts import training from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset +from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -20,12 +22,17 @@ if __name__ == "__main__": print(f'Training model {cfg.model_path}') # Load dataset + size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)) if cfg.datatype == "clevrer": - trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False) - valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True) + trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False) + valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True) elif cfg.datatype == "adept": - trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) - valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + trainset = AdeptDataset("./", cfg.dataset, "train", size) + valset = AdeptDataset("./", cfg.dataset, "test", size) + valset.train = True + elif cfg.datatype == "bouncingballs": + trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario) + valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario) else: raise Exception("Dataset not supported") -- cgit v1.2.3