diff --git a/data/calibration_result.xml b/data/calibration_result.xml
index 1398f03..1b57eb5 100644
--- a/data/calibration_result.xml
+++ b/data/calibration_result.xml
@@ -6,51 +6,51 @@
     d
     488. 648.
-   6.3901938604977060e-01
+   7.3876627268710637e-01
     3 3
     d
-   1.6399564573473415e+03 0. -6.5062953701584874e+01 0.
-   1.5741778806528637e+03 2.3634226604202689e+02 0. 0. 1.
+   1.5726417056187443e+03 0. 1.4574187727420562e+02 0.
+   1.5816754032205320e+03 2.4066087342652420e+02 0. 0. 1.
     1 5
     d
-   4.0533496459876667e-01 -7.9789330239048994e-01
-   -3.4496681677903256e-02 -1.1244970513014216e-01
-   7.0913484303897389e-01
+   9.4803929000342360e-02 -7.4931503928649663e+00
+   2.7510446825876069e-03 -1.5574797970388680e-02
+   5.9686429523557969e+01
     3 3
     d
-   1.6483495467542737e+03 0. 3.8306162278275889e+02 0.
-   1.6326239866472497e+03 7.3044314024967093e+02 0. 0. 1.
+   1.8196441201415089e+03 0. 3.1215139179762173e+02 0.
+   1.7710473039077285e+03 6.4652482452978484e+02 0. 0. 1.
     1 5
     d
-   -4.7323012136273940e-01 6.1654050808332572e+00
-   -2.6533525558408575e-02 -4.8302040441684145e-02
-   -2.1030103617531569e+01
+   4.6501355527112370e-01 -5.2653146171000911e+00
+   -3.1399879320030987e-03 -7.4973212336674019e-02
+   2.5370499794178890e+01
     3 3
     d
-   9.9544319031800177e-01 2.2550253921095241e-02
-   -9.2651718265839858e-02 -3.7412411799396868e-02
-   9.8608873522691420e-01 -1.6195467792545257e-01
-   8.7710696570434246e-02 1.6468300551872025e-01 9.8243897591679974e-01
+   9.9751030556985942e-01 7.9013584755337138e-03 7.0076806549434587e-02
+   -7.0415421716406692e-04 9.9476980417395522e-01
+   -1.0213981040979671e-01 -7.0517334384988042e-02
+   1.0183616861386664e-01 9.9229869510812307e-01
     3 1
     d
-   -5.4987956675391622e+01 3.6267509838011689e+00
-   -1.5791458092388201e+01
+   -3.6051053224527990e+01 -1.1530953901520501e+01
+   1.0668513452875833e+02
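For reference, a calibration file in this OpenCV matrix format can be read back with the FileStorage API. A minimal sketch, assuming hypothetical node names ('cam_K', 'cam_dist', 'proj_K', 'proj_dist', 'R', 'T' are placeholders; the actual XML tag names are not preserved in the excerpt above):

    import cv2

    fs = cv2.FileStorage('data/calibration_result.xml', cv2.FILE_STORAGE_READ)
    params = {
        'cam':  {'K': fs.getNode('cam_K').mat(),  'dist': fs.getNode('cam_dist').mat()},   # 3x3 and 1x5
        'proj': {'K': fs.getNode('proj_K').mat(), 'dist': fs.getNode('proj_dist').mat()},  # 3x3 and 1x5
    }
    R = fs.getNode('R').mat()  # 3x3 rotation between camera and projector
    T = fs.getNode('T').mat()  # 3x1 translation
    fs.release()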
diff --git a/data/create_syn_data.py b/data/create_syn_data.py
index e40ea15..d299f13 100644
--- a/data/create_syn_data.py
+++ b/data/create_syn_data.py
@@ -138,7 +138,6 @@ def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_i
 
     # render the scene at multiple scales
     scales = [1, 0.5, 0.25, 0.125]
-
     for scale in scales:
         fx = K[0, 0] * scale
         fy = K[1, 1] * scale
@@ -254,6 +253,7 @@ if __name__ == '__main__':
     track_length = 4
 
     # load pattern image
+    # FIXME: which pattern image is the correct one?
     pattern_path = './kinect_pattern.png'
     pattern_crop = True
     patterns = get_patterns(pattern_path, imsizes, pattern_crop)
diff --git a/data/rectify.py b/data/rectify.py
index 54f2dd0..720070a 100644
--- a/data/rectify.py
+++ b/data/rectify.py
@@ -85,34 +85,42 @@ def euler_angles_from_rotation_matrix(R):
     return psi, theta, phi
 
 ####################################################
-
-print('R1:\n', R1)
-print(euler_angles_from_rotation_matrix(R1))
-print('R2:\n', R2)
-print(euler_angles_from_rotation_matrix(R2))
-print('P1:\n', P1)
-print('P2:\n', P2)
-print('Q :\n', Q)
+# print('R1:\n', R1)
+# print(euler_angles_from_rotation_matrix(R1))
+# print('R2:\n', R2)
+# print(euler_angles_from_rotation_matrix(R2))
+# print('P1:\n', P1)
+# print('P2:\n', P2)
+# print('Q :\n', Q)
+#
+#
+# print(P1.shape)
 
 pattern = cv2.imread('kinect_pattern.png')
 sampled_pattern = cv2.imread('sampled_kinect_pattern.png')
 
 proj_rect_map1, proj_rect_map2 = cv2.initInverseRectificationMap(
+# proj_rect_map1, proj_rect_map2 = cv2.initUndistortRectifyMap(
     params['proj']['K'],
     params['proj']['dist'],
     R1,
-    # None,
     P1,
-    # (688, 488),
-    (1280, 1024),
+    (688, 488),
+    # (1280, 800),
     cv2.CV_16SC2,
 )
+# print(proj_rect_map1.shape, proj_rect_map2.shape)
 
-rect_pat = cv2.remap(pattern, proj_rect_map1, proj_rect_map2, cv2.INTER_LINEAR)
+samp_rect_pat = cv2.remap(sampled_pattern, proj_rect_map1, proj_rect_map2, cv2.INTER_CUBIC)
+rect_pat = cv2.remap(pattern, proj_rect_map1, proj_rect_map2, cv2.INTER_CUBIC)
 
-# FIXME rect_pat is always zero
+# print(rect_pat.shape)
+cv2.imshow('get rect', samp_rect_pat)
+cv2.waitKey()
 cv2.imshow('get rect', rect_pat)
 cv2.waitKey()
 # cv2.imshow(rect_pat2)
 cv2.waitKey()
+cv2.imwrite('rectified_sampled_pattern_new.png', samp_rect_pat)
+cv2.imwrite('rectified_pattern_new.png', rect_pat)
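A note on the two map builders in rectify.py, as I read the OpenCV docs: initUndistortRectifyMap produces maps for warping a captured (distorted) image into the rectified frame, while initInverseRectificationMap goes the other way, building maps that take an ideal, undistorted image, here the projector pattern, toward the distorted, unrectified view a real device would produce. A sketch of the camera-side counterpart, assuming the camera is the second view of the stereoRectify call that produced R1/R2/P1/P2, and that params exists as in the script (ir_image is an assumed name for a captured IR frame):

    cam_rect_map1, cam_rect_map2 = cv2.initUndistortRectifyMap(
        params['cam']['K'],
        params['cam']['dist'],
        R2,
        P2,
        (688, 488),  # same rectified size as the projector side
        cv2.CV_16SC2,
    )
    ir_rect = cv2.remap(ir_image, cam_rect_map1, cam_rect_map2, cv2.INTER_CUBIC)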
diff --git a/model/exp_synph.py b/model/exp_synph.py
index 1574ef6..e74257c 100644
--- a/model/exp_synph.py
+++ b/model/exp_synph.py
@@ -7,6 +7,10 @@ import sys
 import itertools
 import json
 import matplotlib.pyplot as plt
+import cv2
+import torchvision.transforms as transforms
+
+
 import co
 import torchext
 from model import networks
@@ -14,7 +18,7 @@ from data import dataset
 
 
 class Worker(torchext.Worker):
-    def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
+    def __init__(self, args, num_workers=18, train_batch_size=4, test_batch_size=4, save_frequency=1, **kwargs):
         super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
                          train_batch_size=train_batch_size, test_batch_size=test_batch_size,
                          save_frequency=save_frequency, **kwargs)
@@ -43,6 +47,8 @@ class Worker(torchext.Worker):
 
         self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
         self.disparity_loss = networks.DisparityLoss()
+        # self.sup_disp_loss = torch.nn.CrossEntropyLoss()
+        self.sup_disp_loss = torch.nn.MSELoss()
         self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
 
         # evaluate in the region where opencv Block Matching has valid values
@@ -96,9 +102,61 @@ class Worker(torchext.Worker):
                 self.data[key_std] = im_std.to(device).detach()
 
     def net_forward(self, net, train):
+        # FIXME: snip here?
         out = net(self.data['im0'])
         return out
 
+    @staticmethod
+    def find_corr_points_and_F(left, right):
+        sift = cv2.SIFT_create()
+        # find the keypoints and descriptors with SIFT
+        kp1, des1 = sift.detectAndCompute(cv2.normalize(left, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'), None)
+        kp2, des2 = sift.detectAndCompute(cv2.normalize(right, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'), None)
+        # FLANN parameters
+        FLANN_INDEX_KDTREE = 1
+        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
+        search_params = dict(checks=50)
+        flann = cv2.FlannBasedMatcher(index_params, search_params)
+        matches = flann.knnMatch(des1, des2, k=2)
+        pts1 = []
+        pts2 = []
+        # ratio test as per Lowe's paper
+        for i, (m, n) in enumerate(matches):
+            if m.distance < 0.8 * n.distance:
+                pts2.append(kp2[m.trainIdx].pt)
+                pts1.append(kp1[m.queryIdx].pt)
+
+        pts1 = np.int32(pts1)
+        pts2 = np.int32(pts2)
+        F, mask = cv2.findFundamentalMat(pts1, pts2, cv2.FM_LMEDS)
+        # We select only inlier points
+        pts1 = pts1[mask.ravel() == 1]
+        pts2 = pts2[mask.ravel() == 1]
+        return pts1, pts2, F
+
+    def calc_sgbm_gt(self):
+        sgbm_matcher = cv2.StereoSGBM_create()
+        disp_gt = []
+        # cam_view = np.array(np.array_split(self.data['im0'].detach().to('cpu').numpy(), 4)[2:])
+        # for i in range(self.data['im0'].shape[0]):
+        for i in range(1):
+            cam_view = self.data['im0'].detach().to('cpu').numpy()[i, 0]
+            pattern = self.pattern_proj.to('cpu').numpy()[i, 0]
+            pts_l, pts_r, F = self.find_corr_points_and_F(cam_view, pattern)
+            H_l, _ = cv2.findHomography(pts_l, pts_r)
+            H_r, _ = cv2.findHomography(pts_r, pts_l)
+
+            left_rect = cv2.warpPerspective(cam_view, H_l, cam_view.shape)
+            right_rect = cv2.warpPerspective(pattern, H_r, pattern.shape)
+
+            transform = transforms.ToTensor()
+            disparity_gt = transform(cv2.normalize(
+                sgbm_matcher.compute(cv2.normalize(left_rect, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'),
+                                     cv2.normalize(right_rect, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')), None,
+                alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F).T)
+            disp_gt.append(disparity_gt)
+        return disp_gt
+
     def loss_forward(self, out, train):
         out, edge = out
         if not (isinstance(out, tuple) or isinstance(out, list)):
             out = [out]
@@ -110,15 +168,24 @@ class Worker(torchext.Worker):
 
         # apply photometric loss
         for s, l, o in zip(itertools.count(), self.losses, out):
-            val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
+            val, pattern_proj = l(o[0], self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
             if s == 0:
                 self.pattern_proj = pattern_proj.detach()
             vals.append(val)
 
         # apply disparity loss
         # 1-edge as ground truth edge if inverted
-        edge0 = 1 - torch.sigmoid(edge[0])
-        val = self.disparity_loss(out[0], edge0)
+        if isinstance(edge, tuple):
+            edge0 = 1 - torch.sigmoid(edge[0][0])
+        else:
+            edge0 = 1 - torch.sigmoid(edge[0])
+        val = 0
+        if isinstance(out[0], tuple):
+            val += self.sup_disp_loss(out[0][1], self.data['disp0'])
+            val += self.disparity_loss(out[0][0], edge0)
+        else:
+            val += self.disparity_loss(out[0], edge0)
         if self.dp_weight > 0:
             vals.append(val * self.dp_weight)
@@ -130,11 +197,17 @@ class Worker(torchext.Worker):
             ids = self.data['id']
             mask = ids > self.train_edge
             if mask.sum() > 0:
-                val = self.edge_loss(e[mask], grad[mask])
+                if isinstance(e, tuple):
+                    val = self.edge_loss(e[0][mask], grad[mask])
+                else:
+                    val = self.edge_loss(e[mask], grad[mask])
             else:
                 val = torch.zeros_like(vals[0])
             if s == 0:
-                self.edge = e.detach()
+                if isinstance(e, tuple):
+                    self.edge = e[0].detach()
+                else:
+                    self.edge = e.detach()
                 self.edge = torch.sigmoid(self.edge)
                 self.edge_gt = grad.detach()
             vals.append(val)
@@ -145,7 +218,7 @@ class Worker(torchext.Worker):
         output, edge = output
         if not (isinstance(output, tuple) or isinstance(output, list)):
             output = [output]
-        es = output[0].detach().to('cpu').numpy()
+        es = output[0][0].detach().to('cpu').numpy()
         gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
         im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
@@ -271,4 +344,9 @@ class Worker(torchext.Worker):
 
 
 if __name__ == '__main__':
+    # FIXME Nicolas' pet idea:
+    # use SGBM to find ground truth
+    # 'duplicate' the last few layers of dispnet (or whatever), i.e. build a second head, and thus train several loss functions at the same time
+    # L1 + L2, then optimize in the same backward pass
+    # adapt the forward pass for all of this
     pass
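The FIXME above is the plan these changes implement: duplicate the final decoder layers into a second head, supervise it with MSE against ground-truth disparity, keep the unsupervised photometric/edge-guided losses on the first head, and sum both terms so they are optimized in a single backward pass. A minimal self-contained sketch of that training pattern (the toy module and tensor names are illustrative, not the repo's API):

    import torch

    class TwoHeadToy(torch.nn.Module):
        # stand-in for a decoder whose last layer is duplicated into two heads
        def __init__(self):
            super().__init__()
            self.backbone = torch.nn.Conv2d(1, 8, 3, padding=1)
            self.head_unsup = torch.nn.Conv2d(8, 1, 3, padding=1)  # gets the unsupervised losses
            self.head_sup = torch.nn.Conv2d(8, 1, 3, padding=1)    # gets the supervised MSE

        def forward(self, x):
            feat = torch.relu(self.backbone(x))
            return self.head_unsup(feat), self.head_sup(feat)

    net = TwoHeadToy()
    opt = torch.optim.Adam(net.parameters())
    mse = torch.nn.MSELoss()

    im, disp_gt = torch.rand(2, 1, 64, 64), torch.rand(2, 1, 64, 64)
    disp_u, disp_s = net(im)
    unsup_loss = disp_u.abs().mean()           # stands in for the photometric + disparity terms
    loss = unsup_loss + mse(disp_s, disp_gt)   # both heads, one backward pass
    opt.zero_grad()
    loss.backward()
    opt.step()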
diff --git a/model/exp_synph_real.py b/model/exp_synph_real.py
new file mode 100644
index 0000000..9c9d863
--- /dev/null
+++ b/model/exp_synph_real.py
@@ -0,0 +1,368 @@
+import torch
+import numpy as np
+import time
+from pathlib import Path
+import logging
+import sys
+import itertools
+import json
+import matplotlib.pyplot as plt
+import cv2
+import torchvision.transforms as transforms
+
+
+import co
+import torchext
+from model import networks
+from data import dataset
+
+
+class Worker(torchext.Worker):
+    def __init__(self, args, num_workers=18, train_batch_size=4, test_batch_size=4, save_frequency=1, **kwargs):
+        super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
+                         train_batch_size=train_batch_size, test_batch_size=test_batch_size,
+                         save_frequency=save_frequency, **kwargs)
+
+        self.ms = args.ms
+        self.pattern_path = args.pattern_path
+        self.lcn_radius = args.lcn_radius
+        self.dp_weight = args.dp_weight
+        self.data_type = args.data_type
+
+        self.imsizes = [(488, 648)]
+        for iter in range(3):
+            self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
+
+        with open('config.json') as fp:
+            config = json.load(fp)
+        data_root = Path(config['DATA_ROOT'])
+        self.settings_path = data_root / self.data_type / 'settings.pkl'
+        sample_paths = sorted((data_root / self.data_type).glob('0*/'))
+
+        self.train_paths = sample_paths[2 ** 10:]
+        self.test_paths = sample_paths[:2 ** 8]
+
+        # supervise the edge encoder with only 2**8 samples
+        self.train_edge = len(self.train_paths) - 2 ** 8
+
+        self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
+        self.disparity_loss = networks.DisparityLoss()
+        # self.sup_disp_loss = torch.nn.CrossEntropyLoss()
+        self.sup_disp_loss = torch.nn.MSELoss()
+        self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
+
+        # evaluate in the region where opencv Block Matching has valid values
+        self.eval_mask = np.zeros(self.imsizes[0])
+        self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
+        self.eval_mask = self.eval_mask.astype(bool)
+        self.eval_h = self.imsizes[0][0] - 2 * 13
+        self.eval_w = self.imsizes[0][1] - 13 - 140
+
+    def get_train_set(self):
+        train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
+                                            track_length=1)
+
+        return train_set
+
+    def get_test_sets(self):
+        test_sets = torchext.TestSets()
+        test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
+                                           track_length=1)
+        test_sets.append('simple', test_set, test_frequency=1)
+
+        # initialize photometric loss modules according to image sizes
+        self.losses = []
+        for imsize, pat in zip(test_set.imsizes, test_set.patterns):
+            pat = pat.mean(axis=2)
+            pat = torch.from_numpy(pat[None][None].astype(np.float32))
+            pat = pat.to(self.train_device)
+            self.lcn_in = self.lcn_in.to(self.train_device)
+            pat, _ = self.lcn_in(pat)
+            pat = torch.cat([pat for idx in range(3)], dim=1)
+            self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
+
+        return test_sets
+
+    def copy_data(self, data, device, requires_grad, train):
+        self.lcn_in = self.lcn_in.to(device)
+
+        self.data = {}
+        for key, val in data.items():
+            grad = 'im' in key and requires_grad
+            self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
+
+            # apply lcn to IR input
+            # concatenate the normalized IR input and the original IR image
+            if 'im' in key and 'blend' not in key:
+                im = self.data[key]
+                im_lcn, im_std = self.lcn_in(im)
+                im_cat = torch.cat((im_lcn, im), dim=1)
+                key_std = key.replace('im', 'std')
+                self.data[key] = im_cat
+                self.data[key_std] = im_std.to(device).detach()
+
+    def net_forward(self, net, train):
+        # FIXME: snip here?
+        out = net(self.data['im0'])
+        return out
+
+    @staticmethod
+    def find_corr_points_and_F(left, right):
+        sift = cv2.SIFT_create()
+        # find the keypoints and descriptors with SIFT
+        kp1, des1 = sift.detectAndCompute(cv2.normalize(left, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'), None)
+        kp2, des2 = sift.detectAndCompute(cv2.normalize(right, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'), None)
+        # FLANN parameters
+        FLANN_INDEX_KDTREE = 1
+        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
+        search_params = dict(checks=50)
+        flann = cv2.FlannBasedMatcher(index_params, search_params)
+        matches = flann.knnMatch(des1, des2, k=2)
+        pts1 = []
+        pts2 = []
+        # ratio test as per Lowe's paper
+        for i, (m, n) in enumerate(matches):
+            if m.distance < 0.8 * n.distance:
+                pts2.append(kp2[m.trainIdx].pt)
+                pts1.append(kp1[m.queryIdx].pt)
+
+        pts1 = np.int32(pts1)
+        pts2 = np.int32(pts2)
+        F, mask = cv2.findFundamentalMat(pts1, pts2, cv2.FM_LMEDS)
+        # We select only inlier points
+        pts1 = pts1[mask.ravel() == 1]
+        pts2 = pts2[mask.ravel() == 1]
+        return pts1, pts2, F
+
+    def calc_sgbm_gt(self):
+        sgbm_matcher = cv2.StereoSGBM_create()
+        disp_gt = []
+        # cam_view = np.array(np.array_split(self.data['im0'].detach().to('cpu').numpy(), 4)[2:])
+        # for i in range(self.data['im0'].shape[0]):
+        for i in range(1):
+            cam_view = self.data['im0'].detach().to('cpu').numpy()[i, 0]
+            pattern = self.pattern_proj.to('cpu').numpy()[i, 0]
+            pts_l, pts_r, F = self.find_corr_points_and_F(cam_view, pattern)
+            H_l, _ = cv2.findHomography(pts_l, pts_r)
+            H_r, _ = cv2.findHomography(pts_r, pts_l)
+
+            left_rect = cv2.warpPerspective(cam_view, H_l, cam_view.shape)
+            right_rect = cv2.warpPerspective(pattern, H_r, pattern.shape)
+
+            transform = transforms.ToTensor()
+            disparity_gt = transform(cv2.normalize(
+                sgbm_matcher.compute(cv2.normalize(left_rect, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'),
+                                     cv2.normalize(right_rect, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')), None,
+                alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F).T)
+            disp_gt.append(disparity_gt)
+        return disp_gt
+
+    def loss_forward(self, out, train):
+        out, edge = out
+        if not (isinstance(out, tuple) or isinstance(out, list)):
+            out = [out]
+        if not (isinstance(edge, tuple) or isinstance(edge, list)):
+            edge = [edge]
+
+        vals = []
+
+        # apply photometric loss
+        for s, l, o in zip(itertools.count(), self.losses, out):
+            val, pattern_proj = l(o[0], self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
+            if s == 0:
+                self.pattern_proj = pattern_proj.detach()
+            vals.append(val)
+
+        # apply disparity loss
+        # 1-edge as ground truth edge if inverted
+        if isinstance(edge, tuple):
+            edge0 = 1 - torch.sigmoid(edge[0][0])
+        else:
+            edge0 = 1 - torch.sigmoid(edge[0])
+        val = 0
+        if isinstance(out[0], tuple):
+            # val = self.disparity_loss(out[0][1], edge0)
+            # FIXME: disparity_loss is unsupervised, but we want supervised(?)
+            # why not simply use the GT that we already have anyway?
+            # gt = self.data[f'disp0'].type('torch.LongTensor')
+
+            val += self.sup_disp_loss(out[0][1], self.data['disp0'])
+            # disp_gt = self.calc_sgbm_gt()
+            # if len(disp_gt) > 1:
+            #     disparity_gt = torch.stack(disp_gt).to('cuda')
+            #     # val += self.sup_disp_loss(out[0][1], disparity_gt)
+            # else:
+            #     disparity_gt = disp_gt[0].to('cuda')
+            #     val += self.sup_disp_loss(out[0][1][0], disparity_gt)
+            # print(disparity_gt)
+            # print(disparity_gt.shape)
+            # print(out[0][1])
+            # print(out[0][1].shape)
+        if isinstance(out[0], tuple):
+            val += self.disparity_loss(out[0][0], edge0)
+        else:
+            val += self.disparity_loss(out[0], edge0)
+        if self.dp_weight > 0:
+            vals.append(val * self.dp_weight)
+
+        # apply edge loss on a subset of training samples
+        for s, e in zip(itertools.count(), edge):
+            # inverted ground truth edge where 0 means edge
+            grad = self.data[f'grad{s}'] < 0.2
+            grad = grad.to(torch.float32)
+            ids = self.data['id']
+            mask = ids > self.train_edge
+            if mask.sum() > 0:
+                if isinstance(e, tuple):
+                    val = self.edge_loss(e[0][mask], grad[mask])
+                else:
+                    val = self.edge_loss(e[mask], grad[mask])
+            else:
+                val = torch.zeros_like(vals[0])
+            if s == 0:
+                if isinstance(e, tuple):
+                    self.edge = e[0].detach()
+                else:
+                    self.edge = e.detach()
+                self.edge = torch.sigmoid(self.edge)
+                self.edge_gt = grad.detach()
+            vals.append(val)
+
+        return vals
+
+    def numpy_in_out(self, output):
+        output, edge = output
+        if not (isinstance(output, tuple) or isinstance(output, list)):
+            output = [output]
+        es = output[0][0].detach().to('cpu').numpy()
+        gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
+        im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
+
+        ma = gt > 0
+        return es, gt, im, ma
+
+    def write_img(self, out_path, es, gt, im, ma):
+        logging.info(f'write img {out_path}')
+        u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
+
+        diff = np.abs(es - gt)
+
+        vmin, vmax = np.nanmin(gt), np.nanmax(gt)
+        vmin = vmin - 0.2 * (vmax - vmin)
+        vmax = vmax + 0.2 * (vmax - vmin)
+
+        pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
+        im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
+        pattern_diff = np.abs(im_orig - pattern_proj)
+
+        fig = plt.figure(figsize=(16, 16))
+        es_ = co.cmap.color_depth_map(es, scale=vmax)
+        gt_ = co.cmap.color_depth_map(gt, scale=vmax)
+        diff_ = co.cmap.color_error_image(diff, BGR=True)
+
+        # plot disparities, ground truth disparity is shown only for reference
+        ax = plt.subplot(3, 3, 1)
+        plt.imshow(es_[..., [2, 1, 0]])
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
+        ax = plt.subplot(3, 3, 2)
+        plt.imshow(gt_[..., [2, 1, 0]])
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
+        ax = plt.subplot(3, 3, 3)
+        plt.imshow(diff_[..., [2, 1, 0]])
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'Disparity Err. {diff.mean():.5f}')
+
+        # plot edges
+        edge = self.edge.to('cpu').numpy()[0, 0]
+        edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
+        edge_err = np.abs(edge - edge_gt)
+        ax = plt.subplot(3, 3, 4)
+        plt.imshow(edge, cmap='gray')
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
+        ax = plt.subplot(3, 3, 5)
+        plt.imshow(edge_gt, cmap='gray')
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
+        ax = plt.subplot(3, 3, 6)
+        plt.imshow(edge_err, cmap='gray')
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
+
+        # plot normalized IR input and warped pattern
+        ax = plt.subplot(3, 3, 7)
+        plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray')
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
+        ax = plt.subplot(3, 3, 8)
+        plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray')
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
+        im_std = self.data['std0'].to('cpu').numpy()[0, 0]
+        ax = plt.subplot(3, 3, 9)
+        plt.imshow(im_std, cmap='gray')
+        plt.xticks([])
+        plt.yticks([])
+        ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
+
+        plt.tight_layout()
+        plt.savefig(str(out_path))
+        plt.close(fig)
+
+    def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
+        if batch_idx % 512 == 0:
+            out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
+            es, gt, im, ma = self.numpy_in_out(output)
+            self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
+
+    def callback_test_start(self, epoch, set_idx):
+        self.metric = co.metric.MultipleMetric(
+            co.metric.DistanceMetric(vec_length=1),
+            co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
+        )
+
+    def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
+        es, gt, im, ma = self.numpy_in_out(output)
+
+        if batch_idx % 8 == 0:
+            out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
+            self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
+
+        es, gt, im, ma = self.crop_output(es, gt, im, ma)
+
+        es = es.reshape(-1, 1)
+        gt = gt.reshape(-1, 1)
+        ma = ma.ravel()
+        self.metric.add(es, gt, ma)
+
+    def callback_test_stop(self, epoch, set_idx, loss):
+        logging.info(f'{self.metric}')
+        for k, v in self.metric.items():
+            self.metric_add_test(epoch, set_idx, k, v)
+
+    def crop_output(self, es, gt, im, ma):
+        bs = es.shape[0]
+        es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
+        gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
+        im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
+        ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
+        return es, gt, im, ma
+
+
+if __name__ == '__main__':
+    # FIXME Nicolas' pet idea:
+    # use SGBM to find ground truth
+    # 'duplicate' the last few layers of dispnet (or whatever), i.e. build a second head, and thus train several loss functions at the same time
+    # L1 + L2, then optimize in the same backward pass
+    # adapt the forward pass for all of this
+    pass
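One caveat on calc_sgbm_gt above: cv2.StereoSGBM_create() with no arguments uses the defaults (numDisparities=16, blockSize=3), which is a very small search range for pseudo ground truth at this image width, and compute() returns fixed-point disparities scaled by 16. A hedged parameterization sketch; the values are tuning assumptions rather than validated settings, and left_u8/right_u8 stand in for the rectified uint8 pair:

    import cv2
    import numpy as np

    left_u8 = np.random.randint(0, 256, (488, 648), dtype=np.uint8)   # stand-ins for the rectified pair
    right_u8 = np.random.randint(0, 256, (488, 648), dtype=np.uint8)

    block = 5
    sgbm = cv2.StereoSGBM_create(
        minDisparity=0,
        numDisparities=128,    # must be divisible by 16
        blockSize=block,
        P1=8 * block * block,  # smoothness penalties as suggested in the OpenCV docs
        P2=32 * block * block,
        uniquenessRatio=10,
        speckleWindowSize=100,
        speckleRange=2,
    )
    disp = sgbm.compute(left_u8, right_u8).astype(np.float32) / 16.0  # undo the x16 fixed-point scaling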
diff --git a/model/exp_synphge.py b/model/exp_synphge.py
index eec320a..23da8e5 100644
--- a/model/exp_synphge.py
+++ b/model/exp_synphge.py
@@ -14,7 +14,7 @@ from data import dataset
 
 
 class Worker(torchext.Worker):
-    def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
+    def __init__(self, args, num_workers=18, train_batch_size=6, test_batch_size=6, save_frequency=1, **kwargs):
         super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
                          train_batch_size=train_batch_size, test_batch_size=test_batch_size,
                          save_frequency=save_frequency, **kwargs)
@@ -28,7 +28,7 @@ class Worker(torchext.Worker):
         self.data_type = args.data_type
         assert (self.track_length > 1)
 
-        self.imsizes = [(480, 640)]
+        self.imsizes = [(488, 648)]
         for iter in range(3):
             self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
diff --git a/model/networks.py b/model/networks.py
index d6e06d9..feb4c69 100644
--- a/model/networks.py
+++ b/model/networks.py
@@ -125,11 +125,12 @@ class DispNetS(TimedModule):
     '''
 
     def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
-                 channel_multiplier=1):
+                 channel_multiplier=1, double_head=True):
        super(DispNetS, self).__init__(mod_name='DispNetS')
        self.output_ms = output_ms
        self.coordconv = coordconv
+       self.double_head = double_head
 
        conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512])
        self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
@@ -149,9 +150,10 @@ class DispNetS(TimedModule):
        self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
        self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
 
-       self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
-       self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
-       self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
+       # TODO try this!!!
+       self.iconv7 = self.norm_conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
+       self.iconv6 = self.norm_conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
+       self.iconv5 = self.norm_conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
        self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
        self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
        self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
@@ -160,13 +162,25 @@ class DispNetS(TimedModule):
        if isinstance(output_facs, list):
            self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
            self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
-           self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
-           self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
+           if double_head:
+               self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
+               self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
+               self.predict_disp2_double = output_facs[1](upconv_planes[5], imsizes[1])
+               self.predict_disp1_double = output_facs[0](upconv_planes[6], imsizes[0])
+           else:
+               self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
+               self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
        else:
            self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
            self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
-           self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
-           self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
+           if double_head:
+               self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
+               self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
+               self.predict_disp2_double = output_facs(upconv_planes[5], imsizes[1])
+               self.predict_disp1_double = output_facs(upconv_planes[6], imsizes[0])
+           else:
+               self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
+               self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
 
    def init_weights(self):
        for m in self.modules():
@@ -190,6 +204,12 @@ class DispNetS(TimedModule):
        )
 
    def conv(self, in_planes, out_planes):
+       return torch.nn.Sequential(
+           torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
+           torch.nn.ReLU(inplace=True)
+       )
+
+   def norm_conv(self, in_planes, out_planes):
        return torch.nn.Sequential(
            torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
            # TODO try this
@@ -254,9 +274,28 @@ class DispNetS(TimedModule):
        out_iconv1 = self.iconv1(concat1)
        disp1 = self.predict_disp1(out_iconv1)
 
+       if self.double_head:
+           # second head: mirror the last two decoder stages
+           out_upconv2_d = self.crop_like(self.upconv2(out_iconv3), out_conv1)
+           disp3_up_d = self.crop_like(
+               torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
+           concat2_d = torch.cat((out_upconv2_d, out_conv1, disp3_up_d), 1)
+           out_iconv2_d = self.iconv2(concat2_d)
+           disp2_d = self.predict_disp2_double(out_iconv2_d)
+
+           out_upconv1_d = self.crop_like(self.upconv1(out_iconv2_d), x)
+           disp2_up_d = self.crop_like(
+               torch.nn.functional.interpolate(disp2_d, scale_factor=2, mode='bilinear', align_corners=False), x)
+           concat1_d = torch.cat((out_upconv1_d, disp2_up_d), 1)
+           out_iconv1_d = self.iconv1(concat1_d)
+           disp1_d = self.predict_disp1_double(out_iconv1_d)
+
        if self.output_ms:
+           if self.double_head:
+               return (disp1, disp1_d), (disp2, disp2_d), disp3, disp4
            return disp1, disp2, disp3, disp4
        else:
+           if self.double_head:
+               return disp1, disp1_d
            return disp1
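With double_head=True, the DispNetS forward pass now returns tuples at the two finest scales, which is exactly what the new isinstance(..., tuple) branches in loss_forward and numpy_in_out unpack. The return contract for the bare module (net and im0 as in the Worker code above; the names are illustrative):

    # output_ms=True, double_head=True:
    (disp1, disp1_d), (disp2, disp2_d), disp3, disp4 = net(im0)
    # disp1/disp2: original head, fed to the unsupervised photometric and disparity losses
    # disp1_d/disp2_d: duplicated head, supervised with MSE against self.data['disp0']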
diff --git a/requirements.txt b/requirements.txt
index d295769..0d00441 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,5 +3,5 @@ numpy
 matplotlib
 pandas
 scipy
-opencv
+opencv-python
 xmltodict
diff --git a/torchext/functions.py b/torchext/functions.py
index 9ae1d4c..c54a983 100644
--- a/torchext/functions.py
+++ b/torchext/functions.py
@@ -1,5 +1,6 @@
 import torch
 
+
 def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
     type = type.lower()
     p = block_size // 2