aboutsummaryrefslogtreecommitdiff
path: root/scripts/training.py
diff options
context:
space:
mode:
authorfredeee2023-11-02 10:47:21 +0100
committerfredeee2023-11-02 10:47:21 +0100
commitf8302ee886ef9b631f11a52900dac964a61350e1 (patch)
tree87288be6f851ab69405e524b81940c501c52789a /scripts/training.py
parentf16fef1ab9371e1c81a2e0b2fbea59dee285a9f8 (diff)
initial commit
Diffstat (limited to 'scripts/training.py')
-rw-r--r--scripts/training.py523
1 files changed, 523 insertions, 0 deletions
diff --git a/scripts/training.py b/scripts/training.py
new file mode 100644
index 0000000..b807e54
--- /dev/null
+++ b/scripts/training.py
@@ -0,0 +1,523 @@
+import torch as th
+from torch import nn
+from torch.utils.data import Dataset, DataLoader, Subset
+import cv2
+import numpy as np
+import os
+from einops import rearrange, repeat, reduce
+from scripts.utils.configuration import Configuration
+from scripts.utils.io import init_device, model_path, LossLogger
+from scripts.utils.optimizers import RAdam
+from model.loci import Loci
+import random
+from scripts.utils.io import Timer
+from scripts.utils.plot_utils import color_mask
+from scripts.validation import validation_clevrer, validation_adept
+
+
def train_loci(cfg: Configuration, trainset: Dataset, valset: Dataset, file):
    """Train a Loci network with teacher forcing and truncated BPTT.

    Runs the multi-phase curriculum (background pretraining, entity
    pretraining phases, optional inner loop), validates once per epoch,
    and checkpoints periodically.

    Args:
        cfg: full experiment configuration (model, phases, bptt, defaults, ...).
        trainset: training dataset; each item yields (sequence, background).
        valset: validation dataset (truncated to 8 samples in verbose mode).
        file: checkpoint path to resume from, or "" to start from scratch.
    """

    # Set up cpu or gpu training
    device, verbose = init_device(cfg)
    if verbose:
        valset = Subset(valset, range(0, 8))

    # Define model path
    path = model_path(cfg, overwrite=False)
    cfg.save(path)
    os.makedirs(os.path.join(path, 'nets'), exist_ok=True)

    # Configure model
    cfg_net = cfg.model
    net = Loci(
        cfg = cfg_net,
        teacher_forcing = cfg.defaults.teacher_forcing
    )
    net = net.to(device=device)

    # Log model size
    log_modelsize(net)

    # Init Optimizers
    optimizer_init, optimizer_encoder, optimizer_decoder, optimizer_predictor, optimizer_background, optimizer_update = init_optimizer(cfg, net)

    # Option to load model
    if file != "":
        load_model(
            file,
            cfg,
            net,
            optimizer_init,
            optimizer_encoder,
            optimizer_decoder,
            optimizer_predictor,
            optimizer_background,
            cfg.defaults.load_optimizers,
            only_encoder_decoder = (cfg.num_updates == 0) # only load encoder and decoder for initial training
        )
        print(f'loaded {file}', flush=True)

    # Set up data loaders
    trainloader = get_loader(cfg, trainset, cfg_net, shuffle=True)
    valset.train = True #valset.dataset.train = True
    valloader = get_loader(cfg, valset, cfg_net, shuffle=False)

    # initial save
    save_model(
        os.path.join(path, 'nets', 'net0.pt'),
        net,
        optimizer_init,
        optimizer_encoder,
        optimizer_decoder,
        optimizer_predictor,
        optimizer_background
    )

    # Start training at num_updates
    num_updates = cfg.num_updates
    if num_updates > 0:
        print('!!! Start training at num_updates: ', num_updates)
        print('!!! Net init status: ', net.get_init_status())

    # Set up statistics
    loss_tracker = LossLogger()

    # Set up training variables
    num_time_steps = 0
    bptt_steps = cfg.bptt.bptt_steps
    increase_bptt_steps = False
    background_blendin_factor = 0.0
    th.backends.cudnn.benchmark = True
    timer = Timer()
    bceloss = nn.BCELoss()

    # Init net to current num_updates (fast-forward the curriculum when resuming)
    if num_updates >= cfg.phases.background_pretraining_end and net.get_init_status() < 1:
        net.inc_init_level()

    if num_updates >= cfg.phases.entity_pretraining_phase1_end and net.get_init_status() < 2:
        net.inc_init_level()

    if num_updates >= cfg.phases.entity_pretraining_phase2_end and net.get_init_status() < 3:
        net.inc_init_level()
        for param in optimizer_init.param_groups:
            param['lr'] = cfg.learning_rate

    if num_updates > cfg.phases.start_inner_loop:
        net.cfg.inner_loop_enabled = True

    if num_updates >= cfg.phases.entity_pretraining_phase1_end:
        background_blendin_factor = max(min((num_updates - cfg.phases.entity_pretraining_phase1_end)/30000, 1.0), 0.0)


    # --- Start Training
    print('Start training')
    for epoch in range(cfg.max_epochs):

        # Validation every epoch
        if epoch >= 0:
            if cfg.datatype == 'adept':
                validation_adept(valloader, net, cfg, device)
            elif cfg.datatype == 'clevrer':
                validation_clevrer(valloader, net, cfg, device)

        # Start epoch training
        print('Start epoch:', epoch)

        # Backprop through time steps
        if increase_bptt_steps:
            # Grow the truncation window by one, capped at bptt_steps_max.
            # BUGFIX: the original used max(), which jumped straight to the cap
            # on the first increase instead of growing one step at a time.
            bptt_steps = min(bptt_steps + 1, cfg.bptt.bptt_steps_max)
            print('Increase closed loop steps to', bptt_steps)
            increase_bptt_steps = False

        for batch_index, input in enumerate(trainloader):

            # Extract input and background
            tensor = input[0]
            background = input[1].to(device)
            shuffleslots = (num_updates <= cfg.phases.shufleslots_end)

            # Placeholders for the recurrent state carried across time steps
            position = None
            gestalt = None
            priority = None
            mask = None
            object = None
            rawmask = None
            loss = th.tensor(0)
            summed_loss = None
            slots_occlusionfactor = None

            # Apply skip frames to sequence (random phase, fixed stride)
            selec = range(random.randrange(cfg.defaults.skip_frames), tensor.shape[1], cfg.defaults.skip_frames)
            tensor = tensor[:,selec]
            sequence_len = tensor.shape[1]

            # Initial frame
            input = tensor[:,0].to(device)
            input_next = input
            target = th.clip(input, 0, 1).detach()
            error_last = None

            # First apply teacher forcing for the first x frames
            for t in range(-cfg.defaults.teacher_forcing, sequence_len-1):

                # Set update scheme for backprop through time
                if t >= cfg.bptt.bptt_start_timestep:
                    t_run = (t - cfg.bptt.bptt_start_timestep)
                    run_optimizers = t_run % bptt_steps == bptt_steps - 1
                    detach = (t_run % bptt_steps == 0) or t == -cfg.defaults.teacher_forcing
                else:
                    run_optimizers = True
                    detach = True

                if verbose:
                    print(f't: {t}, run_optimizers: {run_optimizers}, detach: {detach}')

                if t >= 0:
                    # Skip to next frame
                    num_time_steps += 1
                    input = input_next
                    input_next = tensor[:,t+1].to(device)
                    target = th.clip(input_next, 0, 1)

                    # Apply error dropout (randomly zero the error feedback)
                    if net.get_init_status() > 2 and cfg.defaults.error_dropout > 0 and np.random.rand() < cfg.defaults.error_dropout:
                        error_last = th.zeros_like(error_last)

                    # Apply sensation blackout when training clevrer
                    if net.cfg.inner_loop_enabled and cfg.datatype == 'clevrer':
                        if t >= 10:
                            blackout = th.tensor((np.random.rand(cfg_net.batch_size) < 0.2)[:,None,None,None]).float().to(device)
                            input = blackout * (input * 0) + (1-blackout) * input
                            error_last = blackout * (error_last * 0) + (1-blackout) * error_last

                # Forward Pass
                (
                    output_next,
                    output_cur,
                    position,
                    gestalt,
                    priority,
                    mask,
                    rawmask,
                    object,
                    background,
                    slots_occlusionfactor,
                    position_loss,
                    time_loss,
                    slots_closed
                ) = net(
                    input,      # current frame
                    error_last, # error of last frame --> missing object
                    mask,       # masks of current frame
                    rawmask,    # raw masks of current frame
                    position,   # positions of objects of next frame
                    gestalt,    # gestalt of objects of next frame
                    priority,   # priority of objects of next frame
                    background,
                    slots_occlusionfactor,
                    reset = (t == -cfg.defaults.teacher_forcing), # new sequence
                    warmup = (t < 0),                             # teacher forcing
                    detach = detach,
                    shuffleslots = shuffleslots or ((cfg.datatype == 'clevrer') and (t<=0)),
                    reset_mask = (t <= 0),
                    clean_slots = (t <= 0 and not shuffleslots),
                )

                # Loss weighting
                position_loss = position_loss * cfg_net.position_regularizer
                time_loss = time_loss * cfg_net.time_regularizer

                # Compute background error
                bg_error_cur = th.sqrt(reduce((input - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                bg_error_next = th.sqrt(reduce((target - background)**2, 'b c h w -> b 1 h w', 'mean')).detach()

                # Compute next-frame prediction error
                error_next = th.sqrt(reduce((target - output_next)**2, 'b c h w -> b 1 h w', 'mean')).detach()
                # NOTE(review): sqrt is applied twice here (once above, once below),
                # flattening the error map — presumed intentional; confirm upstream.
                error_next = th.sqrt(error_next) * bg_error_next
                error_last = error_next

                # Initially focus on foreground learning
                if background_blendin_factor < 1:
                    fg_mask_next = th.gt(bg_error_next, 0.1).float().detach()
                    fg_mask_next[fg_mask_next == 0] = background_blendin_factor
                    target = th.clip(target * fg_mask_next, 0, 1)

                    fg_mask_cur = th.gt(bg_error_cur, 0.1).float().detach()
                    fg_mask_cur[fg_mask_cur == 0] = background_blendin_factor
                    input = th.clip(input * fg_mask_cur, 0, 1)

                    # Gradually blend in background for more stable training
                    if num_updates % 30 == 0 and num_updates >= cfg.phases.entity_pretraining_phase1_end:
                        background_blendin_factor = min(1, background_blendin_factor + 0.001)

                # Final Loss computation
                encoder_loss = th.mean((output_cur - input)**2) * cfg_net.encoder_regularizer
                cliped_output_next = th.clip(output_next, 0, 1)
                loss = bceloss(cliped_output_next, target) + encoder_loss + position_loss + time_loss

                # Accumulate loss over BPTT steps
                summed_loss = loss if summed_loss is None else summed_loss + loss
                mask = mask.detach()

                if run_optimizers:

                    # detach gradients for next step
                    position = position.detach()
                    gestalt = gestalt.detach()
                    rawmask = rawmask.detach()
                    object = object.detach()
                    priority = priority.detach()

                    # zero grad
                    optimizer_init.zero_grad()
                    optimizer_encoder.zero_grad()
                    optimizer_decoder.zero_grad()
                    optimizer_predictor.zero_grad()
                    optimizer_background.zero_grad()
                    if net.cfg.inner_loop_enabled:
                        optimizer_update.zero_grad()

                    # optimize
                    summed_loss.backward()
                    optimizer_init.step()
                    optimizer_encoder.step()
                    optimizer_decoder.step()
                    optimizer_predictor.step()
                    optimizer_background.step()
                    if net.cfg.inner_loop_enabled:
                        optimizer_update.step()

                    # Reset loss
                    num_updates += 1
                    summed_loss = None

                    # Update net status
                    update_net_status(num_updates, net, cfg, optimizer_init)

                    if num_updates == cfg.phases.start_inner_loop:
                        print('Start inner loop')
                        net.cfg.inner_loop_enabled = True

                    if (cfg.bptt.increase_bptt_steps_every > 0) and ((num_updates-cfg.num_updates) % cfg.bptt.increase_bptt_steps_every == 0) and ((num_updates-cfg.num_updates) > 0):
                        increase_bptt_steps = True

                    # Plots for online evaluation
                    if num_updates % 20000 == 0:
                        plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next)


                # Track statistics
                if t >= cfg.defaults.statistics_offset:
                    track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next)
                    loss_tracker.update_average_loss(loss.item())

                # Logging
                if num_updates % 100 == 0 and run_optimizers:
                    print(f'Epoch[{num_updates}/{num_time_steps}/{sequence_len}]: {str(timer)}, {epoch + 1}, Blendin:{float(background_blendin_factor)}, i: {net.get_init_status() + net.initial_states.init.get():.2f},' + loss_tracker.get_log(), flush=True)

                # Training finished
                if num_updates > cfg.max_updates:
                    save_model(
                        os.path.join(path, 'nets', 'net_final.pt'),
                        net,
                        optimizer_init,
                        optimizer_encoder,
                        optimizer_decoder,
                        optimizer_predictor,
                        optimizer_background
                    )
                    print("Training finished")
                    return

                # Checkpointing
                if num_updates % 50000 == 0 and run_optimizers:
                    save_model(
                        os.path.join(path, 'nets', f'net_{num_updates}.pt'),
                        net,
                        optimizer_init,
                        optimizer_encoder,
                        optimizer_decoder,
                        optimizer_predictor,
                        optimizer_background
                    )
    pass
+
def track_statistics(cfg_net, net, loss_tracker, input, gestalt, mask, target, output_next, output_cur, position_loss, time_loss, slots_closed, bg_error_cur, bg_error_next):
    """Compute per-step diagnostics and push them into the loss tracker.

    Only used for logging — nothing here feeds back into training.
    Assumes mask is (b, slots+1, h, w) with the background in the last
    channel, and gestalt is (b, num_objects * code_dim) — TODO confirm.
    """
    mse = nn.MSELoss()

    # Prediction and encoder reconstruction losses, weighted by background error
    loss_next = mse(output_next * bg_error_next, target * bg_error_next)
    loss_cur = mse(output_cur * bg_error_cur, input * bg_error_cur)

    # Number of active slots: a slot counts as an object if its mask peaks > 0.5
    object_masks = mask[:, :-1]
    slot_active = (object_masks.amax(dim=(2, 3)) > 0.5).float()
    num_objects = slot_active.sum(dim=1).mean().item()

    # Gestalt "binariness": distance of each code entry to the nearest of {0, 1}
    batch = gestalt.shape[0]
    n_obj = cfg_net.num_objects
    distance = th.min(th.abs(gestalt), th.abs(1 - gestalt))
    per_slot = distance.reshape(batch, n_obj, -1).mean(dim=2).reshape(batch * n_obj)
    per_slot_sq = (distance ** 2).reshape(batch, n_obj, -1).mean(dim=2).reshape(batch * n_obj)

    # Average only over slots that are actually visible
    visible = slot_active.reshape(-1)
    denom = 1e-16 + th.sum(visible)
    avg_gestalt = th.sum(per_slot * visible) / denom
    avg_gestalt2 = th.sum(per_slot_sq * visible) / denom
    avg_gestalt_mean = th.mean(th.clip(gestalt, 0, 1))

    # Update gates (gestalt / position channels of the closed-slot gates)
    avg_update_gestalt = slots_closed[:, :, 0].mean()
    avg_update_position = slots_closed[:, :, 1].mean()

    loss_tracker.update_complete(position_loss, time_loss, loss_cur, loss_next, loss_next, num_objects, net.get_openings(), avg_gestalt, avg_gestalt2, avg_gestalt_mean, avg_update_gestalt, avg_update_position)
    pass
+
def log_modelsize(net):
    """Print the parameter count of the full model and of each sub-network."""
    def _count(module):
        # Total number of scalar parameters in a module.
        return sum(p.numel() for p in module.parameters())

    print(f'Loaded model with {_count(net):7d} parameters', flush=True)
    print(f'  States:     {_count(net.initial_states):7d} parameters', flush=True)
    print(f'  Encoder:    {_count(net.encoder):7d} parameters', flush=True)
    print(f'  Decoder:    {_count(net.decoder):7d} parameters', flush=True)
    print(f'  predictor:  {_count(net.predictor):7d} parameters', flush=True)
    print(f'  background: {_count(net.background):7d} parameters', flush=True)
    print("\n")
    pass
+
def init_optimizer(cfg, net):
    """Build one RAdam optimizer per sub-network.

    Returns a 6-tuple: (init, encoder, decoder, predictor, background, update).
    The initial-state optimizer runs at 30x the base learning rate; the
    background optimizer only trains the background mask parameter.
    """
    base_lr = cfg.learning_rate
    opt_init       = RAdam(net.initial_states.parameters(), lr = base_lr * 30)
    opt_encoder    = RAdam(net.encoder.parameters(), lr = base_lr)
    opt_decoder    = RAdam(net.decoder.parameters(), lr = base_lr)
    opt_predictor  = RAdam(net.predictor.parameters(), lr = base_lr)
    opt_background = RAdam([net.background.mask], lr = base_lr)
    opt_update     = RAdam(net.percept_gate_controller.parameters(), lr = base_lr)
    return opt_init, opt_encoder, opt_decoder, opt_predictor, opt_background, opt_update
+
def save_model(
    file,
    net,
    optimizer_init,
    optimizer_encoder,
    optimizer_decoder,
    optimizer_predictor,
    optimizer_background
):
    """Serialize model weights and all optimizer states into `file`."""
    state = {
        'optimizer_init':       optimizer_init.state_dict(),
        'optimizer_encoder':    optimizer_encoder.state_dict(),
        'optimizer_decoder':    optimizer_decoder.state_dict(),
        'optimizer_predictor':  optimizer_predictor.state_dict(),
        'optimizer_background': optimizer_background.state_dict(),
        'model':                net.state_dict(),
    }
    th.save(state, file)
    pass
+
def load_model(
    file,
    cfg,
    net,
    optimizer_init,
    optimizer_encoder,
    optimizer_decoder,
    optimizer_predictor,
    optimizer_background,
    load_optimizers = True,
    only_encoder_decoder = False
):
    """Load a checkpoint produced by save_model into `net` and the optimizers.

    Optimizer learning rates are reset from `cfg` after loading (the
    background optimizer gets cfg.model.background.learning_rate). Model
    weights are merged key-by-key: keys absent from the checkpoint keep the
    freshly initialized values, legacy "update_module" keys are remapped to
    "percept_gate_controller", and with only_encoder_decoder=True only
    encoder/decoder weights are taken from the checkpoint.
    """
    device = th.device(cfg.device)
    state = th.load(file, map_location=device)
    print(f"load {file} to device {device}, only encoder/decoder: {only_encoder_decoder}")
    print(f"load optimizers: {load_optimizers}")

    if load_optimizers:
        # After each load, force the learning rate back to the configured value
        # (the checkpoint may have been saved with a different schedule).
        optimizer_init.load_state_dict(state[f'optimizer_init'])
        for n in range(len(optimizer_init.param_groups)):
            optimizer_init.param_groups[n]['lr'] = cfg.learning_rate

        optimizer_encoder.load_state_dict(state[f'optimizer_encoder'])
        for n in range(len(optimizer_encoder.param_groups)):
            optimizer_encoder.param_groups[n]['lr'] = cfg.learning_rate

        optimizer_decoder.load_state_dict(state[f'optimizer_decoder'])
        for n in range(len(optimizer_decoder.param_groups)):
            optimizer_decoder.param_groups[n]['lr'] = cfg.learning_rate

        optimizer_predictor.load_state_dict(state['optimizer_predictor'])
        for n in range(len(optimizer_predictor.param_groups)):
            optimizer_predictor.param_groups[n]['lr'] = cfg.learning_rate

        optimizer_background.load_state_dict(state['optimizer_background'])
        for n in range(len(optimizer_background.param_groups)):
            # NOTE: the background optimizer uses its own configured lr.
            optimizer_background.param_groups[n]['lr'] = cfg.model.background.learning_rate

    # 1. Fill model with values of net
    # Start from the randomly initialized state dict so that any key missing
    # from the checkpoint keeps a valid value; ".module." is stripped to
    # accept checkpoints saved from a DataParallel-wrapped model.
    model = {}
    allowed_keys = []
    rand_state = net.state_dict()
    for key, value in rand_state.items():
        allowed_keys.append(key)
        model[key.replace(".module.", ".")] = value

    # 2. Overwrite with values from file
    for key, value in state["model"].items():
        # replace update_module with percept_gate_controller in key string:
        key = key.replace("update_module", "percept_gate_controller")

        if key in allowed_keys:
            if only_encoder_decoder:
                # Partial load: only encoder/decoder weights come from the file.
                if ('encoder' in key) or ('decoder' in key):
                    model[key.replace(".module.", ".")] = value
            else:
                model[key.replace(".module.", ".")] = value

    net.load_state_dict(model)

    pass
+
def update_net_status(num_updates, net, cfg, optimizer_init):
    """Advance the curriculum when a phase boundary is hit exactly.

    Each phase end raises the net's init level once (guarded so a level is
    never raised twice); finishing phase 2 also restores the initial-state
    optimizer to the base learning rate (it starts at 30x, see init_optimizer).
    """
    phases = cfg.phases
    schedule = (
        (phases.background_pretraining_end, 1, False),
        (phases.entity_pretraining_phase1_end, 2, False),
        (phases.entity_pretraining_phase2_end, 3, True),
    )
    for boundary, level_cap, reset_init_lr in schedule:
        # Re-query the status each time: boundaries may coincide, in which
        # case several levels can be raised in a single call.
        if num_updates == boundary and net.get_init_status() < level_cap:
            net.inc_init_level()
            if reset_init_lr:
                for group in optimizer_init.param_groups:
                    group['lr'] = cfg.learning_rate
    pass
+
def plot_online(cfg, path, num_updates, input, background, mask, sequence_len, t, output_next, bg_error_next):
    """Write diagnostic JPEGs for the first batch element to plots/net_<step>/.

    Dumps the input frame, background, error mask, background mask, predicted
    next frame, and an input image with object masks highlighted in color.
    """
    plot_path = os.path.join(path, 'plots', f'net_{num_updates}')
    os.makedirs(plot_path, exist_ok=True)

    # Highlight detected objects: dim the grayscale input where objects are
    # and overlay a per-slot color mask.
    grayscale = input[:,0:1] * 0.299 + input[:,1:2] * 0.587 + input[:,2:3] * 0.114
    object_mask_cur = th.sum(mask[:,:-1], dim=1).unsqueeze(dim=1)
    highlited_input = grayscale * (1 - object_mask_cur)
    highlited_input += grayscale * object_mask_cur * 0.3333333
    highlited_input = highlited_input + color_mask(mask[:,:-1]) * 0.6666666

    # Shared <sequence>-<frame> suffix for every image of this step.
    frame_id = f'{num_updates // sequence_len:05d}-{t+cfg.defaults.teacher_forcing:03d}'
    frames = (
        ('input', input[0]),
        ('background', background[0]),
        ('error_mask', bg_error_next[0]),
        ('background_mask', mask[0,-1:]),
        ('output_next', output_next[0]),
        ('output_highlight', highlited_input[0]),
    )
    for name, img in frames:
        # CHW -> HWC, scaled to 0..255 for cv2.
        cv2.imwrite(os.path.join(plot_path, f'{name}-{frame_id}.jpg'), img.permute(1, 2, 0).detach().cpu().numpy() * 255)
    pass
+
def get_loader(cfg, set, cfg_net, shuffle = True):
    """Wrap a dataset in a DataLoader configured from cfg/cfg_net.

    Uses pinned memory, persistent workers, and drops the last incomplete
    batch so every batch has exactly cfg_net.batch_size samples.
    """
    return DataLoader(
        set,
        batch_size = cfg_net.batch_size,
        shuffle = shuffle,
        num_workers = cfg.defaults.num_workers,
        prefetch_factor = cfg.defaults.prefetch_factor,
        pin_memory = True,
        drop_last = True,
        persistent_workers = True,
    )