aboutsummaryrefslogtreecommitdiff
path: root/scripts/exec
diff options
context:
space:
mode:
authorfredeee2023-11-02 10:47:21 +0100
committerfredeee2023-11-02 10:47:21 +0100
commitf8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree87288be6f851ab69405e524b81940c501c52789a /scripts/exec
parentf16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initiaƶ commit
Diffstat (limited to 'scripts/exec')
-rw-r--r--scripts/exec/eval.py28
-rw-r--r--scripts/exec/eval_savi.py17
-rw-r--r--scripts/exec/train.py33
3 files changed, 78 insertions, 0 deletions
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py
new file mode 100644
index 0000000..96ed32d
--- /dev/null
+++ b/scripts/exec/eval.py
@@ -0,0 +1,28 @@
+import argparse
+import sys
+from data.datasets.ADEPT.dataset import AdeptDataset
+from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
+from scripts.utils.configuration import Configuration
+from scripts import evaluation_adept, evaluation_clevrer
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-cfg", default="", help='path to the configuration file')
+ parser.add_argument("-load", default="", type=str, help='path to model')
+ parser.add_argument("-n", default="", type=str, help='results name')
+
+ # Load configuration
+ args = parser.parse_args(sys.argv[1:])
+ cfg = Configuration(args.cfg)
+ print(f'Evaluating model {args.load}')
+
+ # Load dataset
+ if cfg.datatype == "clevrer":
+ testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True)
+ evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting
+ evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation
+ elif cfg.datatype == "adept":
+ testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
+ else:
+ raise Exception("Dataset not supported") \ No newline at end of file
diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_savi.py
new file mode 100644
index 0000000..87d4e59
--- /dev/null
+++ b/scripts/exec/eval_savi.py
@@ -0,0 +1,17 @@
+import argparse
+import sys
+from data.datasets.ADEPT.dataset import AdeptDataset
+from scripts.evaluation_adept_savi import evaluate
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-load", default="", type=str, help='path to savi slots')
+ parser.add_argument("-n", default="", type=str, help='results name')
+
+ # Load configuration
+ args = parser.parse_args(sys.argv[1:])
+ print(f'Evaluating savi slots {args.load}')
+
+ # Load dataset
+ testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2)))
+ evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) \ No newline at end of file
diff --git a/scripts/exec/train.py b/scripts/exec/train.py
new file mode 100644
index 0000000..f6e78fd
--- /dev/null
+++ b/scripts/exec/train.py
@@ -0,0 +1,33 @@
+import argparse
+import sys
+from scripts.utils.configuration import Configuration
+from scripts import training
+from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage
+from data.datasets.ADEPT.dataset import AdeptDataset
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-cfg", default="", help='path to the configuration file')
+ parser.add_argument("-n", default=-1, type=int, help='optional run number')
+ parser.add_argument("-load", default="", type=str, help='path to pretrained model or checkpoint')
+
+ # Load configuration
+ args = parser.parse_args(sys.argv[1:])
+ cfg = Configuration(args.cfg)
+ cfg.model_path = f"{cfg.model_path}"
+ if args.n >= 0:
+ cfg.model_path = f"{cfg.model_path}.run{args.n}"
+ print(f'Training model {cfg.model_path}')
+
+ # Load dataset
+ if cfg.datatype == "clevrer":
+ trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False)
+ valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True)
+ elif cfg.datatype == "adept":
+ trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)))
+ else:
+ raise Exception("Dataset not supported")
+
+ # Final call
+ training.train_loci(cfg, trainset, valset, args.load) \ No newline at end of file