aboutsummaryrefslogtreecommitdiff
path: root/scripts/exec/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/exec/train.py')
-rw-r--r--scripts/exec/train.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/scripts/exec/train.py b/scripts/exec/train.py
index f6e78fd..6fa6176 100644
--- a/scripts/exec/train.py
+++ b/scripts/exec/train.py
@@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration
from scripts import training
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset
+from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -20,12 +22,17 @@ if __name__ == "__main__":
print(f'Training model {cfg.model_path}')
# Load dataset
+ size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))
if cfg.datatype == "clevrer":
- trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False)
- valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True)
+ trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False)
+ valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True)
elif cfg.datatype == "adept":
- trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
- valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ trainset = AdeptDataset("./", cfg.dataset, "train", size)
+ valset = AdeptDataset("./", cfg.dataset, "test", size)
+ valset.train = True
+ elif cfg.datatype == "bouncingballs":
+ trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario)
+ valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario)
else:
raise Exception("Dataset not supported")