aboutsummaryrefslogtreecommitdiff
path: root/scripts/training.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/training.py')
-rw-r--r--scripts/training.py324
1 files changed, 228 insertions, 96 deletions
diff --git a/scripts/training.py b/scripts/training.py
index b807e54..fd4a4bf 100644
--- a/scripts/training.py
+++ b/scripts/training.py
@@ -1,3 +1,4 @@
+from copy import deepcopy
import torch as th
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
@@ -5,15 +6,18 @@ import cv2
import numpy as np
import os
from einops import rearrange, repeat, reduce
-from scripts.utils.configuration import Configuration
-from scripts.utils.io import init_device, model_path, LossLogger
+from scripts.utils.configuration import Configuration, Dict
+from scripts.utils.io import WriterWrapper, init_device, model_path, LossLogger
from scripts.utils.optimizers import RAdam
from model.loci import Loci
import random
from scripts.utils.io import Timer
-from scripts.utils.plot_utils import color_mask
-from scripts.validation import validation_clevrer, validation_adept
+from scripts.utils.plot_utils import color_mask, plot_timestep
+from scripts.validation import validation_bb, validation_clevrer, validation_adept
+from scripts.exec.eval import main as eval_main
+os.environ["WANDB__SERVICE_WAIT"] = "300"
+os.environ["WANDB_ DISABLE_SERVICE"] = "true"
def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
@@ -21,6 +25,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
device, verbose = init_device(cfg)
if verbose:
valset = Subset(valset, range(0, 8))
+ use_wandb = (verbose == False)
+
+ # generate random seed if not set
+ if not ('seed' in cfg.defaults):
+ cfg.defaults['seed'] = random.randint(0, 100)
+ th.manual_seed(cfg.defaults.seed)
# Define model path
path = model_path(cfg, overwrite=False)
@@ -34,12 +44,16 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
teacher_forcing = cfg.defaults.teacher_forcing
)
net = net.to(device=device)
+ #net = th.compile(net)
+ #net = th.jit.script(net)
# Log model size
log_modelsize(net)
+ writer = WriterWrapper(use_wandb, cfg)
# Init Optimizers
optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net)
+ scheduler = init_lr_scheduler_StepLR(cfg, optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update)
# Option to load model
if file != "":
@@ -58,9 +72,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
print(f'loaded {file}', flush=True)
# Set up data loaders
- trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True)
- valset.train = True #valset.dataset.train = True
- valloader = get_loader(cfg, valset, cfg_net, shuffle=False)
+ trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True, verbose=verbose)
+ valloader = get_loader(cfg, valset, cfg_net, shuffle=False, verbose=verbose)
# initial save
save_model(
@@ -80,16 +93,24 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
print('!!! Net init status: ', net.get_init_status())
# Set up statistics
- loss_tracker = LossLogger()
+ loss_tracker = LossLogger(writer)
+
+ # Init loss function
+ imageloss = initialize_loss_function(cfg)
# Set up training variables
num_time_steps = 0
bptt_steps = cfg.bptt.bptt_steps
+ if not 'plot_interval' in cfg.defaults:
+ cfg.defaults.plot_interval = 20000
+ blackout_rate = cfg.blackout.blackout_rate if ('blackout' in cfg) else 0.0
+ rollout_length = cfg.vp.rollout_length if ('vp' in cfg) else 0
+ burnin_length = cfg.vp.burnin_length if ('vp' in cfg) else 0
increase_bptt_steps = False
background_blendin_factor = 0.0
th.backends.cudnn.benchmark = True
+ plot_next_sample = False
timer = Timer()
- bceloss = nn.BCELoss()
# Init net to current num_updates
if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1:
@@ -101,7 +122,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
net.inc_init_level()
for param in optimizer_init.param_groups:
- param['lr'] = cfg.learning_rate
+ param['lr'] = cfg.learning_rate.lr
if num_updates > cfg.phases.start_inner_loop:
net.cfg.inner_loop_enabled = True
@@ -117,16 +138,18 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
# Validation every epoch
if epoch >= 0:
if cfg.datatype == 'adept':
- validation_adept(valloader, net, cfg, device)
+ validation_adept(valloader, net, cfg, device, writer, epoch, path)
elif cfg.datatype == 'clevrer':
- validation_clevrer(valloader, net, cfg, device)
+ validation_clevrer(valloader, net, cfg, device, writer, epoch, path)
+ elif cfg.datatype == "bouncingballs" and (epoch % 2 == 0):
+ validation_bb(valloader, net, cfg, device, writer, epoch, path)
# Start epoch training
print('Start epoch:', epoch)
# Backprop through time steps
if increase_bptt_steps:
- bptt_steps = max(bptt_steps + 1, cfg.bptt.bptt_steps_max)
+ bptt_steps = min(bptt_steps + 1, cfg.bptt.bptt_steps_max)
print('Increase closed loop steps to', bptt_steps)
increase_bptt_steps = False
@@ -149,7 +172,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
slots_occlusionfactor = None
# Apply skip frames to sequence
- selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames)
+ start = random.randrange(cfg.defaults.skip_frames) if rollout_length == 0 else random.randrange(0, (tensor.shape[1]-(rollout_length+burnin_length)))
+ selec = range(start, tensor.shape[1], cfg.defaults.skip_frames)
tensor = tensor[:,selec]
sequence_len = tensor.shape[1]
@@ -159,6 +183,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
target = th.clip(input, 0, 1).detach()
error_last = None
+ # plotting mode
+ plot_this_sample = plot_next_sample
+ plot_next_sample = False
+ video_list = []
+ num_rollout = 0
+
# First apply teacher forcing for the first x frames
for t in range(-cfg.defaults.teacher_forcing, sequence_len-1):
@@ -185,13 +215,27 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout:
error_last = th.zeros_like(error_last)
- # Apply sensation blackout when training clevrere
- if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer':
- if t >= 10:
- blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device)
+ # Apply sensation blackout when training clevrer
+ if net.cfg.inner_loop_enabled and blackout_rate > 0:
+ if t >= cfg.blackout.blackout_start_timestep:
+ blackout = th.tensor((np.random.rand(cfg_net.batch_size) < blackout_rate)[:,None,None,None]).float().to(device)
input = blackout * (input * 0) + (1-blackout) * input
error_last = blackout * (error_last * 0) + (1-blackout) * error_last
-
+
+ # Apply rollout training
+ if net.cfg.inner_loop_enabled and (rollout_length > 0) and (burnin_length-1) < t:
+ input = input * 0
+ error_last = error_last * 0
+ num_rollout += 1
+
+ if (burnin_length + rollout_length -1) == t:
+ run_optimizers = True
+ detach = True
+
+ if (burnin_length + rollout_length) == t:
+ break
+
+
# Forward Pass
(
output_next,
@@ -206,7 +250,9 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
slots_occlusionfactor,
position_loss,
time_loss,
- slots_closed
+ latent_loss,
+ slots_closed,
+ slots_bounded
) = net(
input, # current frame
error_last, # error of last frame --> missing object
@@ -228,6 +274,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
# Loss weighting
position_loss = position_loss * cfg_net.position_regularizer
time_loss = time_loss * cfg_net.time_regularizer
+ if latent_loss.item() > 0:
+ latent_loss = latent_loss * cfg_net.latent_regularizer
# Compute background error
bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
@@ -253,10 +301,14 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
background_blendin_factor = min(1, background_blendin_factor + 0.001)
# Final Loss computation
- encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer
- cliped_output_next = th.clip(output_next, 0, 1)
- loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss
+ encoding_loss = imageloss(output_cur, input) * cfg_net.encoder_regularizer
+ prediction_loss = imageloss(output_next, target)
+ loss = prediction_loss + encoding_loss + position_loss + time_loss + latent_loss
+ # apply loss decay according to num_rollout
+ if rollout_length > 0:
+ loss = loss * 0.75**(num_rollout-1)
+
# Accumulate loss over BPP steps
summed_loss = loss if summed_loss is None else summed_loss + loss
mask = mask.detach()
@@ -295,6 +347,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
# Update net status
update_net_status(num_updates, net, cfg, optimizer_init)
+ step_lr_scheduler(scheduler)
if num_updates == cfg.phases.start_inner_loop:
print('Start inner loop')
@@ -303,38 +356,39 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0):
increase_bptt_steps = True
+ if net.cfg.inner_loop_enabled and ('blackout' in cfg) and (cfg.blackout.blackout_increase_every > 0) and ((num_updates-cfg.num_updates) % cfg.blackout.blackout_increase_every == 0) and ((num_updates-cfg.num_updates) > 0):
+ blackout_rate = min(blackout_rate + cfg.blackout.blackout_increase_rate, cfg.blackout.blackout_rate_max)
+
# Plots for online evaluation
- if num_updates % 20000 == 0:
- plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next)
+ if num_updates % cfg.defaults.plot_interval == 0:
+ plot_next_sample = True
+ if plot_this_sample:
+ img_tensor = plot_online(cfg, path, f'{epoch}_{batch_index}', input, background, mask, sequence_len, t, output_next, bg_error_next)
+ video_list.append(img_tensor)
# Track statisitcs
if t >= cfg.defaults.statistics_offset:
- track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next)
- loss_tracker.update_average_loss(loss.item())
+ track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates)
+ loss_tracker.update_average_loss(loss.item(), num_updates)
+ writer.add_scalar('Train/BPTT_steps', bptt_steps, num_updates)
+ writer.add_scalar('Train/Background_Blendin', background_blendin_factor, num_updates)
+ writer.add_scalar('Train/Blackout_Rate', blackout_rate, num_updates)
# Logging
if num_updates % 100 == 0 and run_optimizers:
print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True)
# Training finished
+ net_path = None
if num_updates > cfg.max_updates:
- save_model(
- os.path.join(path, 'nets', 'net_final.pt'),
- net,
- optimizer_init,
- optimizer_encoder,
- optimizer_decoder,
- optimizer_predictor,
- optimizer_background
- )
- print("Training finished")
- return
-
- # Checkpointing
+ net_path = os.path.join(path, 'nets', 'net_final.pt')
if num_updates % 50000 == 0 and run_optimizers:
+ net_path = os.path.join(path, 'nets', f'net_{num_updates}.pt')
+
+ if net_path is not None:
save_model(
- os.path.join(path, 'nets', f'net_{num_updates}.pt'),
+ net_path,
net,
optimizer_init,
optimizer_encoder,
@@ -342,33 +396,56 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
optimizer_predictor,
optimizer_background
)
- pass
+ if ('final' in net_path) or ('num_objects_test' in cfg.model):
+ eval_main(net_path, 1, deepcopy(cfg))
-def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next):
-
- # Prediction Loss
- mseloss = nn.MSELoss()
- loss_next = mseloss(output_next * bg_error_next, target * bg_error_next)
+ if 'final' in net_path:
+ print("Training finished")
+ writer.flush()
+ return
- # Encoder Loss (only for stats)
- loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur)
+ if plot_this_sample:
+ video_tensor = rearrange(th.stack(video_list, dim=0)[None], 'b t h w c -> b t c h w')
+ writer.add_video('Train/Video', video_tensor, num_updates)
+
+ pass
+def track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates):
+
# area of foreground mask
num_objects = th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item()
# difference in shape
- _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
- _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
+ _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects)
+ _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects)
max_mask = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float()
avg_gestalt = (th.sum(_gestalt * max_mask) / (1e-16 + th.sum(max_mask)))
avg_gestalt2 = (th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask)))
avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1))
# udpdate gates
- avg_update_gestalt = slots_closed[:,:,0].mean()
- avg_update_position = slots_closed[:,:,1].mean()
+ num_bounded = reduce(slots_bounded, 'b o -> b', 'sum').mean().item()
+ slots_closed = slots_closed * slots_bounded[:,:,None]
+ avg_update_gestalt = slots_closed[:,:,0].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0
+ avg_update_position = slots_closed[:,:,1].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0
+ avg_update_gestalt = float(avg_update_gestalt)
+ avg_update_position = float(avg_update_position)
+
+ # Prediction Loss + Encoder Loss as MSE + only foreground pixels
+ mseloss = nn.MSELoss()
+ loss_next = mseloss(output_next * bg_error_next, target * bg_error_next)
+ loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur)
+
+ # learning rate
+ if scheduler is not None:
+ lr = scheduler[0].get_last_lr()[0]
+ else:
+ lr = cfg.learning_rate.lr
- loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position)
+ # gatelORD openings
+ openings = th.mean(net.get_openings()).item()
+
+ loss_tracker.update_complete(position_loss, time_loss, latent_loss, loss_cur, loss_next, num_objects, openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, num_bounded, lr, num_updates)
pass
def log_modelsize(net):
@@ -381,15 +458,49 @@ def log_modelsize(net):
print("\n")
pass
+def initialize_loss_function(cfg):
+ if not ('loss' in cfg.defaults):
+ cfg.defaults.loss = 'bce'
+
+ if cfg.defaults.loss == 'mse':
+ imageloss = nn.MSELoss()
+ elif cfg.defaults.loss == 'bce':
+ imageloss = nn.BCELoss()
+ else:
+ raise NotImplementedError
+
+ return imageloss
+
def init_optimizer(cfg, net):
- optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30)
- optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate)
- optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate)
- optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate)
- optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate)
- optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = cfg.learning_rate)
+ # backward compability:
+ if not isinstance(cfg.learning_rate, dict):
+ cfg.learning_rate = Dict({'lr': cfg.learning_rate})
+
+ lr = cfg.learning_rate.lr
+ optimizer_init = RAdam(net.initial_states.parameters(), lr = lr * 30)
+ optimizer_encoder = RAdam(net.encoder.parameters(), lr = lr)
+ optimizer_decoder = RAdam(net.decoder.parameters(), lr = lr)
+ optimizer_predictor = RAdam(net.predictor.parameters(), lr = lr)
+ optimizer_background = RAdam([net.background.mask], lr = lr)
+ optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = lr)
return optimizer_init,optimizer_encoder,optimizer_decoder,optimizer_predictor,optimizer_background,optimizer_update
+def init_lr_scheduler_StepLR(cfg, *optimizer_list):
+ scheduler_list = None
+ if 'deacrease_lr_every' in cfg.learning_rate:
+ print('Init lr scheduler')
+ scheduler_list = []
+ for optimizer in optimizer_list:
+ scheduler = th.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.learning_rate.deacrease_lr_every, gamma=cfg.learning_rate.deacrease_lr_factor)
+ scheduler_list.append(scheduler)
+ return scheduler_list
+
+def step_lr_scheduler(scheduler_list):
+ if scheduler_list is not None:
+ for scheduler in scheduler_list:
+ scheduler.step()
+ pass
+
def save_model(
file,
net,
@@ -432,23 +543,23 @@ def load_model(
if load_optimizers:
optimizer_init.load_state_dict(state[f'optimizer_init'])
for n in range(len(optimizer_init.param_groups)):
- optimizer_init.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_init.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_encoder.load_state_dict(state[f'optimizer_encoder'])
for n in range(len(optimizer_encoder.param_groups)):
- optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_decoder.load_state_dict(state[f'optimizer_decoder'])
for n in range(len(optimizer_decoder.param_groups)):
- optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_predictor.load_state_dict(state['optimizer_predictor'])
for n in range(len(optimizer_predictor.param_groups)):
- optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate
+ optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate.lr
optimizer_background.load_state_dict(state['optimizer_background'])
for n in range(len(optimizer_background.param_groups)):
- optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate
+ optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate.lr
# 1. Fill model with values of net
model = {}
@@ -484,40 +595,61 @@ def update_net_status(num_updates, net, cfg, optimizer_init):
if num_updates == cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
net.inc_init_level()
for param in optimizer_init.param_groups:
- param['lr'] = cfg.learning_rate
+ param['lr'] = cfg.learning_rate.lr
pass
def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next):
- plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
- if not os.path.exists(plot_path):
+
+ # highlight error
+ grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
+ object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
+ highlited_input = grayscale * (1 - object_mask_cur)
+ highlited_input += grayscale * object_mask_cur * 0.3333333
+ cmask = color_mask(mask[:,:-1])
+ highlited_input = highlited_input + cmask * 0.6666666
+
+ input_ = rearrange(input[0], 'c h w -> h w c').detach().cpu()
+ background_ = rearrange(background[0], 'c h w -> h w c').detach().cpu()
+ mask_ = rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu()
+ output_next_ = rearrange(output_next[0], 'c h w -> h w c').detach().cpu()
+ bg_error_next_ = rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu()
+ highlited_input_ = rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu()
+
+ if False:
+ plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
os.makedirs(plot_path, exist_ok=True)
-
- # highlight error
- grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
- object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
- highlited_input = grayscale * (1 - object_mask_cur)
- highlited_input += grayscale * object_mask_cur * 0.3333333
- cmask = color_mask(mask[:,:-1])
- highlited_input = highlited_input + cmask * 0.6666666
-
- cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
- pass
-
-def get_loader(cfg, set, cfg_net, shuffle = True):
- loader = DataLoader(
- set,
- pin_memory = True,
- num_workers = cfg.defaults.num_workers,
- batch_size = cfg_net.batch_size,
- shuffle = shuffle,
- drop_last = True,
- prefetch_factor = cfg.defaults.prefetch_factor,
- persistent_workers = True
- )
+ cv2.imwrite(os.path.join(plot_path, f'input-{t+cfg.defaults.teacher_forcing:03d}.jpg'), input_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'background-{t+cfg.defaults.teacher_forcing:03d}.jpg'), background_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'error_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), bg_error_next_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'background_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), mask_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'output_next-{t+cfg.defaults.teacher_forcing:03d}.jpg'), output_next_.numpy() * 255)
+ cv2.imwrite(os.path.join(plot_path, f'output_highlight-{t+cfg.defaults.teacher_forcing:03d}.jpg'), highlited_input_.numpy() * 255)
+
+ # stack input, output, mask and highlight into one image horizontally
+ img_tensor = th.cat([input_, highlited_input_, output_next_], dim=1)
+ return img_tensor
+
+def get_loader(cfg, set, cfg_net, shuffle = True, verbose=False):
+ if ((cfg.datatype == 'bouncingballs') and verbose) or ((cfg.datatype == 'adept') and not shuffle):
+ loader = DataLoader(
+ set,
+ pin_memory = True,
+ num_workers = 0,
+ batch_size = cfg_net.batch_size,
+ shuffle = shuffle,
+ drop_last = True,
+ persistent_workers = False
+ )
+ else:
+ loader = DataLoader(
+ set,
+ pin_memory = True,
+ num_workers = cfg.defaults.num_workers,
+ batch_size = cfg_net.batch_size,
+ shuffle = shuffle,
+ drop_last = True,
+ prefetch_factor = cfg.defaults.prefetch_factor,
+ persistent_workers = True
+ )
return loader \ No newline at end of file