From 70e4bf6fe1c2766ed156b5d2f82d3064cd1896cf Mon Sep 17 00:00:00 2001
From: Nils Koch
Date: Mon, 30 May 2022 16:13:06 +0200
Subject: [PATCH] add wandb, make compatible with ctd data

---
 dataset.py        | 121 +++++++++++++++++++++++-
 nets/crestereo.py |   2 +
 test_model.py     | 233 +++++++++++++++++++++++++++++++++++++++------
 train.py          | 175 +++++++++++++++++++++++++++++++---
 4 files changed, 479 insertions(+), 52 deletions(-)

diff --git a/dataset.py b/dataset.py
index e13e7b9..69a8bc0 100644
--- a/dataset.py
+++ b/dataset.py
@@ -1,4 +1,5 @@
 import os
+import random
 import cv2
 import glob
 import numpy as np
@@ -48,8 +49,8 @@ class Augmentor:
     def __call__(self, left_img, right_img, left_disp):
         # 1. chromatic augmentation
-        left_img = self.chromatic_augmentation(left_img)
-        right_img = self.chromatic_augmentation(right_img)
+        # left_img = self.chromatic_augmentation(left_img)
+        # right_img = self.chromatic_augmentation(right_img)
 
         # 2. spatial augmentation
         # 2.1) rotate & vertical shift for right image
@@ -62,6 +63,7 @@ class Augmentor:
             self.rng.uniform(0, right_img.shape[1]),
         )
         rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
+        # right_img = right_img.transpose(2, 1, 0)
         right_img = cv2.warpAffine(
             right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
         )
@@ -69,6 +71,7 @@ class Augmentor:
         right_img = cv2.warpAffine(
             right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
         )
+        # right_img = right_img.transpose(1, 2, 0)
 
         # 2.2) random resize
        resize_scale = self.rng.uniform(self.scale_min, self.scale_max)
@@ -80,6 +83,10 @@ class Augmentor:
             fy=resize_scale,
             interpolation=cv2.INTER_LINEAR,
         )
+        if len(left_img.shape) == 2:  # grayscale: add a trailing channel axis
+            left_img.shape += 1,
+            # left_img = cv2.merge([left_img, left_img, left_img])
+
         right_img = cv2.resize(
             right_img,
             None,
@@ -87,6 +94,9 @@ class Augmentor:
             fy=resize_scale,
             interpolation=cv2.INTER_LINEAR,
         )
+        if len(right_img.shape) == 2:  # same guard for the right image
+            right_img.shape += 1,
+            # right_img = cv2.merge([right_img, right_img, right_img])
 
         disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
         disp_mask = disp_mask.astype("float32")
@@ -110,7 +120,11 @@ class Augmentor:
         )
 
         # 2.3) random crop
-        h, w, c = left_img.shape
+        if len(left_img.shape) == 3:
+            h, w, c = left_img.shape
+        else:
+            h, w = left_img.shape
+            c = 1
         dx = w - self.image_width
         dy = h - self.image_height
         dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
@@ -144,7 +158,7 @@ class Augmentor:
             (self.image_width, self.image_height),
             flags=cv2.INTER_LINEAR,
             borderValue=0,
-        )
+        )
 
         # 3. add random occlusion to right image
         if self.rng.binomial(1, 0.5):
@@ -156,6 +170,9 @@ class Augmentor:
                 np.mean(right_img, 0), 0
             )[np.newaxis, np.newaxis]
 
+        if len(left_img.shape) == 2:
+            left_img = cv2.merge([left_img, left_img, left_img])
+
         return left_img, right_img, left_disp, disp_mask
 
 
@@ -197,7 +214,8 @@ class CREStereoDataset(Dataset):
         left_disp[left_disp == np.inf] = 0
 
         # augmentaion
-        left_img, right_img, left_disp, disp_mask = self.augmentor(
+        # left_img, right_img, left_disp, disp_mask = self.augmentor(
+        _, _, left_disp, disp_mask = self.augmentor(
             left_img, right_img, left_disp
         )
@@ -213,3 +231,96 @@ class CREStereoDataset(Dataset):
 
     def __len__(self):
         return len(self.imgs)
+
+
+class CTDDataset(Dataset):
+    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False):
+        super().__init__()
+        self.rng = np.random.RandomState(0)
+        self.augment = augment
+        self.blur = blur
+        self.imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
+        self.pattern = cv2.imread(pattern_path)  # , cv2.IMREAD_GRAYSCALE)
+
+        if resize_pattern and self.pattern.shape != (480, 640, 3):
+            # self.pattern = cv2.resize(self.pattern, (640, 480))
+            print(self.pattern.shape)
+            # halve the pattern resolution, then center-crop the height to 480
+            downsampled = cv2.pyrDown(self.pattern)
+            diff = (downsampled.shape[0] - 480) // 2
+            self.pattern = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
+
+        self.augmentor = Augmentor(
+            image_height=480,
+            image_width=640,
+            max_disp=256,
+            scale_min=0.6,
+            scale_max=1.0,
+            seed=0,
+        )
+
+    def get_disp(self, path):
+        # disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
+        disp = np.load(path).transpose(1, 2, 0)
+        # return disp.astype(np.float32) / 32
+        return disp
+
+    def __getitem__(self, index):
+        # find path
+        left_path = self.imgs[index]
+        prefix = left_path[: left_path.rfind("_")]
+        # right_path = prefix + "_right.jpg"
+        left_disp_path = left_path.replace('im', 'disp')
+        # right_disp_path = prefix + "_right.disp.png"
+
+        # read img, disp
+        # left_img = cv2.imread(left_path, cv2.IMREAD_COLOR)
+        # right_img = cv2.imread(right_path, cv2.IMREAD_COLOR)
+        # left_img = np.load(left_path).transpose(1, 2, 0)
+        left_img = np.load(left_path)
+        # left_img = cv2.cvtColor(left_img, cv2.COLOR_GRAY2RGB)
+        # FIXME: is this normalization needed? Without it the image is pretty dark.
+        left_img = cv2.normalize(left_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+        # left_img = cv2.merge([left_img, left_img, left_img]).reshape((3, 480, 640))
+        left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
+
+        right_img = self.pattern
+        # right_img = cv2.merge([right_img, right_img, right_img]).reshape((3, 480, 640))
+        left_disp = self.get_disp(left_disp_path)
+        # right_disp = self.get_disp(right_disp_path)
+
+        if False:  # self.rng.binomial(1, 0.5):
+            left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
+            left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
+        left_disp[left_disp == np.inf] = 0
+
+        if self.blur:
+            kernel_size = random.sample([1, 3, 5, 7, 9], 1)[0]
+            kernel = (kernel_size, kernel_size)
+            left_img = cv2.GaussianBlur(left_img, kernel, 0)
+
+        # augmentation
+        if not self.augment:
+            # run the augmentor only for its disparity mask; keep the images as-is
+            _left_img, _right_img, _left_disp, disp_mask = self.augmentor(
+                left_img, right_img, left_disp
+            )
+        else:
+            left_img, right_img, left_disp, disp_mask = self.augmentor(
+                left_img, right_img, left_disp
+            )
+
+        # left_img = left_img.transpose(2, 0, 1).astype("uint8")
+        # left_img = left_img.transpose(2, 0, 1).astype("float32")
+        # left_img = left_img.astype("float32")
+        right_img = right_img.transpose(2, 0, 1).astype("uint8")  # HWC -> CHW
+        # right_img = right_img.astype("float32")
+
+        return {
+            "left": left_img,
+            "right": right_img,
+            "disparity": left_disp,
+            "mask": disp_mask,
+        }
+
+    def __len__(self):
+        return len(self.imgs)
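
For orientation, a minimal usage sketch for the new CTDDataset; the root and pattern paths below are placeholders, not values from this patch:

    from torch.utils.data import DataLoader
    from dataset import CTDDataset

    # hypothetical paths, for illustration only
    dataset = CTDDataset(
        '/path/to/ctd_data/',                        # holds <data_type>/*/im0_*.npy
        pattern_path='/path/to/kinect_pattern.png',  # projector reference pattern
        data_type='kinect',
        augment=False,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    batch = next(iter(loader))
    # each sample pairs a captured image ('left') with the fixed pattern ('right')
    print(batch['left'].shape, batch['right'].shape,
          batch['disparity'].shape, batch['mask'].shape)
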
diff --git a/nets/crestereo.py b/nets/crestereo.py
index 3f99917..70a6bac 100644
--- a/nets/crestereo.py
+++ b/nets/crestereo.py
@@ -155,6 +155,8 @@ class CREStereo(nn.Module):
         flow = None
         flow_up = None
         if flow_init is not None:
+            if isinstance(flow_init, list):
+                flow_init = flow_init[0]
             scale = fmap1.shape[2] / flow_init.shape[2]
             flow = -scale * F.interpolate(
                 flow_init,
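
The guard above exists because the scripts below run the model twice, coarse to fine, and a training-mode call returns a list of per-iteration flow predictions rather than a single tensor. A sketch of that cascade, under the constructor arguments this patch uses elsewhere (random inputs stand in for real images):

    import torch
    import torch.nn.functional as F
    from nets import Model

    model = Model(max_disp=256, mixed_precision=False, test_mode=False).cuda().eval()

    imgL = torch.rand(1, 3, 480, 640, device='cuda')
    imgR = torch.rand(1, 3, 480, 640, device='cuda')

    # half-resolution first pass
    imgL_dw2 = F.interpolate(imgL, scale_factor=0.5, mode='bilinear', align_corners=True)
    imgR_dw2 = F.interpolate(imgR, scale_factor=0.5, mode='bilinear', align_corners=True)

    with torch.inference_mode():
        # with test_mode=False the model returns a list of per-iteration predictions
        pred_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=20, flow_init=None)
        # seeding the full-resolution pass with that list is what the new guard handles
        pred = model(imgL, imgR, iters=20, flow_init=pred_dw2)
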
diff --git a/test_model.py b/test_model.py
index ddf98e0..9e04be7 100644
--- a/test_model.py
+++ b/test_model.py
@@ -1,12 +1,22 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
 import numpy as np
 import cv2
-from imread_from_url import imread_from_url
 
 from nets import Model
 
+import wandb
+
+import random
+from torch.utils.data import DataLoader
+from dataset import CTDDataset
+from train import normalize_and_colormap, parse_yaml, inference as ctd_inference
+
 device = 'cuda'
 
+wandb.init(project="crestereo", entity="cpt-captain")
+
+
 #Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
 def inference(left, right, model, n_iter=20):
@@ -41,39 +51,198 @@ def inference(left, right, model, n_iter=20):
 
     return pred_disp
 
-if __name__ == '__main__':
-
-    left_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/left.png")
-    right_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/right.png")
-
-    in_h, in_w = left_img.shape[:2]
-
-    # Resize image in case the GPU memory overflows
-    eval_h, eval_w = (in_h, in_w)
-    imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
-    imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
+def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
+
+    print("Model Forwarding...")
+    left = left.cpu().detach().numpy()
+    imgL = left
+    imgR = right.cpu().detach().numpy()
+    imgL = np.ascontiguousarray(imgL[None, :, :, :])
+    imgR = np.ascontiguousarray(imgR[None, :, :, :])
+
+    # chosen for convenience
+    device = torch.device('cuda:0')
+
+    imgL = torch.tensor(imgL.astype("float32")).to(device)
+    imgR = torch.tensor(imgR.astype("float32")).to(device)
+    imgL = imgL.transpose(2, 3).transpose(1, 2)  # NHWC -> NCHW
+
+    imgL_dw2 = F.interpolate(
+        imgL,
+        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
+        mode="bilinear",
+        align_corners=True,
+    )
+    imgR_dw2 = F.interpolate(
+        imgR,
+        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
+        mode="bilinear",
+        align_corners=True,
+    )
+    with torch.inference_mode():
+        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
+        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
+
+    log = {}
+    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
+        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
+        pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()
+
+        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+        pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+
+        log[f'pred_{i}'] = wandb.Image(
+            np.array([pred_disp.reshape(480, 640)]),
+            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
+        )
+        log[f'pred_norm_{i}'] = wandb.Image(
+            np.array([pred_disp_norm.reshape(480, 640)]),
+            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
+        )
+
+        log[f'pred_dw2_{i}'] = wandb.Image(
+            np.array([pred_disp_dw2.reshape(240, 320)]),
+            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
+        )
+        log[f'pred_dw2_norm_{i}'] = wandb.Image(
+            np.array([pred_disp_dw2_norm.reshape(240, 320)]),
+            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
+        )
+
+    log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
+    log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1, 2, 0).astype('uint8'), caption="Input Right")
+
+    log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
+
+    # compare against the final (most refined) full-resolution prediction
+    disp_error = gt_disp - pred_disp
+    log['disp_error'] = wandb.Image(
+        normalize_and_colormap(disp_error),
+        caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}",
+    )
+
+    wandb.log(log)
+
+
+def do_infer(left_img, right_img, gt_disp, model):
+    in_h, in_w = left_img.shape[:2]
+
+    # Resize image in case the GPU memory overflows
+    eval_h, eval_w = (in_h, in_w)
+
+    # FIXME: this resize path is borked for some reason, hopefully not very important
+    imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img
+    imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img
+    imgL = cv2.resize(imgL, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
+    imgR = cv2.resize(imgR, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
+
+    # pred = ctd_inference(imgL, imgR, gt_disp, None, model, None, n_iter=20)
+    pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)
+
+    t = float(in_w) / float(eval_w)
+    disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
+
+    disp_vis = normalize_and_colormap(disp)
+    gt_disp_vis = normalize_and_colormap(gt_disp)
+    if gt_disp.shape != disp.shape:
+        gt_disp = gt_disp.reshape(disp.shape)
+    disp_err = gt_disp - disp
+    disp_err = normalize_and_colormap(disp_err.abs())
+
+    wandb.log({
+        'disp_vis': wandb.Image(
+            disp_vis,
+            caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
+        ),
+        'gt_disp_vis': wandb.Image(
+            gt_disp_vis,
+            caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
+        ),
+        'disp_err': wandb.Image(
+            disp_err,
+            caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
+        ),
+        'input_left': wandb.Image(
+            left_img.cpu().detach().numpy().astype('uint8'),
+            caption="Input left",
+        ),
+        'input_right': wandb.Image(
+            right_img.cpu().detach().numpy().astype('uint8'),
+            caption="Input right",
+        ),
+    })
 
-    model_path = "models/crestereo_eth3d.pth"
-
-    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
-    model.load_state_dict(torch.load(model_path), strict=True)
-    model.to(device)
-    model.eval()
-
-    pred = inference(imgL, imgR, model, n_iter=20)
-
-    t = float(in_w) / float(eval_w)
-    disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
-
-    disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
-    disp_vis = disp_vis.astype("uint8")
-    disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
-
-    combined_img = np.hstack((left_img, disp_vis))
-    cv2.namedWindow("output", cv2.WINDOW_NORMAL)
-    cv2.imshow("output", combined_img)
-    cv2.imwrite("output.jpg", disp_vis)
-    cv2.waitKey(0)
+
+if __name__ == '__main__':
+    # model_path = "models/crestereo_eth3d.pth"
+    model_path = "train_log/models/latest.pth"
+
+    # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
+    reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
+    # reference_pattern_path = '/home/nils/new_reference.png'
+    # reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
+    # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
+
+    data_type = 'kinect'
+    augment = False
+
+    args = parse_yaml("cfgs/train.yaml")
+
+    wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
+
+    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
+    model = nn.DataParallel(model, device_ids=[device])
+    # model.load_state_dict(torch.load(model_path), strict=False)
+    state_dict = torch.load(model_path)['state_dict']
+    model.load_state_dict(state_dict, strict=True)
+    model.to(device)
+    model.eval()
+
+    CTD = True
+    if not CTD:
+        left_img = cv2.imread("../test_imgs/left.png")
+        right_img = cv2.imread("../test_imgs/right.png")
+        in_h, in_w = left_img.shape[:2]
+
+        # Resize image in case the GPU memory overflows
+        eval_h, eval_w = (in_h, in_w)
+
+        # FIXME: borked for some reason, hopefully not very important
+        imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
+        imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
+
+        pred = inference(imgL, imgR, model, n_iter=20)
+
+        t = float(in_w) / float(eval_w)
+        disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
+
+        disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
+        disp_vis = disp_vis.astype("uint8")
+        disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
+
+        combined_img = np.hstack((left_img, disp_vis))
+        # cv2.namedWindow("output", cv2.WINDOW_NORMAL)
+        # cv2.imshow("output", combined_img)
+        cv2.imwrite("output.jpg", disp_vis)
+        # cv2.waitKey(0)
+    else:
+        dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
+                             pattern_path=reference_pattern_path, augment=augment)
+        dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
+                                num_workers=0, drop_last=False,
+                                persistent_workers=False, pin_memory=True)
+        for batch in dataloader:
+            for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
+                right = right.transpose(0, 2).transpose(0, 1)  # CHW -> HWC
+                left_img = left
+                imgL = left.cpu().detach().numpy()
+                right_img = right
+                imgR = right.cpu().detach().numpy()
+                gt_disp = disparity
+                do_infer(left_img, right_img, gt_disp, model)
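
The train.py changes below swap the tensorboardX scalars for a wandb metrics dict. The pattern is roughly this (metric keys as in the patch, values illustrative):

    import wandb

    wandb.init(project="crestereo", entity="cpt-captain")

    metrics = {}
    metrics['train/loss_batch'] = 0.42  # was: tb_log.add_scalar("train/loss_batch", ...)
    metrics['train/lr'] = 4e-4          # was: tb_log.add_scalar("train/lr", ...)
    wandb.log(metrics)                  # one logging step; no explicit flush needed
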
diff --git a/train.py b/train.py
index 8f8d385..770d115 100644
--- a/train.py
+++ b/train.py
@@ -5,16 +5,131 @@ import logging
 from collections import namedtuple
 
 import yaml
-from tensorboardX import SummaryWriter
+# from tensorboardX import SummaryWriter
 
 from nets import Model
-from dataset import CREStereoDataset
+# from dataset import CREStereoDataset
+from dataset import CREStereoDataset, CTDDataset
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torch.optim as optim
 from torch.utils.data import DataLoader
+import wandb
+
+import numpy as np
+import cv2
+
+
+def normalize_and_colormap(img):
+    # scale to 0..255, move to the CPU if needed, apply the inferno colormap
+    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
+    if isinstance(ret, torch.Tensor):
+        ret = ret.cpu().detach().numpy()
+    ret = ret.astype("uint8")
+    ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
+    return ret
+
+
+def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
+
+    print("Model Forwarding...")
+    left = left.cpu().detach().numpy()
+    imgL = left
+    imgR = right.cpu().detach().numpy()
+    imgL = np.ascontiguousarray(imgL[None, :, :, :])
+    imgR = np.ascontiguousarray(imgR[None, :, :, :])
+
+    # chosen for convenience
+    device = torch.device('cuda:0')
+
+    imgL = torch.tensor(imgL.astype("float32")).to(device)
+    imgR = torch.tensor(imgR.astype("float32")).to(device)
+    imgL = imgL.transpose(2, 3).transpose(1, 2)  # NHWC -> NCHW
+    if imgL.shape != imgR.shape:
+        imgR = imgR.transpose(2, 3).transpose(1, 2)
+
+    imgL_dw2 = F.interpolate(
+        imgL,
+        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
+        mode="bilinear",
+        align_corners=True,
+    )
+    imgR_dw2 = F.interpolate(
+        imgR,
+        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
+        mode="bilinear",
+        align_corners=True,
+    )
+    with torch.inference_mode():
+        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
+        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
+
+    if not wandb_log:
+        return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
+
+    log = {}
+    in_h, in_w = left.shape[:2]
+
+    # Resize image in case the GPU memory overflows
+    eval_h, eval_w = (in_h, in_w)
+
+    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
+        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
+        pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()
+
+        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+        pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
+
+        if i == n_iter - 1:
+            t = float(in_w) / float(eval_w)
+            disp = cv2.resize(pred_disp, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
+
+            log['disp_vis'] = wandb.Image(
+                normalize_and_colormap(disp),
+                caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
+            )
+        log[f'pred_{i}'] = wandb.Image(
+            np.array([pred_disp.reshape(480, 640)]),
+            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
+        )
+        log[f'pred_norm_{i}'] = wandb.Image(
+            np.array([pred_disp_norm.reshape(480, 640)]),
+            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
+        )
+
+        log[f'pred_dw2_{i}'] = wandb.Image(
+            np.array([pred_disp_dw2.reshape(240, 320)]),
+            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
+        )
+        log[f'pred_dw2_norm_{i}'] = wandb.Image(
+            np.array([pred_disp_dw2_norm.reshape(240, 320)]),
+            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
+        )
+
+    log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
+    log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1, 2, 0).astype('uint8'), caption="Input Right")
+
+    log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
+
+    disp_error = gt_disp - disp
+    log['disp_error'] = wandb.Image(
+        normalize_and_colormap(disp_error.abs()),
+        caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}",
+    )
+
+    log['gt_disp_vis'] = wandb.Image(
+        normalize_and_colormap(gt_disp),
+        caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
+    )
+
+    wandb.log(log)
+
 
 def parse_yaml(file_path: str) -> namedtuple:
     """Parse yaml configuration file and return the object in `namedtuple`."""
@@ -65,6 +180,10 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
     '''
     n_predictions = len(flow_preds)
     flow_loss = 0.0
+
+    # the CTD ground truth carries a trailing singleton channel; drop it
+    flow_gt = torch.squeeze(flow_gt, dim=-1)
+
     for i in range(n_predictions):
         i_weight = gamma ** (n_predictions - i - 1)
         i_loss = torch.abs(flow_preds[i] - flow_gt)
@@ -93,7 +212,9 @@ def main(args):
     optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
     # model = nn.DataParallel(model,device_ids=[0])
 
-    tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
+    # tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
+    wandb.watch(model)
+    metrics = {}
 
     # worklog
     logging.basicConfig(level=eval(args.log_level))
@@ -138,8 +259,12 @@ def main(args):
         start_epoch_idx = 1
         start_iters = 0
 
+    # pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
+    pattern_path = '/home/nils/kinect_reference_cropped.png'
+    # pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
     # datasets
-    dataset = CREStereoDataset(args.training_data_path)
+    # dataset = CREStereoDataset(args.training_data_path)
+    dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
     # if rank == 0:
     worklog.info(f"Dataset size: {len(dataset)}")
     dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
@@ -183,6 +308,9 @@ def main(args):
             gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
 
             # forward
+            # left = left.transpose(1, 2).transpose(2, 3)
+            left = left.transpose(1, 3).transpose(2, 3)  # NHWC -> NCHW
+            right = right.transpose(2, 3).transpose(2, 3)  # the two transposes cancel; right already comes NCHW from the dataset
             flow_predictions = model(left.cuda(), right.cuda())
 
             # loss & backword
             loss = sequence_loss(
                 flow_predictions, gt_flow, valid_mask, gamma=0.8
             )
 
+            if batch_idx % 128 == 0:
+                inference(
+                    mini_batch_data['left'][0],
+                    mini_batch_data['right'][0],
+                    mini_batch_data['disparity'][0],
+                    mini_batch_data['mask'][0],
+                    model,
+                    batch_idx,
+                )
+
             # loss stats
             loss_item = loss.data.item()
             epoch_total_train_loss += loss_item
@@ -234,20 +372,25 @@ def main(args):
                 worklog.info("".join(info))
 
                 # minibatch loss
tb_log.add_scalar("train/loss_batch", loss_item, cur_iters) + metrics['train/loss_batch'] = loss_item + # tb_log.add_scalar( + # "train/lr", optimizer.param_groups[0]["lr"], cur_iters + # ) + metrics['train/lr'] = optimizer.param_groups[0]["lr"] + # tb_log.flush() + wandb.log(metrics) t1 = time.perf_counter() - tb_log.add_scalar( - "train/loss", - epoch_total_train_loss / args.minibatch_per_epoch, - epoch_idx, - ) - tb_log.flush() + # tb_log.add_scalar( + # "train/loss", + # epoch_total_train_loss / args.minibatch_per_epoch, + # epoch_idx, + # ) + metrics['train/loss'] = epoch_total_train_loss / args.minibatch_per_epoch + # tb_log.flush() + wandb.log(metrics) # save model params ckp_data = { @@ -271,4 +414,6 @@ def main(args): if __name__ == "__main__": # train configuration args = parse_yaml("cfgs/train.yaml") + wandb.init(project="crestereo", entity="cpt-captain") + wandb.config.update(args._asdict()) main(args)