diff options
Diffstat (limited to 'scripts/training.py')
-rw-r--r-- | scripts/training.py | 324 |
1 files changed, 228 insertions, 96 deletions
diff --git a/scripts/training.py b/scripts/training.py index b807e54..fd4a4bf 100644 --- a/scripts/training.py +++ b/scripts/training.py @@ -1,3 +1,4 @@ +from copy import deepcopy import torch as th from torch import nn from torch.utils.data import Dataset, DataLoader, Subset @@ -5,15 +6,18 @@ import cv2 import numpy as np import os from einops import rearrange, repeat, reduce -from scripts.utils.configuration import Configuration -from scripts.utils.io import init_device, model_path, LossLogger +from scripts.utils.configuration import Configuration, Dict +from scripts.utils.io import WriterWrapper, init_device, model_path, LossLogger from scripts.utils.optimizers import RAdam from model.loci import Loci import random from scripts.utils.io import Timer -from scripts.utils.plot_utils import color_mask -from scripts.validation import validation_clevrer, validation_adept +from scripts.utils.plot_utils import color_mask, plot_timestep +from scripts.validation import validation_bb, validation_clevrer, validation_adept +from scripts.exec.eval import main as eval_main +os.environ["WANDB__SERVICE_WAIT"] = "300" +os.environ["WANDB_ DISABLE_SERVICE"] = "true" def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): @@ -21,6 +25,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): device, verbose = init_device(cfg) if verbose: valset = Subset(valset, range(0, 8)) + use_wandb = (verbose == False) + + # generate random seed if not set + if not ('seed' in cfg.defaults): + cfg.defaults['seed'] = random.randint(0, 100) + th.manual_seed(cfg.defaults.seed) # Define model path path = model_path(cfg, overwrite=False) @@ -34,12 +44,16 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): teacher_forcing = cfg.defaults.teacher_forcing ) net = net.to(device=device) + #net = th.compile(net) + #net = th.jit.script(net) # Log model size log_modelsize(net) + writer = WriterWrapper(use_wandb, cfg) # Init Optimizers optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net) + scheduler = init_lr_scheduler_StepLR(cfg, optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update) # Option to load model if file != "": @@ -58,9 +72,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): print(f'loaded {file}', flush=True) # Set up data loaders - trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True) - valset.train = True #valset.dataset.train = True - valloader = get_loader(cfg, valset, cfg_net, shuffle=False) + trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True, verbose=verbose) + valloader = get_loader(cfg, valset, cfg_net, shuffle=False, verbose=verbose) # initial save save_model( @@ -80,16 +93,24 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): print('!!! Net init status: ', net.get_init_status()) # Set up statistics - loss_tracker = LossLogger() + loss_tracker = LossLogger(writer) + + # Init loss function + imageloss = initialize_loss_function(cfg) # Set up training variables num_time_steps = 0 bptt_steps = cfg.bptt.bptt_steps + if not 'plot_interval' in cfg.defaults: + cfg.defaults.plot_interval = 20000 + blackout_rate = cfg.blackout.blackout_rate if ('blackout' in cfg) else 0.0 + rollout_length = cfg.vp.rollout_length if ('vp' in cfg) else 0 + burnin_length = cfg.vp.burnin_length if ('vp' in cfg) else 0 increase_bptt_steps = False background_blendin_factor = 0.0 th.backends.cudnn.benchmark = True + plot_next_sample = False timer = Timer() - bceloss = nn.BCELoss() # Init net to current num_updates if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1: @@ -101,7 +122,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3: net.inc_init_level() for param in optimizer_init.param_groups: - param['lr'] = cfg.learning_rate + param['lr'] = cfg.learning_rate.lr if num_updates > cfg.phases.start_inner_loop: net.cfg.inner_loop_enabled = True @@ -117,16 +138,18 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): # Validation every epoch if epoch >= 0: if cfg.datatype == 'adept': - validation_adept(valloader, net, cfg, device) + validation_adept(valloader, net, cfg, device, writer, epoch, path) elif cfg.datatype == 'clevrer': - validation_clevrer(valloader, net, cfg, device) + validation_clevrer(valloader, net, cfg, device, writer, epoch, path) + elif cfg.datatype == "bouncingballs" and (epoch % 2 == 0): + validation_bb(valloader, net, cfg, device, writer, epoch, path) # Start epoch training print('Start epoch:', epoch) # Backprop through time steps if increase_bptt_steps: - bptt_steps = max(bptt_steps + 1, cfg.bptt.bptt_steps_max) + bptt_steps = min(bptt_steps + 1, cfg.bptt.bptt_steps_max) print('Increase closed loop steps to', bptt_steps) increase_bptt_steps = False @@ -149,7 +172,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): slots_occlusionfactor = None # Apply skip frames to sequence - selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames) + start = random.randrange(cfg.defaults.skip_frames) if rollout_length == 0 else random.randrange(0, (tensor.shape[1]-(rollout_length+burnin_length))) + selec = range(start, tensor.shape[1], cfg.defaults.skip_frames) tensor = tensor[:,selec] sequence_len = tensor.shape[1] @@ -159,6 +183,12 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): target = th.clip(input, 0, 1).detach() error_last = None + # plotting mode + plot_this_sample = plot_next_sample + plot_next_sample = False + video_list = [] + num_rollout = 0 + # First apply teacher forcing for the first x frames for t in range(-cfg.defaults.teacher_forcing, sequence_len-1): @@ -185,13 +215,27 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout: error_last = th.zeros_like(error_last) - # Apply sensation blackout when training clevrere - if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer': - if t >= 10: - blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device) + # Apply sensation blackout when training clevrer + if net.cfg.inner_loop_enabled and blackout_rate > 0: + if t >= cfg.blackout.blackout_start_timestep: + blackout = th.tensor((np.random.rand(cfg_net.batch_size) < blackout_rate)[:,None,None,None]).float().to(device) input = blackout * (input * 0) + (1-blackout) * input error_last = blackout * (error_last * 0) + (1-blackout) * error_last - + + # Apply rollout training + if net.cfg.inner_loop_enabled and (rollout_length > 0) and (burnin_length-1) < t: + input = input * 0 + error_last = error_last * 0 + num_rollout += 1 + + if (burnin_length + rollout_length -1) == t: + run_optimizers = True + detach = True + + if (burnin_length + rollout_length) == t: + break + + # Forward Pass ( output_next, @@ -206,7 +250,9 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): slots_occlusionfactor, position_loss, time_loss, - slots_closed + latent_loss, + slots_closed, + slots_bounded ) = net( input, # current frame error_last, # error of last frame --> missing object @@ -228,6 +274,8 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): # Loss weighting position_loss = position_loss * cfg_net.position_regularizer time_loss = time_loss * cfg_net.time_regularizer + if latent_loss.item() > 0: + latent_loss = latent_loss * cfg_net.latent_regularizer # Compute background error bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach() @@ -253,10 +301,14 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): background_blendin_factor = min(1, background_blendin_factor + 0.001) # Final Loss computation - encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer - cliped_output_next = th.clip(output_next, 0, 1) - loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss + encoding_loss = imageloss(output_cur, input) * cfg_net.encoder_regularizer + prediction_loss = imageloss(output_next, target) + loss = prediction_loss + encoding_loss + position_loss + time_loss + latent_loss + # apply loss decay according to num_rollout + if rollout_length > 0: + loss = loss * 0.75**(num_rollout-1) + # Accumulate loss over BPP steps summed_loss = loss if summed_loss is None else summed_loss + loss mask = mask.detach() @@ -295,6 +347,7 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): # Update net status update_net_status(num_updates, net, cfg, optimizer_init) + step_lr_scheduler(scheduler) if num_updates == cfg.phases.start_inner_loop: print('Start inner loop') @@ -303,38 +356,39 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0): increase_bptt_steps = True + if net.cfg.inner_loop_enabled and ('blackout' in cfg) and (cfg.blackout.blackout_increase_every > 0) and ((num_updates-cfg.num_updates) % cfg.blackout.blackout_increase_every == 0) and ((num_updates-cfg.num_updates) > 0): + blackout_rate = min(blackout_rate + cfg.blackout.blackout_increase_rate, cfg.blackout.blackout_rate_max) + # Plots for online evaluation - if num_updates % 20000 == 0: - plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next) + if num_updates % cfg.defaults.plot_interval == 0: + plot_next_sample = True + if plot_this_sample: + img_tensor = plot_online(cfg, path, f'{epoch}_{batch_index}', input, background, mask, sequence_len, t, output_next, bg_error_next) + video_list.append(img_tensor) # Track statisitcs if t >= cfg.defaults.statistics_offset: - track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next) - loss_tracker.update_average_loss(loss.item()) + track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates) + loss_tracker.update_average_loss(loss.item(), num_updates) + writer.add_scalar('Train/BPTT_steps', bptt_steps, num_updates) + writer.add_scalar('Train/Background_Blendin', background_blendin_factor, num_updates) + writer.add_scalar('Train/Blackout_Rate', blackout_rate, num_updates) # Logging if num_updates % 100 == 0 and run_optimizers: print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True) # Training finished + net_path = None if num_updates > cfg.max_updates: - save_model( - os.path.join(path, 'nets', 'net_final.pt'), - net, - optimizer_init, - optimizer_encoder, - optimizer_decoder, - optimizer_predictor, - optimizer_background - ) - print("Training finished") - return - - # Checkpointing + net_path = os.path.join(path, 'nets', 'net_final.pt') if num_updates % 50000 == 0 and run_optimizers: + net_path = os.path.join(path, 'nets', f'net_{num_updates}.pt') + + if net_path is not None: save_model( - os.path.join(path, 'nets', f'net_{num_updates}.pt'), + net_path, net, optimizer_init, optimizer_encoder, @@ -342,33 +396,56 @@ def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file): optimizer_predictor, optimizer_background ) - pass + if ('final' in net_path) or ('num_objects_test' in cfg.model): + eval_main(net_path, 1, deepcopy(cfg)) -def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next): - - # Prediction Loss - mseloss = nn.MSELoss() - loss_next = mseloss(output_next * bg_error_next, target * bg_error_next) + if 'final' in net_path: + print("Training finished") + writer.flush() + return - # Encoder Loss (only for stats) - loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur) + if plot_this_sample: + video_tensor = rearrange(th.stack(video_list, dim=0)[None], 'b t h w c -> b t c h w') + writer.add_video('Train/Video', video_tensor, num_updates) + + pass +def track_statistics(cfg, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, encoding_loss, prediction_loss, position_loss, time_loss, latent_loss, slots_closed, slots_bounded, bg_error_cur, bg_error_next, scheduler, num_updates): + # area of foreground mask num_objects = th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item() # difference in shape - _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects) - _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects) + _gestalt = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)), 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects) + _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg.model.num_objects) max_mask = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float() avg_gestalt = (th.sum(_gestalt * max_mask) / (1e-16 + th.sum(max_mask))) avg_gestalt2 = (th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask))) avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1)) # udpdate gates - avg_update_gestalt = slots_closed[:,:,0].mean() - avg_update_position = slots_closed[:,:,1].mean() + num_bounded = reduce(slots_bounded, 'b o -> b', 'sum').mean().item() + slots_closed = slots_closed * slots_bounded[:,:,None] + avg_update_gestalt = slots_closed[:,:,0].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0 + avg_update_position = slots_closed[:,:,1].sum()/slots_bounded.sum() if slots_bounded.sum() > 0 else 0.0 + avg_update_gestalt = float(avg_update_gestalt) + avg_update_position = float(avg_update_position) + + # Prediction Loss + Encoder Loss as MSE + only foreground pixels + mseloss = nn.MSELoss() + loss_next = mseloss(output_next * bg_error_next, target * bg_error_next) + loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur) + + # learning rate + if scheduler is not None: + lr = scheduler[0].get_last_lr()[0] + else: + lr = cfg.learning_rate.lr - loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position) + # gatelORD openings + openings = th.mean(net.get_openings()).item() + + loss_tracker.update_complete(position_loss, time_loss, latent_loss, loss_cur, loss_next, num_objects, openings, avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position, num_bounded, lr, num_updates) pass def log_modelsize(net): @@ -381,15 +458,49 @@ def log_modelsize(net): print("\n") pass +def initialize_loss_function(cfg): + if not ('loss' in cfg.defaults): + cfg.defaults.loss = 'bce' + + if cfg.defaults.loss == 'mse': + imageloss = nn.MSELoss() + elif cfg.defaults.loss == 'bce': + imageloss = nn.BCELoss() + else: + raise NotImplementedError + + return imageloss + def init_optimizer(cfg, net): - optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30) - optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate) - optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate) - optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate) - optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate) - optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = cfg.learning_rate) + # backward compability: + if not isinstance(cfg.learning_rate, dict): + cfg.learning_rate = Dict({'lr': cfg.learning_rate}) + + lr = cfg.learning_rate.lr + optimizer_init = RAdam(net.initial_states.parameters(), lr = lr * 30) + optimizer_encoder = RAdam(net.encoder.parameters(), lr = lr) + optimizer_decoder = RAdam(net.decoder.parameters(), lr = lr) + optimizer_predictor = RAdam(net.predictor.parameters(), lr = lr) + optimizer_background = RAdam([net.background.mask], lr = lr) + optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = lr) return optimizer_init,optimizer_encoder,optimizer_decoder,optimizer_predictor,optimizer_background,optimizer_update +def init_lr_scheduler_StepLR(cfg, *optimizer_list): + scheduler_list = None + if 'deacrease_lr_every' in cfg.learning_rate: + print('Init lr scheduler') + scheduler_list = [] + for optimizer in optimizer_list: + scheduler = th.optim.lr_scheduler.StepLR(optimizer, step_size=cfg.learning_rate.deacrease_lr_every, gamma=cfg.learning_rate.deacrease_lr_factor) + scheduler_list.append(scheduler) + return scheduler_list + +def step_lr_scheduler(scheduler_list): + if scheduler_list is not None: + for scheduler in scheduler_list: + scheduler.step() + pass + def save_model( file, net, @@ -432,23 +543,23 @@ def load_model( if load_optimizers: optimizer_init.load_state_dict(state[f'optimizer_init']) for n in range(len(optimizer_init.param_groups)): - optimizer_init.param_groups[n]['lr'] = cfg.learning_rate + optimizer_init.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_encoder.load_state_dict(state[f'optimizer_encoder']) for n in range(len(optimizer_encoder.param_groups)): - optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate + optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_decoder.load_state_dict(state[f'optimizer_decoder']) for n in range(len(optimizer_decoder.param_groups)): - optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate + optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_predictor.load_state_dict(state['optimizer_predictor']) for n in range(len(optimizer_predictor.param_groups)): - optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate + optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate.lr optimizer_background.load_state_dict(state['optimizer_background']) for n in range(len(optimizer_background.param_groups)): - optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate + optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate.lr # 1. Fill model with values of net model = {} @@ -484,40 +595,61 @@ def update_net_status(num_updates, net, cfg, optimizer_init): if num_updates == cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3: net.inc_init_level() for param in optimizer_init.param_groups: - param['lr'] = cfg.learning_rate + param['lr'] = cfg.learning_rate.lr pass def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next): - plot_path = os.path.join(path, 'plots', f'net_{num_updates}') - if not os.path.exists(plot_path): + + # highlight error + grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114 + object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1) + highlited_input = grayscale * (1 - object_mask_cur) + highlited_input += grayscale * object_mask_cur * 0.3333333 + cmask = color_mask(mask[:,:-1]) + highlited_input = highlited_input + cmask * 0.6666666 + + input_ = rearrange(input[0], 'c h w -> h w c').detach().cpu() + background_ = rearrange(background[0], 'c h w -> h w c').detach().cpu() + mask_ = rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu() + output_next_ = rearrange(output_next[0], 'c h w -> h w c').detach().cpu() + bg_error_next_ = rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu() + highlited_input_ = rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu() + + if False: + plot_path = os.path.join(path, 'plots', f'net_{num_updates}') os.makedirs(plot_path, exist_ok=True) - - # highlight error - grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114 - object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1) - highlited_input = grayscale * (1 - object_mask_cur) - highlited_input += grayscale * object_mask_cur * 0.3333333 - cmask = color_mask(mask[:,:-1]) - highlited_input = highlited_input + cmask * 0.6666666 - - cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255) - pass - -def get_loader(cfg, set, cfg_net, shuffle = True): - loader = DataLoader( - set, - pin_memory = True, - num_workers = cfg.defaults.num_workers, - batch_size = cfg_net.batch_size, - shuffle = shuffle, - drop_last = True, - prefetch_factor = cfg.defaults.prefetch_factor, - persistent_workers = True - ) + cv2.imwrite(os.path.join(plot_path, f'input-{t+cfg.defaults.teacher_forcing:03d}.jpg'), input_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'background-{t+cfg.defaults.teacher_forcing:03d}.jpg'), background_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'error_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), bg_error_next_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'background_mask-{t+cfg.defaults.teacher_forcing:03d}.jpg'), mask_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'output_next-{t+cfg.defaults.teacher_forcing:03d}.jpg'), output_next_.numpy() * 255) + cv2.imwrite(os.path.join(plot_path, f'output_highlight-{t+cfg.defaults.teacher_forcing:03d}.jpg'), highlited_input_.numpy() * 255) + + # stack input, output, mask and highlight into one image horizontally + img_tensor = th.cat([input_, highlited_input_, output_next_], dim=1) + return img_tensor + +def get_loader(cfg, set, cfg_net, shuffle = True, verbose=False): + if ((cfg.datatype == 'bouncingballs') and verbose) or ((cfg.datatype == 'adept') and not shuffle): + loader = DataLoader( + set, + pin_memory = True, + num_workers = 0, + batch_size = cfg_net.batch_size, + shuffle = shuffle, + drop_last = True, + persistent_workers = False + ) + else: + loader = DataLoader( + set, + pin_memory = True, + num_workers = cfg.defaults.num_workers, + batch_size = cfg_net.batch_size, + shuffle = shuffle, + drop_last = True, + prefetch_factor = cfg.defaults.prefetch_factor, + persistent_workers = True + ) return loader
\ No newline at end of file |