diff options
author | fredeee | 2023-11-02 10:47:21 +0100 |
---|---|---|
committer | fredeee | 2023-11-02 10:47:21 +0100 |
commit | f8302ee886ef9b631f11a52900dac964a61350e1 (patch) | |
tree | 87288be6f851ab69405e524b81940c501c52789a /scripts/exec | |
parent | f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff) |
initiaƶ commit
Diffstat (limited to 'scripts/exec')
-rw-r--r-- | scripts/exec/eval.py | 28 | ||||
-rw-r--r-- | scripts/exec/eval_savi.py | 17 | ||||
-rw-r--r-- | scripts/exec/train.py | 33 |
3 files changed, 78 insertions, 0 deletions
diff --git a/scripts/exec/eval.py b/scripts/exec/eval.py new file mode 100644 index 0000000..96ed32d --- /dev/null +++ b/scripts/exec/eval.py @@ -0,0 +1,28 @@ +import argparse +import sys +from data.datasets.ADEPT.dataset import AdeptDataset +from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage +from scripts.utils.configuration import Configuration +from scripts import evaluation_adept, evaluation_clevrer + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", default="", help='path to the configuration file') + parser.add_argument("-load", default="", type=str, help='path to model') + parser.add_argument("-n", default="", type=str, help='results name') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + cfg = Configuration(args.cfg) + print(f'Evaluating model {args.load}') + + # Load dataset + if cfg.datatype == "clevrer": + testset = ClevrerDataset("./", cfg.dataset, 'val', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True, evaluation=True) + evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_frequency= 2, plot_first_samples = 3) # only plotting + evaluation_clevrer.evaluate(cfg, testset, args.load, args.n, plot_first_samples = 0) # evaluation + elif cfg.datatype == "adept": + testset = AdeptDataset("./", cfg.dataset, 'createdown', (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + evaluation_adept.evaluate(cfg, testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2) + else: + raise Exception("Dataset not supported")
\ No newline at end of file diff --git a/scripts/exec/eval_savi.py b/scripts/exec/eval_savi.py new file mode 100644 index 0000000..87d4e59 --- /dev/null +++ b/scripts/exec/eval_savi.py @@ -0,0 +1,17 @@ +import argparse +import sys +from data.datasets.ADEPT.dataset import AdeptDataset +from scripts.evaluation_adept_savi import evaluate + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-load", default="", type=str, help='path to savi slots') + parser.add_argument("-n", default="", type=str, help='results name') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + print(f'Evaluating savi slots {args.load}') + + # Load dataset + testset = AdeptDataset("./", 'adept', 'createdown', (30 * 2**(2*2), 20 * 2**(2*2))) + evaluate(testset, args.load, args.n, plot_frequency= 1, plot_first_samples = 2)
\ No newline at end of file diff --git a/scripts/exec/train.py b/scripts/exec/train.py new file mode 100644 index 0000000..f6e78fd --- /dev/null +++ b/scripts/exec/train.py @@ -0,0 +1,33 @@ +import argparse +import sys +from scripts.utils.configuration import Configuration +from scripts import training +from data.datasets.CLEVRER.dataset import ClevrerDataset, ClevrerSample, RamImage +from data.datasets.ADEPT.dataset import AdeptDataset + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("-cfg", default="", help='path to the configuration file') + parser.add_argument("-n", default=-1, type=int, help='optional run number') + parser.add_argument("-load", default="", type=str, help='path to pretrained model or checkpoint') + + # Load configuration + args = parser.parse_args(sys.argv[1:]) + cfg = Configuration(args.cfg) + cfg.model_path = f"{cfg.model_path}" + if args.n >= 0: + cfg.model_path = f"{cfg.model_path}.run{args.n}" + print(f'Training model {cfg.model_path}') + + # Load dataset + if cfg.datatype == "clevrer": + trainset = ClevrerDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=False) + valset = ClevrerDataset("./", cfg.dataset, "val", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2)), use_slotformer=True) + elif cfg.datatype == "adept": + trainset = AdeptDataset("./", cfg.dataset, "train", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + valset = AdeptDataset("./", cfg.dataset, "test", (cfg.model.latent_size[1] * 2**(cfg.model.level*2), cfg.model.latent_size[0] * 2**(cfg.model.level*2))) + else: + raise Exception("Dataset not supported") + + # Final call + training.train_loci(cfg, trainset, valset, args.load)
\ No newline at end of file |