"""CREStereo training entry point built on PyTorch Lightning.

Trains the CREStereo disparity model on Blender-rendered structured-light
data, logging losses and qualitative disparity visualisations to
Weights & Biases.
"""

import os
import sys
import time
import logging
from collections import namedtuple

import yaml
import numpy as np
import cv2
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

from pytorch_lightning import (
    LightningDataModule,
    LightningModule,
    Trainer,
    seed_everything,
)
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy

from nets import Model
from dataset import BlenderDataset, CREStereoDataset, CTDDataset

# Seed Python/NumPy/torch (and, with workers=True, dataloader workers) for
# reproducible runs; must happen before any RNG consumer runs.
seed_everything(42, workers=True)


def normalize_and_colormap(img):
    """Min-max normalise *img* to [0, 255] and apply the INFERNO colormap.

    Accepts a 2-D ``torch.Tensor`` or ``numpy.ndarray``; returns a BGR
    ``uint8`` numpy image suitable for logging.
    """
    value_range = img.max() - img.min()
    # Guard against a constant image, which would otherwise divide by zero.
    if value_range == 0:
        ret = img * 0
    else:
        ret = (img - img.min()) / value_range * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)


def log_images(left, right, pred_disp, gt_disp):
    """Build a ``WandbLogger.log_image`` kwargs dict for one sample.

    Visualises predicted disparity, absolute disparity error, ground truth,
    and the raw left/right inputs for the first element of the batch.
    ``pred_disp`` may be a list of per-iteration predictions; only the final
    refinement is shown.
    """
    batch_idx = 0  # always visualise the first sample of the batch

    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

    # Drop the channel dimension; assumes inputs are (B, C, H, W) with the
    # disparity in channel 0 — TODO confirm against the dataset.
    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])

    disp_error = gt_disp - pred_disp

    input_left = left[batch_idx].cpu().detach().numpy()  # .transpose(1,2,0)
    input_right = right[batch_idx].cpu().detach().numpy()  # .transpose(1,2,0)

    wandb_log = dict(
        key='samples',
        images=[
            normalize_and_colormap(pred_disp[batch_idx]),
            normalize_and_colormap(abs(disp_error[batch_idx])),
            normalize_and_colormap(gt_disp[batch_idx]),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity \n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
            f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n{abs(disp_error[batch_idx]).mean():.2f}",
            f"GT Disp Vis \n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
    return wandb_log


def parse_yaml(file_path: str) -> namedtuple:
    """Parse yaml configuration file and return the object in `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args


def format_time(elapse):
    """Format a duration in seconds as ``HH:MM:SS``."""
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)


def ensure_dir(path):
    """Create *path* (including parents) if it does not already exist."""
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)


def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    """Exponentially weighted L1 loss over a sequence of flow predictions.

    Later (more refined) predictions receive a higher weight
    (``gamma ** (n - i - 1)``), following the RAFT/CREStereo training recipe.

    Args:
        flow_preds: list of per-iteration predictions, each (B, 2, H, W).
        flow_gt: ground-truth flow, (B, 2, H, W).
        valid: per-pixel validity mask, (B, H, W); broadcast over the flow
            channels via ``unsqueeze(1)``.
        gamma: decay factor for earlier, coarser predictions.
        test: unused; kept for backward compatibility with callers.

    Returns:
        Scalar loss tensor.
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
    return flow_loss


class CREStereoLightning(LightningModule):
    """LightningModule wrapping the CREStereo ``Model`` plus its dataloaders."""

    def __init__(self, args, logger, pattern_path, data_path):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.lr = args.base_lr
        print(f'lr = {self.lr}')
        # Optional override for the cosine-annealing horizon; falls back to a
        # dataloader-derived value in configure_optimizers().
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = data_path
        self.model = Model(
            max_disp=args.max_disp,
            mixed_precision=args.mixed_precision,
            test_mode=False,
        )

    def train_dataloader(self):
        dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            use_lightning=True,
        )
        dataloader = DataLoader(
            dataset,
            self.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            persistent_workers=True,
            pin_memory=True,
        )
        return dataloader

    def val_dataloader(self):
        test_dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            test_set=True,
            use_lightning=True,
        )
        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader

    def test_dataloader(self):
        # TODO change this to use IRL data?
        test_dataset = BlenderDataset(
            root=self.data_path,
            pattern_path=self.pattern_path,
            test_set=True,
            use_lightning=True,
        )
        test_dataloader = DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
        return test_dataloader

    def forward(
        self,
        image1,
        image2,
        flow_init=None,
        iters=10,
        upsample=True,
        test_mode=False,
    ):
        return self.model(image1, image2, flow_init, iters, upsample,
                          test_mode, self.pattern_attention)

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        # Lift disparity to a 2-channel flow with zero vertical component,
        # e.g. [2, 384, 512] -> [2, 2, 384, 512].
        gt_disp = torch.unsqueeze(gt_disp, dim=1)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        # Log qualitative samples periodically to keep wandb traffic low.
        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log['key'] = 'debug_train'
            self.wandb_logger.log_image(**image_log)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)
        val_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("val_loss", val_loss)
        if batch_idx % 8 == 0:
            self.wandb_logger.log_image(
                **log_images(left, right, flow_predictions, gt_disp))

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("test_loss", test_loss)
        self.wandb_logger.log_image(
            **log_images(left, right, flow_predictions, gt_disp))

    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr,
                               betas=(0.9, 0.999))
        # Build the dataloader once; the original built it twice.
        steps_per_epoch = len(self.train_dataloader())
        print('len(self.train_dataloader)', steps_per_epoch)
        # NOTE(review): len(dataloader) is already a count of batches, so the
        # extra division by batch_size looks dubious — confirm intent.
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.T_max if self.T_max else steps_per_epoch / self.batch_size,
            ),
            'name': 'CosineAnnealingLRScheduler',
        }
        return [optimizer], [lr_scheduler]


if __name__ == "__main__":
    # Train configuration from YAML.
    args = parse_yaml("cfgs/train.yaml")
    pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    run = wandb.init(
        project="crestereo-lightning",
        config=args._asdict(),
        tags=['new_scheduler', 'default_lr',
              f'{"" if args.pattern_attention else "no-"}pattern-attention'],
        notes='',
    )
    run.config.update(args._asdict())
    config = wandb.config
    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id,
                               log_model=True)
    # wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all')
    # wandb_logger.experiment.config.update(args._asdict())

    model = CREStereoLightning(
        # args,
        config,
        wandb_logger,
        pattern_path,
        args.training_data_path,
        # lr=0.00017378008287493763,  # found with auto_lr_find=True
    )
    # NOTE turn this down once it's working, this might use too much space
    # wandb_logger.watch(model, log_graph=False)  # , log='all')

    trainer = Trainer(
        accelerator='gpu',
        devices=args.nr_gpus,
        max_epochs=args.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=16,
            ),
            LearningRateMonitor(),
            ModelCheckpoint(
                monitor="val_loss",
                mode="min",
                save_top_k=2,
                save_last=True,
            ),
        ],
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        # auto_scale_batch_size='binsearch',
        # auto_lr_find=True,
        accumulate_grad_batches=4,
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=args.log_dir_lightning,
    )
    # trainer.tune(model)
    trainer.fit(model)
    trainer.validate()