aboutsummaryrefslogtreecommitdiff
path: root/scripts/exec
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/exec')
-rw-r--r--scripts/exec/eval.py32
-rw-r--r--scripts/exec/eval_baselines.py (renamed from scripts/exec/eval_savi.py)5
-rw-r--r--scripts/exec/train.py15
3 files changed, 35 insertions, 17 deletions
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py
index 96ed32d..bc461ec 100644
--- a/scripts/exec/eval.py
+++ b/scripts/exec/eval.py
@@ -1,10 +1,29 @@
import argparse
import sys
from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
from scripts.utils.configuration import Configuration
-from scripts import evaluation_adept, evaluation_clevrer
+from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb
+def main(load, n, cfg):
+
+ size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))
+
+ # Load dataset
+ if cfg.datatype == "clevrer":
+ testset = ClevrerDataset("./", cfg.dataset, 'val', size, use_slotformer=True, evaluation=True)
+ evaluation_clevrer.evaluate(cfg, testset, load, n, plot_frequency= 2, plot_first_samples = 3) # only plotting
+ evaluation_clevrer.evaluate(cfg, testset, load, n, plot_first_samples = 0) # evaluation
+ elif cfg.datatype == "adept":
+ testset = AdeptDataset("./", cfg.dataset, 'createdown', size)
+ evaluation_adept.evaluate(cfg, testset, load, n, plot_frequency= 1, plot_first_samples = 2)
+ elif cfg.datatype == "bouncingballs":
+ testset = BouncingBallDataset("./", cfg.dataset, "test", size, type_name = cfg.scenario)
+ evaluation_bb.evaluate(cfg, testset, load, n, plot_first_samples = 4)
+ else:
+ raise Exception("Dataset not supported")
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-cfg", default="", help='path to the configuration file')
@@ -16,13 +35,4 @@ if __name__ == "__main__":
cfg = Configuration(args.cfg)
print(f'Evaluating model {args.load}')
- # Load dataset
- if cfg.datatype == "clevrer":
- testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True)
- evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting
- evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation
- elif cfg.datatype == "adept":
- testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
- evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
- else:
- raise Exception("Dataset not supported") \ No newline at end of file
+ main(args.load, args.n, cfg) \ No newline at end of file
diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_baselines.py
index 87d4e59..9e451ab 100644
--- a/scripts/exec/eval_savi.py
+++ b/scripts/exec/eval_baselines.py
@@ -1,12 +1,13 @@
import argparse
import sys
from data.datasets.ADEPT.dataset import AdeptDataset
-from scripts.evaluation_adept_savi import evaluate
+from scripts.evaluation_adept_baselines import evaluate
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-load", default="", type=str, help='path to savi slots')
parser.add_argument("-n", default="", type=str, help='results name')
+ parser.add_argument("-model", default="", type=str, help='model, either savi or gswm')
# Load configuration
args = parser.parse_args(sys.argv[1:])
@@ -14,4 +15,4 @@ if __name__ == "__main__":
# Load dataset
testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2)))
- evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file
+ evaluate(testset, args.load, args.n, args.model, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file
diff --git a/scripts/exec/train.py b/scripts/exec/train.py
index f6e78fd..6fa6176 100644
--- a/scripts/exec/train.py
+++ b/scripts/exec/train.py
@@ -4,6 +4,8 @@ from scripts.utils.configuration import Configuration
from scripts import training
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset
+from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb
if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -20,12 +22,17 @@ if __name__ == "__main__":
print(f'Training model {cfg.model_path}')
# Load dataset
+ size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))
if cfg.datatype == "clevrer":
- trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False)
- valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True)
+ trainset = ClevrerDataset("./", cfg.dataset, "train", size, use_slotformer=False)
+ valset = ClevrerDataset("./", cfg.dataset, "val", size, use_slotformer=True)
elif cfg.datatype == "adept":
- trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
- valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ trainset = AdeptDataset("./", cfg.dataset, "train", size)
+ valset = AdeptDataset("./", cfg.dataset, "test", size)
+ valset.train = True
+ elif cfg.datatype == "bouncingballs":
+ trainset = BouncingBallDataset("./", cfg.dataset, "train", size, type_name = cfg.scenario)
+ valset = BouncingBallDataset("./", cfg.dataset, "val", size, type_name = cfg.scenario)
else:
raise Exception("Dataset not supported")