aboutsummaryrefslogtreecommitdiff
path: root/scripts/utils/io.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/utils/io.py')
-rw-r--r--scripts/utils/io.py72
1 file changed, 59 insertions, 13 deletions
diff --git a/scripts/utils/io.py b/scripts/utils/io.py
index 9bd8158..787c575 100644
--- a/scripts/utils/io.py
+++ b/scripts/utils/io.py
@@ -65,7 +65,7 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True):
:param move_old: Moves old folder with the same name to an old folder, if not overwrite
:return: Model path
"""
- _path = os.path.join('out')
+ _path = os.path.join('out', cfg.dataset)
path = os.path.join(_path, cfg.model_path)
if not os.path.exists(_path):
@@ -95,14 +95,15 @@ def model_path(cfg: Configuration, overwrite=False, move_old=True):
class LossLogger:
- def __init__(self):
+ def __init__(self, writer):
self.avgloss = UEMA()
self.avg_position_loss = UEMA()
self.avg_time_loss = UEMA()
- self.avg_encoder_loss = UEMA()
- self.avg_mse_object_loss = UEMA()
- self.avg_long_mse_object_loss = UEMA(33333)
+ self.avg_latent_loss = UEMA()
+ self.avg_encoding_loss = UEMA()
+ self.avg_prediction_loss = UEMA()
+ self.avg_prediction_loss_long = UEMA(33333)
self.avg_num_objects = UEMA()
self.avg_openings = UEMA()
self.avg_gestalt = UEMA()
@@ -110,28 +111,73 @@ class LossLogger:
self.avg_gestalt_mean = UEMA()
self.avg_update_gestalt = UEMA()
self.avg_update_position = UEMA()
+ self.avg_num_bounded = UEMA()
+
+ self.writer = writer
- def update_complete(self, avg_position_loss, avg_time_loss, avg_encoder_loss, avg_mse_object_loss, avg_long_mse_object_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position):
+ def update_complete(self, avg_position_loss, avg_time_loss, avg_latent_loss, avg_encoding_loss, avg_prediction_loss, avg_num_objects, avg_openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, avg_num_bounded, lr, num_updates):
self.avg_position_loss.update(avg_position_loss.item())
self.avg_time_loss.update(avg_time_loss.item())
- self.avg_encoder_loss.update(avg_encoder_loss.item())
- self.avg_mse_object_loss.update(avg_mse_object_loss.item())
- self.avg_long_mse_object_loss.update(avg_long_mse_object_loss.item())
+ self.avg_latent_loss.update(avg_latent_loss.item())
+ self.avg_encoding_loss.update(avg_encoding_loss.item())
+ self.avg_prediction_loss.update(avg_prediction_loss.item())
+ self.avg_prediction_loss_long.update(avg_prediction_loss.item())
self.avg_num_objects.update(avg_num_objects)
self.avg_openings.update(avg_openings)
self.avg_gestalt.update(avg_gestalt.item())
self.avg_gestalt2.update(avg_gestalt2.item())
self.avg_gestalt_mean.update(avg_gestalt_mean.item())
- self.avg_update_gestalt.update(avg_update_gestalt.item())
- self.avg_update_position.update(avg_update_position.item())
+ self.avg_update_gestalt.update(avg_update_gestalt)
+ self.avg_update_position.update(avg_update_position)
+ self.avg_num_bounded.update(avg_num_bounded)
+
+ self.writer.add_scalar("Train/Position Loss", avg_position_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Time Loss", avg_time_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Latent Loss", avg_latent_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Encoder Loss", avg_encoding_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Prediction Loss", avg_prediction_loss.item(), num_updates)
+ self.writer.add_scalar("Train/Number of Objects", avg_num_objects, num_updates)
+ self.writer.add_scalar("Train/Openings", avg_openings, num_updates)
+ self.writer.add_scalar("Train/Gestalt", avg_gestalt.item(), num_updates)
+ self.writer.add_scalar("Train/Gestalt2", avg_gestalt2.item(), num_updates)
+ self.writer.add_scalar("Train/Gestalt Mean", avg_gestalt_mean.item(), num_updates)
+ self.writer.add_scalar("Train/Update Gestalt", avg_update_gestalt, num_updates)
+ self.writer.add_scalar("Train/Update Position", avg_update_position, num_updates)
+ self.writer.add_scalar("Train/Number Bounded", avg_num_bounded, num_updates)
+ self.writer.add_scalar("Train/Learning Rate", lr, num_updates)
+
pass
- def update_average_loss(self, avgloss):
+ def update_average_loss(self, avgloss, num_updates):
self.avgloss.update(avgloss)
+ self.writer.add_scalar("Train/Loss", avgloss, num_updates)
pass
def get_log(self):
- info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_mse_object_loss):.2e}|{float(self.avg_long_mse_object_loss):.2e}, reg: {float(self.avg_encoder_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
+ info = f'Loss: {np.abs(float(self.avgloss)):.2e}|{float(self.avg_prediction_loss):.2e}|{float(self.avg_prediction_loss_long):.2e}, reg: {float(self.avg_encoding_loss):.2e}|{float(self.avg_time_loss):.2e}|{float(self.avg_latent_loss):.2e}|{float(self.avg_position_loss):.2e}, obj: {float(self.avg_num_objects):.1f}, open: {float(self.avg_openings):.2e}|{float(self.avg_gestalt):.2f}, bin: {float(self.avg_gestalt_mean):.2e}|{np.sqrt(float(self.avg_gestalt2) - float(self.avg_gestalt)**2):.2e} closed: {float(self.avg_update_gestalt):.2e}|{float(self.avg_update_position):.2e}'
return info
+
+class WriterWrapper():
+
+ def __init__(self, use_wandb: bool, cfg: Configuration):
+ if use_wandb:
+ from torch.utils.tensorboard import SummaryWriter
+ import wandb
+ wandb.init(project=f'Loci_Looped_{cfg.dataset}', name= cfg.model_path, sync_tensorboard=True, config=cfg)
+ self.writer = SummaryWriter()
+ else:
+ self.writer = None
+
+ def add_scalar(self, name, value, step):
+ if self.writer is not None:
+ self.writer.add_scalar(name, value, step)
+
+ def add_video(self, name, value, step):
+ if self.writer is not None:
+ self.writer.add_video(name, value, step)
+
+ def flush(self):
+ if self.writer is not None:
+ self.writer.flush()