You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
connecting_the_dots/model/exp_synph_real.py

369 lines
15 KiB

import torch
import numpy as np
import time
from pathlib import Path
import logging
import sys
import itertools
import json
import matplotlib.pyplot as plt
import cv2
import torchvision.transforms as transforms
import co
import torchext
from model import networks
from data import dataset
class Worker(torchext.Worker):
    """Experiment worker combining photometric, disparity and edge losses.

    The worker builds the train/test datasets, prepares each batch (local
    contrast normalisation of the IR images), evaluates the losses for the
    network output, writes diagnostic figures and accumulates test metrics.

    The scheduling of the ``callback_*`` hooks and attributes such as
    ``train_device`` and ``exp_out_root`` come from ``torchext.Worker``,
    which is not part of this file.
    """

    def __init__(self, args, num_workers=18, train_batch_size=4, test_batch_size=4, save_frequency=1, **kwargs):
        """Store the experiment settings and create the loss modules.

        Args:
            args: Namespace providing ``output_dir``, ``exp_name``, ``epochs``,
                ``ms``, ``pattern_path``, ``lcn_radius``, ``dp_weight`` and
                ``data_type``.
            num_workers: Forwarded to ``torchext.Worker``.
            train_batch_size: Forwarded to ``torchext.Worker``.
            test_batch_size: Forwarded to ``torchext.Worker``.
            save_frequency: Forwarded to ``torchext.Worker``.
            **kwargs: Forwarded to ``torchext.Worker``.

        Reads ``config.json`` from the current working directory; its
        ``DATA_ROOT`` entry locates the dataset.
        """
        super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
                         train_batch_size=train_batch_size, test_batch_size=test_batch_size,
                         save_frequency=save_frequency, **kwargs)
        self.ms = args.ms
        self.pattern_path = args.pattern_path
        self.lcn_radius = args.lcn_radius
        self.dp_weight = args.dp_weight
        self.data_type = args.data_type

        # Full resolution followed by three successively halved resolutions.
        self.imsizes = [(488, 648)]
        for _ in range(3):
            height, width = self.imsizes[-1]
            self.imsizes.append((int(height / 2), int(width / 2)))

        with open('config.json') as fp:
            config = json.load(fp)
        data_root = Path(config['DATA_ROOT'])
        self.settings_path = data_root / self.data_type / 'settings.pkl'
        sample_paths = sorted((data_root / self.data_type).glob('0*/'))
        self.train_paths = sample_paths[2 ** 10:]
        self.test_paths = sample_paths[:2 ** 8]

        # supervise the edge encoder with only 2**8 samples
        # (loss_forward applies the edge loss to samples whose id exceeds this value)
        self.train_edge = len(self.train_paths) - 2 ** 8

        self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
        self.disparity_loss = networks.DisparityLoss()
        # Supervised disparity term: MSE against the dataset disparity.
        self.sup_disp_loss = torch.nn.MSELoss()
        self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))

        # evaluate in the region where opencv Block Matching has valid values
        border, left_offset = 13, 140
        full_h, full_w = self.imsizes[0]
        self.eval_mask = np.zeros(self.imsizes[0])
        self.eval_mask[border:full_h - border, left_offset:full_w - border] = 1
        # The builtin ``bool`` replaces ``np.bool``, which was removed in NumPy 1.24.
        self.eval_mask = self.eval_mask.astype(bool)
        self.eval_h = full_h - 2 * border
        self.eval_w = full_w - border - left_offset

    def get_train_set(self):
        """Return the training dataset built from ``self.train_paths``."""
        return dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
                                       track_length=1)

    def get_test_sets(self):
        """Return the test sets and initialise the photometric loss modules.

        As a side effect ``self.losses`` is rebuilt with one
        ``RectifiedPatternSimilarityLoss`` per (image size, pattern) pair
        exposed by the test dataset.
        """
        test_sets = torchext.TestSets()
        test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
                                           track_length=1)
        test_sets.append('simple', test_set, test_frequency=1)

        # initialize photometric loss modules according to image sizes
        self.losses = []
        # Moving the LCN module does not depend on the loop variable, so do it once.
        self.lcn_in = self.lcn_in.to(self.train_device)
        for imsize, pat in zip(test_set.imsizes, test_set.patterns):
            pat = pat.mean(axis=2)
            pat = torch.from_numpy(pat[None][None].astype(np.float32))
            pat = pat.to(self.train_device)
            pat, _ = self.lcn_in(pat)
            # Replicate the normalised pattern three times along dim 1.
            pat = torch.cat([pat] * 3, dim=1)
            self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
        return test_sets

    def copy_data(self, data, device, requires_grad, train):
        """Move a batch to ``device`` and preprocess the IR images.

        Every entry of ``data`` is copied into ``self.data``. Entries whose
        key contains ``'im'`` but not ``'blend'`` are replaced by the
        concatenation (dim 1) of their LCN-normalised version and the
        original image; the LCN standard deviation is stored under the same
        key with ``'im'`` replaced by ``'std'``.

        Args:
            data: Mapping from key to tensor.
            device: Target device.
            requires_grad: Gradients are enabled only for keys containing ``'im'``.
            train: Unused.
        """
        self.lcn_in = self.lcn_in.to(device)
        self.data = {}
        for key, val in data.items():
            is_image = 'im' in key
            self.data[key] = val.to(device).requires_grad_(requires_grad=is_image and requires_grad)

            # apply lcn to IR input
            # concatenate the normalized IR input and the original IR image
            if is_image and 'blend' not in key:
                im = self.data[key]
                im_lcn, im_std = self.lcn_in(im)
                self.data[key] = torch.cat((im_lcn, im), dim=1)
                self.data[key.replace('im', 'std')] = im_std.to(device).detach()

    def net_forward(self, net, train):
        """Run ``net`` on the preprocessed first-scale image ``self.data['im0']``."""
        # FIXME: crop here?
        return net(self.data['im0'])

    @staticmethod
    def find_corr_points_and_F(left, right):
        """Match SIFT features between two images and estimate the fundamental matrix.

        Args:
            left: First image (2-D array); min-max normalised to uint8 before detection.
            right: Second image, treated the same way.

        Returns:
            Tuple ``(pts1, pts2, F)``: the int32 inlier points in ``left`` and
            ``right`` and the matrix returned by ``cv2.findFundamentalMat``.
        """
        sift = cv2.SIFT_create()
        # find the keypoints and descriptors with SIFT
        kp1, des1 = sift.detectAndCompute(cv2.normalize(left, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'), None)
        kp2, des2 = sift.detectAndCompute(cv2.normalize(right, None, 0, 255, cv2.NORM_MINMAX).astype('uint8'), None)

        # FLANN parameters
        FLANN_INDEX_KDTREE = 1
        index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
        search_params = dict(checks=50)
        flann = cv2.FlannBasedMatcher(index_params, search_params)
        matches = flann.knnMatch(des1, des2, k=2)

        pts1 = []
        pts2 = []
        # ratio test as per Lowe's paper
        for candidates in matches:
            # knnMatch may return fewer than k candidates for a query; the
            # ratio test needs two, so such entries are skipped.
            if len(candidates) < 2:
                continue
            best, second = candidates
            if best.distance < 0.8 * second.distance:
                pts2.append(kp2[best.trainIdx].pt)
                pts1.append(kp1[best.queryIdx].pt)

        pts1 = np.int32(pts1)
        pts2 = np.int32(pts2)
        F, mask = cv2.findFundamentalMat(pts1, pts2, cv2.FM_LMEDS)

        # We select only inlier points
        inliers = mask.ravel() == 1
        return pts1[inliers], pts2[inliers], F

    def calc_sgbm_gt(self):
        """Estimate a normalised disparity map with OpenCV semi-global matching.

        The first channel of ``self.data['im0']`` is matched against
        ``self.pattern_proj`` (set by ``loss_forward``). Only the first
        sample of the batch is processed.

        Returns:
            List with one tensor per processed sample, min-max normalised to [0, 1].
        """
        sgbm_matcher = cv2.StereoSGBM_create()
        to_tensor = transforms.ToTensor()
        disp_gt = []
        # Only sample 0 is handled; iterate over self.data['im0'].shape[0] to cover the batch.
        for i in range(1):
            cam_view = self.data['im0'].detach().to('cpu').numpy()[i, 0]
            pattern = self.pattern_proj.to('cpu').numpy()[i, 0]
            pts_l, pts_r, F = self.find_corr_points_and_F(cam_view, pattern)
            H_l, _ = cv2.findHomography(pts_l, pts_r)
            H_r, _ = cv2.findHomography(pts_r, pts_l)
            # NOTE(review): warpPerspective expects dsize as (width, height) while
            # ``.shape`` is (rows, cols); the transpose below restores the array
            # shape only. Confirm that this is intended.
            left_rect = cv2.warpPerspective(cam_view, H_l, cam_view.shape)
            right_rect = cv2.warpPerspective(pattern, H_r, pattern.shape)

            left_u8 = cv2.normalize(left_rect, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
            right_u8 = cv2.normalize(right_rect, None, 0, 255, cv2.NORM_MINMAX).astype('uint8')
            raw_disparity = sgbm_matcher.compute(left_u8, right_u8)
            disparity = cv2.normalize(raw_disparity, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX,
                                      dtype=cv2.CV_32F)
            disp_gt.append(to_tensor(disparity.T))
        return disp_gt

    def loss_forward(self, out, train):
        """Compute all loss terms for one network output.

        Args:
            out: Pair ``(disparity_output, edge_output)``; each element may be
                a single item or a tuple/list with one item per scale.
            train: Unused.

        Returns:
            List of loss values: one photometric loss per scale, the weighted
            disparity loss (only if ``self.dp_weight > 0``) and one edge loss
            per scale.

        Side effects:
            Sets ``self.pattern_proj``, ``self.edge`` and ``self.edge_gt``
            from scale 0 for later visualisation.
        """
        out, edge = out
        if not isinstance(out, (tuple, list)):
            out = [out]
        if not isinstance(edge, (tuple, list)):
            edge = [edge]

        vals = []

        # apply photometric loss
        for s, (loss_fn, o) in enumerate(zip(self.losses, out)):
            val, pattern_proj = loss_fn(o[0], self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
            if s == 0:
                self.pattern_proj = pattern_proj.detach()
            vals.append(val)

        # apply disparity loss
        # 1-edge as ground truth edge if inversed
        # NOTE(review): this checks the container type, whereas the edge loop
        # below checks each element - confirm which layout the network emits.
        if isinstance(edge, tuple):
            edge0 = 1 - torch.sigmoid(edge[0][0])
        else:
            edge0 = 1 - torch.sigmoid(edge[0])

        first_out = out[0]
        if isinstance(first_out, tuple):
            # FIXME: disparity_loss is unsupervised but a supervised term is
            # wanted, so the dataset disparity is used as the target here. A
            # disabled alternative is to supervise with calc_sgbm_gt().
            val = self.sup_disp_loss(first_out[1], self.data['disp0'])
            val = val + self.disparity_loss(first_out[0], edge0)
        else:
            val = self.disparity_loss(first_out, edge0)
        if self.dp_weight > 0:
            vals.append(val * self.dp_weight)

        # apply edge loss on a subset of training samples
        for s, e in enumerate(edge):
            edge_logits = e[0] if isinstance(e, tuple) else e
            # inversed ground truth edge where 0 means edge
            grad = (self.data[f'grad{s}'] < 0.2).to(torch.float32)
            ids = self.data['id']
            mask = ids > self.train_edge
            if mask.any():
                val = self.edge_loss(edge_logits[mask], grad[mask])
            else:
                val = torch.zeros_like(vals[0])
            if s == 0:
                self.edge = torch.sigmoid(edge_logits.detach())
                self.edge_gt = grad.detach()
            vals.append(val)

        return vals

    def numpy_in_out(self, output):
        """Convert the network output and the batch data to NumPy arrays.

        Returns:
            Tuple ``(es, gt, im, ma)``: estimated disparity, ground-truth
            disparity (float32), first channel of the IR input and the
            validity mask ``gt > 0``.
        """
        output, edge = output
        if not isinstance(output, (tuple, list)):
            output = [output]
        es = output[0][0].detach().to('cpu').numpy()
        gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
        im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
        ma = gt > 0
        return es, gt, im, ma

    @staticmethod
    def _draw_grid_panel(index, image, title, **imshow_kwargs):
        """Show ``image`` without ticks in cell ``index`` of the 3x3 figure grid."""
        ax = plt.subplot(3, 3, index)
        plt.imshow(image, **imshow_kwargs)
        plt.xticks([])
        plt.yticks([])
        ax.set_title(title)

    def write_img(self, out_path, es, gt, im, ma):
        """Write a 3x3 diagnostic figure to ``out_path``.

        Rows: disparity (estimate / ground truth / error), edges (estimate /
        ground truth / error), and IR input / warped pattern / IR std.

        Args:
            out_path: Destination of the figure.
            es: Estimated disparity (2-D array).
            gt: Ground-truth disparity (2-D array).
            im: IR input image (2-D array).
            ma: Unused.
        """
        logging.info(f'write img {out_path}')
        diff = np.abs(es - gt)

        # NOTE(review): vmax is padded using the already lowered vmin, so the
        # upper margin is larger than the lower one. Kept as is because it
        # defines the colour scale of the saved figures.
        vmin, vmax = np.nanmin(gt), np.nanmax(gt)
        vmin = vmin - 0.2 * (vmax - vmin)
        vmax = vmax + 0.2 * (vmax - vmin)

        pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
        edge = self.edge.to('cpu').numpy()[0, 0]
        edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
        edge_err = np.abs(edge - edge_gt)
        im_std = self.data['std0'].to('cpu').numpy()[0, 0]

        fig = plt.figure(figsize=(16, 16))

        es_ = co.cmap.color_depth_map(es, scale=vmax)
        gt_ = co.cmap.color_depth_map(gt, scale=vmax)
        diff_ = co.cmap.color_error_image(diff, BGR=True)

        # plot disparities, ground truth disparity is shown only for reference
        # ([2, 1, 0] reverses the channel order - presumably BGR to RGB; confirm)
        self._draw_grid_panel(1, es_[..., [2, 1, 0]], f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
        self._draw_grid_panel(2, gt_[..., [2, 1, 0]], f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
        self._draw_grid_panel(3, diff_[..., [2, 1, 0]], f'Disparity Err. {diff.mean():.5f}')

        # plot edges
        self._draw_grid_panel(4, edge, f'Edge Est. {edge.min():.5f}/{edge.max():.5f}', cmap='gray')
        self._draw_grid_panel(5, edge_gt, f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}', cmap='gray')
        self._draw_grid_panel(6, edge_err, f'Edge Err. {edge_err.mean():.5f}', cmap='gray')

        # plot normalized IR input and warped pattern
        self._draw_grid_panel(7, im, f'IR input {im.mean():.5f}/{im.std():.5f}',
                              vmin=im.min(), vmax=im.max(), cmap='gray')
        self._draw_grid_panel(8, pattern_proj, f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}',
                              vmin=im.min(), vmax=im.max(), cmap='gray')
        self._draw_grid_panel(9, im_std, f'IR std {im_std.min():.5f}/{im_std.max():.5f}', cmap='gray')

        plt.tight_layout()
        plt.savefig(str(out_path))
        plt.close(fig)

    def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
        """Write a training figure for every 512th batch.

        ``net``, ``errs`` and ``masks`` are unused here; the signature is
        presumably dictated by ``torchext.Worker`` - confirm before changing it.
        """
        if batch_idx % 512 == 0:
            out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
            es, gt, im, ma = self.numpy_in_out(output)
            self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])

    def callback_test_start(self, epoch, set_idx):
        """Create a fresh metric accumulator (distance and outlier fractions)."""
        self.metric = co.metric.MultipleMetric(
            co.metric.DistanceMetric(vec_length=1),
            co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
        )

    def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
        """Add one test batch to the metrics and write a figure for every 8th batch."""
        es, gt, im, ma = self.numpy_in_out(output)

        if batch_idx % 8 == 0:
            out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
            self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])

        # Metrics are computed only inside the evaluation region.
        es, gt, im, ma = self.crop_output(es, gt, im, ma)

        es = es.reshape(-1, 1)
        gt = gt.reshape(-1, 1)
        ma = ma.ravel()
        self.metric.add(es, gt, ma)

    def callback_test_stop(self, epoch, set_idx, loss):
        """Log the accumulated metrics and register each of them for the epoch."""
        logging.info(f'{self.metric}')
        for k, v in self.metric.items():
            self.metric_add_test(epoch, set_idx, k, v)

    def crop_output(self, es, gt, im, ma):
        """Crop the four arrays to the evaluation region given by ``self.eval_mask``.

        Each array is indexed with the boolean mask on its last two axes and
        reshaped to ``[batch, 1, self.eval_h, self.eval_w]``.
        """
        bs = es.shape[0]
        cropped_shape = [bs, 1, self.eval_h, self.eval_w]
        es, gt, im, ma = (
            np.reshape(arr[:, :, self.eval_mask], cropped_shape)
            for arr in (es, gt, im, ma)
        )
        return es, gt, im, ma
if __name__ == '__main__':
    # FIXME Nicolas' pet idea:
    # use SGBM to obtain the ground truth
    # for dispnet (or whatever) 'duplicate' the last few layers (build a second head) and so train several loss functions at the same time
    # L1 + L2 and then optimise them in the same backward pass
    # adapt the forward pass for all of this
    pass