author | fredeee | 2023-11-02 10:47:21 +0100 |
---|---|---|
committer | fredeee | 2023-11-02 10:47:21 +0100 |
commit | f8302ee886ef9b631f11a52900dac964a61350e1 (patch) | |
tree | 87288be6f851ab69405e524b81940c501c52789a /scripts/training.py | |
parent | f16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff) |
initial commit
Diffstat (limited to 'scripts/training.py')
-rw-r--r-- | scripts/training.py | 523 |
1 files changed, 523 insertions, 0 deletions
diff --git a/scripts/training.py b/scripts/training.py
new file mode 100644
index 0000000..b807e54
--- /dev/null
+++ b/scripts/training.py
@@ -0,0 +1,523 @@
+import torch as th
+from torch import nn
+from torch.utils.data import Dataset, DataLoader, Subset
+import cv2
+import numpy as np
+import os
+from einops import rearrange, repeat, reduce
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device, model_path, LossLogger
+from scripts.utils.optimizers import RAdam
+from model.loci import Loci
+import random
+from scripts.utils.io import Timer
+from scripts.utils.plot_utils import color_mask
+from scripts.validation import validation_clevrer, validation_adept
+
+
+def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
+
+    # Set up cpu or gpu training
+    device, verbose = init_device(cfg)
+    if verbose:
+        valset = Subset(valset, range(0, 8))
+
+    # Define model path
+    path = model_path(cfg, overwrite=False)
+    cfg.save(path)
+    os.makedirs(os.path.join(path, 'nets'), exist_ok=True)
+
+    # Configure model
+    cfg_net = cfg.model
+    net = Loci(
+        cfg = cfg_net,
+        teacher_forcing = cfg.defaults.teacher_forcing
+    )
+    net = net.to(device=device)
+
+    # Log model size
+    log_modelsize(net)
+
+    # Init Optimizers
+    optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net)
+
+    # Option to load model
+    if file != "":
+        load_model(
+            file,
+            cfg,
+            net,
+            optimizer_init,
+            optimizer_encoder,
+            optimizer_decoder,
+            optimizer_predictor,
+            optimizer_background,
+            cfg.defaults.load_optimizers,
+            only_encoder_decoder = (cfg.num_updates == 0) # only load encoder and decoder for initial training
+        )
+        print(f'loaded {file}', flush=True)
+
+    # Set up data loaders
+    trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True)
+    valset.train = True #valset.dataset.train = True
+    valloader = get_loader(cfg, valset, cfg_net, shuffle=False)
+
+    # initial save
+    save_model(
+        os.path.join(path, 'nets', 'net0.pt'),
+        net,
+        optimizer_init,
+        optimizer_encoder,
+        optimizer_decoder,
+        optimizer_predictor,
+        optimizer_background
+    )
+
+    # Start training at num_updates
+    num_updates = cfg.num_updates
+    if num_updates > 0:
+        print('!!! Start training at num_updates: ', num_updates)
+        print('!!! Net init status: ', net.get_init_status())
+
+    # Set up statistics
+    loss_tracker = LossLogger()
+
+    # Set up training variables
+    num_time_steps = 0
+    bptt_steps = cfg.bptt.bptt_steps
+    increase_bptt_steps = False
+    background_blendin_factor = 0.0
+    th.backends.cudnn.benchmark = True
+    timer = Timer()
+    bceloss = nn.BCELoss()
+
+    # Init net to current num_updates
+    if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1:
+        net.inc_init_level()
+
+    if num_updates >= cfg.phases.entity_pretraining_phase1_end and net.get_init_status() < 2:
+        net.inc_init_level()
+
+    if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
+        net.inc_init_level()
+        for param in optimizer_init.param_groups:
+            param['lr'] = cfg.learning_rate
+
+    if num_updates > cfg.phases.start_inner_loop:
+        net.cfg.inner_loop_enabled = True
+
+    if num_updates >= cfg.phases.entity_pretraining_phase1_end:
+        background_blendin_factor = max(min((num_updates - cfg.phases.entity_pretraining_phase1_end)/30000, 1.0), 0.0)
+
+
+    # --- Start Training
+    print('Start training')
+    for epoch in range(cfg.max_epochs):
+
+        # Validation every epoch
+        if epoch >= 0:
+            if cfg.datatype == 'adept':
+                validation_adept(valloader, net, cfg, device)
+            elif cfg.datatype == 'clevrer':
+                validation_clevrer(valloader, net, cfg, device)
+
+        # Start epoch training
+        print('Start epoch:', epoch)
+
+        # Backprop through time steps
+        if increase_bptt_steps:
+            bptt_steps = max(bptt_steps + 1, cfg.bptt.bptt_steps_max)
+            print('Increase closed loop steps to', bptt_steps)
+            increase_bptt_steps = False
+
+        for batch_index, input in enumerate(trainloader):
+
+            # Extract input and background
+            tensor = input[0]
+            background = input[1].to(device)
+            shuffleslots = (num_updates <= cfg.phases.shufleslots_end)
+
+            # Placeholders
+            position = None
+            gestalt = None
+            priority = None
+            mask = None
+            object = None
+            rawmask = None
+            loss = th.tensor(0)
+            summed_loss = None
+            slots_occlusionfactor = None
+
+            # Apply skip frames to sequence
+            selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames)
+            tensor = tensor[:,selec]
+            sequence_len = tensor.shape[1]
+
+            # Initial frame
+            input = tensor[:,0].to(device)
+            input_next = input
+            target = th.clip(input, 0, 1).detach()
+            error_last = None
+
+            # First apply teacher forcing for the first x frames
+            for t in range(-cfg.defaults.teacher_forcing, sequence_len-1):
+
+                # Set update scheme for backprop through time
+                if t >= cfg.bptt.bptt_start_timestep:
+                    t_run = (t - cfg.bptt.bptt_start_timestep)
+                    run_optimizers = t_run % bptt_steps == bptt_steps - 1
+                    detach = (t_run % bptt_steps == 0) or t == -cfg.defaults.teacher_forcing
+                else:
+                    run_optimizers = True
+                    detach = True
+
+                if verbose:
+                    print(f't: {t}, run_optimizers: {run_optimizers}, detach: {detach}')
+
+                if t >= 0:
+                    # Skip to next frame
+                    num_time_steps += 1
+                    input = input_next
+                    input_next = tensor[:,t+1].to(device)
+                    target = th.clip(input_next, 0, 1)
+
+                    # Apply error dropout
+                    if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout:
+                        error_last = th.zeros_like(error_last)
+
+                    # Apply sensation blackout when training clevrer
+                    if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer':
+                        if t >= 10:
+                            blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device)
+                            input = blackout * (input * 0) + (1-blackout) * input
+                            error_last = blackout * (error_last * 0) + (1-blackout) * error_last
+
+                # Forward Pass
+                (
+                    output_next,
+                    output_cur,
+                    position,
+                    gestalt,
+                    priority,
+                    mask,
+                    rawmask,
+                    object,
+                    background,
+                    slots_occlusionfactor,
+                    position_loss,
+                    time_loss,
+                    slots_closed
+                ) = net(
+                    input,      # current frame
+                    error_last, # error of last frame --> missing object
+                    mask,       # masks of current frame
+                    rawmask,    # raw masks of current frame
+                    position,   # positions of objects of next frame
+                    gestalt,    # gestalt of objects of next frame
+                    priority,   # priority of objects of next frame
+                    background,
+                    slots_occlusionfactor,
+                    reset = (t == -cfg.defaults.teacher_forcing), # new sequence
+                    warmup = (t < 0),                             # teacher forcing
+                    detach = detach,
+                    shuffleslots = shuffleslots or ((cfg.datatype == 'clevrer') and (t<=0)),
+                    reset_mask = (t <= 0),
+                    clean_slots = (t <= 0 and not shuffleslots),
+                )
+
+                # Loss weighting
+                position_loss = position_loss * cfg_net.position_regularizer
+                time_loss = time_loss * cfg_net.time_regularizer
+
+                # Compute background error
+                bg_error_cur  = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+
+                # Compute next-frame prediction error
+                error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
+                error_next = th.sqrt(error_next) * bg_error_next
+                error_last = error_next
+
+                # Initially focus on foreground learning
+                if background_blendin_factor < 1:
+                    fg_mask_next = th.gt(bg_error_next, 0.1).float().detach()
+                    fg_mask_next[fg_mask_next == 0] = background_blendin_factor
+                    target = th.clip(target * fg_mask_next, 0, 1)
+
+                    fg_mask_cur = th.gt(bg_error_cur, 0.1).float().detach()
+                    fg_mask_cur[fg_mask_cur == 0] = background_blendin_factor
+                    input = th.clip(input * fg_mask_cur, 0, 1)
+
+                    # Gradually blend in background for more stable training
+                    if num_updates % 30 == 0 and num_updates >= cfg.phases.entity_pretraining_phase1_end:
+                        background_blendin_factor = min(1, background_blendin_factor + 0.001)
+
+                # Final Loss computation
+                encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer
+                cliped_output_next = th.clip(output_next, 0, 1)
+                loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss
+
+                # Accumulate loss over BPTT steps
+                summed_loss = loss if summed_loss is None else summed_loss + loss
+                mask = mask.detach()
+
+                if run_optimizers:
+
+                    # detach gradients for next step
+                    position = position.detach()
+                    gestalt = gestalt.detach()
+                    rawmask = rawmask.detach()
+                    object = object.detach()
+                    priority = priority.detach()
+
+                    # zero grad
+                    optimizer_init.zero_grad()
+                    optimizer_encoder.zero_grad()
+                    optimizer_decoder.zero_grad()
+                    optimizer_predictor.zero_grad()
+                    optimizer_background.zero_grad()
+                    if net.cfg.inner_loop_enabled:
+                        optimizer_update.zero_grad()
+
+                    # optimize
+                    summed_loss.backward()
+                    optimizer_init.step()
+                    optimizer_encoder.step()
+                    optimizer_decoder.step()
+                    optimizer_predictor.step()
+                    optimizer_background.step()
+                    if net.cfg.inner_loop_enabled:
+                        optimizer_update.step()
+
+                    # Reset loss
+                    num_updates += 1
+                    summed_loss = None
+
+                    # Update net status
+                    update_net_status(num_updates, net, cfg, optimizer_init)
+
+                    if num_updates == cfg.phases.start_inner_loop:
+                        print('Start inner loop')
+                        net.cfg.inner_loop_enabled = True
+
+                    if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0):
+                        increase_bptt_steps = True
+
+                    # Plots for online evaluation
+                    if num_updates % 20000 == 0:
+                        plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next)
+
+
+                # Track statistics
+                if t >= cfg.defaults.statistics_offset:
+                    track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next)
+                    loss_tracker.update_average_loss(loss.item())
+
+                # Logging
+                if num_updates % 100 == 0 and run_optimizers:
+                    print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True)
+
+                # Training finished
+                if num_updates > cfg.max_updates:
+                    save_model(
+                        os.path.join(path, 'nets', 'net_final.pt'),
+                        net,
+                        optimizer_init,
+                        optimizer_encoder,
+                        optimizer_decoder,
+                        optimizer_predictor,
+                        optimizer_background
+                    )
+                    print("Training finished")
+                    return
+
+                # Checkpointing
+                if num_updates % 50000 == 0 and run_optimizers:
+                    save_model(
+                        os.path.join(path, 'nets', f'net_{num_updates}.pt'),
+                        net,
+                        optimizer_init,
+                        optimizer_encoder,
+                        optimizer_decoder,
+                        optimizer_predictor,
+                        optimizer_background
+                    )
+    pass
+
+def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next):
+
+    # Prediction Loss
+    mseloss = nn.MSELoss()
+    loss_next = mseloss(output_next * bg_error_next, target * bg_error_next)
+
+    # Encoder Loss (only for stats)
+    loss_cur = mseloss(output_cur * bg_error_cur, input * bg_error_cur)
+
+    # area of foreground mask
+    num_objects = th.mean(reduce((reduce(mask[:,:-1], 'b c h w -> b c', 'max') > 0.5).float(), 'b c -> b', 'sum')).item()
+
+    # difference in shape
+    _gestalt  = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt)),    'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
+    _gestalt2 = reduce(th.min(th.abs(gestalt), th.abs(1 - gestalt))**2, 'b (o c) -> (b o)', 'mean', o = cfg_net.num_objects)
+    max_mask = (reduce(mask[:,:-1], 'b c h w -> (b c)', 'max') > 0.5).float()
+    avg_gestalt = (th.sum(_gestalt * max_mask) / (1e-16 + th.sum(max_mask)))
+    avg_gestalt2 = (th.sum(_gestalt2 * max_mask) / (1e-16 + th.sum(max_mask)))
+    avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1))
+
+    # update gates
+    avg_update_gestalt = slots_closed[:,:,0].mean()
+    avg_update_position = slots_closed[:,:,1].mean()
+
+    loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position)
+    pass
+
+def log_modelsize(net):
+    print(f'Loaded model with {sum([param.numel() for param in net.parameters()]):7d} parameters', flush=True)
+    print(f'  States:     {sum([param.numel() for param in net.initial_states.parameters()]):7d} parameters', flush=True)
+    print(f'  Encoder:    {sum([param.numel() for param in net.encoder.parameters()]):7d} parameters', flush=True)
+    print(f'  Decoder:    {sum([param.numel() for param in net.decoder.parameters()]):7d} parameters', flush=True)
+    print(f'  predictor:  {sum([param.numel() for param in net.predictor.parameters()]):7d} parameters', flush=True)
+    print(f'  background: {sum([param.numel() for param in net.background.parameters()]):7d} parameters', flush=True)
+    print("\n")
+    pass
+
+def init_optimizer(cfg, net):
+    optimizer_init = RAdam(net.initial_states.parameters(), lr = cfg.learning_rate * 30)
+    optimizer_encoder = RAdam(net.encoder.parameters(), lr = cfg.learning_rate)
+    optimizer_decoder = RAdam(net.decoder.parameters(), lr = cfg.learning_rate)
+    optimizer_predictor = RAdam(net.predictor.parameters(), lr = cfg.learning_rate)
+    optimizer_background = RAdam([net.background.mask], lr = cfg.learning_rate)
+    optimizer_update = RAdam(net.percept_gate_controller.parameters(), lr = cfg.learning_rate)
+    return optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update
+
+def save_model(
+    file,
+    net,
+    optimizer_init,
+    optimizer_encoder,
+    optimizer_decoder,
+    optimizer_predictor,
+    optimizer_background
+):
+
+    state = { }
+
+    state['optimizer_init'] = optimizer_init.state_dict()
+    state['optimizer_encoder'] = optimizer_encoder.state_dict()
+    state['optimizer_decoder'] = optimizer_decoder.state_dict()
+    state['optimizer_predictor'] = optimizer_predictor.state_dict()
+    state['optimizer_background'] = optimizer_background.state_dict()
+
+    state["model"] = net.state_dict()
+    th.save(state, file)
+    pass
+
+def load_model(
+    file,
+    cfg,
+    net,
+    optimizer_init,
+    optimizer_encoder,
+    optimizer_decoder,
+    optimizer_predictor,
+    optimizer_background,
+    load_optimizers = True,
+    only_encoder_decoder = False
+):
+    device = th.device(cfg.device)
+    state = th.load(file, map_location=device)
+    print(f"load {file} to device {device}, only encoder/decoder: {only_encoder_decoder}")
+    print(f"load optimizers: {load_optimizers}")
+
+    if load_optimizers:
+        optimizer_init.load_state_dict(state['optimizer_init'])
+        for n in range(len(optimizer_init.param_groups)):
+            optimizer_init.param_groups[n]['lr'] = cfg.learning_rate
+
+        optimizer_encoder.load_state_dict(state['optimizer_encoder'])
+        for n in range(len(optimizer_encoder.param_groups)):
+            optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate
+
+        optimizer_decoder.load_state_dict(state['optimizer_decoder'])
+        for n in range(len(optimizer_decoder.param_groups)):
+            optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate
+
+        optimizer_predictor.load_state_dict(state['optimizer_predictor'])
+        for n in range(len(optimizer_predictor.param_groups)):
+            optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate
+
+        optimizer_background.load_state_dict(state['optimizer_background'])
+        for n in range(len(optimizer_background.param_groups)):
+            optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate
+
+    # 1. Fill model with values of net
+    model = {}
+    allowed_keys = []
+    rand_state = net.state_dict()
+    for key, value in rand_state.items():
+        allowed_keys.append(key)
+        model[key.replace(".module.", ".")] = value
+
+    # 2. Overwrite with values from file
+    for key, value in state["model"].items():
+        # replace update_module with percept_gate_controller in key string:
+        key = key.replace("update_module", "percept_gate_controller")
+
+        if key in allowed_keys:
+            if only_encoder_decoder:
+                if ('encoder' in key) or ('decoder' in key):
+                    model[key.replace(".module.", ".")] = value
+            else:
+                model[key.replace(".module.", ".")] = value
+
+    net.load_state_dict(model)
+
+    pass
+
+def update_net_status(num_updates, net, cfg, optimizer_init):
+    if num_updates == cfg.phases.background_pretraining_end and net.get_init_status() < 1:
+        net.inc_init_level()
+
+    if num_updates == cfg.phases.entity_pretraining_phase1_end and net.get_init_status() < 2:
+        net.inc_init_level()
+
+    if num_updates == cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
+        net.inc_init_level()
+        for param in optimizer_init.param_groups:
+            param['lr'] = cfg.learning_rate
+
+    pass
+
+def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next):
+    plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
+    if not os.path.exists(plot_path):
+        os.makedirs(plot_path, exist_ok=True)
+
+    # highlight error
+    grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
+    object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
+    highlited_input = grayscale * (1 - object_mask_cur)
+    highlited_input += grayscale * object_mask_cur * 0.3333333
+    cmask = color_mask(mask[:,:-1])
+    highlited_input = highlited_input + cmask * 0.6666666
+
+    cv2.imwrite(os.path.join(plot_path, f'input-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+    cv2.imwrite(os.path.join(plot_path, f'background-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(background[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+    cv2.imwrite(os.path.join(plot_path, f'error_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(bg_error_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+    cv2.imwrite(os.path.join(plot_path, f'background_mask-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(mask[0,-1:], 'c h w -> h w c').detach().cpu().numpy() * 255)
+    cv2.imwrite(os.path.join(plot_path, f'output_next-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(output_next[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+    cv2.imwrite(os.path.join(plot_path, f'output_highlight-{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}.jpg'), rearrange(highlited_input[0], 'c h w -> h w c').detach().cpu().numpy() * 255)
+    pass
+
+def get_loader(cfg, set, cfg_net, shuffle = True):
+    loader = DataLoader(
+        set,
+        pin_memory = True,
+        num_workers = cfg.defaults.num_workers,
+        batch_size = cfg_net.batch_size,
+        shuffle = shuffle,
+        drop_last = True,
+        prefetch_factor = cfg.defaults.prefetch_factor,
+        persistent_workers = True
+    )
+    return loader
\ No newline at end of file
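
For orientation, a minimal driver sketch of how this new `train_loci` entry point might be invoked. The config file path, the dataset class name, and its constructor arguments below are assumptions made for illustration and are not part of this commit; only `Configuration` and `train_loci` come from the repository itself.

```python
# Hypothetical usage sketch -- not part of the commit above.
# Assumed: config location 'cfg/clevrer.json', dataset class ClevrerDataset
# and its constructor signature. file="" means no checkpoint is loaded,
# matching the `if file != "":` branch in train_loci.
from scripts.utils.configuration import Configuration
from scripts.training import train_loci
from data.datasets import ClevrerDataset  # assumed dataset wrapper

cfg = Configuration('cfg/clevrer.json')        # assumed config path
trainset = ClevrerDataset(cfg, split='train')  # assumed constructor
valset   = ClevrerDataset(cfg, split='val')

train_loci(cfg, trainset, valset, file="")     # train from scratch
```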