import os import sys import time import logging from collections import namedtuple import yaml # from tensorboardX import SummaryWriter from nets import Model # from dataset import CREStereoDataset from dataset import BlenderDataset, CREStereoDataset, CTDDataset import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader from pytorch_lightning.lite import LightningLite from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning import Trainer, seed_everything from pytorch_lightning.loggers import WandbLogger seed_everything(42, workers=True) import wandb import numpy as np import cv2 def normalize_and_colormap(img): ret = (img - img.min()) / (img.max() - img.min()) * 255.0 if isinstance(ret, torch.Tensor): ret = ret.cpu().detach().numpy() ret = ret.astype("uint8") ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO) return ret def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True): print("Model Forwarding...") if isinstance(left, torch.Tensor): left = left# .cpu().detach().numpy() imgR = right# .cpu().detach().numpy() imgL = left imgR = right imgL = np.ascontiguousarray(imgL[None, :, :, :]) imgR = np.ascontiguousarray(imgR[None, :, :, :]) flow_init = None # chosen for convenience imgL = torch.tensor(imgL.astype("float32")) imgR = torch.tensor(imgR.astype("float32")) imgL = imgL.transpose(2,3).transpose(1,2) if imgL.shape != imgR.shape: imgR = imgR.transpose(2,3).transpose(1,2) imgL_dw2 = F.interpolate( imgL, size=(imgL.shape[2] // 2, imgL.shape[3] // 2), mode="bilinear", align_corners=True, ).clamp(min=0, max=255) imgR_dw2 = F.interpolate( imgR, size=(imgL.shape[2] // 2, imgL.shape[3] // 2), mode="bilinear", align_corners=True, ).clamp(min=0, max=255) if last_img is not None: print('using flow_initialization') print(last_img.shape) # FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help print(last_img.max(), last_img.min()) if last_img.min() < 0: # print('Negative disparity detected. shifting...') last_img = last_img - last_img.min() if last_img.max() > 255: # print('Excessive disparity detected. scaling...') last_img = last_img / (last_img.max() / 255) last_img = np.dstack([last_img, last_img]) # last_img = np.dstack([last_img, last_img, last_img]) last_img = np.dstack([last_img]) last_img = last_img.reshape((1, 2, 480, 640)) # print(last_img.shape) # print(last_img.dtype) # print(last_img.max(), last_img.min()) flow_init = torch.tensor(last_img.astype("float32")) # flow_init = F.interpolate( # last_img, # size=(last_img.shape[0] // 2, last_img.shape[1] // 2), # mode="bilinear", # align_corners=True, # ) with torch.inference_mode(): pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern) pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern) pf_base = pred_flow if isinstance(pf_base, list): pf_base = pred_flow[0] pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy() print('pred_flow max min') print(pf.max(), pf.min()) if not wandb_log: if test: return pred_flow return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy() log = {} in_h, in_w = left.shape[:2] # Resize image in case the GPU memory overflows eval_h, eval_w = (in_h,in_w) for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)): pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy() pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy() # pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) # pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) if i == n_iter-1: t = float(in_w) / float(eval_w) disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t log[f'disp_vis'] = wandb.Image( normalize_and_colormap(disp), caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", ) log[f'pred_{i}'] = wandb.Image( np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]), caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", ) # log[f'pred_norm_{i}'] = wandb.Image( # np.array([pred_disp_norm.reshape(480, 640)]), # caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", # ) # log[f'pred_dw2_{i}'] = wandb.Image( # np.array([pred_disp_dw2.reshape(240, 320)]), # caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", # ) # log[f'pred_dw2_norm_{i}'] = wandb.Image( # np.array([pred_disp_dw2_norm.reshape(240, 320)]), # caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", # ) log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right if input_right.shape != (480, 640, 3): input_right.transpose(1,2,0) log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right") log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp disp_error = gt_disp - disp log['disp_error'] = wandb.Image( normalize_and_colormap(abs(disp_error)), caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}", ) log[f'gt_disp_vis'] = wandb.Image( normalize_and_colormap(gt_disp), caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", ) wandb.log(log) return pred_flow def parse_yaml(file_path: str) -> namedtuple: """Parse yaml configuration file and return the object in `namedtuple`.""" with open(file_path, "rb") as f: cfg: dict = yaml.safe_load(f) args = namedtuple("train_args", cfg.keys())(*cfg.values()) return args def format_time(elapse): elapse = int(elapse) hour = elapse // 3600 minute = elapse % 3600 // 60 seconds = elapse % 60 return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds) def ensure_dir(path): if not os.path.exists(path): os.makedirs(path, exist_ok=True) def adjust_learning_rate(optimizer, epoch): warm_up = 0.02 const_range = 0.6 min_lr_rate = 0.05 if epoch <= args.n_total_epoch * warm_up: lr = (1 - min_lr_rate) * args.base_lr / ( args.n_total_epoch * warm_up ) * epoch + min_lr_rate * args.base_lr elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range: lr = args.base_lr else: lr = (min_lr_rate - 1) * args.base_lr / ( (1 - const_range) * args.n_total_epoch ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr for param_group in optimizer.param_groups: param_group['lr'] = lr def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): ''' valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W) flow_preds[0]: (B, 2, H, W) flow_gt: (B, 2, H, W) ''' if test: # print('sequence loss') if valid.shape != (2, 480, 640): valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2) # print(valid.shape) #valid = torch.stack([valid, valid]) # print(valid.shape) if valid.shape != (2, 480, 640): valid = valid.transpose(0,1) # print(valid.shape) # print(valid.shape) # print(flow_preds[0].shape) # print(flow_gt.shape) n_predictions = len(flow_preds) flow_loss = 0.0 # TEST flow_gt = torch.squeeze(flow_gt, dim=-1) for i in range(n_predictions): i_weight = gamma ** (n_predictions - i - 1) i_loss = torch.abs(flow_preds[i] - flow_gt) flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean() return flow_loss class CREStereoLightning(LightningModule): def __init__(self, args): super().__init__() self.batch_size = args.batch_size self.model = Model( max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False ) def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True): return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right) def training_step(self, batch, batch_idx): # loss = self(batch) left, right, gt_disp, valid_mask = batch left = torch.Tensor(left).to(self.device) right = torch.Tensor(right).to(self.device) left = left right = right flow_predictions = self.forward(left, right) loss = sequence_loss( flow_predictions, gt_flow, valid_mask, gamma=0.8 ) self.log("train_loss", loss) return loss def validation_step(self, batch, batch_idx): left, right, gt_disp, valid_mask = batch left = torch.Tensor(left).to(self.device) right = torch.Tensor(right).to(self.device) print(left.shape) print(right.shape) flow_predictions = self.forward(left, right) val_loss = sequence_loss( flow_predictions, gt_flow, valid_mask, gamma=0.8 ) self.log("val_loss", val_loss) def test_step(self, batch, batch_idx): left, right, gt_disp, valid_mask = batch # left, right, gt_disp, valid_mask = ( # batch["left"], # batch["right"], # batch["disparity"], # batch["mask"], # ) left = torch.Tensor(left).to(self.device) right = torch.Tensor(right).to(self.device) flow_predictions = self.forward(left, right) test_loss = sequence_loss( flow_predictions, gt_flow, valid_mask, gamma=0.8 ) self.log("test_loss", test_loss) def configure_optimizers(self): return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999)) if __name__ == "__main__": # train configuration args = parse_yaml("cfgs/train.yaml") # wandb.init(project="crestereo-lightning", entity="cpt-captain") # Lite(strategy='dp', accelerator='gpu', devices=2).run(args) pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' model = CREStereoLightning(args) dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True) test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True) print(len(dataset)) print(len(test_dataset)) wandb_logger = WandbLogger(project="crestereo-lightning") wandb.config.update(args._asdict()) trainer = Trainer( max_epochs=args.n_total_epoch, accelerator='gpu', devices=2, # auto_scale_batch_size='binsearch', # strategy='ddp', deterministic=True, check_val_every_n_epoch=1, limit_val_batches=24, limit_test_batches=24, logger=wandb_logger, default_root_dir=args.log_dir_lightning, ) # trainer.tune(model) trainer.fit(model, dataset, test_dataset)