diff options
Diffstat (limited to 'scripts/exec/train.py')
-rw-r--r-- | scripts/exec/train.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/scripts/exec/train.py b/scripts/exec/train.py index f6e78fd..6fa6176 100644 --- a/scripts/exec/train.py +++ b/scripts/exec/train.py @@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration from scripts import training from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset +from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -20,12 +22,17 @@ if __name__ == "__main__": print(f'Training model {cfg.model_path}') # Load dataset + size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)) if cfg.datatype == "clevrer": - trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False) - valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True) + trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False) + valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True) elif cfg.datatype == "adept": - trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) - valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + trainset = AdeptDataset("./", cfg.dataset, "train", size) + valset = AdeptDataset("./", cfg.dataset, "test", size) + valset.train = True + elif cfg.datatype == "bouncingballs": + trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario) + valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario) else: raise Exception("Dataset not supported") |