From 63da24f4297c755c4da37f025cebc3db2b06b67b Mon Sep 17 00:00:00 2001 From: "Cpt.Captain" Date: Wed, 24 Aug 2022 16:25:12 +0200 Subject: [PATCH] change a bunch of stuff, add wip lightning implementation --- api_server.py | 181 ++++++++++++++++-- cfgs/train.yaml | 8 +- dataset.py | 111 ++++++++++- nets/attention/transformer.py | 13 +- nets/crestereo.py | 26 ++- nets/extractor.py | 41 ++-- nets/update.py | 2 +- test_model.py | 111 +++++++---- train.py | 193 +++++++++++++++---- train_lightning.py | 346 ++++++++++++++++++++++++++++++++++ 10 files changed, 918 insertions(+), 114 deletions(-) create mode 100644 train_lightning.py diff --git a/api_server.py b/api_server.py index 0448071..dd50fcf 100644 --- a/api_server.py +++ b/api_server.py @@ -1,3 +1,4 @@ +import os import json from datetime import datetime from typing import Union, Literal @@ -7,32 +8,149 @@ import numpy as np import torch import torch.nn as nn import torch.nn.functional as F +import torch_tensorrt from cv2 import cv2 from fastapi import FastAPI, File, UploadFile from PIL import Image + from nets import Model +from train import inference as ctd_inference + + app = FastAPI() -reference_pattern_path = '/home/nils/kinect_reference_cropped.png' +# reference_pattern_path = '/home/nils/kinect_reference_cropped.png' +reference_pattern_path = '/home/nils/kinect_reference_far.png' +# reference_pattern_path = '/home/nils/kinect_diff_ref.png' +print(reference_pattern_path) reference_pattern = cv2.imread(reference_pattern_path) + +# shift reference pattern a few pixels to the left to simulate further backdrop +trans_mat = np.float32([[1, 0, 0], [0, 1, 0]]) +reference_pattern = cv2.warpAffine( + reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR +) + iters = 20 -minimal_data = False +minimal_data = True +temporal_init = False +last_img = None device = torch.device('cuda:0') -def load_model(epoch): +def downsize(img): + diff = (512 - 480) // 2 + downsampled = cv2.pyrDown(img) + img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]] + return img + + +if 1024 in reference_pattern.shape: + reference_pattern = downsize(reference_pattern) + + +def ghetto_lcn(img): + # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + gray = img + + float_gray = gray.astype(np.float32) / 255.0 + + blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2) + num = float_gray - blur + + blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20) + den = cv2.pow(blur, 0.5) + + gray = num / den + + # cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX) + cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX) + return gray + + +# reference_pattern = ghetto_lcn(reference_pattern) + + +def load_model(epoch, use_tensorrt=False): global model epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth' model_path = f"train_log/models/{epoch}" model = Model(max_disp=256, mixed_precision=False, test_mode=True) + # FIXME WIP Workaround Dataparallel TensorRT incompatibility model = nn.DataParallel(model, device_ids=[device]) # model.load_state_dict(torch.load(model_path), strict=False) state_dict = torch.load(model_path)['state_dict'] model.load_state_dict(state_dict, strict=True) model.to(device) model.eval() + if use_tensorrt: + np_model = Model(max_disp=256, mixed_precision=False, test_mode=True) + np_model.load_state_dict(model.module.state_dict(), strict=True) + np_model.to(device) + np_model.eval() + + spec_dict = { + "inputs": [ + 
torch_tensorrt.Input( + min_shape=[1, 2, 240, 320], + max_shape=[1, 2, 480, 640], + opt_shape=[1, 2, 480, 640], + dtype=torch.int32, + ), + torch_tensorrt.Input( + min_shape=[1, 2, 240, 320], + max_shape=[1, 2, 480, 640], + opt_shape=[1, 2, 480, 640], + dtype=torch.int32, + ), + ], + "enabled_precisions": {torch.float, torch.half}, + "refit": False, + "debug": False, + "device": { + "device_type": torch_tensorrt.DeviceType.GPU, + "gpu_id": 0, + "dla_core": 0, + "allow_gpu_fallback": True + }, + "capability": torch_tensorrt.EngineCapability.default, + "num_min_timing_iters": 2, + "num_avg_timing_iters": 1, + } + spec = { + "forward": + torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict) + } + + # trt_model = torch_tensorrt.compile(np_model , + # inputs=torch_tensorrt.Input( + # min_shape=[1, 2, 240, 320], + # max_shape=[1, 2, 480, 640], + # opt_shape=[1, 2, 480, 640], + # dtype=torch.int32, + # inputs = [torch_tensorrt.Input((1, 2, 480, 640)), torch_tensorrt.Input((1, 2, 480, 640))], # input shape + # enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16 + # ) + # trt_dw_model = torch_tensorrt.compile(np_model , + # inputs = [torch_tensorrt.Input((1, 2, 240, 320)), torch_tensorrt.Input((1, 2, 240, 320))], # input shape + # enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16 + # ) + + script_model = torch.jit.script(np_model.eval()) + # script_dw_model = torch.jit.script(trt_dw_model.eval()) + + # save the TensorRT embedded Torchscript + # torch.jit.save(trt_model, 'trt_torchscript_module.ts') + # torch.jit.save(trt_dw_model, 'trt_torchscript_dw_module.ts') + print(script_model) + print(script_model.forward) + print(script_model.forward()) + print(dir(script_model)) + + model = torch._C._jit_to_backend("tensorrt", script_model, spec) + print(f'loaded model {epoch}') return model @@ -74,14 +192,37 @@ def inference(left, right, model, n_iter=20): ) with torch.inference_mode(): - pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) - pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) + # pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) + # pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) + pred_flow_dw2 = model.forward(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) + pred_flow = model.forward(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() return pred_disp +def get_reference(): + refs = [ref.path for ref in os.scandir('/home/nils/references/')] + for ref in refs: + reference = cv2.imread(ref) + yield reference + +references = get_reference() + + +@app.post('/params/update_reference') +async def update_reference(): + global references, reference_pattern + try: + reference_pattern = downsize(next(references)) + print(reference_pattern.shape) + return {'status': 'success'} + except StopIteration: + references = get_reference() + return {'status': 'finished'} + + @app.post("/model/update/{epoch}") async def change_model(epoch: Union[int, Literal['latest']]): global model @@ -103,8 +244,15 @@ async def set_minimal_data(enable: bool): minimal_data = enable +@app.post("/params/temporal_init/{enable}") +async def set_temporal_init(enable: bool): + global temporal_init + temporal_init = enable + + @app.put("/ir") async def read_ir_input(file: UploadFile = File(...)): + global last_img, minimal_data try: img = np.array(Image.open(BytesIO(await file.read()))) except Exception as e: @@ 
-114,24 +262,35 @@ async def read_ir_input(file: UploadFile = File(...)): if len(img.shape) == 2: img = cv2.merge([img for _ in range(3)]) if img.shape == (1024, 1280, 3): - diff = (512 - 480) // 2 - downsampled = cv2.pyrDown(img) - img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]] + img = downsize(img) - img = img.transpose((1, 2, 0)) - ref_pat = reference_pattern.transpose((1, 2, 0)) + # img = img.transpose((1, 2, 0)) + # ref_pat = reference_pattern.transpose((1, 2, 0)) + ref_pat = reference_pattern start = datetime.now() - pred_disp = inference(img, ref_pat, model, iters) + if temporal_init: + pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img) + last_img = pred_disp + else: + pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False) + # pred_disp = inference(img, ref_pat, model, iters) duration = (datetime.now() - start).total_seconds() if minimal_data: return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) else: + # return json.dumps({'disp': pred_disp, 'input': img, 'duration': duration}, return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder) +@app.get('/temporal_init') +def get_temporal_init(): + return {'status': 'enabled' if temporal_init else 'disabled'} + + + @app.get('/') def main(): return {'test': 'abc'} diff --git a/cfgs/train.yaml b/cfgs/train.yaml index fc54fa2..c30a1ec 100644 --- a/cfgs/train.yaml +++ b/cfgs/train.yaml @@ -4,18 +4,22 @@ base_lr: 4.0e-4 nr_gpus: 3 batch_size: 4 -n_total_epoch: 600 +n_total_epoch: 300 minibatch_per_epoch: 500 loadmodel: ~ log_dir: "./train_log" +log_dir_lightning: "./train_log_lightning" model_save_freq_epoch: 1 max_disp: 256 image_width: 640 image_height: 480 # training_data_path: "./stereo_trainset/crestereo" -training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/" +pattern_attention: true +dataset: "blender" +# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/" +training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data" log_level: "logging.INFO" diff --git a/dataset.py b/dataset.py index 2cab5cc..3e4f36e 100644 --- a/dataset.py +++ b/dataset.py @@ -17,7 +17,7 @@ class Augmentor: scale_min=0.6, scale_max=1.0, seed=0, - ): + ): super().__init__() self.image_height = image_height self.image_width = image_width @@ -234,12 +234,16 @@ class CREStereoDataset(Dataset): class CTDDataset(Dataset): - def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False): + def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False): super().__init__() self.rng = np.random.RandomState(0) self.augment = augment self.blur = blur - self.imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True) + imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True) + if test_set: + self.imgs = imgs[:int(split * len(imgs))] + else: + self.imgs = imgs[int(split * len(imgs)):] self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE) if resize_pattern and self.pattern.shape != (480, 640, 3): @@ -271,6 +275,10 @@ class CTDDataset(Dataset): # read img, disp left_img = np.load(left_path) + + if left_img.dtype == 'float32': + left_img = (left_img * 255).astype('uint8') + left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3)) right_img = self.pattern @@ 
-307,3 +315,100 @@ class CTDDataset(Dataset): def __len__(self): return len(self.imgs) + + +class BlenderDataset(CTDDataset): + def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False): + super().__init__(root, pattern_path) + self.use_lightning = use_lightning + imgs = [f for f in glob.glob(f"{root}/im_*.png", recursive=True) if not 'depth0001' in f] + if test_set: + self.imgs = imgs[:int(split * len(imgs))] + else: + self.imgs = imgs[int(split * len(imgs)):] + self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE) + + if resize_pattern and self.pattern.shape != (480, 640, 3): + downsampled = cv2.pyrDown(self.pattern) + diff = (downsampled.shape[0] - 480) // 2 + self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]] + + self.augmentor = Augmentor( + image_height=480, + image_width=640, + max_disp=256, + scale_min=0.6, + scale_max=1.0, + seed=0, + ) + + def __getitem__(self, index): + # find path + left_path = self.imgs[index] + left_disp_path = left_path.split('.')[0] + '_depth0001.png' + + # read img, disp + left_img = cv2.imread(left_path) + + if left_img.dtype == 'float32': + left_img = (left_img * 255).astype('uint8') + + if left_img.shape != (480, 640, 3): + downsampled = cv2.pyrDown(left_img) + diff = (downsampled.shape[0] - 480) // 2 + left_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]] + if left_img.shape[-1] != 3: + left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3)) + + right_img = self.pattern + left_disp = self.get_disp(left_disp_path) + + if False: # self.rng.binomial(1, 0.5): + left_img, right_img = np.fliplr(right_img), np.fliplr(left_img) + left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp) + left_disp[left_disp == np.inf] = 0 + + if self.blur: + kernel_size = random.sample([1,3,5,7,9], 1)[0] + kernel = (kernel_size, kernel_size) + left_img = cv2.GaussianBlur(left_img, kernel, 0) + + # augmentation + if not self.augment: + _left_img, _right_img, _left_disp, disp_mask = self.augmentor( + left_img, right_img, left_disp + ) + else: + left_img, right_img, left_disp, disp_mask = self.augmentor( + left_img, right_img, left_disp + ) + + if not self.use_lightning: + right_img = right_img.transpose((2, 0, 1)).astype("uint8") + return { + "left": left_img, + "right": right_img, + "disparity": left_disp, + "mask": disp_mask, + } + + right_img = right_img.transpose((2, 0, 1)).astype("uint8") + left_img = left_img.transpose((2, 0, 1)).astype("uint8") + return left_img, right_img, left_disp, disp_mask + + def get_disp(self, path): + baseline = 0.075 # meters + fl = 560. 
# as per CTD + depth = cv2.imread(path, cv2.IMREAD_UNCHANGED) + downsampled = cv2.pyrDown(depth) + diff = (downsampled.shape[0] - 480) // 2 + depth = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]] + # disp = np.load(path).transpose(1,2,0) + # disp = baseline * fl / depth + # return disp.astype(np.float32) / 32 + # FIXME temporarily increase disparity until new data with better depth values is generated + # higher values seem to speedup convergence, but introduce much stronger artifacting + # mystery_factor = 150 + mystery_factor = 1 + disp = (baseline * fl * mystery_factor) / depth + return disp.astype(np.float32) diff --git a/nets/attention/transformer.py b/nets/attention/transformer.py index de55ffc..040f681 100644 --- a/nets/attention/transformer.py +++ b/nets/attention/transformer.py @@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module): """ assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" - for layer, name in zip(self.layers, self.layer_names): - + # NOTE Workaround for non statically determinable zip + # for layer, name in zip(self.layers, self.layer_names): + # layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers)) + # layer_zip = [] + # for i, layer in enumerate(self.layers): + # layer_zip.append((layer, self.layer_names[i])) + + for i, layer in enumerate(self.layers): + name = self.layer_names[i] if name == 'self': feat0 = layer(feat0, feat0, mask0, mask0) feat1 = layer(feat1, feat1, mask1, mask1) @@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module): else: raise KeyError - return feat0, feat1 \ No newline at end of file + return feat0, feat1 diff --git a/nets/crestereo.py b/nets/crestereo.py index 70a6bac..377e36b 100644 --- a/nets/crestereo.py +++ b/nets/crestereo.py @@ -36,6 +36,12 @@ class CREStereo(nn.Module): self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout) self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4) + # # NOTE Position_encoding as workaround for TensorRt + # image1_shape = [1, 2, 480, 640] + # self.pos_encoding_fn_small = PositionEncodingSine( + # d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16) + # ) + # loftr self.self_att_fn = LocalFeatureTransformer( d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear" @@ -81,7 +87,7 @@ class CREStereo(nn.Module): zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device) return zero_flow - def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False): + def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True): """ Estimate optical flow between pair of frames """ image1 = 2 * (image1 / 255.0) - 1.0 @@ -130,17 +136,22 @@ class CREStereo(nn.Module): inp_dw16 = F.avg_pool2d(inp, 4, stride=4) # positional encoding and self-attention - pos_encoding_fn_small = PositionEncodingSine( - d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16) - ) + # pos_encoding_fn_small = PositionEncodingSine( + # d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16) + # ) # 'n c h w -> n (h w) c' - x_tmp = pos_encoding_fn_small(fmap1_dw16) + x_tmp = self.pos_encoding_fn_small(fmap1_dw16) fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]) # 'n c h w -> n (h w) c' - x_tmp = pos_encoding_fn_small(fmap2_dw16) + x_tmp = self.pos_encoding_fn_small(fmap2_dw16) fmap2_dw16 
= x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]) - fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16) + # FIXME experimental ! no self-attention for pattern + if not self_attend_right: + fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16) + else: + fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16) + fmap1_dw16, fmap2_dw16 = [ x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2) for x in [fmap1_dw16, fmap2_dw16] @@ -258,3 +269,4 @@ class CREStereo(nn.Module): return flow_up return predictions + diff --git a/nets/extractor.py b/nets/extractor.py index 993cd3a..367a416 100644 --- a/nets/extractor.py +++ b/nets/extractor.py @@ -1,6 +1,8 @@ +from typing import List import torch import torch.nn as nn import torch.nn.functional as F +from torch import Tensor # Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py class ResidualBlock(nn.Module): @@ -96,28 +98,43 @@ class BasicEncoder(nn.Module): self.in_planes = dim return nn.Sequential(*layers) - def forward(self, x): + def forward(self, x: List[Tensor]): + # NOTE always assume list, otherwise TensorRT is sad + # batch_dim = x[0].shape[0] + # x_tensor = torch.cat(list(x), dim=0) # if input is list, combine batch dimension is_list = isinstance(x, tuple) or isinstance(x, list) if is_list: batch_dim = x[0].shape[0] - x = torch.cat(x, dim=0) + x_tensor = torch.cat(x, dim=0) + else: + x_tensor = x + + print() + print() + print(x_tensor.shape) + print() + print() - x = self.conv1(x) - x = self.norm1(x) - x = self.relu1(x) + x_tensor = self.conv1(x_tensor) + x_tensor = self.norm1(x_tensor) + x_tensor = self.relu1(x_tensor) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) + x_tensor = self.layer1(x_tensor) + x_tensor = self.layer2(x_tensor) + x_tensor = self.layer3(x_tensor) - x = self.conv2(x) + x_tensor = self.conv2(x_tensor) if self.dropout is not None: - x = self.dropout(x) + x_tensor = self.dropout(x_tensor) if is_list: - x = torch.split(x, x.shape[0]//2, dim=0) + x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0) + return x_list + + x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0) + return x_list - return x \ No newline at end of file + # return list(x) diff --git a/nets/update.py b/nets/update.py index 401d504..3602b46 100644 --- a/nets/update.py +++ b/nets/update.py @@ -77,7 +77,7 @@ class BasicUpdateBlock(nn.Module): nn.ReLU(inplace=True), nn.Conv2d(256, mask_size**2 *9, 1, padding=0)) - def forward(self, net, inp, corr, flow, upsample=True): + def forward(self, net, inp, corr, flow, upsample: bool=True): # print(inp.shape, corr.shape, flow.shape) motion_features = self.encoder(flow, corr) # print(motion_features.shape, inp.shape) diff --git a/test_model.py b/test_model.py index 902e455..f7e6e51 100644 --- a/test_model.py +++ b/test_model.py @@ -3,6 +3,7 @@ import torch.nn as nn import torch.nn.functional as F import numpy as np import cv2 +import os from nets import Model @@ -16,17 +17,20 @@ device = 'cuda' wandb.init(project="crestereo", entity="cpt-captain") -def do_infer(left_img, right_img, gt_disp, model): - disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False) +def do_infer(left_img, right_img, gt_disp, model, attend_pattern=True): + disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False, attend_pattern=attend_pattern) disp_vis = normalize_and_colormap(disp) - gt_disp_vis = 
normalize_and_colormap(gt_disp) - if gt_disp.shape != disp.shape: - gt_disp = gt_disp.reshape(disp.shape) - disp_err = gt_disp - disp - disp_err = normalize_and_colormap(disp_err.abs()) - - wandb.log({ + # gt_disp_vis = normalize_and_colormap(gt_disp) + # if gt_disp.shape != disp.shape: + # gt_disp = gt_disp.reshape(disp.shape) + # disp_err = gt_disp - disp + # disp_err = normalize_and_colormap(disp_err.abs()) + if isinstance(left_img, torch.Tensor): + left_img = left_img.cpu().detach().numpy().astype('uint8') + right_img = right_img.cpu().detach().numpy().astype('uint8') + + results = { 'disp': wandb.Image( disp, caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", @@ -35,41 +39,60 @@ def do_infer(left_img, right_img, gt_disp, model): disp_vis, caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", ), - 'gt_disp_vis': wandb.Image( - gt_disp_vis, - caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", - ), - 'disp_err': wandb.Image( - disp_err, - caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}", - ), + # 'disp_err': wandb.Image( + # disp_err, + # caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}", + # ), 'input_left': wandb.Image( - left_img.cpu().detach().numpy().astype('uint8'), + left_img, caption=f"Input left", ), 'input_right': wandb.Image( - right_img.cpu().detach().numpy().astype('uint8'), + right_img, caption=f"Input right", ), - }) + } + + if gt_disp is not None: + print('logging gt') + print(f'gt: {gt_disp.max()}/{gt_disp.min()}/{gt_disp.mean()}') + gt_disp_vis = normalize_and_colormap(gt_disp) + results.update({ + 'gt_disp_vis': wandb.Image( + gt_disp_vis, + caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", + )}) + wandb.log(results) + + +def downsample(img, half_height_out=480): + downsampled = cv2.pyrDown(img) + diff = (downsampled.shape[0] - half_height_out) // 2 + return downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]] if __name__ == '__main__': # model_path = "models/crestereo_eth3d.pth" model_path = "train_log/models/latest.pth" + # model_path = "train_log/models/epoch-120.pth" + # model_path = "train_log/models/epoch-250.pth" + print(model_path) # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png' - reference_pattern_path = '/home/nils/kinect_reference_cropped.png' # reference_pattern_path = '/home/nils/new_reference.png' # reference_pattern_path = '/home/nils/kinect_reference_high_res.png' + reference_pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png' + # reference_pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' + # reference_pattern_path = '/home/nils/kinect_reference_cropped.png' # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png' - data_type = 'kinect' + # data_type = 'kinect' + data_type = 'blender' augment = False args = parse_yaml("cfgs/train.yaml") - wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment}) + wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment, 'data_type': data_type, 'pattern_self_attention': args.pattern_attention}) model = Model(max_disp=256, mixed_precision=False, test_mode=True) model = nn.DataParallel(model, device_ids=[device]) @@ -78,16 +101,32 @@ if __name__ == '__main__': model.to(device) model.eval() - dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', 
data_type=data_type, - pattern_path=reference_pattern_path, augment=augment) - dataloader = DataLoader(dataset, args.batch_size, shuffle=True, - num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) - for batch in dataloader: - for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']): - right = right.transpose(0, 2).transpose(0, 1) - left_img = left - imgL = left.cpu().detach().numpy() - right_img = right - imgR = right.cpu().detach().numpy() - gt_disp = disparity - do_infer(left_img, right_img, gt_disp, model) + # dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, + # pattern_path=reference_pattern_path, augment=augment) + # dataloader = DataLoader(dataset, args.batch_size, shuffle=True, + # num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) + + # for batch in dataloader: + gt_disp = None + right = downsample(cv2.imread(reference_pattern_path)) + + if data_type == 'blender': + img_path = '/media/Data1/connecting_the_dots_data/blender_renders/data/' + elif data_type == 'kinect': + img_path = '/home/nils/kinect_pngs/ir/' + + for img in sorted(list(entry for entry in os.scandir(img_path) if 'depth' not in entry.name), key=lambda x:x.name)[:25]: + print(img.path) + if data_type == 'blender': + baseline = 0.075 # meters + fl = 560. # as per CTD + + gt_path = img.path.rsplit('.')[0] + '_depth0001.png' + gt_depth = downsample(cv2.imread(gt_path)) + + mystery_factor = 35 # we don't get reasonable disparities due to incorrect depth scaling (or something like that) + gt_disp = (baseline * fl * mystery_factor) / gt_depth + + left = downsample(cv2.imread(img.path)) + + do_infer(left, right, gt_disp, model, attend_pattern=args.pattern_attention) diff --git a/train.py b/train.py index 770d115..b036e03 100644 --- a/train.py +++ b/train.py @@ -9,7 +9,7 @@ import yaml from nets import Model # from dataset import CREStereoDataset -from dataset import CREStereoDataset, CTDDataset +from dataset import BlenderDataset, CREStereoDataset, CTDDataset import torch import torch.nn as nn @@ -32,14 +32,18 @@ def normalize_and_colormap(img): return ret -def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True): +def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True): print("Model Forwarding...") - left = left.cpu().detach().numpy() + if isinstance(left, torch.Tensor): + left = left.cpu().detach().numpy() + imgR = right.cpu().detach().numpy() imgL = left - imgR = right.cpu().detach().numpy() + imgR = right imgL = np.ascontiguousarray(imgL[None, :, :, :]) imgR = np.ascontiguousarray(imgR[None, :, :, :]) + + flow_init = None # chosen for convenience device = torch.device('cuda:0') @@ -55,19 +59,54 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru size=(imgL.shape[2] // 2, imgL.shape[3] // 2), mode="bilinear", align_corners=True, - ) + ).clamp(min=0, max=255) imgR_dw2 = F.interpolate( imgR, size=(imgL.shape[2] // 2, imgL.shape[3] // 2), mode="bilinear", align_corners=True, - ) + ).clamp(min=0, max=255) + if last_img is not None: + print('using flow_initialization') + print(last_img.shape) + # FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help + print(last_img.max(), last_img.min()) + if last_img.min() < 0: + # print('Negative disparity detected. 
shifting...') + last_img = last_img - last_img.min() + if last_img.max() > 255: + # print('Excessive disparity detected. scaling...') + last_img = last_img / (last_img.max() / 255) + + + last_img = np.dstack([last_img, last_img]) + # last_img = np.dstack([last_img, last_img, last_img]) + last_img = np.dstack([last_img]) + last_img = last_img.reshape((1, 2, 480, 640)) + # print(last_img.shape) + # print(last_img.dtype) + # print(last_img.max(), last_img.min()) + flow_init = torch.tensor(last_img.astype("float32")).to(device) + # flow_init = F.interpolate( + # last_img, + # size=(last_img.shape[0] // 2, last_img.shape[1] // 2), + # mode="bilinear", + # align_corners=True, + # ) with torch.inference_mode(): - pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) - pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) + pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern) + pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern) + pf_base = pred_flow + if isinstance(pf_base, list): + pf_base = pred_flow[0] + pf = torch.squeeze(pf_base[:, 0, :, :]).cpu().detach().numpy() + print('pred_flow max min') + print(pf.max(), pf.min()) if not wandb_log: + if test: + return pred_flow return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() log = {} @@ -96,30 +135,36 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru np.array([pred_disp.reshape(480, 640)]), caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", ) - log[f'pred_norm_{i}'] = wandb.Image( - np.array([pred_disp_norm.reshape(480, 640)]), - caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", - ) + # log[f'pred_norm_{i}'] = wandb.Image( + # np.array([pred_disp_norm.reshape(480, 640)]), + # caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + # ) - log[f'pred_dw2_{i}'] = wandb.Image( - np.array([pred_disp_dw2.reshape(240, 320)]), - caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", - ) - log[f'pred_dw2_norm_{i}'] = wandb.Image( - np.array([pred_disp_dw2_norm.reshape(240, 320)]), - caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", - ) + # log[f'pred_dw2_{i}'] = wandb.Image( + # np.array([pred_disp_dw2.reshape(240, 320)]), + # caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", + # ) + # log[f'pred_dw2_norm_{i}'] = wandb.Image( + # np.array([pred_disp_dw2_norm.reshape(240, 320)]), + # caption=f"Pred. Disp. 
Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", + # ) log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") - log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right") + input_right = right.cpu().detach().numpy() if isinstance(right, torch.Tensor) else right + if input_right.shape != (480, 640, 3): + input_right.transpose(1,2,0) + log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right") log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") + gt_disp = gt_disp.cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp + disp = disp.cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp + disp_error = gt_disp - disp log['disp_error'] = wandb.Image( - normalize_and_colormap(disp_error.abs()), - caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}", + normalize_and_colormap(abs(disp_error)), + caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}", ) @@ -129,6 +174,7 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru ) wandb.log(log) + return pred_flow def parse_yaml(file_path: str) -> namedtuple: @@ -172,12 +218,25 @@ def adjust_learning_rate(optimizer, epoch): for param_group in optimizer.param_groups: param_group['lr'] = lr -def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8): +def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): ''' valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W) flow_preds[0]: (B, 2, H, W) flow_gt: (B, 2, H, W) ''' + if test: + # print('sequence loss') + if valid.shape != (2, 480, 640): + valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2) + # print(valid.shape) + #valid = torch.stack([valid, valid]) + # print(valid.shape) + if valid.shape != (2, 480, 640): + valid = valid.transpose(0,1) + # print(valid.shape) + # print(valid.shape) + # print(flow_preds[0].shape) + # print(flow_gt.shape) n_predictions = len(flow_preds) flow_loss = 0.0 @@ -260,20 +319,41 @@ def main(args): start_iters = 0 # pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png' - pattern_path = '/home/nils/kinect_reference_cropped.png' + # pattern_path = '/home/nils/kinect_reference_cropped.png' + # pattern_path = '/home/nils/kinect_reference_far.png' + # pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' + pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png' # pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png' # datasets # dataset = CREStereoDataset(args.training_data_path) - dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path) + # dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path) + # test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True) + if args.dataset == 'blender': + pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' + dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path) + test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True) + elif args.dataset == 'ctd': + dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path) + test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True) + else: + print('unrecognized dataset') + quit() + + 
test_data_iter = iter(test_dataset) # if rank == 0: worklog.info(f"Dataset size: {len(dataset)}") + print(args.batch_size) dataloader = DataLoader(dataset, args.batch_size, shuffle=True, - num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) + num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) + # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) + test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False, + num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) # counter cur_iters = start_iters total_iters = args.minibatch_per_epoch * args.n_total_epoch t0 = time.perf_counter() + test_idx = 0 for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1): # adjust learning rate @@ -310,24 +390,59 @@ def main(args): # forward # left = left.transpose(1, 2).transpose(2, 3) left = left.transpose(1, 3).transpose(2, 3) - right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2) - flow_predictions = model(left.cuda(), right.cuda()) + # right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2) + flow_predictions = model(left.cuda(), right.cuda(), self_attend_right=args.pattern_attention) # loss & backword loss = sequence_loss( flow_predictions, gt_flow, valid_mask, gamma=0.8 ) - if batch_idx % 128 == 0: - inference( - mini_batch_data['left'][0], - mini_batch_data['right'][0], - mini_batch_data['disparity'][0], - mini_batch_data['mask'][0], - model, - batch_idx, - ) - + if batch_idx % 512 == 0: + test_idx = 0 + test_loss = 0 + for i, test_batch in enumerate(test_dataset): + # test_batch = next(test_data_iter) + if i >= 24: + break + + # TODO refactor, DRY + left, right, gt_disp, valid_mask = ( + test_batch['left'], + test_batch['right'], + torch.Tensor(test_batch['disparity']).cuda(), + torch.Tensor(test_batch['mask']).cuda(), + ) + gt_disp = torch.dstack([gt_disp, gt_disp]).transpose(2,0).transpose(1,2) + gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] + gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512] + # print(f'left {left.shape}, right {right.shape}') + # left = left.transpose([2, 0, 1]) + right = right.transpose([1, 2, 0]) + # right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2) + # print(f'left {left.shape}, right {right.shape}') + + model.eval() + flow_predictions = inference( + left, + right, + # gt_disp, + torch.Tensor(test_batch['disparity']).cuda(), + valid_mask, + model, + test_idx, + wandb_log=i % 4 == 0, + test=True, + ) + test_idx += 1 + test_loss += sequence_loss( + flow_predictions, gt_flow, valid_mask, gamma=0.8, test=True + ).data.item() + model.train() + + avg_test_loss = test_loss / test_idx + print(f'test_loss: {test_loss}\nlen test: {test_idx}\navg. 
loss: {avg_test_loss}') + metrics['test/loss'] = avg_test_loss # loss stats loss_item = loss.data.item() epoch_total_train_loss += loss_item diff --git a/train_lightning.py b/train_lightning.py new file mode 100644 index 0000000..b65b58e --- /dev/null +++ b/train_lightning.py @@ -0,0 +1,346 @@ +import os +import sys +import time +import logging +from collections import namedtuple + +import yaml +# from tensorboardX import SummaryWriter + +from nets import Model +# from dataset import CREStereoDataset +from dataset import BlenderDataset, CREStereoDataset, CTDDataset + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +from pytorch_lightning.lite import LightningLite +from pytorch_lightning import LightningDataModule, LightningModule, Trainer +from pytorch_lightning import Trainer, seed_everything +from pytorch_lightning.loggers import WandbLogger + +seed_everything(42, workers=True) + + +import wandb + +import numpy as np +import cv2 + + +def normalize_and_colormap(img): + ret = (img - img.min()) / (img.max() - img.min()) * 255.0 + if isinstance(ret, torch.Tensor): + ret = ret.cpu().detach().numpy() + ret = ret.astype("uint8") + ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO) + return ret + + +def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True): + + print("Model Forwarding...") + if isinstance(left, torch.Tensor): + left = left# .cpu().detach().numpy() + imgR = right# .cpu().detach().numpy() + imgL = left + imgR = right + imgL = np.ascontiguousarray(imgL[None, :, :, :]) + imgR = np.ascontiguousarray(imgR[None, :, :, :]) + + flow_init = None + + # chosen for convenience + + imgL = torch.tensor(imgL.astype("float32")) + imgR = torch.tensor(imgR.astype("float32")) + imgL = imgL.transpose(2,3).transpose(1,2) + if imgL.shape != imgR.shape: + imgR = imgR.transpose(2,3).transpose(1,2) + + imgL_dw2 = F.interpolate( + imgL, + size=(imgL.shape[2] // 2, imgL.shape[3] // 2), + mode="bilinear", + align_corners=True, + ).clamp(min=0, max=255) + imgR_dw2 = F.interpolate( + imgR, + size=(imgL.shape[2] // 2, imgL.shape[3] // 2), + mode="bilinear", + align_corners=True, + ).clamp(min=0, max=255) + if last_img is not None: + print('using flow_initialization') + print(last_img.shape) + # FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help + print(last_img.max(), last_img.min()) + if last_img.min() < 0: + # print('Negative disparity detected. shifting...') + last_img = last_img - last_img.min() + if last_img.max() > 255: + # print('Excessive disparity detected. 
scaling...') + last_img = last_img / (last_img.max() / 255) + + + last_img = np.dstack([last_img, last_img]) + # last_img = np.dstack([last_img, last_img, last_img]) + last_img = np.dstack([last_img]) + last_img = last_img.reshape((1, 2, 480, 640)) + # print(last_img.shape) + # print(last_img.dtype) + # print(last_img.max(), last_img.min()) + flow_init = torch.tensor(last_img.astype("float32")) + # flow_init = F.interpolate( + # last_img, + # size=(last_img.shape[0] // 2, last_img.shape[1] // 2), + # mode="bilinear", + # align_corners=True, + # ) + with torch.inference_mode(): + pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern) + pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern) + pf_base = pred_flow + if isinstance(pf_base, list): + pf_base = pred_flow[0] + pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy() + print('pred_flow max min') + print(pf.max(), pf.min()) + + + if not wandb_log: + if test: + return pred_flow + return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy() + + log = {} + in_h, in_w = left.shape[:2] + + # Resize image in case the GPU memory overflows + eval_h, eval_w = (in_h,in_w) + + for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)): + pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy() + pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy() + + # pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) + # pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) + + if i == n_iter-1: + t = float(in_w) / float(eval_w) + disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t + + log[f'disp_vis'] = wandb.Image( + normalize_and_colormap(disp), + caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + ) + + log[f'pred_{i}'] = wandb.Image( + np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]), + caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + ) + # log[f'pred_norm_{i}'] = wandb.Image( + # np.array([pred_disp_norm.reshape(480, 640)]), + # caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", + # ) + + # log[f'pred_dw2_{i}'] = wandb.Image( + # np.array([pred_disp_dw2.reshape(240, 320)]), + # caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", + # ) + # log[f'pred_dw2_norm_{i}'] = wandb.Image( + # np.array([pred_disp_dw2_norm.reshape(240, 320)]), + # caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", + # ) + + + log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") + input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right + if input_right.shape != (480, 640, 3): + input_right.transpose(1,2,0) + log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right") + + log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") + + gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp + disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp + + disp_error = gt_disp - disp + log['disp_error'] = wandb.Image( + normalize_and_colormap(abs(disp_error)), + caption=f"Disp. 
Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}", + ) + + + log[f'gt_disp_vis'] = wandb.Image( + normalize_and_colormap(gt_disp), + caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", + ) + + wandb.log(log) + return pred_flow + + +def parse_yaml(file_path: str) -> namedtuple: + """Parse yaml configuration file and return the object in `namedtuple`.""" + with open(file_path, "rb") as f: + cfg: dict = yaml.safe_load(f) + args = namedtuple("train_args", cfg.keys())(*cfg.values()) + return args + + +def format_time(elapse): + elapse = int(elapse) + hour = elapse // 3600 + minute = elapse % 3600 // 60 + seconds = elapse % 60 + return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds) + + +def ensure_dir(path): + if not os.path.exists(path): + os.makedirs(path, exist_ok=True) + + +def adjust_learning_rate(optimizer, epoch): + + warm_up = 0.02 + const_range = 0.6 + min_lr_rate = 0.05 + + if epoch <= args.n_total_epoch * warm_up: + lr = (1 - min_lr_rate) * args.base_lr / ( + args.n_total_epoch * warm_up + ) * epoch + min_lr_rate * args.base_lr + elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range: + lr = args.base_lr + else: + lr = (min_lr_rate - 1) * args.base_lr / ( + (1 - const_range) * args.n_total_epoch + ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr + + for param_group in optimizer.param_groups: + param_group['lr'] = lr + +def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False): + ''' + valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W) + flow_preds[0]: (B, 2, H, W) + flow_gt: (B, 2, H, W) + ''' + if test: + # print('sequence loss') + if valid.shape != (2, 480, 640): + valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2) + # print(valid.shape) + #valid = torch.stack([valid, valid]) + # print(valid.shape) + if valid.shape != (2, 480, 640): + valid = valid.transpose(0,1) + # print(valid.shape) + # print(valid.shape) + # print(flow_preds[0].shape) + # print(flow_gt.shape) + n_predictions = len(flow_preds) + flow_loss = 0.0 + + # TEST + flow_gt = torch.squeeze(flow_gt, dim=-1) + + for i in range(n_predictions): + i_weight = gamma ** (n_predictions - i - 1) + i_loss = torch.abs(flow_preds[i] - flow_gt) + flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean() + + return flow_loss + + + +class CREStereoLightning(LightningModule): + def __init__(self, args): + super().__init__() + self.batch_size = args.batch_size + self.model = Model( + max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False + ) + + def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True): + return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right) + + def training_step(self, batch, batch_idx): + # loss = self(batch) + left, right, gt_disp, valid_mask = batch + left = torch.Tensor(left).to(self.device) + right = torch.Tensor(right).to(self.device) + left = left + right = right + flow_predictions = self.forward(left, right) + loss = sequence_loss( + flow_predictions, gt_flow, valid_mask, gamma=0.8 + ) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + left, right, gt_disp, valid_mask = batch + left = torch.Tensor(left).to(self.device) + right = torch.Tensor(right).to(self.device) + print(left.shape) + print(right.shape) + flow_predictions = self.forward(left, right) + val_loss = sequence_loss( + flow_predictions, gt_flow, 
valid_mask, gamma=0.8 + ) + self.log("val_loss", val_loss) + + def test_step(self, batch, batch_idx): + left, right, gt_disp, valid_mask = batch + # left, right, gt_disp, valid_mask = ( + # batch["left"], + # batch["right"], + # batch["disparity"], + # batch["mask"], + # ) + left = torch.Tensor(left).to(self.device) + right = torch.Tensor(right).to(self.device) + flow_predictions = self.forward(left, right) + test_loss = sequence_loss( + flow_predictions, gt_flow, valid_mask, gamma=0.8 + ) + self.log("test_loss", test_loss) + + def configure_optimizers(self): + return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999)) + + +if __name__ == "__main__": + # train configuration + args = parse_yaml("cfgs/train.yaml") + # wandb.init(project="crestereo-lightning", entity="cpt-captain") + # Lite(strategy='dp', accelerator='gpu', devices=2).run(args) + pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' + model = CREStereoLightning(args) + dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True) + test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True) + print(len(dataset)) + print(len(test_dataset)) + wandb_logger = WandbLogger(project="crestereo-lightning") + wandb.config.update(args._asdict()) + + trainer = Trainer( + max_epochs=args.n_total_epoch, + accelerator='gpu', + devices=2, + # auto_scale_batch_size='binsearch', + # strategy='ddp', + deterministic=True, + check_val_every_n_epoch=1, + limit_val_batches=24, + limit_test_batches=24, + logger=wandb_logger, + default_root_dir=args.log_dir_lightning, + ) + # trainer.tune(model) + trainer.fit(model, dataset, test_dataset)
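
Note on the WIP Lightning module: training_step, validation_step and test_step unpack
(left, right, gt_disp, valid_mask) from the batch but hand gt_flow to sequence_loss.
Below is a minimal sketch of the disparity-to-flow conversion, assuming the same
convention train.py uses elsewhere in this patch (disparity in channel 0, zero vertical
flow in channel 1); disp_to_flow is a hypothetical helper name, not part of the patch:

    import torch

    def disp_to_flow(gt_disp: torch.Tensor) -> torch.Tensor:
        # gt_disp: (B, H, W) ground-truth disparity from the dataset
        gt_disp = torch.unsqueeze(gt_disp, dim=1)        # (B, 1, H, W)
        return torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W), vertical flow stays zero

    # e.g. inside training_step:
    #   left, right, gt_disp, valid_mask = batch
    #   flow_predictions = self.forward(left, right)
    #   loss = sequence_loss(flow_predictions, disp_to_flow(gt_disp), valid_mask, gamma=0.8)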
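
The Trainer at the end is handed the datasets directly; train.py in this patch wraps them
in torch.utils.data.DataLoader first. A sketch of the equivalent wiring for the Lightning
entry point, reusing the loader arguments that train.py already uses (batch size comes from
cfgs/train.yaml; whether the loaders should instead live in a LightningDataModule is left open):

    from torch.utils.data import DataLoader

    train_loader = DataLoader(dataset, args.batch_size, shuffle=True,
                              num_workers=0, drop_last=False, pin_memory=True)
    val_loader = DataLoader(test_dataset, args.batch_size, shuffle=False,
                            num_workers=0, drop_last=False, pin_memory=True)
    trainer.fit(model, train_loader, val_loader)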