"""Train and test a CREStereo model with PyTorch Lightning, logging to Weights & Biases."""
import logging
import os
import sys
import time
from collections import namedtuple

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import wandb
import yaml
from pytorch_lightning import (
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy
from torch.utils.data import DataLoader

from dataset import BlenderDataset, CREStereoDataset, CTDDataset
from nets import Model

seed_everything(42, workers=True)

# Thresholds for which the outlier fraction is reported during validation/test.
OUTLIER_THRESHOLDS = (0.1, 0.5, 1, 2, 5)


def normalize_and_colormap(img, reduce_dynamic_range=False):
    """Scale ``img`` to the 0..255 range and render it with the INFERNO colormap.

    Args:
        img: ``torch.Tensor`` or array-like holding a single-channel image.
        reduce_dynamic_range: If true and the maximum exceeds five times the
            mean, the upper end of the scale is capped at ``5 * mean`` so a
            few extreme values do not flatten the rest of the image.

    Returns:
        The ``uint8`` image produced by ``cv2.applyColorMap``.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img, dtype=np.float64)

    low = img.min()
    high = img.max()
    if reduce_dynamic_range:
        cap = 5 * img.mean()
        # Only use the cap when it actually shrinks the range and keeps it positive.
        if low < cap < high:
            high = cap

    span = high - low
    if span > 0:
        scaled = (img - low) / span * 255.0
    else:
        # Constant image: avoid dividing by zero and render the lowest color.
        scaled = np.zeros_like(img)

    # Values above the cap would otherwise wrap around in the uint8 cast.
    scaled = np.clip(scaled, 0, 255).astype("uint8")
    return cv2.applyColorMap(scaled, cv2.COLORMAP_INFERNO)


def _min_max(tensor):
    """Format ``min/max`` of ``tensor`` with two decimals for image captions."""
    return f"{tensor.min():.2f}/{tensor.max():.2f}"


def log_images(left, right, pred_disp, gt_disp):
    """Assemble the keyword arguments for ``WandbLogger.log_image`` for one sample.

    Only channel 0 of every input is used, and only the first sample of the
    batch is visualized.

    Args:
        left: Left input images, indexed as ``[batch, channel, H, W]``.
        right: Right input images, indexed as ``[batch, channel, H, W]``.
        pred_disp: Predicted disparity, or a list of predictions of which the
            last one is visualized.
        gt_disp: Ground-truth disparity, indexed as ``[batch, channel, H, W]``.

    Returns:
        dict with the keys ``key``, ``images`` and ``caption``.
    """
    batch_idx = 0
    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])
    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])

    # With a batch size of 1 the squeeze above also removes the batch axis.
    singular_batch = len(left.shape) == 2
    if singular_batch:
        print('batch_size seems to be 1')
    else:
        left, right = left[batch_idx], right[batch_idx]
        pred_disp, gt_disp = pred_disp[batch_idx], gt_disp[batch_idx]

    input_left = left.cpu().detach().numpy()
    input_right = right.cpu().detach().numpy()
    disp_error = gt_disp - pred_disp
    abs_error = abs(disp_error)
    error_caption = (
        f"Disp. Error\n{_min_max(disp_error)}\n{abs_error.mean():.2f}"
    )

    if singular_batch:
        # The single-sample layout additionally shows the raw prediction and
        # compresses the dynamic range of the error / ground-truth panels.
        images = [
            pred_disp,
            normalize_and_colormap(pred_disp),
            normalize_and_colormap(abs_error, reduce_dynamic_range=True),
            normalize_and_colormap(gt_disp, reduce_dynamic_range=True),
            input_left,
            input_right,
        ]
        caption = [
            f"Disparity \n{_min_max(pred_disp)}",
            f"Disparity (vis) \n{_min_max(pred_disp)}",
            error_caption,
            f"GT Disp Vis \n{_min_max(gt_disp)}",
            "Input Left",
            "Input Right",
        ]
    else:
        images = [
            normalize_and_colormap(pred_disp),
            normalize_and_colormap(abs_error),
            normalize_and_colormap(gt_disp),
            input_left,
            input_right,
        ]
        caption = [
            f"Disparity (vis)\n{_min_max(pred_disp)}",
            error_caption,
            f"GT Disp Vis \n{_min_max(gt_disp)}",
            "Input Left",
            "Input Right",
        ]

    return dict(key='samples', images=images, caption=caption)


def parse_yaml(file_path: str) -> tuple:
    """Parse a YAML configuration file into a ``train_args`` namedtuple.

    Args:
        file_path: Path of the YAML file; its top level must be a mapping.

    Returns:
        A namedtuple with one field per top-level key.

    Raises:
        ValueError: If the file is empty or its top level is not a mapping.
    """
    with open(file_path, "rb") as f:
        cfg = yaml.safe_load(f)
    if not isinstance(cfg, dict):
        raise ValueError(
            f"expected a mapping at the top level of {file_path}, "
            f"got {type(cfg).__name__}"
        )
    return namedtuple("train_args", cfg.keys())(*cfg.values())


def format_time(elapse):
    """Format a duration given in seconds as ``HH:MM:SS``."""
    minutes, seconds = divmod(int(elapse), 60)
    hours, minutes = divmod(minutes, 60)
    return f"{hours:02d}:{minutes:02d}:{seconds:02d}"


def outlier_fraction(estimate, target, mask=None, threshold=0):
    """Return the fraction of masked pixels whose absolute error exceeds ``threshold``.

    Only channel 0 of ``estimate`` and ``target`` is compared.

    Args:
        estimate: Predicted tensor, indexed as ``[batch, channel, H, W]``.
        target: Ground-truth tensor with the same shape as ``estimate``.
        mask: Optional tensor/array; non-zero entries mark pixels to evaluate.
            ``None`` evaluates every pixel.
        threshold: Absolute error above which a pixel counts as an outlier.

    Returns:
        Number of outliers divided by the number of selected pixels
        (``nan`` if the mask selects no pixel).

    Raises:
        ValueError: If the shapes of estimate, target and mask do not match.
    """
    estimate = torch.squeeze(estimate[:, 0, :, :]).detach().cpu().numpy()
    target = torch.squeeze(target[:, 0, :, :]).detach().cpu().numpy()
    if estimate.shape != target.shape:
        raise ValueError(
            'estimate and target have to be same shape '
            f'(expected {estimate.shape} == {target.shape})'
        )

    if mask is None:
        mask = np.ones(estimate.shape, dtype=bool)
    else:
        if isinstance(mask, torch.Tensor):
            mask = mask.detach().cpu().numpy()
        mask = np.asarray(mask) != 0
        # The squeeze above drops a batch axis of size 1; mirror that here.
        if estimate.shape != mask.shape and mask.ndim == 3:
            mask = mask[0]
        if estimate.shape != mask.shape:
            raise ValueError(
                'estimate and mask have to be same shape '
                f'(expected {estimate.shape} == {mask.shape})'
            )

    diff = np.abs(estimate[mask] - target[mask])
    return (diff > threshold).sum() / mask.sum()


def ensure_dir(path):
    """Create ``path`` (including parents) if it does not exist yet."""
    os.makedirs(path, exist_ok=True)


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Prediction ``i`` out of ``n`` is weighted by ``gamma ** (n - i - 1)``, so
    the last prediction has weight 1.

    Args:
        flow_preds: Sequence of predictions, each (B, 2, H, W).
        flow_gt: Ground-truth flow, (B, 2, H, W).
        valid: Validity mask, (B, H, W); broadcast to (B, 1, H, W).
        gamma: Decay factor for earlier predictions.
        test: Unused; kept for interface compatibility.

    Returns:
        The accumulated weighted loss.
    """
    n_predictions = len(flow_preds)
    valid = valid.unsqueeze(1)
    flow_loss = 0.0
    for i, flow_pred in enumerate(flow_preds):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_pred - flow_gt)
        flow_loss += i_weight * (valid * i_loss).mean()
    return flow_loss


class CREStereoLightning(LightningModule):
    """LightningModule wrapping ``Model`` together with its datasets and LR schedule.

    With ``args.scheduler == 'default'`` the module runs in manual
    optimization mode and applies the warm-up / constant / linear-decay
    schedule from ``adjust_learning_rate``; otherwise Lightning optimizes
    automatically with a ``CosineAnnealingLR`` scheduler.
    """

    def __init__(self, args, logger=None, pattern_path=''):
        """Store the configuration and build the model.

        Args:
            args: Configuration object exposing the attributes read below.
            logger: ``WandbLogger`` used for image logging; ``None`` disables it.
            pattern_path: Path of the pattern image handed to the datasets.
        """
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.imwidth = args.image_width
        self.imheight = args.image_height
        self.data_type = 'blender' if 'blender' in args.training_data_path else 'ctd'
        self.eval_type = 'kinect' if 'kinect' in args.test_data_path else args.training_data_path
        self.lr = args.base_lr
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = args.training_data_path
        self.test_data_path = args.test_data_path
        self.data_limit = args.data_limit  # between 0 and 1.

        self.model = Model(
            max_disp=args.max_disp,
            mixed_precision=args.mixed_precision,
            test_mode=False,
        )

        if args.scheduler == 'default':
            self.automatic_optimization = False

        # Kept on the module so adjust_learning_rate can access them easily.
        self.n_total_epoch = args.n_total_epoch
        self.base_lr = args.base_lr

    def _build_dataset(self, root, is_kinect=False, blender_kwargs=None, **shared_kwargs):
        """Create the dataset matching ``self.data_type``.

        ``shared_kwargs`` go to either dataset class; ``blender_kwargs`` only
        to ``BlenderDataset``.
        """
        common = dict(
            root=root,
            pattern_path=self.pattern_path,
            use_lightning=True,
            data_limit=self.data_limit,
            **shared_kwargs,
        )
        if self.data_type == 'blender':
            return BlenderDataset(
                data_type='kinect' if is_kinect else 'blender',
                disp_avail=not is_kinect,
                **(blender_kwargs or {}),
                **common,
            )
        return CTDDataset(**common)

    @staticmethod
    def _build_loader(dataset, batch_size, shuffle, drop_last):
        """Wrap ``dataset`` in a DataLoader with the worker settings used everywhere."""
        return DataLoader(
            dataset,
            batch_size,
            shuffle=shuffle,
            num_workers=4,
            drop_last=drop_last,
            persistent_workers=True,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled training loader (never built from kinect data)."""
        dataset = self._build_dataset(self.data_path)
        return self._build_loader(dataset, self.batch_size, shuffle=True, drop_last=True)

    def val_dataloader(self):
        """Return the validation loader: test split of the training data, not kinect."""
        dataset = self._build_dataset(self.data_path, test_set=True)
        return self._build_loader(dataset, self.batch_size, shuffle=False, drop_last=False)

    def test_dataloader(self):
        """Return the test loader built from ``self.test_data_path``."""
        is_kinect = self.eval_type == 'kinect'
        dataset = self._build_dataset(
            self.test_data_path,
            is_kinect=is_kinect,
            # if we test on kinect data, use all available samples for test set
            blender_kwargs={'split': 0. if is_kinect else 0.9},
            test_set=True,
            augment=False,
        )
        return self._build_loader(
            dataset,
            1 if is_kinect else self.batch_size,
            shuffle=False,
            drop_last=False,
        )

    def forward(
        self,
        image1,
        image2,
        flow_init=None,
        iters=10,
        upsample=True,
        test_mode=False,
    ):
        """Run the wrapped model, forwarding ``self.pattern_attention`` as last argument."""
        return self.model(
            image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention
        )

    @staticmethod
    def _disp_to_flow(gt_disp):
        """Turn (B, H, W) disparity into (B, 1, H, W) disparity and (B, 2, H, W) flow.

        The second flow channel is all zeros.
        """
        gt_disp = torch.unsqueeze(gt_disp, dim=1)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)
        return gt_disp, gt_flow

    def _log_image(self, image_log):
        """Send an image dict from ``log_images`` to W&B, if a logger was supplied."""
        if self.wandb_logger is not None:
            self.wandb_logger.log_image(**image_log)

    def _log_outlier_fractions(self, flow_predictions, gt_flow, valid_mask):
        """Log the outlier fraction for every threshold in ``OUTLIER_THRESHOLDS``."""
        # NOTE(review): this evaluates flow_predictions[0] while log_images
        # visualizes the last prediction - confirm which one is intended.
        self.log_dict({
            f"outlier_fraction/{threshold}": outlier_fraction(
                flow_predictions[0], gt_flow, valid_mask, threshold
            )
            for threshold in OUTLIER_THRESHOLDS
        })

    def training_step(self, batch, batch_idx):
        """Compute the sequence loss for one batch and, in manual mode, take the optimizer step."""
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        gt_disp, gt_flow = self._disp_to_flow(gt_disp)
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)

        if not self.automatic_optimization:
            # Manual optimization: Lightning neither calls backward nor steps
            # the optimizer for us.
            optimizer = self.optimizers()
            optimizer.zero_grad()
            self.manual_backward(loss)
            optimizer.step()

        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log['key'] = 'debug_train'
            self._log_image(image_log)

        self.log("train_loss", loss)

        # The hand-written schedule is updated once per epoch; in automatic
        # mode the CosineAnnealingLR scheduler owns the learning rate instead.
        if not self.automatic_optimization and self.trainer.is_last_batch:
            self.adjust_learning_rate()
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss and outlier fractions; log images every 8th batch."""
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp, gt_flow = self._disp_to_flow(gt_disp)
        val_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("val_loss", val_loss)
        self._log_outlier_fractions(flow_predictions, gt_flow, valid_mask)

        if batch_idx % 8 == 0:
            self._log_image(log_images(left, right, flow_predictions, gt_disp))

    def test_step(self, batch, batch_idx):
        """Log test loss, outlier fractions and images (plus raw ground truth) for every batch."""
        left, right, gt_disp, valid_mask = batch
        gt_disp, gt_flow = self._disp_to_flow(gt_disp)
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("test_loss", test_loss)
        self._log_outlier_fractions(flow_predictions, gt_flow, valid_mask)

        images = log_images(left, right, flow_predictions, gt_disp)
        images['images'].append(gt_disp)
        images['caption'].append('GT Disp')
        self._log_image(images)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Run inference on the first two entries of ``batch`` (left and right image)."""
        left, right = batch[0], batch[1]
        return self(left, right, test_mode=True)

    def configure_optimizers(self):
        """Return Adam alone in manual mode, or Adam plus cosine annealing otherwise."""
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
        if not self.automatic_optimization:
            return optimizer

        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                # No 'interval' is given, so the scheduler steps once per
                # epoch and T_max is therefore counted in epochs.
                T_max=self.T_max if self.T_max else self.n_total_epoch,
            ),
            'name': 'LR Scheduler',
        }
        return [optimizer], [lr_scheduler]

    def adjust_learning_rate(self):
        """Set the learning rate for the coming epoch and log it as ``train/lr``.

        Schedule: linear warm-up from ``min_lr_rate * base_lr`` to ``base_lr``
        over the first 2% of the epochs, constant until 60%, then linear decay
        back to ``min_lr_rate * base_lr`` at the last epoch.
        """
        optimizer = self.optimizers().optimizer
        epoch = self.trainer.current_epoch + 1

        warm_up = 0.02
        const_range = 0.6
        min_lr_rate = 0.05

        warm_up_end = self.n_total_epoch * warm_up
        const_end = self.n_total_epoch * const_range
        if epoch <= warm_up_end:
            lr = (1 - min_lr_rate) * self.base_lr / warm_up_end * epoch \
                + min_lr_rate * self.base_lr
        elif epoch <= const_end:
            lr = self.base_lr
        else:
            lr = (min_lr_rate - 1) * self.base_lr / (
                (1 - const_range) * self.n_total_epoch
            ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr

        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
        self.log('train/lr', lr)


def _select_pattern_path(training_data_path):
    """Pick the pattern image that belongs to the training data set.

    Raises:
        ValueError: If the path names neither a blender nor a ctd data set.
    """
    if 'blender' in training_data_path:
        # this was used for our blender renders
        return '/home/nils/miniprojekt/kinect_syn_ref.png'
    if 'ctd' in training_data_path:
        # this one is used (i hope) for ctd
        return '/home/nils/kinect_from_settings.png'
    raise ValueError(
        "training_data_path must contain 'blender' or 'ctd' to select a "
        f"pattern image, got {training_data_path!r}"
    )


def main():
    """Read ``cfgs/train.yaml``, then fit and test the model."""
    wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)

    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    wandb_logger.experiment.config.update(args._asdict())
    config = wandb.config

    pattern_path = _select_pattern_path(config.training_data_path)

    devices = min(config.nr_gpus, torch.cuda.device_count())
    if devices != config.nr_gpus:
        print(f'Using less devices than expected! ({devices} / {config.nr_gpus})')

    model = CREStereoLightning(config, wandb_logger, pattern_path)

    model_checkpoint = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        save_last=True,
    )

    trainer_kwargs = dict(
        accelerator='gpu',
        devices=devices,
        max_epochs=config.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=8,
            ),
            LearningRateMonitor(),
            model_checkpoint,
        ],
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=config.log_dir_lightning,
    )
    if config.scheduler != 'default':
        # Gradient accumulation has to stay disabled for manual optimization,
        # which the 'default' scheduler uses.
        trainer_kwargs['accumulate_grad_batches'] = 4

    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
    trainer.test(model)


if __name__ == "__main__":
    main()