diff --git a/api_server.py b/api_server.py index dd50fcf..2d439f0 100644 --- a/api_server.py +++ b/api_server.py @@ -22,7 +22,8 @@ from train import inference as ctd_inference app = FastAPI() # reference_pattern_path = '/home/nils/kinect_reference_cropped.png' -reference_pattern_path = '/home/nils/kinect_reference_far.png' +# reference_pattern_path = '/home/nils/kinect_reference_far.png' +reference_pattern_path = '/home/nils/mpc/kinect_downshift_rotate_left-1.png' # reference_pattern_path = '/home/nils/kinect_diff_ref.png' print(reference_pattern_path) reference_pattern = cv2.imread(reference_pattern_path) diff --git a/cfgs/train.yaml b/cfgs/train.yaml index 0df7678..c0a2015 100644 --- a/cfgs/train.yaml +++ b/cfgs/train.yaml @@ -1,12 +1,12 @@ seed: 0 mixed_precision: true -base_lr: 4.0e-4 -# base_lr: 0.00001 +base_lr: 0.00025 t_max: 16100 +scheduler: "cosineannealing" nr_gpus: 3 batch_size: 3 -n_total_epoch: 100 +n_total_epoch: 64 minibatch_per_epoch: 500 loadmodel: ~ @@ -17,17 +17,9 @@ model_save_freq_epoch: 1 max_disp: 256 image_width: 640 image_height: 480 -# dataset: "blender" -# training_data_path: "./stereo_trainset/crestereo" -# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/" -# training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data" training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders_ctd_randomize_light/data" - - -# FIXME any of this?? -pattern_attention: false -scene_attention: true -ignore_pattern_completely: false +test_data_path: "./eval_kinect" +data_limit: 1. log_level: "logging.INFO" diff --git a/dataset.py b/dataset.py index e965259..3905733 100644 --- a/dataset.py +++ b/dataset.py @@ -239,17 +239,22 @@ class CREStereoDataset(Dataset): class CTDDataset(Dataset): - def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=True): + def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=True, data_limit=1.): super().__init__() self.rng = np.random.RandomState(0) self.augment = augment self.blur = blur self.use_lightning = use_lightning - imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True) - if test_set: + self.data_type = data_type + + imgs = glob.glob(os.path.join(root, f"{data_type if not 'syn' in root else ''}/*/im0_0*.npy"), recursive=True) + if not test_set: self.imgs = imgs[:int(split * len(imgs))] else: self.imgs = imgs[int(split * len(imgs)):] + + self.imgs = self.imgs[:int(data_limit * len(self.imgs))] + self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE) if resize_pattern and self.pattern.shape != (480, 640, 3): @@ -325,9 +330,30 @@ class CTDDataset(Dataset): class BlenderDataset(CTDDataset): - def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False): - super().__init__(root, pattern_path) + def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False, disp_avail=False, data_limit=1.): + super().__init__(root, pattern_path, augment=augment) self.use_lightning = use_lightning + self.disp_avail = disp_avail + self.data_type = data_type + + self.get_imgs(root, test_set, split) + self.imgs = self.imgs[:int(data_limit * len(self.imgs))] + + self.pattern = 
cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE) + + if resize_pattern and self.pattern.shape != (480, 640, 3): + self.pattern = downsample(self.pattern) + + self.augmentor = Augmentor( + image_height=480, + image_width=640, + max_disp=256, + scale_min=0.6, + scale_max=1.0, + seed=0, + ) + + def get_imgs(self, root, test_set, split): additional_img_types = { 'depth', 'disp', @@ -347,25 +373,13 @@ class BlenderDataset(CTDDataset): self.imgs = imgs[:int(split * len(imgs))] else: self.imgs = imgs[int(split * len(imgs)):] - - self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE) - - if resize_pattern and self.pattern.shape != (480, 640, 3): - self.pattern = downsample(self.pattern) - - self.augmentor = Augmentor( - image_height=480, - image_width=640, - max_disp=256, - scale_min=0.6, - scale_max=1.0, - seed=0, - ) def __getitem__(self, index): # find path left_path = self.imgs[index] - left_disp_path = left_path.split('.')[0] + '_disp0001.png' + left_disp_path = left_path.rsplit('.', maxsplit=1)[0] + '_disp0001.png' + if not self.disp_avail: + left_depth_path = left_path.rsplit('.', maxsplit=1)[0] + '_depth0001.png' # read img, disp left_img = cv2.imread(left_path) @@ -379,9 +393,23 @@ class BlenderDataset(CTDDataset): left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3)) right_img = self.pattern - # left_disp = self.get_disp(left_disp_path) - disp = cv2.imread(left_disp_path, cv2.IMREAD_UNCHANGED) - left_disp = downsample(disp) + + # In some cases, we have disparity as floats in the range [0..1]. Thus we need to upscale the values. + # 64 has been arbitrarily chosen as max_disp for this case, as this is roughly the max disparity of the CTD dataset + max_disp = 64 + if not self.disp_avail: + left_disp = self.get_disp(left_depth_path) + if left_disp.max() <= 1.: + left_disp = (left_disp * max_disp).astype('uint8') + else: + try: + left_disp = cv2.imread(left_disp_path, cv2.IMREAD_UNCHANGED) + if left_disp.max() <= 1.: + left_disp = (left_disp * max_disp).astype('uint8') + if left_disp.shape != (480, 640, 3): + left_disp = downsample(left_disp) + except: + print(f'something happened, probably couldn\'t find {left_disp_path}') if False: # self.rng.binomial(1, 0.5): left_img, right_img = np.fliplr(right_img), np.fliplr(left_img) @@ -398,6 +426,9 @@ class BlenderDataset(CTDDataset): _left_img, _right_img, _left_disp, disp_mask = self.augmentor( left_img, right_img, left_disp ) + left_img = left_img.astype('float32') + right_img = right_img.astype('float32') + left_disp = left_disp.astype('float32') else: left_img, right_img, left_disp, disp_mask = self.augmentor( left_img, right_img, left_disp @@ -418,14 +449,13 @@ class BlenderDataset(CTDDataset): def get_disp(self, path): baseline = 0.075 # meters fl = 560. 
# as per CTD - depth = cv2.imread(path, cv2.IMREAD_UNCHANGED) - depth = downsample(depth) - # disp = np.load(path).transpose(1,2,0) - # disp = baseline * fl / depth - # return disp.astype(np.float32) / 32 - # FIXME temporarily increase disparity until new data with better depth values is generated - # higher values seem to speedup convergence, but introduce much stronger artifacting - mystery_factor = 35 - # mystery_factor = 1 - disp = (baseline * fl * mystery_factor) / depth + # depth = cv2.imread(path, cv2.IMREAD_UNCHANGED) + depth = cv2.imread(path, cv2.IMREAD_GRAYSCALE) + if depth.shape != (480, 640): + depth = downsample(depth) + + disp = (baseline * fl) / depth + + disp[disp == np.inf] = 0 + return disp.astype(np.float32) diff --git a/nets/attention/transformer.py b/nets/attention/transformer.py index 040f681..cfa5f13 100644 --- a/nets/attention/transformer.py +++ b/nets/attention/transformer.py @@ -76,7 +76,7 @@ class LocalFeatureTransformer(nn.Module): if p.dim() > 1: nn.init.xavier_uniform_(p) - def forward(self, feat0, feat1, mask0=None, mask1=None): + def forward(self, feat0, feat1, mask0=None, mask1=None, ignore_second_feat=False): """ Args: feat0 (torch.Tensor): [N, L, C] @@ -97,6 +97,9 @@ class LocalFeatureTransformer(nn.Module): name = self.layer_names[i] if name == 'self': feat0 = layer(feat0, feat0, mask0, mask0) + if ignore_second_feat: + # save some compute + continue feat1 = layer(feat1, feat1, mask1, mask1) elif name == 'cross': feat0 = layer(feat0, feat1, mask0, mask1) @@ -104,4 +107,6 @@ class LocalFeatureTransformer(nn.Module): else: raise KeyError + if ignore_second_feat: + return feat0 return feat0, feat1 diff --git a/nets/crestereo.py b/nets/crestereo.py index 6cc0bfd..6991f3b 100644 --- a/nets/crestereo.py +++ b/nets/crestereo.py @@ -151,7 +151,8 @@ class CREStereo(nn.Module): # FIXME experimental ! no self-attention for pattern if not self_attend_right: - fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16) + print('skipping right attention') + fmap1_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16, ignore_second_feat=True) else: fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16) diff --git a/train_lightning.py b/train_lightning.py index 91ef02c..6fb6f8c 100644 --- a/train_lightning.py +++ b/train_lightning.py @@ -31,8 +31,22 @@ import numpy as np import cv2 -def normalize_and_colormap(img): +def normalize_and_colormap(img, reduce_dynamic_range=False): + # print(img.min()) + # print(img.max()) + # print(img.mean()) ret = (img - img.min()) / (img.max() - img.min()) * 255.0 + # print(ret.min()) + # print(ret.max()) + # print(ret.mean()) + + # FIXME do I need to compress dynamic range somehow or something? 
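Aside, not part of the patch: the reduce_dynamic_range branch added just below rescales against 5x the mean instead of the true maximum, so a few outlier pixels do not wash the rest of the image out of the colormap. A standalone NumPy sketch of the same idea, which additionally clips values above that ceiling rather than letting them overflow the uint8 cast; the 5x factor is an assumed constant taken from the lines that follow:

```python
import numpy as np

def normalize_for_vis(img: np.ndarray, clip_factor: float = 5.0) -> np.ndarray:
    """Min-max normalize to [0, 255], capping outliers at clip_factor * mean."""
    img = img.astype(np.float64)
    lo, hi = img.min(), img.max()
    if hi > clip_factor * img.mean():
        # a handful of very large values would otherwise squeeze everything else toward 0
        hi = clip_factor * img.mean()
    ret = (np.clip(img, lo, hi) - lo) / (hi - lo + 1e-8) * 255.0
    return ret.astype(np.uint8)
```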
+ if reduce_dynamic_range and img.max() > 5*img.mean(): + ret = (img - img.min()) / (5*img.mean() - img.min()) * 255.0 + # print(ret.min()) + # print(ret.max()) + # print(ret.mean()) + if isinstance(ret, torch.Tensor): ret = ret.cpu().detach().numpy() ret = ret.astype("uint8") @@ -47,34 +61,71 @@ def log_images(left, right, pred_disp, gt_disp): if isinstance(pred_disp, list): pred_disp = pred_disp[-1] - pred_disp = torch.squeeze(pred_disp[:, 0, :, :]) - gt_disp = torch.squeeze(gt_disp[:, 0, :, :]) left = torch.squeeze(left[:, 0, :, :]) right = torch.squeeze(right[:, 0, :, :]) - + pred_disp = torch.squeeze(pred_disp[:, 0, :, :]) + gt_disp = torch.squeeze(gt_disp[:, 0, :, :]) + + # print('gt_disp debug') + # print(gt_disp.shape) + + singular_batch = False + if len(left.shape) == 2: + singular_batch = True + print('batch_size seems to be 1') + input_left = left.cpu().detach().numpy() + input_right = right.cpu().detach().numpy() + else: + input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) + input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) + disp = pred_disp disp_error = gt_disp - disp - - input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) - input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0) - - wandb_log = dict( - key='samples', - images=[ - normalize_and_colormap(pred_disp[batch_idx]), - normalize_and_colormap(abs(disp_error[batch_idx])), - normalize_and_colormap(gt_disp[batch_idx]), - input_left, - input_right, - ], - caption=[ - f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}", - f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}", - f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}", - "Input Left", - "Input Right" - ], - ) + + # print('gt_disp debug normalize') + # print(gt_disp.max(), gt_disp.min()) + # print(gt_disp.dtype) + + if singular_batch: + wandb_log = dict( + key='samples', + images=[ + pred_disp, + normalize_and_colormap(pred_disp), + normalize_and_colormap(abs(disp_error), reduce_dynamic_range=True), + normalize_and_colormap(gt_disp, reduce_dynamic_range=True), + input_left, + input_right, + ], + caption=[ + f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + f"Disparity (vis) \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}", + f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", + "Input Left", + "Input Right" + ], + ) + else: + wandb_log = dict( + key='samples', + images=[ + # pred_disp.cpu().detach().numpy().transpose(1,2,0), + normalize_and_colormap(pred_disp[batch_idx]), + normalize_and_colormap(abs(disp_error[batch_idx])), + normalize_and_colormap(gt_disp[batch_idx]), + input_left, + input_right, + ], + caption=[ + # f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}", + f"Disparity (vis)\n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}", + f"Disp. 
Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}", + f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}", + "Input Left", + "Input Right" + ], + ) return wandb_log @@ -104,7 +155,10 @@ def outlier_fraction(estimate, target, mask=None, threshold=0): else: mask = mask != 0 if estimate.shape != mask.shape: - raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})') + if len(mask.shape) == 3: + mask = mask[0] + if estimate.shape != mask.shape: + raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})') return estimate, target, mask estimate = torch.squeeze(estimate[:, 0, :, :]) target = torch.squeeze(target[:, 0, :, :]) @@ -131,27 +185,9 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): flow_preds[0]: (B, 2, H, W) flow_gt: (B, 2, H, W) ''' - """ - if test: - # print('sequence loss') - if valid.shape != (2, 480, 640): - valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2) - # print(valid.shape) - #valid = torch.stack([valid, valid]) - # print(valid.shape) - if valid.shape != (2, 480, 640): - valid = valid.transpose(0,1) - # print(valid.shape) - """ - # print(valid.shape) - # print(flow_preds[0].shape) - # print(flow_gt.shape) n_predictions = len(flow_preds) flow_loss = 0.0 - # TEST - # flow_gt = torch.squeeze(flow_gt, dim=-1) - for i in range(n_predictions): i_weight = gamma ** (n_predictions - i - 1) i_loss = torch.abs(flow_preds[i] - flow_gt) @@ -161,38 +197,50 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): class CREStereoLightning(LightningModule): - def __init__(self, args, logger=None, pattern_path='', data_path=''): + def __init__(self, args, logger=None, pattern_path=''): super().__init__() self.batch_size = args.batch_size self.wandb_logger = logger - self.data_type = 'blender' if 'blender' in data_path else 'ctd' + self.imwidth = args.image_width + self.imheight = args.image_height + self.data_type = 'blender' if 'blender' in args.training_data_path else 'ctd' + self.eval_type = 'kinect' if 'kinect' in args.test_data_path else args.training_data_path self.lr = args.base_lr - print(f'lr = {self.lr}') self.T_max = args.t_max if args.t_max else None self.pattern_attention = args.pattern_attention self.pattern_path = pattern_path - self.data_path = data_path + self.data_path = args.training_data_path + self.test_data_path = args.test_data_path + self.data_limit = args.data_limit # between 0 and 1. 
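Aside, not part of the patch: the trimmed sequence_loss above is CREStereo's usual exponentially weighted multi-prediction L1, loss = sum_i gamma^(n-1-i) * mean(valid * |flow_i - flow_gt|). A self-contained PyTorch sketch of that weighting, with shapes as in the docstring (flow_preds[i] and flow_gt are (B, 2, H, W), valid is (B, H, W)):

```python
import torch

def sequence_loss_sketch(flow_preds, flow_gt, valid, gamma=0.8):
    """Later refinement steps get weights close to 1, earlier ones are down-weighted."""
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    for i, pred in enumerate(flow_preds):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(pred - flow_gt)                             # per-pixel L1
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()   # invalid pixels zeroed out
    return flow_loss
```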
self.model = Model( max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False ) - # so I can access it in adjust learn rate more easily + + if args.scheduler == 'default': + self.automatic_optimization = False + # so I can access it in adjust learn rate more easily + self.n_total_epoch = args.n_total_epoch self.base_lr = args.base_lr - - self.automatic_optimization = False def train_dataloader(self): + # we never train on kinect + is_kinect = False if self.data_type == 'blender': dataset = BlenderDataset( root=self.data_path, pattern_path=self.pattern_path, use_lightning=True, + data_type='kinect' if is_kinect else 'blender', + disp_avail=not is_kinect, + data_limit = self.data_limit, ) elif self.data_type == 'ctd': dataset = CTDDataset( root=self.data_path, pattern_path=self.pattern_path, use_lightning=True, + data_limit = self.data_limit, ) dataloader = DataLoader( dataset, @@ -203,16 +251,20 @@ class CREStereoLightning(LightningModule): persistent_workers=True, pin_memory=True, ) - # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) return dataloader def val_dataloader(self): + # we also don't want to validate on kinect data + is_kinect = False if self.data_type == 'blender': test_dataset = BlenderDataset( root=self.data_path, pattern_path=self.pattern_path, test_set=True, use_lightning=True, + data_type='kinect' if is_kinect else 'blender', + disp_avail=not is_kinect, + data_limit = self.data_limit, ) elif self.data_type == 'ctd': test_dataset = CTDDataset( @@ -220,6 +272,7 @@ class CREStereoLightning(LightningModule): pattern_path=self.pattern_path, test_set=True, use_lightning=True, + data_limit = self.data_limit, ) test_dataloader = DataLoader( @@ -231,29 +284,35 @@ class CREStereoLightning(LightningModule): persistent_workers=True, pin_memory=True ) - # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) return test_dataloader def test_dataloader(self): - # TODO change this to use IRL data? + is_kinect = self.eval_type == 'kinect' if self.data_type == 'blender': - test_dataset = CTDDataset( - root=self.data_path, + test_dataset = BlenderDataset( + root=self.test_data_path, pattern_path=self.pattern_path, test_set=True, + split=0. 
if is_kinect else 0.9, # if we test on kinect data, use all available samples for test set use_lightning=True, + augment=False, + disp_avail=not is_kinect, + data_type='kinect' if is_kinect else 'blender', + data_limit = self.data_limit, ) elif self.data_type == 'ctd': - test_dataset = BlenderDataset( - root=self.data_path, + test_dataset = CTDDataset( + root=self.test_data_path, pattern_path=self.pattern_path, test_set=True, use_lightning=True, + augment=False, + data_limit = self.data_limit, ) test_dataloader = DataLoader( test_dataset, - self.batch_size, + 1 if is_kinect else self.batch_size, shuffle=False, num_workers=4, drop_last=False, @@ -307,7 +366,8 @@ class CREStereoLightning(LightningModule): self.log("outlier_fraction", of) # print(', '.join(f'of{thr}={val}' for thr, val in of.items())) if batch_idx % 8 == 0: - self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) + images = log_images(left, right, flow_predictions, gt_disp) + self.wandb_logger.log_image(**images) def test_step(self, batch, batch_idx): left, right, gt_disp, valid_mask = batch @@ -318,20 +378,28 @@ class CREStereoLightning(LightningModule): flow_predictions, gt_flow, valid_mask, gamma=0.8 ) self.log("test_loss", test_loss) - self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp)) + of = {} + for threshold in [0.1, 0.5, 1, 2, 5]: + of[str(threshold)] = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold) + self.log("outlier_fraction", of) + images = log_images(left, right, flow_predictions, gt_disp) + images['images'].append(gt_disp) + images['caption'].append('GT Disp') + self.wandb_logger.log_image(**images) def predict_step(self, batch, batch_idx, dataloader_idx=0): return self(batch) def configure_optimizers(self): optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999)) - print('len(self.train_dataloader)', len(self.train_dataloader())) + if not self.automatic_optimization: + return optimizer lr_scheduler = { 'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size, ), - 'name': 'CosineAnnealingLRScheduler', + 'name': 'LR Scheduler', } return [optimizer], [lr_scheduler] @@ -356,18 +424,21 @@ class CREStereoLightning(LightningModule): for param_group in optimizer.param_groups: param_group['lr'] = lr + self.log('train/lr', lr) if __name__ == "__main__": + # wandb.init(project='crestereo-lightning') wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True) # train configuration args = parse_yaml("cfgs/train.yaml") wandb_logger.experiment.config.update(args._asdict()) config = wandb.config + data_limit = config.data_limit if 'blender' in config.training_data_path: # this was used for our blender renders pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' - if 'ctd' in config.training_data_path: + elif 'ctd' in config.training_data_path: # this one is used (i hope) for ctd pattern_path = '/home/nils/kinect_from_settings.png' @@ -381,7 +452,6 @@ if __name__ == "__main__": config, wandb_logger, pattern_path, - config.training_data_path, # lr=0.00017378008287493763, # found with auto_lr_find=True ) # NOTE turn this down once it's working, this might use too much space @@ -394,31 +464,59 @@ if __name__ == "__main__": save_last=True, ) - trainer = Trainer( - accelerator='gpu', - devices=devices, - max_epochs=config.n_total_epoch, - callbacks=[ - EarlyStopping( - monitor="val_loss", - mode="min", - patience=16, - 
), - LearningRateMonitor(), - model_checkpoint, - ], - strategy=DDPSpawnStrategy(find_unused_parameters=False), - # auto_scale_batch_size='binsearch', - # auto_lr_find=True, - # accumulate_grad_batches=4, # needed to disable for manual optimization - deterministic=True, - check_val_every_n_epoch=1, - limit_val_batches=64, - limit_test_batches=256, - logger=wandb_logger, - default_root_dir=config.log_dir_lightning, - ) + if config.scheduler == 'default': + trainer = Trainer( + accelerator='gpu', + devices=devices, + max_epochs=config.n_total_epoch, + callbacks=[ + EarlyStopping( + monitor="val_loss", + mode="min", + patience=8, + ), + LearningRateMonitor(), + model_checkpoint, + ], + strategy=DDPSpawnStrategy(find_unused_parameters=False), + # auto_scale_batch_size='binsearch', + # auto_lr_find=True, + # accumulate_grad_batches=4, # needed to disable for manual optimization + deterministic=True, + check_val_every_n_epoch=1, + limit_val_batches=64, + limit_test_batches=256, + logger=wandb_logger, + default_root_dir=config.log_dir_lightning, + ) + else: + trainer = Trainer( + accelerator='gpu', + devices=devices, + max_epochs=config.n_total_epoch, + callbacks=[ + EarlyStopping( + monitor="val_loss", + mode="min", + patience=8, + ), + LearningRateMonitor(), + model_checkpoint, + ], + strategy=DDPSpawnStrategy(find_unused_parameters=False), + # auto_scale_batch_size='binsearch', + # auto_lr_find=True, + accumulate_grad_batches=4, # needed to disable for manual optimization + deterministic=True, + check_val_every_n_epoch=1, + limit_val_batches=64, + limit_test_batches=256, + logger=wandb_logger, + default_root_dir=config.log_dir_lightning, + ) + # trainer.tune(model) trainer.fit(model) # trainer.validate(chkpt_path=model_checkpoint.best_model_path) + trainer.test(model)
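Two standalone sanity checks for the changes above; neither is part of the patch. First, the rewritten get_disp uses the plain stereo relation disparity = baseline * focal_length / depth instead of the old mystery_factor of 35. With the constants hard-coded in dataset.py (0.075 m baseline, 560 px focal length) a point at 1 m depth lands at 42 px of disparity, assuming depth is expressed in metres (the 8-bit PNG depth maps the loader actually reads are only proportional to that):

```python
import numpy as np

def depth_to_disparity(depth_m: np.ndarray, baseline_m: float = 0.075, focal_px: float = 560.0) -> np.ndarray:
    """d = B * f / Z; zero-depth pixels map to zero disparity instead of inf."""
    with np.errstate(divide="ignore"):
        disp = baseline_m * focal_px / depth_m
    disp[~np.isfinite(disp)] = 0.0
    return disp.astype(np.float32)

print(depth_to_disparity(np.array([1.0, 2.0, 0.0])))  # -> [42. 21.  0.]
```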
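Second, test_step now logs an outlier fraction at thresholds 0.1, 0.5, 1, 2 and 5 px; the metric is simply the share of valid pixels whose absolute disparity error exceeds the threshold. A minimal PyTorch sketch (names assumed, not the exact helper from train_lightning.py):

```python
import torch

def outlier_fraction_sketch(estimate: torch.Tensor, target: torch.Tensor,
                            mask: torch.Tensor, threshold: float) -> float:
    """Fraction of pixels with valid ground truth whose absolute error exceeds threshold."""
    mask = mask != 0                                  # (H, W) validity mask
    err = (estimate[mask] - target[mask]).abs()
    return (err > threshold).float().mean().item()
```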