From 2731ef1ada44b88147fcd4b2c3ad08e3bd1cfed3 Mon Sep 17 00:00:00 2001
From: "Cpt.Captain"
Date: Wed, 14 Sep 2022 17:24:23 +0200
Subject: [PATCH] re-add default CREStereo LR scheduler, pick dataset/pattern by path, log outlier fractions

---
 cfgs/train.yaml    |  25 ++++---
 dataset.py         |  71 ++++++++++++-------
 train_lightning.py | 172 +++++++++++++++++++++++++++++++++++----------
 3 files changed, 194 insertions(+), 74 deletions(-)

diff --git a/cfgs/train.yaml b/cfgs/train.yaml
index 6c152f0..0df7678 100644
--- a/cfgs/train.yaml
+++ b/cfgs/train.yaml
@@ -1,12 +1,12 @@
 seed: 0
-mixed_precision: false
-# base_lr: 4.0e-4
-base_lr: 0.001
-t_max: 161
+mixed_precision: true
+base_lr: 4.0e-4
+# base_lr: 0.00001
+t_max: 16100
 nr_gpus: 3
-batch_size: 2
-n_total_epoch: 300
+batch_size: 3
+n_total_epoch: 100
 minibatch_per_epoch: 500
 
 loadmodel: ~
@@ -17,11 +17,18 @@ model_save_freq_epoch: 1
 max_disp: 256
 image_width: 640
 image_height: 480
+# dataset: "blender"
 # training_data_path: "./stereo_trainset/crestereo"
-pattern_attention: false
-dataset: "blender"
 # training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
-training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
+# training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
+training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders_ctd_randomize_light/data"
+
+
+# FIXME are any of these attention flags still needed?
+pattern_attention: false
+scene_attention: true
+ignore_pattern_completely: false
+
 log_level: "logging.INFO"
diff --git a/dataset.py b/dataset.py
index 2649992..e965259 100644
--- a/dataset.py
+++ b/dataset.py
@@ -7,6 +7,11 @@
 from PIL import Image, ImageEnhance
 from megengine.data.dataset import Dataset
 
 
+def downsample(img):
+    # pyrDown halves the resolution (with Gaussian smoothing), then the height
+    # is cropped symmetrically to 480 rows; the width is left untouched
+    downsampled = cv2.pyrDown(img)
+    diff = (downsampled.shape[0] - 480) // 2
+    return downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+
 
 class Augmentor:
     def __init__(
@@ -234,11 +239,12 @@ class CREStereoDataset(Dataset):
 
 
 class CTDDataset(Dataset):
-    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False):
+    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=True):
         super().__init__()
         self.rng = np.random.RandomState(0)
         self.augment = augment
         self.blur = blur
+        self.use_lightning = use_lightning
         imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
         if test_set:
             self.imgs = imgs[:int(split * len(imgs))]
@@ -248,10 +254,7 @@ class CTDDataset(Dataset):
 
         if resize_pattern and self.pattern.shape != (480, 640, 3):
             # self.pattern = cv2.resize(self.pattern, (640, 480))
-            print(self.pattern.shape)
-            downsampled = cv2.pyrDown(self.pattern)
-            diff = (downsampled.shape[0] - 480) // 2
-            self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+            self.pattern = downsample(self.pattern)
 
         self.augmentor = Augmentor(
             image_height=480,
@@ -304,14 +307,18 @@ class CTDDataset(Dataset):
                 left_img, right_img, left_disp
             )
 
-        right_img = right_img.transpose((2, 0, 1)).astype("uint8")
-        return {
-            "left": left_img,
-            "right": right_img,
-            "disparity": left_disp,
-            "mask": disp_mask,
-        }
+        right_img = right_img.transpose((2, 0, 1)).astype("uint8")
+        if not self.use_lightning:
+            return {
+                "left": left_img,
+                "right": right_img,
+                "disparity": left_disp,
+                "mask": disp_mask,
+            }
+        # lightning path: plain tensors as a tuple instead of a dict
+        left_img = left_img.transpose((2, 0, 1)).astype("uint8")
+        return left_img, right_img, left_disp, disp_mask
 
     def __len__(self):
         return len(self.imgs)
@@ -321,17 +328,30 @@ class BlenderDataset(CTDDataset):
     def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False):
         super().__init__(root, pattern_path)
         self.use_lightning = use_lightning
-        imgs = [f for f in glob.glob(f"{root}/im_*.png", recursive=True) if not 'depth0001' in f]
-        if test_set:
+        # im_*.png also matches the auxiliary depth/disp/grad renders, so filter those out
+        additional_img_types = {
+            'depth',
+            'disp',
+            'grad',
+        }
+
+        pngs = glob.glob(f"{root}/im_*.png", recursive=True)
+        imgs = [
+            img for img in pngs
+            if all(t not in img for t in additional_img_types)
+        ]
+        if not test_set:
             self.imgs = imgs[:int(split * len(imgs))]
         else:
             self.imgs = imgs[int(split * len(imgs)):]
+
         self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
 
         if resize_pattern and self.pattern.shape != (480, 640, 3):
-            downsampled = cv2.pyrDown(self.pattern)
-            diff = (downsampled.shape[0] - 480) // 2
-            self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+            self.pattern = downsample(self.pattern)
 
         self.augmentor = Augmentor(
             image_height=480,
@@ -345,7 +365,7 @@ class BlenderDataset(CTDDataset):
     def __getitem__(self, index):
         # find path
         left_path = self.imgs[index]
-        left_disp_path = left_path.split('.')[0] + '_depth0001.png'
+        left_disp_path = left_path.split('.')[0] + '_disp0001.png'
 
         # read img, disp
         left_img = cv2.imread(left_path)
@@ -354,14 +374,14 @@ class BlenderDataset(CTDDataset):
             left_img = (left_img * 255).astype('uint8')
 
         if left_img.shape != (480, 640, 3):
-            downsampled = cv2.pyrDown(left_img)
-            diff = (downsampled.shape[0] - 480) // 2
-            left_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+            left_img = downsample(left_img)
         if left_img.shape[-1] != 3:
             left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
 
         right_img = self.pattern
-        left_disp = self.get_disp(left_disp_path)
+        # read the pre-rendered disparity map directly instead of deriving it from depth
+        # left_disp = self.get_disp(left_disp_path)
+        disp = cv2.imread(left_disp_path, cv2.IMREAD_UNCHANGED)
+        left_disp = downsample(disp)
 
         if False:  # self.rng.binomial(1, 0.5):
             left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
@@ -384,7 +404,6 @@ class BlenderDataset(CTDDataset):
         )
 
         if not self.use_lightning:
-            # right_img = right_img.transpose((2, 0, 1)).astype("uint8")
             return {
                 "left": left_img,
                 "right": right_img,
@@ -400,15 +419,13 @@ class BlenderDataset(CTDDataset):
         baseline = 0.075 # meters
         fl = 560. # as per CTD
         depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
-        downsampled = cv2.pyrDown(depth)
-        diff = (downsampled.shape[0] - 480) // 2
-        depth = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+        depth = downsample(depth)
         # disp = np.load(path).transpose(1,2,0)
         # disp = baseline * fl / depth
         # return disp.astype(np.float32) / 32
         # FIXME temporarily increase disparity until new data with better depth values is generated
         # higher values seem to speed up convergence, but introduce much stronger artifacting
-        mystery_factor = 150
+        mystery_factor = 35
         # mystery_factor = 1
         disp = (baseline * fl * mystery_factor) / depth
         return disp.astype(np.float32)
diff --git a/train_lightning.py b/train_lightning.py
index de5bce7..91ef02c 100644
--- a/train_lightning.py
+++ b/train_lightning.py
@@ -94,6 +94,32 @@ def format_time(elapse):
     return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
 
 
+
+def outlier_fraction(estimate, target, mask=None, threshold=0):
+    """Fraction of (masked) pixels whose absolute error exceeds `threshold`."""
+    def _process_inputs(estimate, target, mask):
+        if estimate.shape != target.shape:
+            raise Exception(f'estimate and target have to be the same shape (expected {estimate.shape} == {target.shape})')
+        if mask is None:
+            # torch.ones instead of np.ones, so the .cpu() calls below also work
+            # for the default mask (np.bool is deprecated anyway)
+            mask = torch.ones(estimate.shape, dtype=torch.bool, device=estimate.device)
+        else:
+            mask = mask != 0
+            if estimate.shape != mask.shape:
+                raise Exception(f'estimate and mask have to be the same shape (expected {estimate.shape} == {mask.shape})')
+        return estimate, target, mask
+    estimate = torch.squeeze(estimate[:, 0, :, :])
+    target = torch.squeeze(target[:, 0, :, :])
+
+    estimate, target, mask = _process_inputs(estimate, target, mask)
+
+    mask = mask.cpu().detach().numpy()
+    estimate = estimate.cpu().detach().numpy()
+    target = target.cpu().detach().numpy()
+
+    diff = np.abs(estimate[mask] - target[mask])
+    m = (diff > threshold).sum() / mask.sum()
+    return m
+
+
 def ensure_dir(path):
     if not os.path.exists(path):
         os.makedirs(path, exist_ok=True)
@@ -134,12 +160,12 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
 
     return flow_loss
 
-
 class CREStereoLightning(LightningModule):
-    def __init__(self, args, logger, pattern_path, data_path):
+    def __init__(self, args, logger=None, pattern_path='', data_path=''):
         super().__init__()
         self.batch_size = args.batch_size
         self.wandb_logger = logger
+        # 'blender' has to win here: the blender render paths contain "ctd" too
+        self.data_type = 'blender' if 'blender' in data_path else 'ctd'
         self.lr = args.base_lr
         print(f'lr = {self.lr}')
         self.T_max = args.t_max if args.t_max else None
@@ -149,13 +175,25 @@ class CREStereoLightning(LightningModule):
         self.model = Model(
             max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
         )
+        # stash these so adjust_learning_rate() can get at them easily
+        self.n_total_epoch = args.n_total_epoch
+        self.base_lr = args.base_lr
+
+        # manual optimization: we step the optimizer and the LR schedule ourselves
+        self.automatic_optimization = False
 
     def train_dataloader(self):
-        dataset = BlenderDataset(
-            root=self.data_path,
-            pattern_path=self.pattern_path,
-            use_lightning=True,
-        )
+        if self.data_type == 'blender':
+            dataset = BlenderDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                use_lightning=True,
+            )
+        elif self.data_type == 'ctd':
+            dataset = CTDDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                use_lightning=True,
+            )
 
         dataloader = DataLoader(
             dataset,
             self.batch_size,
@@ -169,12 +207,20 @@ class CREStereoLightning(LightningModule):
         return dataloader
 
     def val_dataloader(self):
-        test_dataset = BlenderDataset(
-            root=self.data_path,
-            pattern_path=self.pattern_path,
-            test_set=True,
-            use_lightning=True,
-        )
+        if self.data_type == 'blender':
+            test_dataset = BlenderDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
+        elif self.data_type == 'ctd':
+            test_dataset = CTDDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
 
         test_dataloader = DataLoader(
             test_dataset,
@@ -190,12 +236,20 @@ class CREStereoLightning(LightningModule):
 
     def test_dataloader(self):
        # TODO change this to use IRL data?
-        test_dataset = BlenderDataset(
-            root=self.data_path,
-            pattern_path=self.pattern_path,
-            test_set=True,
-            use_lightning=True,
-        )
+        if self.data_type == 'blender':
+            test_dataset = BlenderDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
+        elif self.data_type == 'ctd':
+            test_dataset = CTDDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
 
         test_dataloader = DataLoader(
             test_dataset,
@@ -232,6 +286,10 @@ class CREStereoLightning(LightningModule):
             image_log['key'] = 'debug_train'
             self.wandb_logger.log_image(**image_log)
         self.log("train_loss", loss)
+
+        # update the learning rate at the end of every epoch
+        if self.trainer.is_last_batch:
+            self.adjust_learning_rate()
         return loss
 
     def validation_step(self, batch, batch_idx):
@@ -243,6 +301,11 @@
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("val_loss", val_loss)
+        # self.log() only takes scalars, so log one outlier fraction per threshold
+        for threshold in [0.1, 0.5, 1, 2, 5]:
+            of = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
+            self.log(f"outlier_fraction/{threshold}", of)
         if batch_idx % 8 == 0:
             self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
@@ -256,6 +319,9 @@
         )
         self.log("test_loss", test_loss)
         self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        return self(batch)
 
     def configure_optimizers(self):
         optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
@@ -268,35 +334,70 @@
             'name': 'CosineAnnealingLRScheduler',
         }
         return [optimizer], [lr_scheduler]
+
+    def adjust_learning_rate(self):
+        # default CREStereo schedule: linear warm-up, constant plateau, linear decay
+        optimizer = self.optimizers().optimizer
+        epoch = self.trainer.current_epoch + 1
+
+        warm_up = 0.02
+        const_range = 0.6
+        min_lr_rate = 0.05
+
+        if epoch <= self.n_total_epoch * warm_up:
+            lr = (1 - min_lr_rate) * self.base_lr / (
+                self.n_total_epoch * warm_up
+            ) * epoch + min_lr_rate * self.base_lr
+        elif self.n_total_epoch * warm_up < epoch <= self.n_total_epoch * const_range:
+            lr = self.base_lr
+        else:
+            lr = (min_lr_rate - 1) * self.base_lr / (
+                (1 - const_range) * self.n_total_epoch
+            ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr
+
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
 
 
 if __name__ == "__main__":
+    wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)
     # train configuration
     args = parse_yaml("cfgs/train.yaml")
-    pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
-
-    run = wandb.init(project="crestereo-lightning", config=args._asdict(), tags=['new_scheduler', 'default_lr', f'{"" if args.pattern_attention else "no-"}pattern-attention'], notes='')
-    run.config.update(args._asdict())
+    wandb_logger.experiment.config.update(args._asdict())
     config = wandb.config
-    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
-    # wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all')
-    # wandb_logger.experiment.config.update(args._asdict())
+    # 'elif' on purpose: the blender render paths contain "ctd" as well
+    if 'blender' in config.training_data_path:
+        # this was used for our blender renders
+        pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
+    elif 'ctd' in config.training_data_path:
+        # this one is (hopefully) the right pattern for CTD
+        pattern_path = '/home/nils/kinect_from_settings.png'
+
+
+    devices = min(config.nr_gpus, torch.cuda.device_count())
+    if devices != config.nr_gpus:
+        print(f'Using fewer devices than expected! ({devices} / {config.nr_gpus})')
 
     model = CREStereoLightning(
         # args,
         config,
         wandb_logger,
         pattern_path,
-        args.training_data_path,
+        config.training_data_path,
         # lr=0.00017378008287493763, # found with auto_lr_find=True
     )
     # NOTE turn this down once it's working, this might use too much space
     # wandb_logger.watch(model, log_graph=False) #, log='all')
 
+    # keep a handle on the checkpoint callback so the best checkpoint path stays accessible
+    model_checkpoint = ModelCheckpoint(
+        monitor="val_loss",
+        mode="min",
+        save_top_k=2,
+        save_last=True,
+    )
+
     trainer = Trainer(
         accelerator='gpu',
-        devices=args.nr_gpus,
-        max_epochs=args.n_total_epoch,
+        devices=devices,
+        max_epochs=config.n_total_epoch,
         callbacks=[
             EarlyStopping(
                 monitor="val_loss",
@@ -304,25 +405,20 @@ if __name__ == "__main__":
                 patience=16,
             ),
             LearningRateMonitor(),
-            ModelCheckpoint(
-                monitor="val_loss",
-                mode="min",
-                save_top_k=2,
-                save_last=True,
-            )
+            model_checkpoint,
         ],
         strategy=DDPSpawnStrategy(find_unused_parameters=False),
         # auto_scale_batch_size='binsearch',
         # auto_lr_find=True,
-        accumulate_grad_batches=4,
+        # accumulate_grad_batches=4,  # had to disable this for manual optimization
         deterministic=True,
         check_val_every_n_epoch=1,
         limit_val_batches=64,
         limit_test_batches=256,
         logger=wandb_logger,
-        default_root_dir=args.log_dir_lightning,
+        default_root_dir=config.log_dir_lightning,
    )
     # trainer.tune(model)
     trainer.fit(model)
-    trainer.validate()
+    # trainer.validate(ckpt_path=model_checkpoint.best_model_path)
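
P.S. on manual optimization: with self.automatic_optimization = False, Lightning no
longer calls backward()/optimizer.step() for us, and it also stops stepping the
CosineAnnealingLR returned by configure_optimizers(); the epoch-end
adjust_learning_rate() above is what actually drives the LR now. The training_step
hunk only shows the tail of the method, so here is a rough sketch of what the full
body has to look like under manual optimization. The batch layout is taken from the
datasets' lightning tuple; the forward call and the gt_flow handling are
assumptions, not part of this patch:

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # the Adam optimizer from configure_optimizers()
        opt.zero_grad()

        # (left_img, right_img, left_disp, disp_mask) as returned by the datasets
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self(left, right)  # assumed forward() signature
        # NOTE: validation_step feeds sequence_loss a gt_flow built from gt_disp;
        # that derivation is elided in this sketch
        loss = sequence_loss(flow_predictions, gt_disp, valid_mask, gamma=0.8)

        self.manual_backward(loss)  # replaces loss.backward() in manual mode
        opt.step()

        self.log("train_loss", loss)
        # the epoch-end LR update from the patch
        if self.trainer.is_last_batch:
            self.adjust_learning_rate()
        return loss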