aboutsummaryrefslogtreecommitdiff
path: root/scripts/exec/eval.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/exec/eval.py')
-rw-r--r--scripts/exec/eval.py32
1 files changed, 21 insertions, 11 deletions
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py
index 96ed32d..bc461ec 100644
--- a/scripts/exec/eval.py
+++ b/scripts/exec/eval.py
@@ -1,10 +1,29 @@
import argparse
import sys
from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.BOUNCINGBALLS.dataset import BouncingBallDataset
from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
from scripts.utils.configuration import Configuration
-from scripts import evaluation_adept, evaluation_clevrer
+from scripts import evaluation_adept, evaluation_clevrer, evaluation_bb
+def main(load, n, cfg):
+
+ size = (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))
+
+ # Load dataset
+ if cfg.datatype == "clevrer":
+ testset = ClevrerDataset("./", cfg.dataset, 'val', size, use_slotformer=True, evaluation=True)
+ evaluation_clevrer.evaluate(cfg, testset, load, n, plot_frequency= 2, plot_first_samples = 3) # only plotting
+ evaluation_clevrer.evaluate(cfg, testset, load, n, plot_first_samples = 0) # evaluation
+ elif cfg.datatype == "adept":
+ testset = AdeptDataset("./", cfg.dataset, 'createdown', size)
+ evaluation_adept.evaluate(cfg, testset, load, n, plot_frequency= 1, plot_first_samples = 2)
+ elif cfg.datatype == "bouncingballs":
+ testset = BouncingBallDataset("./", cfg.dataset, "test", size, type_name = cfg.scenario)
+ evaluation_bb.evaluate(cfg, testset, load, n, plot_first_samples = 4)
+ else:
+ raise Exception("Dataset not supported")
+
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("-cfg", default="", help='path to the configuration file')
@@ -16,13 +35,4 @@ if __name__ == "__main__":
cfg = Configuration(args.cfg)
print(f'Evaluating model {args.load}')
- # Load dataset
- if cfg.datatype == "clevrer":
- testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True)
- evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting
- evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation
- elif cfg.datatype == "adept":
- testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
- evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
- else:
- raise Exception("Dataset not supported") \ No newline at end of file
+ main(args.load, args.n, cfg) \ No newline at end of file