init
This commit is contained in:
parent
26157cbb80
commit
f5e5c4bd3f
@ -1,2 +0,0 @@
|
||||
# connecting_the_dots
|
||||
This repository contains the code for the paper "Connecting the Dots: Learning Representations for Active Monocular Depth Estimation" https://avg.is.tuebingen.mpg.de/publications/riegler2019cvpr
|
23
co/__init__.py
Normal file
23
co/__init__.py
Normal file
@ -0,0 +1,23 @@
|
||||
# import os
|
||||
# this_dir = os.path.dirname(__file__)
|
||||
# print(this_dir)
|
||||
# import sys
|
||||
# sys.path.append(this_dir)
|
||||
|
||||
# set matplotlib backend depending on env
|
||||
import os
|
||||
import matplotlib
|
||||
if os.name == 'posix' and "DISPLAY" not in os.environ:
|
||||
matplotlib.use('Agg')
|
||||
|
||||
from . import geometry
|
||||
from . import plt
|
||||
from . import plt2d
|
||||
from . import plt3d
|
||||
from . import metric
|
||||
from . import table
|
||||
from . import utils
|
||||
from . import io3d
|
||||
from . import gtimer
|
||||
from . import cmap
|
||||
from . import args
|
71
co/args.py
Normal file
71
co/args.py
Normal file
@ -0,0 +1,71 @@
|
||||
import argparse
|
||||
import os
|
||||
from .utils import str2bool
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser()
|
||||
#
|
||||
parser.add_argument('--output_dir',
|
||||
help='Output directory',
|
||||
default='./output', type=str)
|
||||
parser.add_argument('--loss',
|
||||
help='Train with \'ph\' for the first stage without geometric loss, \
|
||||
train with \'phge\' for the second stage with geometric loss',
|
||||
default='ph', choices=['ph','phge'], type=str)
|
||||
parser.add_argument('--data_type',
|
||||
default='syn', choices=['syn'], type=str)
|
||||
#
|
||||
parser.add_argument('--cmd',
|
||||
help='Start training or test',
|
||||
default='resume', choices=['retrain', 'resume', 'retest', 'test_init'], type=str)
|
||||
parser.add_argument('--epoch',
|
||||
help='If larger than -1, retest on the specified epoch',
|
||||
default=-1, type=int)
|
||||
parser.add_argument('--epochs',
|
||||
help='Training epochs',
|
||||
default=100, type=int)
|
||||
|
||||
#
|
||||
parser.add_argument('--ms',
|
||||
help='If true, use multiscale loss',
|
||||
default=True, type=str2bool)
|
||||
parser.add_argument('--pattern_path',
|
||||
help='Path of the pattern image',
|
||||
default='./data/kinect_patttern.png', type=str)
|
||||
#
|
||||
parser.add_argument('--dp_weight',
|
||||
help='Weight of the disparity loss',
|
||||
default=0.02, type=float)
|
||||
parser.add_argument('--ge_weight',
|
||||
help='Weight of the geometric loss',
|
||||
default=0.1, type=float)
|
||||
#
|
||||
parser.add_argument('--lcn_radius',
|
||||
help='Radius of the window for LCN pre-processing',
|
||||
default=5, type=int)
|
||||
parser.add_argument('--max_disp',
|
||||
help='Maximum disparity',
|
||||
default=128, type=int)
|
||||
#
|
||||
parser.add_argument('--track_length',
|
||||
help='Track length for geometric loss',
|
||||
default=2, type=int)
|
||||
#
|
||||
parser.add_argument('--blend_im',
|
||||
help='Parameter for adding texture',
|
||||
default=0.6, type=float)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.exp_name = get_exp_name(args)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def get_exp_name(args):
|
||||
name = f"exp_{args.data_type}"
|
||||
return name
|
||||
|
||||
|
||||
|
93
co/cmap.py
Normal file
93
co/cmap.py
Normal file
@ -0,0 +1,93 @@
|
||||
import numpy as np
|
||||
|
||||
_color_map_errors = np.array([
|
||||
[149, 54, 49], #0: log2(x) = -infinity
|
||||
[180, 117, 69], #0.0625: log2(x) = -4
|
||||
[209, 173, 116], #0.125: log2(x) = -3
|
||||
[233, 217, 171], #0.25: log2(x) = -2
|
||||
[248, 243, 224], #0.5: log2(x) = -1
|
||||
[144, 224, 254], #1.0: log2(x) = 0
|
||||
[97, 174, 253], #2.0: log2(x) = 1
|
||||
[67, 109, 244], #4.0: log2(x) = 2
|
||||
[39, 48, 215], #8.0: log2(x) = 3
|
||||
[38, 0, 165], #16.0: log2(x) = 4
|
||||
[38, 0, 165] #inf: log2(x) = inf
|
||||
]).astype(float)
|
||||
|
||||
def color_error_image(errors, scale=1, mask=None, BGR=True):
|
||||
"""
|
||||
Color an input error map.
|
||||
|
||||
Arguments:
|
||||
errors -- HxW numpy array of errors
|
||||
[scale=1] -- scaling the error map (color change at unit error)
|
||||
[mask=None] -- zero-pixels are masked white in the result
|
||||
[BGR=True] -- toggle between BGR and RGB
|
||||
|
||||
Returns:
|
||||
colored_errors -- HxWx3 numpy array visualizing the errors
|
||||
"""
|
||||
|
||||
errors_flat = errors.flatten()
|
||||
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
|
||||
i0 = np.floor(errors_color_indices).astype(int)
|
||||
f1 = errors_color_indices - i0.astype(float)
|
||||
colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1)
|
||||
|
||||
if mask is not None:
|
||||
colored_errors_flat[mask.flatten() == 0] = 255
|
||||
|
||||
if not BGR:
|
||||
colored_errors_flat = colored_errors_flat[:,[2,1,0]]
|
||||
|
||||
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
|
||||
|
||||
_color_map_depths = np.array([
|
||||
[0, 0, 0], # 0.000
|
||||
[0, 0, 255], # 0.114
|
||||
[255, 0, 0], # 0.299
|
||||
[255, 0, 255], # 0.413
|
||||
[0, 255, 0], # 0.587
|
||||
[0, 255, 255], # 0.701
|
||||
[255, 255, 0], # 0.886
|
||||
[255, 255, 255], # 1.000
|
||||
[255, 255, 255], # 1.000
|
||||
]).astype(float)
|
||||
_color_map_bincenters = np.array([
|
||||
0.0,
|
||||
0.114,
|
||||
0.299,
|
||||
0.413,
|
||||
0.587,
|
||||
0.701,
|
||||
0.886,
|
||||
1.000,
|
||||
2.000, # doesn't make a difference, just strictly higher than 1
|
||||
])
|
||||
|
||||
def color_depth_map(depths, scale=None):
|
||||
"""
|
||||
Color an input depth map.
|
||||
|
||||
Arguments:
|
||||
depths -- HxW numpy array of depths
|
||||
[scale=None] -- scaling the values (defaults to the maximum depth)
|
||||
|
||||
Returns:
|
||||
colored_depths -- HxWx3 numpy array visualizing the depths
|
||||
"""
|
||||
|
||||
if scale is None:
|
||||
scale = depths.max()
|
||||
|
||||
values = np.clip(depths.flatten() / scale, 0, 1)
|
||||
# for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
|
||||
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1,-1)) * np.arange(0,9)).max(axis=1)
|
||||
lower_bin_value = _color_map_bincenters[lower_bin]
|
||||
higher_bin_value = _color_map_bincenters[lower_bin + 1]
|
||||
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
|
||||
colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1)
|
||||
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
|
||||
|
||||
#from utils.debug import save_color_numpy
|
||||
#save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
|
800
co/geometry.py
Normal file
800
co/geometry.py
Normal file
@ -0,0 +1,800 @@
|
||||
import numpy as np
|
||||
import scipy.spatial
|
||||
import scipy.linalg
|
||||
|
||||
def nullspace(A, atol=1e-13, rtol=0):
|
||||
u, s, vh = np.linalg.svd(A)
|
||||
tol = max(atol, rtol * s[0])
|
||||
nnz = (s >= tol).sum()
|
||||
ns = vh[nnz:].conj().T
|
||||
return ns
|
||||
|
||||
def nearest_orthogonal_matrix(R):
|
||||
U,S,Vt = np.linalg.svd(R)
|
||||
return U @ np.eye(3,dtype=R.dtype) @ Vt
|
||||
|
||||
def power_iters(A, n_iters=10):
|
||||
b = np.random.uniform(-1,1, size=(A.shape[0], A.shape[1], 1))
|
||||
for iter in range(n_iters):
|
||||
b = A @ b
|
||||
b = b / np.linalg.norm(b, axis=1, keepdims=True)
|
||||
return b
|
||||
|
||||
def rayleigh_quotient(A, b):
|
||||
return (b.transpose(0,2,1) @ A @ b) / (b.transpose(0,2,1) @ b)
|
||||
|
||||
|
||||
def cross_prod_mat(x):
|
||||
x = x.reshape(-1,3)
|
||||
X = np.empty((x.shape[0],3,3), dtype=x.dtype)
|
||||
X[:,0,0] = 0
|
||||
X[:,0,1] = -x[:,2]
|
||||
X[:,0,2] = x[:,1]
|
||||
X[:,1,0] = x[:,2]
|
||||
X[:,1,1] = 0
|
||||
X[:,1,2] = -x[:,0]
|
||||
X[:,2,0] = -x[:,1]
|
||||
X[:,2,1] = x[:,0]
|
||||
X[:,2,2] = 0
|
||||
return X.squeeze()
|
||||
|
||||
def hat_operator(x):
|
||||
return cross_prod_mat(x)
|
||||
|
||||
def vee_operator(X):
|
||||
X = X.reshape(-1,3,3)
|
||||
x = np.empty((X.shape[0], 3), dtype=X.dtype)
|
||||
x[:,0] = X[:,2,1]
|
||||
x[:,1] = X[:,0,2]
|
||||
x[:,2] = X[:,1,0]
|
||||
return x.squeeze()
|
||||
|
||||
|
||||
def rot_x(x, dtype=np.float32):
|
||||
x = np.array(x, copy=False)
|
||||
x = x.reshape(-1,1)
|
||||
R = np.zeros((x.shape[0],3,3), dtype=dtype)
|
||||
R[:,0,0] = 1
|
||||
R[:,1,1] = np.cos(x).ravel()
|
||||
R[:,1,2] = -np.sin(x).ravel()
|
||||
R[:,2,1] = np.sin(x).ravel()
|
||||
R[:,2,2] = np.cos(x).ravel()
|
||||
return R.squeeze()
|
||||
|
||||
def rot_y(y, dtype=np.float32):
|
||||
y = np.array(y, copy=False)
|
||||
y = y.reshape(-1,1)
|
||||
R = np.zeros((y.shape[0],3,3), dtype=dtype)
|
||||
R[:,0,0] = np.cos(y).ravel()
|
||||
R[:,0,2] = np.sin(y).ravel()
|
||||
R[:,1,1] = 1
|
||||
R[:,2,0] = -np.sin(y).ravel()
|
||||
R[:,2,2] = np.cos(y).ravel()
|
||||
return R.squeeze()
|
||||
|
||||
def rot_z(z, dtype=np.float32):
|
||||
z = np.array(z, copy=False)
|
||||
z = z.reshape(-1,1)
|
||||
R = np.zeros((z.shape[0],3,3), dtype=dtype)
|
||||
R[:,0,0] = np.cos(z).ravel()
|
||||
R[:,0,1] = -np.sin(z).ravel()
|
||||
R[:,1,0] = np.sin(z).ravel()
|
||||
R[:,1,1] = np.cos(z).ravel()
|
||||
R[:,2,2] = 1
|
||||
return R.squeeze()
|
||||
|
||||
def xyz_from_rotm(R):
|
||||
R = R.reshape(-1,3,3)
|
||||
xyz = np.empty((R.shape[0],3), dtype=R.dtype)
|
||||
for bidx in range(R.shape[0]):
|
||||
if R[bidx,0,2] < 1:
|
||||
if R[bidx,0,2] > -1:
|
||||
xyz[bidx,1] = np.arcsin(R[bidx,0,2])
|
||||
xyz[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,2,2])
|
||||
xyz[bidx,2] = np.arctan2(-R[bidx,0,1], R[bidx,0,0])
|
||||
else:
|
||||
xyz[bidx,1] = -np.pi/2
|
||||
xyz[bidx,0] = -np.arctan2(R[bidx,1,0],R[bidx,1,1])
|
||||
xyz[bidx,2] = 0
|
||||
else:
|
||||
xyz[bidx,1] = np.pi/2
|
||||
xyz[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,1,1])
|
||||
xyz[bidx,2] = 0
|
||||
return xyz.squeeze()
|
||||
|
||||
def zyx_from_rotm(R):
|
||||
R = R.reshape(-1,3,3)
|
||||
zyx = np.empty((R.shape[0],3), dtype=R.dtype)
|
||||
for bidx in range(R.shape[0]):
|
||||
if R[bidx,2,0] < 1:
|
||||
if R[bidx,2,0] > -1:
|
||||
zyx[bidx,1] = np.arcsin(-R[bidx,2,0])
|
||||
zyx[bidx,0] = np.arctan2(R[bidx,1,0], R[bidx,0,0])
|
||||
zyx[bidx,2] = np.arctan2(R[bidx,2,1], R[bidx,2,2])
|
||||
else:
|
||||
zyx[bidx,1] = np.pi / 2
|
||||
zyx[bidx,0] = -np.arctan2(-R[bidx,1,2], R[bidx,1,1])
|
||||
zyx[bidx,2] = 0
|
||||
else:
|
||||
zyx[bidx,1] = -np.pi / 2
|
||||
zyx[bidx,0] = np.arctan2(-R[bidx,1,2], R[bidx,1,1])
|
||||
zyx[bidx,2] = 0
|
||||
return zyx.squeeze()
|
||||
|
||||
def rotm_from_xyz(xyz):
|
||||
xyz = np.array(xyz, copy=False).reshape(-1,3)
|
||||
return (rot_x(xyz[:,0]) @ rot_y(xyz[:,1]) @ rot_z(xyz[:,2])).squeeze()
|
||||
|
||||
def rotm_from_zyx(zyx):
|
||||
zyx = np.array(zyx, copy=False).reshape(-1,3)
|
||||
return (rot_z(zyx[:,0]) @ rot_y(zyx[:,1]) @ rot_x(zyx[:,2])).squeeze()
|
||||
|
||||
def rotm_from_quat(q):
|
||||
q = q.reshape(-1,4)
|
||||
w, x, y, z = q[:,0], q[:,1], q[:,2], q[:,3]
|
||||
R = np.array([
|
||||
[1 - 2*y*y - 2*z*z, 2*x*y - 2*z*w, 2*x*z + 2*y*w],
|
||||
[2*x*y + 2*z*w, 1 - 2*x*x - 2*z*z, 2*y*z - 2*x*w],
|
||||
[2*x*z - 2*y*w, 2*y*z + 2*x*w, 1 - 2*x*x - 2*y*y]
|
||||
], dtype=q.dtype)
|
||||
R = R.transpose((2,0,1))
|
||||
return R.squeeze()
|
||||
|
||||
def rotm_from_axisangle(a):
|
||||
# exponential
|
||||
a = a.reshape(-1,3)
|
||||
phi = np.linalg.norm(a, axis=1).reshape(-1,1,1)
|
||||
iphi = np.zeros_like(phi)
|
||||
np.divide(1, phi, out=iphi, where=phi != 0)
|
||||
A = cross_prod_mat(a) * iphi
|
||||
R = np.eye(3, dtype=a.dtype) + np.sin(phi) * A + (1 - np.cos(phi)) * A @ A
|
||||
return R.squeeze()
|
||||
|
||||
def rotm_from_lookat(dir, up=None):
|
||||
dir = dir.reshape(-1,3)
|
||||
if up is None:
|
||||
up = np.zeros_like(dir)
|
||||
up[:,1] = 1
|
||||
dir /= np.linalg.norm(dir, axis=1, keepdims=True)
|
||||
up /= np.linalg.norm(up, axis=1, keepdims=True)
|
||||
x = dir[:,None,:] @ cross_prod_mat(up).transpose(0,2,1)
|
||||
y = x @ cross_prod_mat(dir).transpose(0,2,1)
|
||||
x = x.squeeze()
|
||||
y = y.squeeze()
|
||||
x /= np.linalg.norm(x, axis=1, keepdims=True)
|
||||
y /= np.linalg.norm(y, axis=1, keepdims=True)
|
||||
R = np.empty((dir.shape[0],3,3), dtype=dir.dtype)
|
||||
R[:,0,0] = x[:,0]
|
||||
R[:,0,1] = y[:,0]
|
||||
R[:,0,2] = dir[:,0]
|
||||
R[:,1,0] = x[:,1]
|
||||
R[:,1,1] = y[:,1]
|
||||
R[:,1,2] = dir[:,1]
|
||||
R[:,2,0] = x[:,2]
|
||||
R[:,2,1] = y[:,2]
|
||||
R[:,2,2] = dir[:,2]
|
||||
return R.transpose(0,2,1).squeeze()
|
||||
|
||||
def rotm_distance_identity(R0, R1):
|
||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||
# in [0, 2*sqrt(2)]
|
||||
R0 = R0.reshape(-1,3,3)
|
||||
R1 = R1.reshape(-1,3,3)
|
||||
dists = np.linalg.norm(np.eye(3,dtype=R0.dtype) - R0 @ R1.transpose(0,2,1), axis=(1,2))
|
||||
return dists.squeeze()
|
||||
|
||||
def rotm_distance_geodesic(R0, R1):
|
||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||
# in [0, pi)
|
||||
R0 = R0.reshape(-1,3,3)
|
||||
R1 = R1.reshape(-1,3,3)
|
||||
RtR = R0 @ R1.transpose(0,2,1)
|
||||
aa = axisangle_from_rotm(RtR)
|
||||
S = cross_prod_mat(aa).reshape(-1,3,3)
|
||||
dists = np.linalg.norm(S, axis=(1,2))
|
||||
return dists.squeeze()
|
||||
|
||||
|
||||
|
||||
def axisangle_from_rotm(R):
|
||||
# logarithm of rotation matrix
|
||||
# R = R.reshape(-1,3,3)
|
||||
# tr = np.trace(R, axis1=1, axis2=2)
|
||||
# phi = np.arccos(np.clip((tr - 1) / 2, -1, 1))
|
||||
# scale = np.zeros_like(phi)
|
||||
# div = 2 * np.sin(phi)
|
||||
# np.divide(phi, div, out=scale, where=np.abs(div) > 1e-6)
|
||||
# A = (R - R.transpose(0,2,1)) * scale.reshape(-1,1,1)
|
||||
# aa = np.stack((A[:,2,1], A[:,0,2], A[:,1,0]), axis=1)
|
||||
# return aa.squeeze()
|
||||
R = R.reshape(-1,3,3)
|
||||
omega = np.empty((R.shape[0], 3), dtype=R.dtype)
|
||||
omega[:,0] = R[:,2,1] - R[:,1,2]
|
||||
omega[:,1] = R[:,0,2] - R[:,2,0]
|
||||
omega[:,2] = R[:,1,0] - R[:,0,1]
|
||||
r = np.linalg.norm(omega, axis=1).reshape(-1,1)
|
||||
t = np.trace(R, axis1=1, axis2=2).reshape(-1,1)
|
||||
omega = np.arctan2(r, t-1) * omega
|
||||
aa = np.zeros_like(omega)
|
||||
np.divide(omega, r, out=aa, where=r != 0)
|
||||
return aa.squeeze()
|
||||
|
||||
def axisangle_from_quat(q):
|
||||
q = q.reshape(-1,4)
|
||||
phi = 2 * np.arccos(q[:,0])
|
||||
denom = np.zeros_like(q[:,0])
|
||||
np.divide(1, np.sqrt(1 - q[:,0]**2), out=denom, where=q[:,0] != 1)
|
||||
axis = q[:,1:] * denom.reshape(-1,1)
|
||||
denom = np.linalg.norm(axis, axis=1).reshape(-1,1)
|
||||
a = np.zeros_like(axis)
|
||||
np.divide(phi.reshape(-1,1) * axis, denom, out=a, where=denom != 0)
|
||||
aa = a.astype(q.dtype)
|
||||
return aa.squeeze()
|
||||
|
||||
def axisangle_apply(aa, x):
|
||||
# working only with single aa and single x at the moment
|
||||
xshape = x.shape
|
||||
aa = aa.reshape(3,)
|
||||
x = x.reshape(3,)
|
||||
phi = np.linalg.norm(aa)
|
||||
e = np.zeros_like(aa)
|
||||
np.divide(aa, phi, out=e, where=phi != 0)
|
||||
xr = np.cos(phi) * x + np.sin(phi) * np.cross(e, x) + (1 - np.cos(phi)) * (e.T @ x) * e
|
||||
return xr.reshape(xshape)
|
||||
|
||||
|
||||
def exp_so3(R):
|
||||
w = axisangle_from_rotm(R)
|
||||
return w
|
||||
|
||||
def log_so3(w):
|
||||
R = rotm_from_axisangle(w)
|
||||
return R
|
||||
|
||||
def exp_se3(R, t):
|
||||
R = R.reshape(-1,3,3)
|
||||
t = t.reshape(-1,3)
|
||||
|
||||
w = exp_so3(R).reshape(-1,3)
|
||||
|
||||
phi = np.linalg.norm(w, axis=1).reshape(-1,1,1)
|
||||
A = cross_prod_mat(w)
|
||||
Vi = np.eye(3, dtype=R.dtype) - A/2 + (1 - (phi * np.sin(phi) / (2 * (1 - np.cos(phi))))) / phi**2 * A @ A
|
||||
u = t.reshape(-1,1,3) @ Vi.transpose(0,2,1)
|
||||
|
||||
# v = (u, w)
|
||||
v = np.empty((R.shape[0],6), dtype=R.dtype)
|
||||
v[:,:3] = u.squeeze()
|
||||
v[:,3:] = w
|
||||
|
||||
return v.squeeze()
|
||||
|
||||
def log_se3(v):
|
||||
# v = (u, w)
|
||||
v = v.reshape(-1,6)
|
||||
u = v[:,:3]
|
||||
w = v[:,3:]
|
||||
|
||||
R = log_so3(w)
|
||||
|
||||
phi = np.linalg.norm(w, axis=1).reshape(-1,1,1)
|
||||
A = cross_prod_mat(w)
|
||||
V = np.eye(3, dtype=v.dtype) + (1 - np.cos(phi)) / phi**2 * A + (phi - np.sin(phi)) / phi**3 * A @ A
|
||||
t = u.reshape(-1,1,3) @ V.transpose(0,2,1)
|
||||
|
||||
return R.squeeze(), t.squeeze()
|
||||
|
||||
|
||||
def quat_from_rotm(R):
|
||||
R = R.reshape(-1,3,3)
|
||||
q = np.empty((R.shape[0], 4,), dtype=R.dtype)
|
||||
q[:,0] = np.sqrt( np.maximum(0, 1 + R[:,0,0] + R[:,1,1] + R[:,2,2]) )
|
||||
q[:,1] = np.sqrt( np.maximum(0, 1 + R[:,0,0] - R[:,1,1] - R[:,2,2]) )
|
||||
q[:,2] = np.sqrt( np.maximum(0, 1 - R[:,0,0] + R[:,1,1] - R[:,2,2]) )
|
||||
q[:,3] = np.sqrt( np.maximum(0, 1 - R[:,0,0] - R[:,1,1] + R[:,2,2]) )
|
||||
q[:,1] *= np.sign(q[:,1] * (R[:,2,1] - R[:,1,2]))
|
||||
q[:,2] *= np.sign(q[:,2] * (R[:,0,2] - R[:,2,0]))
|
||||
q[:,3] *= np.sign(q[:,3] * (R[:,1,0] - R[:,0,1]))
|
||||
q /= np.linalg.norm(q,axis=1,keepdims=True)
|
||||
return q.squeeze()
|
||||
|
||||
def quat_from_axisangle(a):
|
||||
a = a.reshape(-1, 3)
|
||||
phi = np.linalg.norm(a, axis=1)
|
||||
iphi = np.zeros_like(phi)
|
||||
np.divide(1, phi, out=iphi, where=phi != 0)
|
||||
a = a * iphi.reshape(-1,1)
|
||||
theta = phi / 2.0
|
||||
r = np.cos(theta)
|
||||
stheta = np.sin(theta)
|
||||
q = np.stack((r, stheta*a[:,0], stheta*a[:,1], stheta*a[:,2]), axis=1)
|
||||
q /= np.linalg.norm(q, axis=1).reshape(-1,1)
|
||||
return q.squeeze()
|
||||
|
||||
def quat_identity(n=1, dtype=np.float32):
|
||||
q = np.zeros((n,4), dtype=dtype)
|
||||
q[:,0] = 1
|
||||
return q.squeeze()
|
||||
|
||||
def quat_conjugate(q):
|
||||
shape = q.shape
|
||||
q = q.reshape(-1,4).copy()
|
||||
q[:,1:] *= -1
|
||||
return q.reshape(shape)
|
||||
|
||||
def quat_product(q1, q2):
|
||||
# q1 . q2 is equivalent to R(q1) @ R(q2)
|
||||
shape = q1.shape
|
||||
q1, q2 = q1.reshape(-1,4), q2.reshape(-1, 4)
|
||||
q = np.empty((max(q1.shape[0], q2.shape[0]), 4), dtype=q1.dtype)
|
||||
a1,b1,c1,d1 = q1[:,0], q1[:,1], q1[:,2], q1[:,3]
|
||||
a2,b2,c2,d2 = q2[:,0], q2[:,1], q2[:,2], q2[:,3]
|
||||
q[:,0] = a1 * a2 - b1 * b2 - c1 * c2 - d1 * d2
|
||||
q[:,1] = a1 * b2 + b1 * a2 + c1 * d2 - d1 * c2
|
||||
q[:,2] = a1 * c2 - b1 * d2 + c1 * a2 + d1 * b2
|
||||
q[:,3] = a1 * d2 + b1 * c2 - c1 * b2 + d1 * a2
|
||||
return q.squeeze()
|
||||
|
||||
def quat_apply(q, x):
|
||||
xshape = x.shape
|
||||
x = x.reshape(-1, 3)
|
||||
qshape = q.shape
|
||||
q = q.reshape(-1, 4)
|
||||
|
||||
p = np.empty((x.shape[0], 4), dtype=x.dtype)
|
||||
p[:,0] = 0
|
||||
p[:,1:] = x
|
||||
|
||||
r = quat_product(quat_product(q, p), quat_conjugate(q))
|
||||
if r.ndim == 1:
|
||||
return r[1:].reshape(xshape)
|
||||
else:
|
||||
return r[:,1:].reshape(xshape)
|
||||
|
||||
|
||||
def quat_random(rng=None, n=1):
|
||||
# http://planning.cs.uiuc.edu/node198.html
|
||||
if rng is not None:
|
||||
u = rng.uniform(0, 1, size=(3,n))
|
||||
else:
|
||||
u = np.random.uniform(0, 1, size=(3,n))
|
||||
q = np.array((
|
||||
np.sqrt(1 - u[0]) * np.sin(2 * np.pi * u[1]),
|
||||
np.sqrt(1 - u[0]) * np.cos(2 * np.pi * u[1]),
|
||||
np.sqrt(u[0]) * np.sin(2 * np.pi * u[2]),
|
||||
np.sqrt(u[0]) * np.cos(2 * np.pi * u[2])
|
||||
)).T
|
||||
q /= np.linalg.norm(q,axis=1,keepdims=True)
|
||||
return q.squeeze()
|
||||
|
||||
def quat_distance_angle(q0, q1):
|
||||
# https://math.stackexchange.com/questions/90081/quaternion-distance
|
||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||
q0 = q0.reshape(-1,4)
|
||||
q1 = q1.reshape(-1,4)
|
||||
dists = np.arccos(np.clip(2 * np.sum(q0 * q1, axis=1)**2 - 1, -1, 1))
|
||||
return dists
|
||||
|
||||
def quat_distance_normdiff(q0, q1):
|
||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||
# \phi_4
|
||||
# [0, 1]
|
||||
q0 = q0.reshape(-1,4)
|
||||
q1 = q1.reshape(-1,4)
|
||||
return 1 - np.sum(q0 * q1, axis=1)**2
|
||||
|
||||
def quat_distance_mineucl(q0, q1):
|
||||
# https://link.springer.com/article/10.1007%2Fs10851-009-0161-2
|
||||
# http://users.cecs.anu.edu.au/~trumpf/pubs/Hartley_Trumpf_Dai_Li.pdf
|
||||
q0 = q0.reshape(-1,4)
|
||||
q1 = q1.reshape(-1,4)
|
||||
diff0 = ((q0 - q1)**2).sum(axis=1)
|
||||
diff1 = ((q0 + q1)**2).sum(axis=1)
|
||||
return np.minimum(diff0, diff1)
|
||||
|
||||
def quat_slerp_space(q0, q1, num=100, endpoint=True):
|
||||
q0 = q0.ravel()
|
||||
q1 = q1.ravel()
|
||||
dot = q0.dot(q1)
|
||||
if dot < 0:
|
||||
q1 *= -1
|
||||
dot *= -1
|
||||
t = np.linspace(0, 1, num=num, endpoint=endpoint, dtype=q0.dtype)
|
||||
t = t.reshape((-1,1))
|
||||
if dot > 0.9995:
|
||||
ret = q0 + t * (q1 - q0)
|
||||
return ret
|
||||
dot = np.clip(dot, -1, 1)
|
||||
theta0 = np.arccos(dot)
|
||||
theta = theta0 * t
|
||||
s0 = np.cos(theta) - dot * np.sin(theta) / np.sin(theta0)
|
||||
s1 = np.sin(theta) / np.sin(theta0)
|
||||
return (s0 * q0) + (s1 * q1)
|
||||
|
||||
def cart_to_spherical(x):
|
||||
shape = x.shape
|
||||
x = x.reshape(-1,3)
|
||||
y = np.empty_like(x)
|
||||
y[:,0] = np.linalg.norm(x, axis=1) # r
|
||||
y[:,1] = np.arccos(x[:,2] / y[:,0]) # theta
|
||||
y[:,2] = np.arctan2(x[:,1], x[:,0]) # phi
|
||||
return y.reshape(shape)
|
||||
|
||||
def spherical_to_cart(x):
|
||||
shape = x.shape
|
||||
x = x.reshape(-1,3)
|
||||
y = np.empty_like(x)
|
||||
y[:,0] = x[:,0] * np.sin(x[:,1]) * np.cos(x[:,2])
|
||||
y[:,1] = x[:,0] * np.sin(x[:,1]) * np.sin(x[:,2])
|
||||
y[:,2] = x[:,0] * np.cos(x[:,1])
|
||||
return y.reshape(shape)
|
||||
|
||||
def spherical_random(r=1, n=1):
|
||||
# http://mathworld.wolfram.com/SpherePointPicking.html
|
||||
# https://math.stackexchange.com/questions/1585975/how-to-generate-random-points-on-a-sphere
|
||||
x = np.empty((n,3))
|
||||
x[:,0] = r
|
||||
x[:,1] = 2 * np.pi * np.random.uniform(0,1, size=(n,))
|
||||
x[:,2] = np.arccos(2 * np.random.uniform(0,1, size=(n,)) - 1)
|
||||
return x.squeeze()
|
||||
|
||||
def color_pcl(pcl, K, im, color_axis=0, as_int=True, invalid_color=[0,0,0]):
|
||||
uvd = K @ pcl.T
|
||||
uvd /= uvd[2]
|
||||
uvd = np.round(uvd).astype(np.int32)
|
||||
mask = np.logical_and(uvd[0] >= 0, uvd[1] >= 0)
|
||||
color = np.empty((pcl.shape[0], 3), dtype=im.dtype)
|
||||
if color_axis == 0:
|
||||
mask = np.logical_and(mask, uvd[0] < im.shape[2])
|
||||
mask = np.logical_and(mask, uvd[1] < im.shape[1])
|
||||
uvd = uvd[:,mask]
|
||||
color[mask,:] = im[:,uvd[1],uvd[0]].T
|
||||
elif color_axis == 2:
|
||||
mask = np.logical_and(mask, uvd[0] < im.shape[1])
|
||||
mask = np.logical_and(mask, uvd[1] < im.shape[0])
|
||||
uvd = uvd[:,mask]
|
||||
color[mask,:] = im[uvd[1],uvd[0], :]
|
||||
else:
|
||||
raise Exception('invalid color_axis')
|
||||
color[np.logical_not(mask),:3] = invalid_color
|
||||
if as_int:
|
||||
color = (255.0 * color).astype(np.int32)
|
||||
return color
|
||||
|
||||
def center_pcl(pcl, robust=False, copy=False, axis=1):
|
||||
if copy:
|
||||
pcl = pcl.copy()
|
||||
if robust:
|
||||
mu = np.median(pcl, axis=axis, keepdims=True)
|
||||
else:
|
||||
mu = np.mean(pcl, axis=axis, keepdims=True)
|
||||
return pcl - mu
|
||||
|
||||
def to_homogeneous(x):
|
||||
# return np.hstack((x, np.ones((x.shape[0],1),dtype=x.dtype)))
|
||||
return np.concatenate((x, np.ones((*x.shape[:-1],1),dtype=x.dtype)), axis=-1)
|
||||
|
||||
def from_homogeneous(x):
|
||||
return x[:,:-1] / x[:,-1]
|
||||
|
||||
def project_uvn(uv, Ki=None):
|
||||
if uv.shape[1] == 2:
|
||||
uvn = to_homogeneous(uv)
|
||||
else:
|
||||
uvn = uv
|
||||
if uvn.shape[1] != 3:
|
||||
raise Exception('uv should have shape Nx2 or Nx3')
|
||||
if Ki is None:
|
||||
return uvn
|
||||
else:
|
||||
return uvn @ Ki.T
|
||||
|
||||
def project_uvd(uv, depth, K=np.eye(3), R=np.eye(3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False):
|
||||
Ki = np.linalg.inv(K)
|
||||
|
||||
if ignore_negative_depth:
|
||||
mask = depth >= 0
|
||||
uv = uv[mask,:]
|
||||
d = depth[mask]
|
||||
else:
|
||||
d = depth.ravel()
|
||||
|
||||
uv1 = to_homogeneous(uv)
|
||||
|
||||
uvn1 = uv1 @ Ki.T
|
||||
xyz = d.reshape(-1,1) * uvn1
|
||||
xyz = (xyz - t.reshape((1,3))) @ R
|
||||
|
||||
if return_uvn:
|
||||
return xyz, uvn1
|
||||
else:
|
||||
return xyz
|
||||
|
||||
def project_depth(depth, K, R=np.eye(3,3), t=np.zeros((3,1)), ignore_negative_depth=True, return_uvn=False):
|
||||
u, v = np.meshgrid(range(depth.shape[1]), range(depth.shape[0]))
|
||||
uv = np.hstack((u.reshape(-1,1), v.reshape(-1,1)))
|
||||
return project_uvd(uv, depth.ravel(), K, R, t, ignore_negative_depth, return_uvn)
|
||||
|
||||
|
||||
def project_xyz(xyz, K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))):
|
||||
uvd = K @ (R @ xyz.T + t.reshape((3,1)))
|
||||
uvd[:2] /= uvd[2]
|
||||
return uvd[:2].T, uvd[2]
|
||||
|
||||
|
||||
def relative_motion(R0, t0, R1, t1, Rt_from_global=True):
|
||||
t0 = t0.reshape((3,1))
|
||||
t1 = t1.reshape((3,1))
|
||||
if Rt_from_global:
|
||||
Rr = R1 @ R0.T
|
||||
tr = t1 - Rr @ t0
|
||||
else:
|
||||
Rr = R1.T @ R0
|
||||
tr = R1.T @ (t0 - t1)
|
||||
return Rr, tr.ravel()
|
||||
|
||||
|
||||
def translation_to_cameracenter(R, t):
|
||||
t = t.reshape(-1,3,1)
|
||||
R = R.reshape(-1,3,3)
|
||||
C = -R.transpose(0,2,1) @ t
|
||||
return C.squeeze()
|
||||
|
||||
def cameracenter_to_translation(R, C):
|
||||
C = C.reshape(-1,3,1)
|
||||
R = R.reshape(-1,3,3)
|
||||
t = -R @ C
|
||||
return t.squeeze()
|
||||
|
||||
def decompose_projection_matrix(P, return_t=True):
|
||||
if P.shape[0] != 3 or P.shape[1] != 4:
|
||||
raise Exception('P has to be 3x4')
|
||||
M = P[:, :3]
|
||||
C = -np.linalg.inv(M) @ P[:, 3:]
|
||||
|
||||
R,K = np.linalg.qr(np.flipud(M).T)
|
||||
K = np.flipud(K.T)
|
||||
K = np.fliplr(K)
|
||||
R = np.flipud(R.T)
|
||||
|
||||
T = np.diag(np.sign(np.diag(K)))
|
||||
K = K @ T
|
||||
R = T @ R
|
||||
|
||||
if np.linalg.det(R) < 0:
|
||||
R *= -1
|
||||
|
||||
K /= K[2,2]
|
||||
if return_t:
|
||||
return K, R, cameracenter_to_translation(R, C)
|
||||
else:
|
||||
return K, R, C
|
||||
|
||||
|
||||
def compose_projection_matrix(K=np.eye(3), R=np.eye(3,3), t=np.zeros((3,1))):
|
||||
return K @ np.hstack((R, t.reshape((3,1))))
|
||||
|
||||
|
||||
|
||||
def point_plane_distance(pts, plane):
|
||||
pts = pts.reshape(-1,3)
|
||||
return np.abs(np.sum(plane[:3] * pts, axis=1) + plane[3]) / np.linalg.norm(plane[:3])
|
||||
|
||||
def fit_plane(pts):
|
||||
pts = pts.reshape(-1,3)
|
||||
center = np.mean(pts, axis=0)
|
||||
A = pts - center
|
||||
u, s, vh = np.linalg.svd(A, full_matrices=False)
|
||||
# if pts.shape[0] > 100:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
plane = np.array([*vh[2], -vh[2].dot(center)])
|
||||
return plane
|
||||
|
||||
def tetrahedron(dtype=np.float32):
|
||||
verts = np.array([
|
||||
(np.sqrt(8/9), 0, -1/3), (-np.sqrt(2/9), np.sqrt(2/3), -1/3),
|
||||
(-np.sqrt(2/9), -np.sqrt(2/3), -1/3), (0, 0, 1)], dtype=dtype)
|
||||
faces = np.array([(0,1,2), (0,2,3), (0,1,3), (1,2,3)], dtype=np.int32)
|
||||
normals = -np.mean(verts, axis=0) + verts
|
||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
|
||||
return verts, faces, normals
|
||||
|
||||
def cube(dtype=np.float32):
|
||||
verts = np.array([
|
||||
[-0.5,-0.5,-0.5], [-0.5,0.5,-0.5], [0.5,0.5,-0.5], [0.5,-0.5,-0.5],
|
||||
[-0.5,-0.5,0.5], [-0.5,0.5,0.5], [0.5,0.5,0.5], [0.5,-0.5,0.5]], dtype=dtype)
|
||||
faces = np.array([
|
||||
(0,1,2), (0,2,3), (4,5,6), (4,6,7),
|
||||
(0,4,7), (0,7,3), (1,5,6), (1,6,2),
|
||||
(3,2,6), (3,6,7), (0,1,5), (0,5,4)], dtype=np.int32)
|
||||
normals = -np.mean(verts, axis=0) + verts
|
||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
|
||||
return verts, faces, normals
|
||||
|
||||
def octahedron(dtype=np.float32):
|
||||
verts = np.array([
|
||||
(+1,0,0), (0,+1,0), (0,0,+1),
|
||||
(-1,0,0), (0,-1,0), (0,0,-1)], dtype=dtype)
|
||||
faces = np.array([
|
||||
(0,1,2), (1,2,3), (3,2,4), (4,2,0),
|
||||
(0,1,5), (1,5,3), (3,5,4), (4,5,0)], dtype=np.int32)
|
||||
normals = -np.mean(verts, axis=0) + verts
|
||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
|
||||
return verts, faces, normals
|
||||
|
||||
def icosahedron(dtype=np.float32):
|
||||
p = (1 + np.sqrt(5)) / 2
|
||||
verts = np.array([
|
||||
(-1,0,p), (1,0,p), (1,0,-p), (-1,0,-p),
|
||||
(0,-p,1), (0,p,1), (0,p,-1), (0,-p,-1),
|
||||
(-p,-1,0), (p,-1,0), (p,1,0), (-p,1,0)
|
||||
], dtype=dtype)
|
||||
faces = np.array([
|
||||
(0,1,4), (0,1,5), (1,4,9), (1,9,10), (1,10,5), (0,4,8), (0,8,11), (0,11,5),
|
||||
(5,6,11), (5,6,10), (4,7,8), (4,7,9),
|
||||
(3,2,6), (3,2,7), (2,6,10), (2,10,9), (2,9,7), (3,6,11), (3,11,8), (3,8,7),
|
||||
], dtype=np.int32)
|
||||
normals = -np.mean(verts, axis=0) + verts
|
||||
normals /= np.linalg.norm(normals, axis=1).reshape(-1,1)
|
||||
return verts, faces, normals
|
||||
|
||||
def xyplane(dtype=np.float32, z=0, interleaved=False):
|
||||
if interleaved:
|
||||
eps = 1e-6
|
||||
verts = np.array([
|
||||
(-1,-1,z), (-1,1,z), (1,1,z),
|
||||
(1-eps,1,z), (1-eps,-1,z), (-1-eps,-1,z)], dtype=dtype)
|
||||
faces = np.array([(0,1,2), (3,4,5)], dtype=np.int32)
|
||||
else:
|
||||
verts = np.array([(-1,-1,z), (-1,1,z), (1,1,z), (1,-1,z)], dtype=dtype)
|
||||
faces = np.array([(0,1,2), (0,2,3)], dtype=np.int32)
|
||||
normals = np.zeros_like(verts)
|
||||
normals[:,2] = -1
|
||||
return verts, faces, normals
|
||||
|
||||
def mesh_independent_verts(verts, faces, normals=None):
|
||||
new_verts = []
|
||||
new_normals = []
|
||||
for f in faces:
|
||||
new_verts.append(verts[f[0]])
|
||||
new_verts.append(verts[f[1]])
|
||||
new_verts.append(verts[f[2]])
|
||||
if normals is not None:
|
||||
new_normals.append(normals[f[0]])
|
||||
new_normals.append(normals[f[1]])
|
||||
new_normals.append(normals[f[2]])
|
||||
new_verts = np.array(new_verts)
|
||||
new_faces = np.arange(0, faces.size, dtype=faces.dtype).reshape(-1,3)
|
||||
if normals is None:
|
||||
return new_verts, new_faces
|
||||
else:
|
||||
new_normals = np.array(new_normals)
|
||||
return new_verts, new_faces, new_normals
|
||||
|
||||
|
||||
def stack_mesh(verts, faces):
|
||||
n_verts = 0
|
||||
mfaces = []
|
||||
for idx, f in enumerate(faces):
|
||||
mfaces.append(f + n_verts)
|
||||
n_verts += verts[idx].shape[0]
|
||||
verts = np.vstack(verts)
|
||||
faces = np.vstack(mfaces)
|
||||
return verts, faces
|
||||
|
||||
def normalize_mesh(verts):
|
||||
# all the verts have unit distance to the center (0,0,0)
|
||||
return verts / np.linalg.norm(verts, axis=1, keepdims=True)
|
||||
|
||||
|
||||
def mesh_triangle_areas(verts, faces):
|
||||
a = verts[faces[:,0]]
|
||||
b = verts[faces[:,1]]
|
||||
c = verts[faces[:,2]]
|
||||
x = np.empty_like(a)
|
||||
x = a - b
|
||||
y = a - c
|
||||
t = np.empty_like(a)
|
||||
t[:,0] = (x[:,1] * y[:,2] - x[:,2] * y[:,1]);
|
||||
t[:,1] = (x[:,2] * y[:,0] - x[:,0] * y[:,2]);
|
||||
t[:,2] = (x[:,0] * y[:,1] - x[:,1] * y[:,0]);
|
||||
return np.linalg.norm(t, axis=1) / 2
|
||||
|
||||
def subdivde_mesh(verts_in, faces_in, n=1):
|
||||
for iter in range(n):
|
||||
verts = []
|
||||
for v in verts_in:
|
||||
verts.append(v)
|
||||
faces = []
|
||||
verts_dict = {}
|
||||
for f in faces_in:
|
||||
f = np.sort(f)
|
||||
i0,i1,i2 = f
|
||||
v0,v1,v2 = verts_in[f]
|
||||
|
||||
k = i0*len(verts_in)+i1
|
||||
if k in verts_dict:
|
||||
i01 = verts_dict[k]
|
||||
else:
|
||||
i01 = len(verts)
|
||||
verts_dict[k] = i01
|
||||
v01 = (v0 + v1) / 2
|
||||
verts.append(v01)
|
||||
|
||||
k = i0*len(verts_in)+i2
|
||||
if k in verts_dict:
|
||||
i02 = verts_dict[k]
|
||||
else:
|
||||
i02 = len(verts)
|
||||
verts_dict[k] = i02
|
||||
v02 = (v0 + v2) / 2
|
||||
verts.append(v02)
|
||||
|
||||
k = i1*len(verts_in)+i2
|
||||
if k in verts_dict:
|
||||
i12 = verts_dict[k]
|
||||
else:
|
||||
i12 = len(verts)
|
||||
verts_dict[k] = i12
|
||||
v12 = (v1 + v2) / 2
|
||||
verts.append(v12)
|
||||
|
||||
faces.append((i0,i01,i02))
|
||||
faces.append((i01,i1,i12))
|
||||
faces.append((i12,i2,i02))
|
||||
faces.append((i01,i12,i02))
|
||||
|
||||
verts_in = np.array(verts, dtype=verts_in.dtype)
|
||||
faces_in = np.array(faces, dtype=np.int32)
|
||||
return verts_in, faces_in
|
||||
|
||||
|
||||
def mesh_adjust_winding_order(verts, faces, normals):
|
||||
n0 = normals[faces[:,0]]
|
||||
n1 = normals[faces[:,1]]
|
||||
n2 = normals[faces[:,2]]
|
||||
fnormals = (n0 + n1 + n2) / 3
|
||||
|
||||
v0 = verts[faces[:,0]]
|
||||
v1 = verts[faces[:,1]]
|
||||
v2 = verts[faces[:,2]]
|
||||
|
||||
e0 = v1 - v0
|
||||
e1 = v2 - v0
|
||||
fn = np.cross(e0, e1)
|
||||
|
||||
dot = np.sum(fnormals * fn, axis=1)
|
||||
ma = dot < 0
|
||||
|
||||
nfaces = faces.copy()
|
||||
nfaces[ma,1], nfaces[ma,2] = nfaces[ma,2], nfaces[ma,1]
|
||||
|
||||
return nfaces
|
||||
|
||||
|
||||
def pcl_to_shapecl(verts, colors=None, shape='cube', width=1.0):
|
||||
if shape == 'tetrahedron':
|
||||
cverts, cfaces, _ = tetrahedron()
|
||||
elif shape == 'cube':
|
||||
cverts, cfaces, _ = cube()
|
||||
elif shape == 'octahedron':
|
||||
cverts, cfaces, _ = octahedron()
|
||||
elif shape == 'icosahedron':
|
||||
cverts, cfaces, _ = icosahedron()
|
||||
else:
|
||||
raise Exception('invalid shape')
|
||||
|
||||
sverts = np.tile(cverts, (verts.shape[0], 1))
|
||||
sverts *= width
|
||||
sverts += np.repeat(verts, cverts.shape[0], axis=0)
|
||||
|
||||
sfaces = np.tile(cfaces, (verts.shape[0], 1))
|
||||
sfoffset = cverts.shape[0] * np.arange(0, verts.shape[0])
|
||||
sfaces += np.repeat(sfoffset, cfaces.shape[0]).reshape(-1,1)
|
||||
|
||||
if colors is not None:
|
||||
scolors = np.repeat(colors, cverts.shape[0], axis=0)
|
||||
else:
|
||||
scolors = None
|
||||
|
||||
return sverts, sfaces, scolors
|
32
co/gtimer.py
Normal file
32
co/gtimer.py
Normal file
@ -0,0 +1,32 @@
|
||||
import numpy as np
|
||||
|
||||
from . import utils
|
||||
|
||||
class StopWatch(utils.StopWatch):
|
||||
def __del__(self):
|
||||
print('='*80)
|
||||
print('gtimer:')
|
||||
total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()])
|
||||
print(f' [total] {total}')
|
||||
mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()])
|
||||
print(f' [mean] {mean}')
|
||||
median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()])
|
||||
print(f' [median] {median}')
|
||||
print('='*80)
|
||||
|
||||
GTIMER = StopWatch()
|
||||
|
||||
def start(name):
|
||||
GTIMER.start(name)
|
||||
def stop(name):
|
||||
GTIMER.stop(name)
|
||||
|
||||
class Ctx(object):
|
||||
def __init__(self, name):
|
||||
self.name = name
|
||||
|
||||
def __enter__(self):
|
||||
start(self.name)
|
||||
|
||||
def __exit__(self, *args):
|
||||
stop(self.name)
|
267
co/io3d.py
Normal file
267
co/io3d.py
Normal file
@ -0,0 +1,267 @@
|
||||
import struct
|
||||
import numpy as np
|
||||
import collections
|
||||
|
||||
def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
|
||||
args = [x,y,z]
|
||||
if color is not None:
|
||||
args += [int(color[0]), int(color[1]), int(color[2])]
|
||||
if normal is not None:
|
||||
args += [normal[0],normal[1],normal[2]]
|
||||
if binary:
|
||||
fmt = '<fff'
|
||||
if color is not None:
|
||||
fmt = fmt + 'BBB'
|
||||
if normal is not None:
|
||||
fmt = fmt + 'fff'
|
||||
fp.write(struct.pack(fmt, *args))
|
||||
else:
|
||||
fmt = '%f %f %f'
|
||||
if color is not None:
|
||||
fmt = fmt + ' %d %d %d'
|
||||
if normal is not None:
|
||||
fmt = fmt + ' %f %f %f'
|
||||
fmt += '\n'
|
||||
fp.write(fmt % tuple(args))
|
||||
|
||||
def _write_ply_triangle(fp, i0,i1,i2, binary):
|
||||
if binary:
|
||||
fp.write(struct.pack('<Biii', 3,i0,i1,i2))
|
||||
else:
|
||||
fp.write('3 %d %d %d\n' % (i0,i1,i2))
|
||||
|
||||
def _write_ply_header_line(fp, str, binary):
|
||||
if binary:
|
||||
fp.write(str.encode())
|
||||
else:
|
||||
fp.write(str)
|
||||
|
||||
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
|
||||
if verts.shape[1] != 3:
|
||||
raise Exception('verts has to be of shape Nx3')
|
||||
if trias is not None and trias.shape[1] != 3:
|
||||
raise Exception('trias has to be of shape Nx3')
|
||||
if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3:
|
||||
raise Exception('color has to be of shape Nx3 or a callable')
|
||||
|
||||
mode = 'wb' if binary else 'w'
|
||||
with open(path, mode) as fp:
|
||||
_write_ply_header_line(fp, "ply\n", binary)
|
||||
if binary:
|
||||
_write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary)
|
||||
else:
|
||||
_write_ply_header_line(fp, "format ascii 1.0\n", binary)
|
||||
_write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary)
|
||||
_write_ply_header_line(fp, "property float32 x\n", binary)
|
||||
_write_ply_header_line(fp, "property float32 y\n", binary)
|
||||
_write_ply_header_line(fp, "property float32 z\n", binary)
|
||||
if color is not None:
|
||||
_write_ply_header_line(fp, "property uchar red\n", binary)
|
||||
_write_ply_header_line(fp, "property uchar green\n", binary)
|
||||
_write_ply_header_line(fp, "property uchar blue\n", binary)
|
||||
if normals is not None:
|
||||
_write_ply_header_line(fp, "property float32 nx\n", binary)
|
||||
_write_ply_header_line(fp, "property float32 ny\n", binary)
|
||||
_write_ply_header_line(fp, "property float32 nz\n", binary)
|
||||
if trias is not None:
|
||||
_write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary)
|
||||
_write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary)
|
||||
_write_ply_header_line(fp, "end_header\n", binary)
|
||||
|
||||
for vidx, v in enumerate(verts):
|
||||
if color is not None:
|
||||
if callable(color):
|
||||
c = color(vidx)
|
||||
elif color.shape[0] > 1:
|
||||
c = color[vidx]
|
||||
else:
|
||||
c = color[0]
|
||||
else:
|
||||
c = None
|
||||
if normals is None:
|
||||
n = None
|
||||
else:
|
||||
n = normals[vidx]
|
||||
_write_ply_point(fp, v[0],v[1],v[2], c, n, binary)
|
||||
|
||||
if trias is not None:
|
||||
for t in trias:
|
||||
_write_ply_triangle(fp, t[0],t[1],t[2], binary)
|
||||
|
||||
def faces_to_triangles(faces):
|
||||
new_faces = []
|
||||
for f in faces:
|
||||
if f[0] == 3:
|
||||
new_faces.append([f[1], f[2], f[3]])
|
||||
elif f[0] == 4:
|
||||
new_faces.append([f[1], f[2], f[3]])
|
||||
new_faces.append([f[3], f[4], f[1]])
|
||||
else:
|
||||
raise Exception('unknown face count %d', f[0])
|
||||
return new_faces
|
||||
|
||||
def read_ply(path):
|
||||
with open(path, 'rb') as f:
|
||||
# parse header
|
||||
line = f.readline().decode().strip()
|
||||
if line != 'ply':
|
||||
raise Exception('Header error')
|
||||
n_verts = 0
|
||||
n_faces = 0
|
||||
vert_types = {}
|
||||
vert_bin_format = []
|
||||
vert_bin_len = 0
|
||||
vert_bin_cols = 0
|
||||
line = f.readline().decode()
|
||||
parse_vertex_prop = False
|
||||
while line.strip() != 'end_header':
|
||||
if 'format' in line:
|
||||
if 'ascii' in line:
|
||||
binary = False
|
||||
elif 'binary_little_endian' in line:
|
||||
binary = True
|
||||
else:
|
||||
raise Exception('invalid ply format')
|
||||
if 'element face' in line:
|
||||
splits = line.strip().split(' ')
|
||||
n_faces = int(splits[-1])
|
||||
parse_vertex_prop = False
|
||||
if 'element camera' in line:
|
||||
parse_vertex_prop = False
|
||||
if 'element vertex' in line:
|
||||
splits = line.strip().split(' ')
|
||||
n_verts = int(splits[-1])
|
||||
parse_vertex_prop = True
|
||||
if parse_vertex_prop and 'property' in line:
|
||||
prop = line.strip().split()
|
||||
if prop[1] == 'float':
|
||||
vert_bin_format.append('f4')
|
||||
vert_bin_len += 4
|
||||
vert_bin_cols += 1
|
||||
elif prop[1] == 'uchar':
|
||||
vert_bin_format.append('B')
|
||||
vert_bin_len += 1
|
||||
vert_bin_cols += 1
|
||||
else:
|
||||
raise Exception('invalid property')
|
||||
vert_types[prop[2]] = len(vert_types)
|
||||
line = f.readline().decode()
|
||||
|
||||
# parse content
|
||||
if binary:
|
||||
sz = n_verts * vert_bin_len
|
||||
fmt = ','.join(vert_bin_format)
|
||||
verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
|
||||
verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1))
|
||||
faces = []
|
||||
for idx in range(n_faces):
|
||||
fmt = '<Biii'
|
||||
length = struct.calcsize(fmt)
|
||||
dat = f.read(length)
|
||||
vals = struct.unpack(fmt, dat)
|
||||
faces.append(vals)
|
||||
faces = faces_to_triangles(faces)
|
||||
faces = np.array(faces, dtype=np.int32)
|
||||
else:
|
||||
verts = []
|
||||
for idx in range(n_verts):
|
||||
vals = [float(v) for v in f.readline().decode().strip().split(' ')]
|
||||
verts.append(vals)
|
||||
verts = np.array(verts, dtype=np.float32)
|
||||
faces = []
|
||||
for idx in range(n_faces):
|
||||
splits = f.readline().decode().strip().split(' ')
|
||||
n_face_verts = int(splits[0])
|
||||
vals = [int(v) for v in splits[0:n_face_verts+1]]
|
||||
faces.append(vals)
|
||||
faces = faces_to_triangles(faces)
|
||||
faces = np.array(faces, dtype=np.int32)
|
||||
|
||||
xyz = None
|
||||
if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
|
||||
xyz = verts[:,[vert_types['x'], vert_types['y'], vert_types['z']]]
|
||||
colors = None
|
||||
if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
|
||||
colors = verts[:,[vert_types['red'], vert_types['green'], vert_types['blue']]]
|
||||
colors /= 255
|
||||
normals = None
|
||||
if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
|
||||
normals = verts[:,[vert_types['nx'], vert_types['ny'], vert_types['nz']]]
|
||||
|
||||
return xyz, faces, colors, normals
|
||||
|
||||
|
||||
def _read_obj_split_f(s):
|
||||
parts = s.split('/')
|
||||
vidx = int(parts[0]) - 1
|
||||
if len(parts) >= 2 and len(parts[1]) > 0:
|
||||
tidx = int(parts[1]) - 1
|
||||
else:
|
||||
tidx = -1
|
||||
if len(parts) >= 3 and len(parts[2]) > 0:
|
||||
nidx = int(parts[2]) - 1
|
||||
else:
|
||||
nidx = -1
|
||||
return vidx, tidx, nidx
|
||||
|
||||
def read_obj(path):
|
||||
with open(path, 'r') as fp:
|
||||
lines = fp.readlines()
|
||||
|
||||
verts = []
|
||||
colors = []
|
||||
fnorms = []
|
||||
fnorm_map = collections.defaultdict(list)
|
||||
faces = []
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith('#') or len(line) == 0:
|
||||
continue
|
||||
|
||||
parts = line.split()
|
||||
if line.startswith('v '):
|
||||
parts = parts[1:]
|
||||
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
|
||||
if len(parts) == 4 or len(parts) == 7:
|
||||
w = float(parts[3])
|
||||
x,y,z = x/w, y/w, z/w
|
||||
verts.append((x,y,z))
|
||||
if len(parts) >= 6:
|
||||
r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1])
|
||||
rgb.append((r,g,b))
|
||||
|
||||
elif line.startswith('vn '):
|
||||
parts = parts[1:]
|
||||
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
|
||||
fnorms.append((x,y,z))
|
||||
|
||||
elif line.startswith('f '):
|
||||
parts = parts[1:]
|
||||
if len(parts) != 3:
|
||||
raise Exception('only triangle meshes supported atm')
|
||||
vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
|
||||
vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
|
||||
vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
|
||||
|
||||
faces.append((vidx0, vidx1, vidx2))
|
||||
if nidx0 >= 0:
|
||||
fnorm_map[vidx0].append( nidx0 )
|
||||
if nidx1 >= 0:
|
||||
fnorm_map[vidx1].append( nidx1 )
|
||||
if nidx2 >= 0:
|
||||
fnorm_map[vidx2].append( nidx2 )
|
||||
|
||||
verts = np.array(verts)
|
||||
colors = np.array(colors)
|
||||
fnorms = np.array(fnorms)
|
||||
faces = np.array(faces)
|
||||
|
||||
# face normals to vertex normals
|
||||
norms = np.zeros_like(verts)
|
||||
for vidx in fnorm_map.keys():
|
||||
ind = fnorm_map[vidx]
|
||||
norms[vidx] = fnorms[ind].sum(axis=0)
|
||||
N = np.linalg.norm(norms, axis=1, keepdims=True)
|
||||
np.divide(norms, N, out=norms, where=N != 0)
|
||||
|
||||
return verts, faces, colors, norms
|
248
co/metric.py
Normal file
248
co/metric.py
Normal file
@ -0,0 +1,248 @@
|
||||
import numpy as np
|
||||
from . import geometry
|
||||
|
||||
def _process_inputs(estimate, target, mask):
|
||||
if estimate.shape != target.shape:
|
||||
raise Exception('estimate and target have to be same shape')
|
||||
if mask is None:
|
||||
mask = np.ones(estimate.shape, dtype=np.bool)
|
||||
else:
|
||||
mask = mask != 0
|
||||
if estimate.shape != mask.shape:
|
||||
raise Exception('estimate and mask have to be same shape')
|
||||
return estimate, target, mask
|
||||
|
||||
def mse(estimate, target, mask=None):
|
||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||
m = np.sum((estimate[mask] - target[mask])**2) / mask.sum()
|
||||
return m
|
||||
|
||||
def rmse(estimate, target, mask=None):
|
||||
return np.sqrt(mse(estimate, target, mask))
|
||||
|
||||
def mae(estimate, target, mask=None):
|
||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
|
||||
return m
|
||||
|
||||
def outlier_fraction(estimate, target, mask=None, threshold=0):
|
||||
estimate, target, mask = _process_inputs(estimate, target, mask)
|
||||
diff = np.abs(estimate[mask] - target[mask])
|
||||
m = (diff > threshold).sum() / mask.sum()
|
||||
return m
|
||||
|
||||
|
||||
class Metric(object):
|
||||
def __init__(self, str_prefix=''):
|
||||
self.str_prefix = str_prefix
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
pass
|
||||
|
||||
def get(self):
|
||||
return {}
|
||||
|
||||
def items(self):
|
||||
return self.get().items()
|
||||
|
||||
def __str__(self):
|
||||
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
|
||||
|
||||
class MultipleMetric(Metric):
|
||||
def __init__(self, *metrics, **kwargs):
|
||||
self.metrics = [*metrics]
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def reset(self):
|
||||
for m in self.metrics:
|
||||
m.reset()
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
for m in self.metrics:
|
||||
m.add(es, ta, ma)
|
||||
|
||||
def get(self):
|
||||
ret = {}
|
||||
for m in self.metrics:
|
||||
vals = m.get()
|
||||
for k in vals:
|
||||
ret[k] = vals[k]
|
||||
return ret
|
||||
|
||||
def __str__(self):
|
||||
return '\n'.join([str(m) for m in self.metrics])
|
||||
|
||||
class BaseDistanceMetric(Metric):
|
||||
def __init__(self, name='', **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.name = name
|
||||
|
||||
def reset(self):
|
||||
self.dists = []
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
pass
|
||||
|
||||
def get(self):
|
||||
dists = np.hstack(self.dists)
|
||||
return {
|
||||
f'dist{self.name}_mean': float(np.mean(dists)),
|
||||
f'dist{self.name}_std': float(np.std(dists)),
|
||||
f'dist{self.name}_median': float(np.median(dists)),
|
||||
f'dist{self.name}_q10': float(np.percentile(dists, 10)),
|
||||
f'dist{self.name}_q90': float(np.percentile(dists, 90)),
|
||||
f'dist{self.name}_min': float(np.min(dists)),
|
||||
f'dist{self.name}_max': float(np.max(dists)),
|
||||
}
|
||||
|
||||
class DistanceMetric(BaseDistanceMetric):
|
||||
def __init__(self, vec_length, p=2, **kwargs):
|
||||
super().__init__(name=f'{p}', **kwargs)
|
||||
self.vec_length = vec_length
|
||||
self.p = p
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
|
||||
print(es.shape, ta.shape)
|
||||
raise Exception('es and ta have to be of shape Nxdim')
|
||||
if ma is not None:
|
||||
es = es[ma != 0]
|
||||
ta = ta[ma != 0]
|
||||
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
||||
self.dists.append( dist )
|
||||
|
||||
class OutlierFractionMetric(DistanceMetric):
|
||||
def __init__(self, thresholds, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.thresholds = thresholds
|
||||
|
||||
def get(self):
|
||||
dists = np.hstack(self.dists)
|
||||
ret = {}
|
||||
for t in self.thresholds:
|
||||
ma = dists > t
|
||||
ret[f'of{t}'] = float(ma.sum() / ma.size)
|
||||
return ret
|
||||
|
||||
class RelativeDistanceMetric(BaseDistanceMetric):
|
||||
def __init__(self, vec_length, p=2, **kwargs):
|
||||
super().__init__(name=f'rel{p}', **kwargs)
|
||||
self.vec_length = vec_length
|
||||
self.p = p
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
|
||||
raise Exception('es and ta have to be of shape Nxdim')
|
||||
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
|
||||
denom = np.linalg.norm(ta, ord=self.p, axis=1)
|
||||
dist /= denom
|
||||
if ma is not None:
|
||||
dist = dist[ma != 0]
|
||||
self.dists.append( dist )
|
||||
|
||||
class RotmDistanceMetric(BaseDistanceMetric):
|
||||
def __init__(self, type='identity', **kwargs):
|
||||
super().__init__(name=type, **kwargs)
|
||||
self.type = type
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
|
||||
print(es.shape, ta.shape)
|
||||
raise Exception('es and ta have to be of shape Nx3x3')
|
||||
if ma is not None:
|
||||
raise Exception('mask is not implemented')
|
||||
if self.type == 'identity':
|
||||
self.dists.append( geometry.rotm_distance_identity(es, ta) )
|
||||
elif self.type == 'geodesic':
|
||||
self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) )
|
||||
else:
|
||||
raise Exception('invalid distance type')
|
||||
|
||||
class QuaternionDistanceMetric(BaseDistanceMetric):
|
||||
def __init__(self, type='angle', **kwargs):
|
||||
super().__init__(name=type, **kwargs)
|
||||
self.type = type
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
|
||||
print(es.shape, ta.shape)
|
||||
raise Exception('es and ta have to be of shape Nx4')
|
||||
if ma is not None:
|
||||
raise Exception('mask is not implemented')
|
||||
if self.type == 'angle':
|
||||
self.dists.append( geometry.quat_distance_angle(es, ta) )
|
||||
elif self.type == 'mineucl':
|
||||
self.dists.append( geometry.quat_distance_mineucl(es, ta) )
|
||||
elif self.type == 'normdiff':
|
||||
self.dists.append( geometry.quat_distance_normdiff(es, ta) )
|
||||
else:
|
||||
raise Exception('invalid distance type')
|
||||
|
||||
|
||||
class BinaryAccuracyMetric(Metric):
|
||||
def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
|
||||
self.thresholds = thresholds
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def reset(self):
|
||||
self.tps = [0 for wp in self.thresholds]
|
||||
self.fps = [0 for wp in self.thresholds]
|
||||
self.fns = [0 for wp in self.thresholds]
|
||||
self.tns = [0 for wp in self.thresholds]
|
||||
self.n_pos = 0
|
||||
self.n_neg = 0
|
||||
|
||||
def add(self, es, ta, ma=None):
|
||||
if ma is not None:
|
||||
raise Exception('mask is not implemented')
|
||||
es = es.ravel()
|
||||
ta = ta.ravel()
|
||||
if es.shape[0] != ta.shape[0]:
|
||||
raise Exception('invalid shape of es, or ta')
|
||||
if es.min() < 0 or es.max() > 1:
|
||||
raise Exception('estimate has wrong value range')
|
||||
ta_p = (ta == 1)
|
||||
ta_n = (ta == 0)
|
||||
es_p = es[ta_p]
|
||||
es_n = es[ta_n]
|
||||
for idx, wp in enumerate(self.thresholds):
|
||||
wp = np.asscalar(wp)
|
||||
self.tps[idx] += (es_p > wp).sum()
|
||||
self.fps[idx] += (es_n > wp).sum()
|
||||
self.fns[idx] += (es_p <= wp).sum()
|
||||
self.tns[idx] += (es_n <= wp).sum()
|
||||
self.n_pos += ta_p.sum()
|
||||
self.n_neg += ta_n.sum()
|
||||
|
||||
def get(self):
|
||||
tps = np.array(self.tps).astype(np.float32)
|
||||
fps = np.array(self.fps).astype(np.float32)
|
||||
fns = np.array(self.fns).astype(np.float32)
|
||||
tns = np.array(self.tns).astype(np.float32)
|
||||
wp = self.thresholds
|
||||
|
||||
ret = {}
|
||||
|
||||
precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
|
||||
recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs
|
||||
fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)
|
||||
|
||||
precisions = np.r_[0, precisions, 1]
|
||||
recalls = np.r_[1, recalls, 0]
|
||||
fprs = np.r_[1, fprs, 0]
|
||||
|
||||
ret['auc'] = float(-np.trapz(recalls, fprs))
|
||||
ret['prauc'] = float(-np.trapz(precisions, recalls))
|
||||
ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())
|
||||
|
||||
accuracies = np.divide(tps + tns, tps + tns + fps + fns)
|
||||
aacc = np.mean(accuracies)
|
||||
for t in np.linspace(0,1,num=11)[1:-1]:
|
||||
idx = np.argmin(np.abs(t - wp))
|
||||
ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])
|
||||
|
||||
return ret
|
99
co/plt.py
Normal file
99
co/plt.py
Normal file
@ -0,0 +1,99 @@
|
||||
import numpy as np
|
||||
import matplotlib as mpl
|
||||
from matplotlib import _pylab_helpers
|
||||
from matplotlib.rcsetup import interactive_bk as _interactive_bk
|
||||
import matplotlib.pyplot as plt
|
||||
import os
|
||||
import time
|
||||
|
||||
def save(path, remove_axis=False, dpi=300, fig=None):
|
||||
if fig is None:
|
||||
fig = plt.gcf()
|
||||
dirname = os.path.dirname(path)
|
||||
if dirname != '' and not os.path.exists(dirname):
|
||||
os.makedirs(dirname)
|
||||
if remove_axis:
|
||||
for ax in fig.axes:
|
||||
ax.axis('off')
|
||||
ax.margins(0,0)
|
||||
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
||||
for ax in fig.axes:
|
||||
ax.xaxis.set_major_locator(plt.NullLocator())
|
||||
ax.yaxis.set_major_locator(plt.NullLocator())
|
||||
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
|
||||
|
||||
def color_map(im_, cmap='viridis', vmin=None, vmax=None):
|
||||
cm = plt.get_cmap(cmap)
|
||||
im = im_.copy()
|
||||
if vmin is None:
|
||||
vmin = np.nanmin(im)
|
||||
if vmax is None:
|
||||
vmax = np.nanmax(im)
|
||||
mask = np.logical_not(np.isfinite(im))
|
||||
im[mask] = vmin
|
||||
im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
|
||||
im = cm(im)
|
||||
im = im[...,:3]
|
||||
for c in range(3):
|
||||
im[mask, c] = 1
|
||||
return im
|
||||
|
||||
def interactive_legend(leg=None, fig=None, all_axes=True):
|
||||
if leg is None:
|
||||
leg = plt.legend()
|
||||
if fig is None:
|
||||
fig = plt.gcf()
|
||||
if all_axes:
|
||||
axs = fig.get_axes()
|
||||
else:
|
||||
axs = [fig.gca()]
|
||||
|
||||
# lined = dict()
|
||||
# lines = ax.lines
|
||||
# for legline, origline in zip(leg.get_lines(), ax.lines):
|
||||
# legline.set_picker(5)
|
||||
# lined[legline] = origline
|
||||
lined = dict()
|
||||
for lidx, legline in enumerate(leg.get_lines()):
|
||||
legline.set_picker(5)
|
||||
lined[legline] = [ax.lines[lidx] for ax in axs]
|
||||
|
||||
def onpick(event):
|
||||
if event.mouseevent.dblclick:
|
||||
tmp = [(k,v) for k,v in lined.items()]
|
||||
else:
|
||||
tmp = [(event.artist, lined[event.artist])]
|
||||
|
||||
for legline, origline in tmp:
|
||||
for ol in origline:
|
||||
vis = not ol.get_visible()
|
||||
ol.set_visible(vis)
|
||||
if vis:
|
||||
legline.set_alpha(1.0)
|
||||
else:
|
||||
legline.set_alpha(0.2)
|
||||
fig.canvas.draw()
|
||||
|
||||
fig.canvas.mpl_connect('pick_event', onpick)
|
||||
|
||||
def non_annoying_pause(interval, focus_figure=False):
|
||||
# https://github.com/matplotlib/matplotlib/issues/11131
|
||||
backend = mpl.rcParams['backend']
|
||||
if backend in _interactive_bk:
|
||||
figManager = _pylab_helpers.Gcf.get_active()
|
||||
if figManager is not None:
|
||||
canvas = figManager.canvas
|
||||
if canvas.figure.stale:
|
||||
canvas.draw()
|
||||
if focus_figure:
|
||||
plt.show(block=False)
|
||||
canvas.start_event_loop(interval)
|
||||
return
|
||||
time.sleep(interval)
|
||||
|
||||
def remove_all_ticks(fig=None):
|
||||
if fig is None:
|
||||
fig = plt.gcf()
|
||||
for ax in fig.axes:
|
||||
ax.axes.get_xaxis().set_visible(False)
|
||||
ax.axes.get_yaxis().set_visible(False)
|
57
co/plt2d.py
Normal file
57
co/plt2d.py
Normal file
@ -0,0 +1,57 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
from . import geometry
|
||||
|
||||
def image_matrix(ims, bgval=0):
|
||||
n = ims.shape[0]
|
||||
m = int( np.ceil(np.sqrt(n)) )
|
||||
h = ims.shape[1]
|
||||
w = ims.shape[2]
|
||||
mat = np.empty((m*h, m*w), dtype=ims.dtype)
|
||||
mat.fill(bgval)
|
||||
idx = 0
|
||||
for r in range(m):
|
||||
for c in range(m):
|
||||
if idx < n:
|
||||
mat[r*h:(r+1)*h, c*w:(c+1)*w] = ims[idx]
|
||||
idx += 1
|
||||
return mat
|
||||
|
||||
def image_cat(ims, vertical=False):
|
||||
offx = [0]
|
||||
offy = [0]
|
||||
if vertical:
|
||||
width = max([im.shape[1] for im in ims])
|
||||
offx += [0 for im in ims[:-1]]
|
||||
offy += [im.shape[0] for im in ims[:-1]]
|
||||
height = sum([im.shape[0] for im in ims])
|
||||
else:
|
||||
height = max([im.shape[0] for im in ims])
|
||||
offx += [im.shape[1] for im in ims[:-1]]
|
||||
offy += [0 for im in ims[:-1]]
|
||||
width = sum([im.shape[1] for im in ims])
|
||||
offx = np.cumsum(offx)
|
||||
offy = np.cumsum(offy)
|
||||
|
||||
im = np.zeros((height,width,*ims[0].shape[2:]), dtype=ims[0].dtype)
|
||||
for im0, ox, oy in zip(ims, offx, offy):
|
||||
im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
|
||||
|
||||
return im, offx, offy
|
||||
|
||||
def line(li, h, w, ax=None, *args, **kwargs):
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
xs = (-li[2] - li[1] * np.array((0, h-1))) / li[0]
|
||||
ys = (-li[2] - li[0] * np.array((0, w-1))) / li[1]
|
||||
pts = np.array([(0,ys[0]), (w-1, ys[1]), (xs[0], 0), (xs[1], h-1)])
|
||||
pts = pts[np.logical_and(np.logical_and(pts[:,0] >= 0, pts[:,0] < w), np.logical_and(pts[:,1] >= 0, pts[:,1] < h))]
|
||||
ax.plot(pts[:,0], pts[:,1], *args, **kwargs)
|
||||
|
||||
def depthshow(depth, *args, ax=None, **kwargs):
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
d = depth.copy()
|
||||
d[d < 0] = np.NaN
|
||||
ax.imshow(d, *args, **kwargs)
|
38
co/plt3d.py
Normal file
38
co/plt3d.py
Normal file
@ -0,0 +1,38 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
from . import geometry
|
||||
|
||||
def ax3d(fig=None):
|
||||
if fig is None:
|
||||
fig = plt.gcf()
|
||||
return fig.add_subplot(111, projection='3d')
|
||||
|
||||
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs):
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
C0 = geometry.translation_to_cameracenter(R, t).ravel()
|
||||
C1 = C0 + R.T.dot( np.array([[-size],[-size],[3*size]], dtype=np.float32) ).ravel()
|
||||
C2 = C0 + R.T.dot( np.array([[-size],[+size],[3*size]], dtype=np.float32) ).ravel()
|
||||
C3 = C0 + R.T.dot( np.array([[+size],[+size],[3*size]], dtype=np.float32) ).ravel()
|
||||
C4 = C0 + R.T.dot( np.array([[+size],[-size],[3*size]], dtype=np.float32) ).ravel()
|
||||
|
||||
if marker_C != '':
|
||||
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
|
||||
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
||||
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
||||
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
||||
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
||||
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
|
||||
|
||||
def axis_equal(ax=None):
|
||||
if ax is None:
|
||||
ax = plt.gca()
|
||||
extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
|
||||
sz = extents[:,1] - extents[:,0]
|
||||
centers = np.mean(extents, axis=1)
|
||||
maxsize = max(abs(sz))
|
||||
r = maxsize/2
|
||||
for ctr, dim in zip(centers, 'xyz'):
|
||||
getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)
|
445
co/table.py
Normal file
445
co/table.py
Normal file
@ -0,0 +1,445 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import enum
|
||||
import itertools
|
||||
|
||||
class Table(object):
|
||||
def __init__(self, n_cols):
|
||||
self.n_cols = n_cols
|
||||
self.rows = []
|
||||
self.aligns = ['r' for c in range(n_cols)]
|
||||
|
||||
def get_cell_align(self, r, c):
|
||||
align = self.rows[r].cells[c].align
|
||||
if align is None:
|
||||
return self.aligns[c]
|
||||
else:
|
||||
return align
|
||||
|
||||
def add_row(self, row):
|
||||
if row.ncols() != self.n_cols:
|
||||
raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
|
||||
self.rows.append(row)
|
||||
|
||||
def empty_row(self):
|
||||
return Row.Empty(self.n_cols)
|
||||
|
||||
def expand_rows(self, n_add_cols=1):
|
||||
if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
|
||||
self.n_cols += n_add_cols
|
||||
for row in self.rows:
|
||||
row.cells.extend([Cell() for cidx in range(n_add_cols)])
|
||||
|
||||
def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
|
||||
if row < 0: row = len(self.rows)
|
||||
while len(self.rows) < row + len(data):
|
||||
self.add_row(self.empty_row())
|
||||
for r in range(len(data)):
|
||||
cols = data[r]
|
||||
if col + len(cols) > self.n_cols:
|
||||
if expand:
|
||||
self.expand_rows(col + len(cols) - self.n_cols)
|
||||
else:
|
||||
raise Exception('number of cols does not fit in table')
|
||||
for c in range(len(cols)):
|
||||
self.rows[row+r].cells[col+c] = Cell(data[r][c], fmt)
|
||||
|
||||
class Row(object):
|
||||
def __init__(self, cells, pre_separator=None, post_separator=None):
|
||||
self.cells = cells
|
||||
self.pre_separator = pre_separator
|
||||
self.post_separator = post_separator
|
||||
|
||||
@classmethod
|
||||
def Empty(cls, n_cols):
|
||||
return Row([Cell() for c in range(n_cols)])
|
||||
|
||||
def add_cell(self, cell):
|
||||
self.cells.append(cell)
|
||||
|
||||
def ncols(self):
|
||||
return sum([c.span for c in self.cells])
|
||||
|
||||
|
||||
|
||||
class Color(object):
|
||||
def __init__(self, color=(0,0,0), fmt='rgb'):
|
||||
if fmt == 'rgb':
|
||||
self.color = color
|
||||
elif fmt == 'RGB':
|
||||
self.color = tuple(c / 255 for c in color)
|
||||
else:
|
||||
return Exception('invalid color format')
|
||||
|
||||
def as_rgb(self):
|
||||
return self.color
|
||||
|
||||
def as_RGB(self):
|
||||
return tuple(int(c * 255) for c in self.color)
|
||||
|
||||
@classmethod
|
||||
def rgb(cls, r, g, b):
|
||||
return Color(color=(r,g,b), fmt='rgb')
|
||||
|
||||
@classmethod
|
||||
def RGB(cls, r, g, b):
|
||||
return Color(color=(r,g,b), fmt='RGB')
|
||||
|
||||
|
||||
class CellFormat(object):
|
||||
def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False):
|
||||
self.fmt = fmt
|
||||
self.fgcolor = fgcolor
|
||||
self.bgcolor = bgcolor
|
||||
self.bold = bold
|
||||
|
||||
class Cell(object):
|
||||
def __init__(self, data=None, fmt=None, span=1, align=None):
|
||||
self.data = data
|
||||
if fmt is None:
|
||||
fmt = CellFormat()
|
||||
self.fmt = fmt
|
||||
self.span = span
|
||||
self.align = align
|
||||
|
||||
def __str__(self):
|
||||
return self.fmt.fmt % self.data
|
||||
|
||||
class Separator(enum.Enum):
|
||||
HEAD = 1
|
||||
BOTTOM = 2
|
||||
INNER = 3
|
||||
|
||||
|
||||
class Renderer(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def cell_str_len(self, cell):
|
||||
return len(str(cell))
|
||||
|
||||
def col_widths(self, table):
|
||||
widths = [0 for c in range(table.n_cols)]
|
||||
for row in table.rows:
|
||||
cidx = 0
|
||||
for cell in row.cells:
|
||||
if cell.span == 1:
|
||||
strlen = self.cell_str_len(cell)
|
||||
widths[cidx] = max(widths[cidx], strlen)
|
||||
cidx += cell.span
|
||||
return widths
|
||||
|
||||
def render(self, table):
|
||||
raise NotImplementedError('not implemented')
|
||||
|
||||
def __call__(self, table):
|
||||
return self.render(table)
|
||||
|
||||
def render_to_file_comment(self):
|
||||
return ''
|
||||
|
||||
def render_to_file(self, path, table):
|
||||
txt = self.render(table)
|
||||
with open(path, 'w') as fp:
|
||||
fp.write(txt)
|
||||
|
||||
class TerminalRenderer(Renderer):
|
||||
def __init__(self, col_sep=' '):
|
||||
super().__init__()
|
||||
self.col_sep = col_sep
|
||||
|
||||
def render_cell(self, table, row, col, widths):
|
||||
cell = table.rows[row].cells[col]
|
||||
str = cell.fmt.fmt % cell.data
|
||||
str_width = len(str)
|
||||
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
|
||||
cell_width += len(self.col_sep) * (cell.span - 1)
|
||||
if len(str) > cell_width:
|
||||
str = str[:cell_width]
|
||||
if cell.fmt.bold:
|
||||
# str = sty.ef.bold + str + sty.rs.bold_dim
|
||||
# str = sty.ef.bold + str + sty.rs.bold
|
||||
pass
|
||||
if cell.fmt.fgcolor is not None:
|
||||
# color = cell.fmt.fgcolor.as_RGB()
|
||||
# str = sty.fg(*color) + str + sty.rs.fg
|
||||
pass
|
||||
if str_width < cell_width:
|
||||
n_ws = (cell_width - str_width)
|
||||
if table.get_cell_align(row, col) == 'r':
|
||||
str = ' '*n_ws + str
|
||||
elif table.get_cell_align(row, col) == 'l':
|
||||
str = str + ' '*n_ws
|
||||
elif table.get_cell_align(row, col) == 'c':
|
||||
n_ws1 = n_ws // 2
|
||||
n_ws0 = n_ws - n_ws1
|
||||
str = ' '*n_ws0 + str + ' '*n_ws1
|
||||
if cell.fmt.bgcolor is not None:
|
||||
# color = cell.fmt.bgcolor.as_RGB()
|
||||
# str = sty.bg(*color) + str + sty.rs.bg
|
||||
pass
|
||||
return str
|
||||
|
||||
def render_separator(self, separator, tab, col_widths, total_width):
|
||||
if separator == Separator.HEAD:
|
||||
return '='*total_width
|
||||
elif separator == Separator.INNER:
|
||||
return '-'*total_width
|
||||
elif separator == Separator.BOTTOM:
|
||||
return '='*total_width
|
||||
|
||||
def render(self, table):
|
||||
widths = self.col_widths(table)
|
||||
total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1)
|
||||
lines = []
|
||||
for ridx, row in enumerate(table.rows):
|
||||
if row.pre_separator is not None:
|
||||
sepline = self.render_separator(row.pre_separator, table, widths, total_width)
|
||||
if len(sepline) > 0:
|
||||
lines.append(sepline)
|
||||
line = []
|
||||
for cidx, cell in enumerate(row.cells):
|
||||
line.append(self.render_cell(table, ridx, cidx, widths))
|
||||
lines.append(self.col_sep.join(line))
|
||||
if row.post_separator is not None:
|
||||
sepline = self.render_separator(row.post_separator, table, widths, total_width)
|
||||
if len(sepline) > 0:
|
||||
lines.append(sepline)
|
||||
return '\n'.join(lines)
|
||||
|
||||
class MarkdownRenderer(TerminalRenderer):
|
||||
def __init__(self):
|
||||
super().__init__(col_sep='|')
|
||||
self.printed_color_warning = False
|
||||
|
||||
def print_color_warning(self):
|
||||
if not self.printed_color_warning:
|
||||
print('[WARNING] MarkdownRenderer does not support color yet')
|
||||
self.printed_color_warning = True
|
||||
|
||||
def cell_str_len(self, cell):
|
||||
strlen = len(str(cell))
|
||||
if cell.fmt.bold:
|
||||
strlen += 4
|
||||
strlen = max(5, strlen)
|
||||
return strlen
|
||||
|
||||
def render_cell(self, table, row, col, widths):
|
||||
cell = table.rows[row].cells[col]
|
||||
str = cell.fmt.fmt % cell.data
|
||||
if cell.fmt.bold:
|
||||
str = f'**{str}**'
|
||||
|
||||
str_width = len(str)
|
||||
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
|
||||
cell_width += len(self.col_sep) * (cell.span - 1)
|
||||
if len(str) > cell_width:
|
||||
str = str[:cell_width]
|
||||
else:
|
||||
n_ws = (cell_width - str_width)
|
||||
if table.get_cell_align(row, col) == 'r':
|
||||
str = ' '*n_ws + str
|
||||
elif table.get_cell_align(row, col) == 'l':
|
||||
str = str + ' '*n_ws
|
||||
elif table.get_cell_align(row, col) == 'c':
|
||||
n_ws1 = n_ws // 2
|
||||
n_ws0 = n_ws - n_ws1
|
||||
str = ' '*n_ws0 + str + ' '*n_ws1
|
||||
|
||||
if col == 0: str = self.col_sep + str
|
||||
if col == table.n_cols - 1: str += self.col_sep
|
||||
|
||||
if cell.fmt.fgcolor is not None:
|
||||
self.print_color_warning()
|
||||
if cell.fmt.bgcolor is not None:
|
||||
self.print_color_warning()
|
||||
return str
|
||||
|
||||
def render_separator(self, separator, tab, widths, total_width):
|
||||
sep = ''
|
||||
if separator == Separator.INNER:
|
||||
sep = self.col_sep
|
||||
for idx, width in enumerate(widths):
|
||||
csep = '-' * (width - 2)
|
||||
if tab.get_cell_align(1, idx) == 'r':
|
||||
csep = '-' + csep + ':'
|
||||
elif tab.get_cell_align(1, idx) == 'l':
|
||||
csep = ':' + csep + '-'
|
||||
elif tab.get_cell_align(1, idx) == 'c':
|
||||
csep = ':' + csep + ':'
|
||||
sep += csep + self.col_sep
|
||||
return sep
|
||||
|
||||
|
||||
class LatexRenderer(Renderer):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def render_cell(self, table, row, col):
|
||||
cell = table.rows[row].cells[col]
|
||||
str = cell.fmt.fmt % cell.data
|
||||
if cell.fmt.bold:
|
||||
str = '{\\bf '+ str + '}'
|
||||
if cell.fmt.fgcolor is not None:
|
||||
color = cell.fmt.fgcolor.as_rgb()
|
||||
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
|
||||
if cell.fmt.bgcolor is not None:
|
||||
color = cell.fmt.bgcolor.as_rgb()
|
||||
str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str
|
||||
align = table.get_cell_align(row, col)
|
||||
if cell.span != 1 or align != table.aligns[col]:
|
||||
str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}'
|
||||
return str
|
||||
|
||||
def render_separator(self, separator):
|
||||
if separator == Separator.HEAD:
|
||||
return '\\toprule'
|
||||
elif separator == Separator.INNER:
|
||||
return '\\midrule'
|
||||
elif separator == Separator.BOTTOM:
|
||||
return '\\bottomrule'
|
||||
|
||||
def render(self, table):
|
||||
lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}']
|
||||
for ridx, row in enumerate(table.rows):
|
||||
if row.pre_separator is not None:
|
||||
lines.append(self.render_separator(row.pre_separator))
|
||||
line = []
|
||||
for cidx, cell in enumerate(row.cells):
|
||||
line.append(self.render_cell(table, ridx, cidx))
|
||||
lines.append(' & '.join(line) + ' \\\\')
|
||||
if row.post_separator is not None:
|
||||
lines.append(self.render_separator(row.post_separator))
|
||||
lines.append('\\end{tabular}')
|
||||
return '\n'.join(lines)
|
||||
|
||||
class HtmlRenderer(Renderer):
|
||||
def __init__(self, html_class='result_table'):
|
||||
super().__init__()
|
||||
self.html_class = html_class
|
||||
|
||||
def render_cell(self, table, row, col):
|
||||
cell = table.rows[row].cells[col]
|
||||
str = cell.fmt.fmt % cell.data
|
||||
styles = []
|
||||
if cell.fmt.bold:
|
||||
styles.append('font-weight: bold;')
|
||||
if cell.fmt.fgcolor is not None:
|
||||
color = cell.fmt.fgcolor.as_RGB()
|
||||
styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});')
|
||||
if cell.fmt.bgcolor is not None:
|
||||
color = cell.fmt.bgcolor.as_RGB()
|
||||
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
|
||||
align = table.get_cell_align(row, col)
|
||||
if align == 'l': align = 'left'
|
||||
elif align == 'r': align = 'right'
|
||||
elif align == 'c': align = 'center'
|
||||
else: raise Exception('invalid align')
|
||||
styles.append(f'text-align: {align};')
|
||||
row = table.rows[row]
|
||||
if row.pre_separator is not None:
|
||||
styles.append(f'border-top: {self.render_separator(row.pre_separator)};')
|
||||
if row.post_separator is not None:
|
||||
styles.append(f'border-bottom: {self.render_separator(row.post_separator)};')
|
||||
style = ' '.join(styles)
|
||||
str = f' <td style="{style}" colspan="{cell.span}">{str}</td>\n'
|
||||
return str
|
||||
|
||||
def render_separator(self, separator):
|
||||
if separator == Separator.HEAD:
|
||||
return '1.5pt solid black'
|
||||
elif separator == Separator.INNER:
|
||||
return '0.75pt solid black'
|
||||
elif separator == Separator.BOTTOM:
|
||||
return '1.5pt solid black'
|
||||
|
||||
def render(self, table):
|
||||
lines = [f'<table width="100%" style="border-collapse: collapse" class={self.html_class}>']
|
||||
for ridx, row in enumerate(table.rows):
|
||||
line = [f' <tr>\n']
|
||||
for cidx, cell in enumerate(row.cells):
|
||||
line.append(self.render_cell(table, ridx, cidx))
|
||||
line.append(' </tr>\n')
|
||||
lines.append(' '.join(line))
|
||||
lines.append('</table>')
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
|
||||
rnames = data[rowname].unique()
|
||||
cnames = data[colname].unique()
|
||||
tab = Table(1+len(cnames))
|
||||
|
||||
header = [Cell('', align='r')]
|
||||
header.extend([Cell(h, align='r') for h in cnames])
|
||||
header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER)
|
||||
tab.add_row(header)
|
||||
|
||||
for rname in rnames:
|
||||
cells = [Cell(rname, align='l')]
|
||||
for cname in cnames:
|
||||
cdata = data[data[colname] == cname]
|
||||
if cname in best_is_max:
|
||||
bestval = cdata[valname].max()
|
||||
val = cdata[cdata[rowname] == rname][valname].max()
|
||||
else:
|
||||
bestval = cdata[valname].min()
|
||||
val = cdata[cdata[rowname] == rname][valname].min()
|
||||
if val == bestval:
|
||||
fmt = best_val_cell_fmt
|
||||
else:
|
||||
fmt = val_cell_fmt
|
||||
cells.append(Cell(val, align='r', fmt=fmt))
|
||||
tab.add_row(Row(cells))
|
||||
tab.rows[-1].post_separator = Separator.BOTTOM
|
||||
return tab
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# df = pd.read_pickle('full.df')
|
||||
# best_is_max = ['movF0.5', 'movF1.0']
|
||||
# tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max)
|
||||
|
||||
# renderer = TerminalRenderer()
|
||||
# print(renderer(tab))
|
||||
|
||||
tab = Table(7)
|
||||
# header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER)
|
||||
# tab.add_row(header)
|
||||
# header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER)
|
||||
# tab.add_row(header2)
|
||||
tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
|
||||
tab.rows[-1].post_separator = Separator.INNER
|
||||
tab.add_block(np.arange(15*7).reshape(15,7))
|
||||
tab.rows[4].cells[2].fmt = CellFormat(bold=True)
|
||||
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2,0.6,0.1))
|
||||
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7,0.1,0.5))
|
||||
tab.rows[5].cells[3].fmt = CellFormat(bold=True,bgcolor=Color.rgb(0.7,0.1,0.5),fgcolor=Color.rgb(0.1,0.1,0.1))
|
||||
tab.rows[-1].post_separator = Separator.BOTTOM
|
||||
|
||||
renderer = TerminalRenderer()
|
||||
print(renderer(tab))
|
||||
renderer = MarkdownRenderer()
|
||||
print(renderer(tab))
|
||||
|
||||
# renderer = HtmlRenderer()
|
||||
# html_tab = renderer(tab)
|
||||
# print(html_tab)
|
||||
# with open('test.html', 'w') as fp:
|
||||
# fp.write(html_tab)
|
||||
|
||||
# import latex
|
||||
|
||||
# renderer = LatexRenderer()
|
||||
# ltx_tab = renderer(tab)
|
||||
# print(ltx_tab)
|
||||
|
||||
# with open('test.tex', 'w') as fp:
|
||||
# latex.write_doc_prefix(fp, document_class='article')
|
||||
# fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40)
|
||||
# fp.write('\\begin{table}')
|
||||
# fp.write(ltx_tab)
|
||||
# fp.write('\\end{table}')
|
||||
# fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40)
|
||||
# latex.write_doc_suffix(fp)
|
86
co/utils.py
Normal file
86
co/utils.py
Normal file
@ -0,0 +1,86 @@
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import pickle
|
||||
import subprocess
|
||||
|
||||
def str2bool(v):
|
||||
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
||||
return True
|
||||
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
||||
return False
|
||||
else:
|
||||
raise argparse.ArgumentTypeError('Boolean value expected.')
|
||||
|
||||
class StopWatch(object):
|
||||
def __init__(self):
|
||||
self.timings = OrderedDict()
|
||||
self.starts = {}
|
||||
|
||||
def start(self, name):
|
||||
self.starts[name] = time.time()
|
||||
|
||||
def stop(self, name):
|
||||
if name not in self.timings:
|
||||
self.timings[name] = []
|
||||
self.timings[name].append(time.time() - self.starts[name])
|
||||
|
||||
def get(self, name=None, reduce=np.sum):
|
||||
if name is not None:
|
||||
return reduce(self.timings[name])
|
||||
else:
|
||||
ret = {}
|
||||
for k in self.timings:
|
||||
ret[k] = reduce(self.timings[k])
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
||||
def __str__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
||||
|
||||
class ETA(object):
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
self.start_time = time.time()
|
||||
self.current_idx = 0
|
||||
self.current_time = time.time()
|
||||
|
||||
def update(self, idx):
|
||||
self.current_idx = idx
|
||||
self.current_time = time.time()
|
||||
|
||||
def get_elapsed_time(self):
|
||||
return self.current_time - self.start_time
|
||||
|
||||
def get_item_time(self):
|
||||
return self.get_elapsed_time() / (self.current_idx + 1)
|
||||
|
||||
def get_remaining_time(self):
|
||||
return self.get_item_time() * (self.length - self.current_idx + 1)
|
||||
|
||||
def format_time(self, seconds):
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
hours = int(hours)
|
||||
minutes = int(minutes)
|
||||
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
||||
|
||||
def get_elapsed_time_str(self):
|
||||
return self.format_time(self.get_elapsed_time())
|
||||
|
||||
def get_remaining_time_str(self):
|
||||
return self.format_time(self.get_remaining_time())
|
||||
|
||||
def git_hash(cwd=None):
|
||||
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
|
||||
hash = ret.stdout
|
||||
if hash is not None and 'fatal' not in hash.decode():
|
||||
return hash.decode().strip()
|
||||
else:
|
||||
return None
|
||||
|
5
config.json
Normal file
5
config.json
Normal file
@ -0,0 +1,5 @@
|
||||
{
|
||||
"CUDA_LIBRARY_DIR": "/usr/local/cuda/lib64",
|
||||
"DATA_ROOT": "/path/to/where/to/store/data/",
|
||||
"SHAPENET_ROOT": "/path/to/shapenet/dataset"
|
||||
}
|
7
create_syn_data.sh
Executable file
7
create_syn_data.sh
Executable file
@ -0,0 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
cd data/lcn
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
cd ../
|
||||
python create_syn_data.py
|
0
data/__init__.py
Normal file
0
data/__init__.py
Normal file
110
data/commons.py
Normal file
110
data/commons.py
Normal file
@ -0,0 +1,110 @@
|
||||
import co
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def get_patterns(path='syn', imsizes=[], crop=True):
|
||||
pattern_size = imsizes[0]
|
||||
if path == 'syn':
|
||||
np.random.seed(42)
|
||||
pattern = np.random.uniform(0,1, size=pattern_size)
|
||||
pattern = (pattern < 0.1).astype(np.float32)
|
||||
pattern.reshape(*imsizes[0])
|
||||
else:
|
||||
pattern = cv2.imread(path)
|
||||
pattern = pattern.astype(np.float32)
|
||||
pattern /= 255
|
||||
|
||||
if pattern.ndim == 2:
|
||||
pattern = np.stack([pattern for idx in range(3)], axis=2)
|
||||
|
||||
if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
|
||||
r0 = (pattern.shape[0] - pattern_size[0]) // 2
|
||||
c0 = (pattern.shape[1] - pattern_size[1]) // 2
|
||||
pattern = pattern[r0:r0+imsizes[0][0], c0:c0+imsizes[0][1]]
|
||||
|
||||
patterns = []
|
||||
for imsize in imsizes:
|
||||
pat = cv2.resize(pattern, (imsize[1],imsize[0]), interpolation=cv2.INTER_LINEAR)
|
||||
patterns.append(pat)
|
||||
|
||||
return patterns
|
||||
|
||||
def get_rotation_matrix(v0, v1):
|
||||
v0 = v0/np.linalg.norm(v0)
|
||||
v1 = v1/np.linalg.norm(v1)
|
||||
v = np.cross(v0,v1)
|
||||
c = np.dot(v0,v1)
|
||||
s = np.linalg.norm(v)
|
||||
I = np.eye(3)
|
||||
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
|
||||
k = np.matrix(vXStr)
|
||||
r = I + k + k @ k * ((1 -c)/(s**2))
|
||||
return np.asarray(r.astype(np.float32))
|
||||
|
||||
|
||||
def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001):
|
||||
|
||||
# get min/max values of image
|
||||
min_val = np.min(img)
|
||||
max_val = np.max(img)
|
||||
|
||||
# init augmented image
|
||||
img_aug = img
|
||||
|
||||
# init disparity correction map
|
||||
disp_aug = disp
|
||||
grad_aug = grad
|
||||
|
||||
# apply affine transformation
|
||||
if max_shift>1:
|
||||
|
||||
# affine parameters
|
||||
rows,cols = img.shape
|
||||
shear = 0
|
||||
shift = 0
|
||||
shear_correction = 0
|
||||
if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability
|
||||
else: shift = rng.uniform(0,max_shift) # shift with 25% probability
|
||||
if shear<0: shear_correction = -shear
|
||||
|
||||
# affine transformation
|
||||
a = shear/float(rows)
|
||||
b = shift+shear_correction
|
||||
|
||||
# warp image
|
||||
T = np.float32([[1,a,b],[0,1,0]])
|
||||
img_aug = cv2.warpAffine(img_aug,T,(cols,rows))
|
||||
if grad is not None:
|
||||
grad_aug = cv2.warpAffine(grad,T,(cols,rows))
|
||||
|
||||
# disparity correction map
|
||||
col = a*np.array(range(rows))+b
|
||||
disp_delta = np.tile(col,(cols,1)).transpose()
|
||||
if disp is not None:
|
||||
disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows))
|
||||
|
||||
# gaussian smoothing
|
||||
if rng.uniform(0,1)<0.5:
|
||||
img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur))
|
||||
|
||||
# per-pixel gaussian noise
|
||||
img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0
|
||||
|
||||
# salt-and-pepper noise
|
||||
if rng.uniform(0,1)<0.5:
|
||||
ratio=rng.uniform(0.0,max_sp_noise)
|
||||
img_shape = img_aug.shape
|
||||
img_aug = img_aug.flatten()
|
||||
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
|
||||
img_aug[coord] = max_val
|
||||
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
|
||||
img_aug[coord] = min_val
|
||||
img_aug = np.reshape(img_aug, img_shape)
|
||||
|
||||
# clip intensities back to [0,1]
|
||||
img_aug = np.maximum(img_aug,0.0)
|
||||
img_aug = np.minimum(img_aug,1.0)
|
||||
|
||||
# return image
|
||||
return img_aug, disp_aug, grad_aug
|
259
data/create_syn_data.py
Normal file
259
data/create_syn_data.py
Normal file
@ -0,0 +1,259 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import itertools
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
import multiprocessing
|
||||
import time
|
||||
import json
|
||||
import cv2
|
||||
import os
|
||||
import collections
|
||||
import sys
|
||||
sys.path.append('../')
|
||||
import renderer
|
||||
import co
|
||||
from commons import get_patterns,get_rotation_matrix
|
||||
from lcn import lcn
|
||||
|
||||
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
|
||||
|
||||
shapenet = {'chair': '03001627',
|
||||
'airplane': '02691156',
|
||||
'car': '02958343',
|
||||
'watercraft': '04530566'}
|
||||
|
||||
obj_paths = []
|
||||
for cls in obj_classes:
|
||||
if cls not in shapenet.keys():
|
||||
raise Exception('unknown class name')
|
||||
ids = shapenet[cls]
|
||||
obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
|
||||
obj_paths += obj_path[:num_perclass]
|
||||
print(f'found {len(obj_paths)} object paths')
|
||||
|
||||
objs = []
|
||||
for obj_path in obj_paths:
|
||||
print(f'load {obj_path}')
|
||||
v, f, _, n = co.io3d.read_obj(obj_path)
|
||||
diffs = v.max(axis=0) - v.min(axis=0)
|
||||
v /= (0.5 * diffs.max())
|
||||
v -= (v.min(axis=0) + 1)
|
||||
f = f.astype(np.int32)
|
||||
objs.append((v,f,n))
|
||||
print(f'loaded {len(objs)} objects')
|
||||
|
||||
return objs
|
||||
|
||||
|
||||
def get_mesh(rng, min_z=0):
|
||||
# set up background board
|
||||
verts, faces, normals, colors = [], [], [], []
|
||||
v, f, n = co.geometry.xyplane(z=0, interleaved=True)
|
||||
v[:,2] += -v[:,2].min() + rng.uniform(2,7)
|
||||
v[:,:2] *= 5e2
|
||||
v[:,2] = np.mean(v[:,2]) + (v[:,2] - np.mean(v[:,2])) * 5e2
|
||||
c = np.empty_like(v)
|
||||
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
|
||||
verts.append(v)
|
||||
faces.append(f)
|
||||
normals.append(n)
|
||||
colors.append(c)
|
||||
|
||||
# randomly sample 4 foreground objects for each scene
|
||||
for shape_idx in range(4):
|
||||
v, f, n = objs[rng.randint(0,len(objs))]
|
||||
v, f, n = v.copy(), f.copy(), n.copy()
|
||||
|
||||
s = rng.uniform(0.25, 1)
|
||||
v *= s
|
||||
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
|
||||
v = v @ R.T
|
||||
n = n @ R.T
|
||||
v[:,2] += -v[:,2].min() + min_z + rng.uniform(0.5, 3)
|
||||
v[:,:2] += rng.uniform(-1, 1, size=(1,2))
|
||||
|
||||
c = np.empty_like(v)
|
||||
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
|
||||
|
||||
verts.append(v.astype(np.float32))
|
||||
faces.append(f)
|
||||
normals.append(n)
|
||||
colors.append(c)
|
||||
|
||||
verts, faces = co.geometry.stack_mesh(verts, faces)
|
||||
normals = np.vstack(normals).astype(np.float32)
|
||||
colors = np.vstack(colors).astype(np.float32)
|
||||
return verts, faces, colors, normals
|
||||
|
||||
|
||||
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
|
||||
|
||||
tic = time.time()
|
||||
rng = np.random.RandomState()
|
||||
|
||||
rng.seed(idx)
|
||||
|
||||
verts, faces, colors, normals = get_mesh(rng)
|
||||
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
|
||||
print(f'loading mesh for sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
|
||||
|
||||
|
||||
# let the camera point to the center
|
||||
center = np.array([0,0,3], dtype=np.float32)
|
||||
|
||||
basevec = np.array([-baseline,0,0], dtype=np.float32)
|
||||
unit = np.array([0,0,1],dtype=np.float32)
|
||||
|
||||
cam_x_ = rng.uniform(-0.2,0.2)
|
||||
cam_y_ = rng.uniform(-0.2,0.2)
|
||||
cam_z_ = rng.uniform(-0.2,0.2)
|
||||
|
||||
ret = collections.defaultdict(list)
|
||||
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1,0.1), 0,1)
|
||||
|
||||
# capture the same static scene from different view points as a track
|
||||
for ind in range(track_length):
|
||||
|
||||
cam_x = cam_x_ + rng.uniform(-0.1,0.1)
|
||||
cam_y = cam_y_ + rng.uniform(-0.1,0.1)
|
||||
cam_z = cam_z_ + rng.uniform(-0.1,0.1)
|
||||
|
||||
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
|
||||
|
||||
if np.linalg.norm(tcam[0:2])<1e-9:
|
||||
Rcam = np.eye(3, dtype=np.float32)
|
||||
else:
|
||||
Rcam = get_rotation_matrix(center, center-tcam)
|
||||
|
||||
tproj = tcam + basevec
|
||||
Rproj = Rcam
|
||||
|
||||
ret['R'].append(Rcam)
|
||||
ret['t'].append(tcam)
|
||||
|
||||
cams = []
|
||||
projs = []
|
||||
|
||||
# render the scene at multiple scales
|
||||
scales = [1, 0.5, 0.25, 0.125]
|
||||
|
||||
for scale in scales:
|
||||
fx = K[0,0] * scale
|
||||
fy = K[1,1] * scale
|
||||
px = K[0,2] * scale
|
||||
py = K[1,2] * scale
|
||||
im_height = imsize[0] * scale
|
||||
im_width = imsize[1] * scale
|
||||
cams.append( renderer.PyCamera(fx,fy,px,py, Rcam, tcam, im_width, im_height) )
|
||||
projs.append( renderer.PyCamera(fx,fy,px,py, Rproj, tproj, im_width, im_height) )
|
||||
|
||||
|
||||
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
|
||||
fl = K[0,0] / (2**s)
|
||||
|
||||
shader = renderer.PyShader(0.5,1.5,0.0,10)
|
||||
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
|
||||
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)
|
||||
|
||||
# get the reflected laser pattern $R$
|
||||
im = pyrenderer.color().copy()
|
||||
depth = pyrenderer.depth().copy()
|
||||
disp = baseline * fl / depth
|
||||
mask = depth > 0
|
||||
im = np.mean(im, axis=2)
|
||||
|
||||
# get the ambient image $A$
|
||||
ambient = pyrenderer.normal().copy()
|
||||
ambient = np.mean(ambient, axis=2)
|
||||
|
||||
# get the noise free IR image $J$
|
||||
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
|
||||
ret[f'ambient{s}'].append( ambient[None].astype(np.float32) )
|
||||
|
||||
# get the gradient magnitude of the ambient image $|\nabla A|$
|
||||
ambient = ambient.astype(np.float32)
|
||||
sobelx = cv2.Sobel(ambient,cv2.CV_32F,1,0,ksize=5)
|
||||
sobely = cv2.Sobel(ambient,cv2.CV_32F,0,1,ksize=5)
|
||||
grad = np.sqrt(sobelx**2 + sobely**2)
|
||||
grad = np.maximum(grad-0.8,0.0) # parameter
|
||||
|
||||
# get the local contract normalized grad LCN($|\nabla A|$)
|
||||
grad_lcn, grad_std = lcn.normalize(grad,5,0.1)
|
||||
grad_lcn = np.clip(grad_lcn,0.0,1.0) # parameter
|
||||
ret[f'grad{s}'].append( grad_lcn[None].astype(np.float32))
|
||||
|
||||
ret[f'im{s}'].append( im[None].astype(np.float32))
|
||||
ret[f'mask{s}'].append(mask[None].astype(np.float32))
|
||||
ret[f'disp{s}'].append(disp[None].astype(np.float32))
|
||||
|
||||
for key in ret.keys():
|
||||
ret[key] = np.stack(ret[key], axis=0)
|
||||
|
||||
# save to files
|
||||
out_dir = out_root / f'{idx:08d}'
|
||||
out_dir.mkdir(exist_ok=True, parents=True)
|
||||
for k,val in ret.items():
|
||||
for tidx in range(track_length):
|
||||
v = val[tidx]
|
||||
out_path = out_dir / f'{k}_{tidx}.npy'
|
||||
np.save(out_path, v)
|
||||
np.save( str(out_dir /'blend_im.npy'), blend_im_rnd)
|
||||
|
||||
print(f'create sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
|
||||
|
||||
|
||||
|
||||
if __name__=='__main__':
|
||||
|
||||
np.random.seed(42)
|
||||
|
||||
# output directory
|
||||
with open('../config.json') as fp:
|
||||
config = json.load(fp)
|
||||
data_root = Path(config['DATA_ROOT'])
|
||||
shapenet_root = config['SHAPENET_ROOT']
|
||||
|
||||
data_type = 'syn'
|
||||
out_root = data_root / f'{data_type}'
|
||||
out_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# load shapenet models
|
||||
obj_classes = ['chair']
|
||||
objs = get_objs(shapenet_root, obj_classes)
|
||||
|
||||
# camera parameters
|
||||
imsize = (480, 640)
|
||||
imsizes = [(imsize[0]//(2**s), imsize[1]//(2**s)) for s in range(4)]
|
||||
K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
|
||||
focal_lengths = [K[0,0]/(2**s) for s in range(4)]
|
||||
baseline=0.075
|
||||
blend_im = 0.6
|
||||
noise = 0
|
||||
|
||||
# capture the same static scene from different view points as a track
|
||||
track_length = 4
|
||||
|
||||
# load pattern image
|
||||
pattern_path = './kinect_pattern.png'
|
||||
pattern_crop = True
|
||||
patterns = get_patterns(pattern_path, imsizes, pattern_crop)
|
||||
|
||||
# write settings to file
|
||||
settings = {
|
||||
'imsizes': imsizes,
|
||||
'patterns': patterns,
|
||||
'focal_lengths': focal_lengths,
|
||||
'baseline': baseline,
|
||||
'K': K,
|
||||
}
|
||||
out_path = out_root / f'settings.pkl'
|
||||
print(f'write settings to {out_path}')
|
||||
with open(str(out_path), 'wb') as f:
|
||||
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
# start the job
|
||||
n_samples = 2**10 + 2**13
|
||||
for idx in range(n_samples):
|
||||
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
|
||||
create_data(*args)
|
148
data/dataset.py
Normal file
148
data/dataset.py
Normal file
@ -0,0 +1,148 @@
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import itertools
|
||||
import pickle
|
||||
import json
|
||||
import time
|
||||
from pathlib import Path
|
||||
import collections
|
||||
import cv2
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import glob
|
||||
|
||||
import torchext
|
||||
import renderer
|
||||
import co
|
||||
from .commons import get_patterns, augment_image
|
||||
|
||||
from mpl_toolkits.mplot3d import Axes3D
|
||||
|
||||
class TrackSynDataset(torchext.BaseDataset):
|
||||
'''
|
||||
Load locally saved synthetic dataset
|
||||
Please run ./create_syn_data.sh to generate the dataset
|
||||
'''
|
||||
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
|
||||
super().__init__(train=train)
|
||||
|
||||
self.settings_path = settings_path
|
||||
self.sample_paths = sample_paths
|
||||
self.data_aug = data_aug
|
||||
self.train = train
|
||||
self.track_length=track_length
|
||||
assert(track_length<=4)
|
||||
|
||||
with open(str(settings_path), 'rb') as f:
|
||||
settings = pickle.load(f)
|
||||
self.imsizes = settings['imsizes']
|
||||
self.patterns = settings['patterns']
|
||||
self.focal_lengths = settings['focal_lengths']
|
||||
self.baseline = settings['baseline']
|
||||
self.K = settings['K']
|
||||
|
||||
self.scale = len(self.imsizes)
|
||||
|
||||
self.max_shift=0
|
||||
self.max_blur=0.5
|
||||
self.max_noise=3.0
|
||||
self.max_sp_noise=0.0005
|
||||
|
||||
def __len__(self):
|
||||
return len(self.sample_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
if not self.train:
|
||||
rng = self.get_rng(idx)
|
||||
else:
|
||||
rng = np.random.RandomState()
|
||||
sample_path = self.sample_paths[idx]
|
||||
|
||||
if self.train:
|
||||
track_ind = np.random.permutation(4)[0:self.track_length]
|
||||
else:
|
||||
track_ind = [0]
|
||||
|
||||
ret = {}
|
||||
ret['id'] = idx
|
||||
|
||||
# load imgs, at all scales
|
||||
for sidx in range(len(self.imsizes)):
|
||||
imgs = []
|
||||
ambs = []
|
||||
grads = []
|
||||
for tidx in track_ind:
|
||||
imgs.append(np.load(os.path.join(sample_path,f'im{sidx}_{tidx}.npy')))
|
||||
ambs.append(np.load(os.path.join(sample_path,f'ambient{sidx}_{tidx}.npy')))
|
||||
grads.append(np.load(os.path.join(sample_path,f'grad{sidx}_{tidx}.npy')))
|
||||
ret[f'im{sidx}'] = np.stack(imgs, axis=0)
|
||||
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
|
||||
ret[f'grad{sidx}'] = np.stack(grads, axis=0)
|
||||
|
||||
# load disp and grad only at full resolution
|
||||
disps = []
|
||||
R = []
|
||||
t = []
|
||||
for tidx in track_ind:
|
||||
disps.append(np.load(os.path.join(sample_path,f'disp0_{tidx}.npy')))
|
||||
R.append(np.load(os.path.join(sample_path,f'R_{tidx}.npy')))
|
||||
t.append(np.load(os.path.join(sample_path,f't_{tidx}.npy')))
|
||||
ret[f'disp0'] = np.stack(disps, axis=0)
|
||||
ret['R'] = np.stack(R, axis=0)
|
||||
ret['t'] = np.stack(t, axis=0)
|
||||
|
||||
blend_im = np.load(os.path.join(sample_path,'blend_im.npy'))
|
||||
ret['blend_im'] = blend_im.astype(np.float32)
|
||||
|
||||
#### apply data augmentation at different scales seperately, only work for max_shift=0
|
||||
if self.data_aug:
|
||||
for sidx in range(len(self.imsizes)):
|
||||
if sidx==0:
|
||||
img = ret[f'im{sidx}']
|
||||
disp = ret[f'disp{sidx}']
|
||||
grad = ret[f'grad{sidx}']
|
||||
img_aug = np.zeros_like(img)
|
||||
disp_aug = np.zeros_like(img)
|
||||
grad_aug = np.zeros_like(img)
|
||||
for i in range(img.shape[0]):
|
||||
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i,0],rng,
|
||||
disp=disp[i,0],grad=grad[i,0],
|
||||
max_shift=self.max_shift, max_blur=self.max_blur,
|
||||
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
|
||||
img_aug[i] = img_aug_[None].astype(np.float32)
|
||||
disp_aug[i] = disp_aug_[None].astype(np.float32)
|
||||
grad_aug[i] = grad_aug_[None].astype(np.float32)
|
||||
ret[f'im{sidx}'] = img_aug
|
||||
ret[f'disp{sidx}'] = disp_aug
|
||||
ret[f'grad{sidx}'] = grad_aug
|
||||
else:
|
||||
img = ret[f'im{sidx}']
|
||||
img_aug = np.zeros_like(img)
|
||||
for i in range(img.shape[0]):
|
||||
img_aug_, _, _ = augment_image(img[i,0],rng,
|
||||
max_shift=self.max_shift, max_blur=self.max_blur,
|
||||
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
|
||||
img_aug[i] = img_aug_[None].astype(np.float32)
|
||||
ret[f'im{sidx}'] = img_aug
|
||||
|
||||
if len(track_ind)==1:
|
||||
for key, val in ret.items():
|
||||
if key!='blend_im' and key!='id':
|
||||
ret[key] = val[0]
|
||||
|
||||
|
||||
return ret
|
||||
|
||||
def getK(self, sidx=0):
|
||||
K = self.K.copy() / (2**sidx)
|
||||
K[2,2] = 1
|
||||
return K
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
||||
|
BIN
data/kinect_pattern.png
Normal file
BIN
data/kinect_pattern.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 585 KiB |
BIN
data/lcn/img.png
Normal file
BIN
data/lcn/img.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 750 KiB |
20968
data/lcn/lcn.c
Normal file
20968
data/lcn/lcn.c
Normal file
File diff suppressed because it is too large
Load Diff
720
data/lcn/lcn.html
Normal file
720
data/lcn/lcn.html
Normal file
@ -0,0 +1,720 @@
|
||||
<!DOCTYPE html>
|
||||
<!-- Generated by Cython 0.29 -->
|
||||
<html>
|
||||
<head>
|
||||
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
|
||||
<title>Cython: lcn.pyx</title>
|
||||
<style type="text/css">
|
||||
|
||||
body.cython { font-family: courier; font-size: 12; }
|
||||
|
||||
.cython.tag { }
|
||||
.cython.line { margin: 0em }
|
||||
.cython.code { font-size: 9; color: #444444; display: none; margin: 0px 0px 0px 8px; border-left: 8px none; }
|
||||
|
||||
.cython.line .run { background-color: #B0FFB0; }
|
||||
.cython.line .mis { background-color: #FFB0B0; }
|
||||
.cython.code.run { border-left: 8px solid #B0FFB0; }
|
||||
.cython.code.mis { border-left: 8px solid #FFB0B0; }
|
||||
|
||||
.cython.code .py_c_api { color: red; }
|
||||
.cython.code .py_macro_api { color: #FF7000; }
|
||||
.cython.code .pyx_c_api { color: #FF3000; }
|
||||
.cython.code .pyx_macro_api { color: #FF7000; }
|
||||
.cython.code .refnanny { color: #FFA000; }
|
||||
.cython.code .trace { color: #FFA000; }
|
||||
.cython.code .error_goto { color: #FFA000; }
|
||||
|
||||
.cython.code .coerce { color: #008000; border: 1px dotted #008000 }
|
||||
.cython.code .py_attr { color: #FF0000; font-weight: bold; }
|
||||
.cython.code .c_attr { color: #0000FF; }
|
||||
.cython.code .py_call { color: #FF0000; font-weight: bold; }
|
||||
.cython.code .c_call { color: #0000FF; }
|
||||
|
||||
.cython.score-0 {background-color: #FFFFff;}
|
||||
.cython.score-1 {background-color: #FFFFe7;}
|
||||
.cython.score-2 {background-color: #FFFFd4;}
|
||||
.cython.score-3 {background-color: #FFFFc4;}
|
||||
.cython.score-4 {background-color: #FFFFb6;}
|
||||
.cython.score-5 {background-color: #FFFFaa;}
|
||||
.cython.score-6 {background-color: #FFFF9f;}
|
||||
.cython.score-7 {background-color: #FFFF96;}
|
||||
.cython.score-8 {background-color: #FFFF8d;}
|
||||
.cython.score-9 {background-color: #FFFF86;}
|
||||
.cython.score-10 {background-color: #FFFF7f;}
|
||||
.cython.score-11 {background-color: #FFFF79;}
|
||||
.cython.score-12 {background-color: #FFFF73;}
|
||||
.cython.score-13 {background-color: #FFFF6e;}
|
||||
.cython.score-14 {background-color: #FFFF6a;}
|
||||
.cython.score-15 {background-color: #FFFF66;}
|
||||
.cython.score-16 {background-color: #FFFF62;}
|
||||
.cython.score-17 {background-color: #FFFF5e;}
|
||||
.cython.score-18 {background-color: #FFFF5b;}
|
||||
.cython.score-19 {background-color: #FFFF57;}
|
||||
.cython.score-20 {background-color: #FFFF55;}
|
||||
.cython.score-21 {background-color: #FFFF52;}
|
||||
.cython.score-22 {background-color: #FFFF4f;}
|
||||
.cython.score-23 {background-color: #FFFF4d;}
|
||||
.cython.score-24 {background-color: #FFFF4b;}
|
||||
.cython.score-25 {background-color: #FFFF48;}
|
||||
.cython.score-26 {background-color: #FFFF46;}
|
||||
.cython.score-27 {background-color: #FFFF44;}
|
||||
.cython.score-28 {background-color: #FFFF43;}
|
||||
.cython.score-29 {background-color: #FFFF41;}
|
||||
.cython.score-30 {background-color: #FFFF3f;}
|
||||
.cython.score-31 {background-color: #FFFF3e;}
|
||||
.cython.score-32 {background-color: #FFFF3c;}
|
||||
.cython.score-33 {background-color: #FFFF3b;}
|
||||
.cython.score-34 {background-color: #FFFF39;}
|
||||
.cython.score-35 {background-color: #FFFF38;}
|
||||
.cython.score-36 {background-color: #FFFF37;}
|
||||
.cython.score-37 {background-color: #FFFF36;}
|
||||
.cython.score-38 {background-color: #FFFF35;}
|
||||
.cython.score-39 {background-color: #FFFF34;}
|
||||
.cython.score-40 {background-color: #FFFF33;}
|
||||
.cython.score-41 {background-color: #FFFF32;}
|
||||
.cython.score-42 {background-color: #FFFF31;}
|
||||
.cython.score-43 {background-color: #FFFF30;}
|
||||
.cython.score-44 {background-color: #FFFF2f;}
|
||||
.cython.score-45 {background-color: #FFFF2e;}
|
||||
.cython.score-46 {background-color: #FFFF2d;}
|
||||
.cython.score-47 {background-color: #FFFF2c;}
|
||||
.cython.score-48 {background-color: #FFFF2b;}
|
||||
.cython.score-49 {background-color: #FFFF2b;}
|
||||
.cython.score-50 {background-color: #FFFF2a;}
|
||||
.cython.score-51 {background-color: #FFFF29;}
|
||||
.cython.score-52 {background-color: #FFFF29;}
|
||||
.cython.score-53 {background-color: #FFFF28;}
|
||||
.cython.score-54 {background-color: #FFFF27;}
|
||||
.cython.score-55 {background-color: #FFFF27;}
|
||||
.cython.score-56 {background-color: #FFFF26;}
|
||||
.cython.score-57 {background-color: #FFFF26;}
|
||||
.cython.score-58 {background-color: #FFFF25;}
|
||||
.cython.score-59 {background-color: #FFFF24;}
|
||||
.cython.score-60 {background-color: #FFFF24;}
|
||||
.cython.score-61 {background-color: #FFFF23;}
|
||||
.cython.score-62 {background-color: #FFFF23;}
|
||||
.cython.score-63 {background-color: #FFFF22;}
|
||||
.cython.score-64 {background-color: #FFFF22;}
|
||||
.cython.score-65 {background-color: #FFFF22;}
|
||||
.cython.score-66 {background-color: #FFFF21;}
|
||||
.cython.score-67 {background-color: #FFFF21;}
|
||||
.cython.score-68 {background-color: #FFFF20;}
|
||||
.cython.score-69 {background-color: #FFFF20;}
|
||||
.cython.score-70 {background-color: #FFFF1f;}
|
||||
.cython.score-71 {background-color: #FFFF1f;}
|
||||
.cython.score-72 {background-color: #FFFF1f;}
|
||||
.cython.score-73 {background-color: #FFFF1e;}
|
||||
.cython.score-74 {background-color: #FFFF1e;}
|
||||
.cython.score-75 {background-color: #FFFF1e;}
|
||||
.cython.score-76 {background-color: #FFFF1d;}
|
||||
.cython.score-77 {background-color: #FFFF1d;}
|
||||
.cython.score-78 {background-color: #FFFF1c;}
|
||||
.cython.score-79 {background-color: #FFFF1c;}
|
||||
.cython.score-80 {background-color: #FFFF1c;}
|
||||
.cython.score-81 {background-color: #FFFF1c;}
|
||||
.cython.score-82 {background-color: #FFFF1b;}
|
||||
.cython.score-83 {background-color: #FFFF1b;}
|
||||
.cython.score-84 {background-color: #FFFF1b;}
|
||||
.cython.score-85 {background-color: #FFFF1a;}
|
||||
.cython.score-86 {background-color: #FFFF1a;}
|
||||
.cython.score-87 {background-color: #FFFF1a;}
|
||||
.cython.score-88 {background-color: #FFFF1a;}
|
||||
.cython.score-89 {background-color: #FFFF19;}
|
||||
.cython.score-90 {background-color: #FFFF19;}
|
||||
.cython.score-91 {background-color: #FFFF19;}
|
||||
.cython.score-92 {background-color: #FFFF19;}
|
||||
.cython.score-93 {background-color: #FFFF18;}
|
||||
.cython.score-94 {background-color: #FFFF18;}
|
||||
.cython.score-95 {background-color: #FFFF18;}
|
||||
.cython.score-96 {background-color: #FFFF18;}
|
||||
.cython.score-97 {background-color: #FFFF17;}
|
||||
.cython.score-98 {background-color: #FFFF17;}
|
||||
.cython.score-99 {background-color: #FFFF17;}
|
||||
.cython.score-100 {background-color: #FFFF17;}
|
||||
.cython.score-101 {background-color: #FFFF16;}
|
||||
.cython.score-102 {background-color: #FFFF16;}
|
||||
.cython.score-103 {background-color: #FFFF16;}
|
||||
.cython.score-104 {background-color: #FFFF16;}
|
||||
.cython.score-105 {background-color: #FFFF16;}
|
||||
.cython.score-106 {background-color: #FFFF15;}
|
||||
.cython.score-107 {background-color: #FFFF15;}
|
||||
.cython.score-108 {background-color: #FFFF15;}
|
||||
.cython.score-109 {background-color: #FFFF15;}
|
||||
.cython.score-110 {background-color: #FFFF15;}
|
||||
.cython.score-111 {background-color: #FFFF15;}
|
||||
.cython.score-112 {background-color: #FFFF14;}
|
||||
.cython.score-113 {background-color: #FFFF14;}
|
||||
.cython.score-114 {background-color: #FFFF14;}
|
||||
.cython.score-115 {background-color: #FFFF14;}
|
||||
.cython.score-116 {background-color: #FFFF14;}
|
||||
.cython.score-117 {background-color: #FFFF14;}
|
||||
.cython.score-118 {background-color: #FFFF13;}
|
||||
.cython.score-119 {background-color: #FFFF13;}
|
||||
.cython.score-120 {background-color: #FFFF13;}
|
||||
.cython.score-121 {background-color: #FFFF13;}
|
||||
.cython.score-122 {background-color: #FFFF13;}
|
||||
.cython.score-123 {background-color: #FFFF13;}
|
||||
.cython.score-124 {background-color: #FFFF13;}
|
||||
.cython.score-125 {background-color: #FFFF12;}
|
||||
.cython.score-126 {background-color: #FFFF12;}
|
||||
.cython.score-127 {background-color: #FFFF12;}
|
||||
.cython.score-128 {background-color: #FFFF12;}
|
||||
.cython.score-129 {background-color: #FFFF12;}
|
||||
.cython.score-130 {background-color: #FFFF12;}
|
||||
.cython.score-131 {background-color: #FFFF12;}
|
||||
.cython.score-132 {background-color: #FFFF11;}
|
||||
.cython.score-133 {background-color: #FFFF11;}
|
||||
.cython.score-134 {background-color: #FFFF11;}
|
||||
.cython.score-135 {background-color: #FFFF11;}
|
||||
.cython.score-136 {background-color: #FFFF11;}
|
||||
.cython.score-137 {background-color: #FFFF11;}
|
||||
.cython.score-138 {background-color: #FFFF11;}
|
||||
.cython.score-139 {background-color: #FFFF11;}
|
||||
.cython.score-140 {background-color: #FFFF11;}
|
||||
.cython.score-141 {background-color: #FFFF10;}
|
||||
.cython.score-142 {background-color: #FFFF10;}
|
||||
.cython.score-143 {background-color: #FFFF10;}
|
||||
.cython.score-144 {background-color: #FFFF10;}
|
||||
.cython.score-145 {background-color: #FFFF10;}
|
||||
.cython.score-146 {background-color: #FFFF10;}
|
||||
.cython.score-147 {background-color: #FFFF10;}
|
||||
.cython.score-148 {background-color: #FFFF10;}
|
||||
.cython.score-149 {background-color: #FFFF10;}
|
||||
.cython.score-150 {background-color: #FFFF0f;}
|
||||
.cython.score-151 {background-color: #FFFF0f;}
|
||||
.cython.score-152 {background-color: #FFFF0f;}
|
||||
.cython.score-153 {background-color: #FFFF0f;}
|
||||
.cython.score-154 {background-color: #FFFF0f;}
|
||||
.cython.score-155 {background-color: #FFFF0f;}
|
||||
.cython.score-156 {background-color: #FFFF0f;}
|
||||
.cython.score-157 {background-color: #FFFF0f;}
|
||||
.cython.score-158 {background-color: #FFFF0f;}
|
||||
.cython.score-159 {background-color: #FFFF0f;}
|
||||
.cython.score-160 {background-color: #FFFF0f;}
|
||||
.cython.score-161 {background-color: #FFFF0e;}
|
||||
.cython.score-162 {background-color: #FFFF0e;}
|
||||
.cython.score-163 {background-color: #FFFF0e;}
|
||||
.cython.score-164 {background-color: #FFFF0e;}
|
||||
.cython.score-165 {background-color: #FFFF0e;}
|
||||
.cython.score-166 {background-color: #FFFF0e;}
|
||||
.cython.score-167 {background-color: #FFFF0e;}
|
||||
.cython.score-168 {background-color: #FFFF0e;}
|
||||
.cython.score-169 {background-color: #FFFF0e;}
|
||||
.cython.score-170 {background-color: #FFFF0e;}
|
||||
.cython.score-171 {background-color: #FFFF0e;}
|
||||
.cython.score-172 {background-color: #FFFF0e;}
|
||||
.cython.score-173 {background-color: #FFFF0d;}
|
||||
.cython.score-174 {background-color: #FFFF0d;}
|
||||
.cython.score-175 {background-color: #FFFF0d;}
|
||||
.cython.score-176 {background-color: #FFFF0d;}
|
||||
.cython.score-177 {background-color: #FFFF0d;}
|
||||
.cython.score-178 {background-color: #FFFF0d;}
|
||||
.cython.score-179 {background-color: #FFFF0d;}
|
||||
.cython.score-180 {background-color: #FFFF0d;}
|
||||
.cython.score-181 {background-color: #FFFF0d;}
|
||||
.cython.score-182 {background-color: #FFFF0d;}
|
||||
.cython.score-183 {background-color: #FFFF0d;}
|
||||
.cython.score-184 {background-color: #FFFF0d;}
|
||||
.cython.score-185 {background-color: #FFFF0d;}
|
||||
.cython.score-186 {background-color: #FFFF0d;}
|
||||
.cython.score-187 {background-color: #FFFF0c;}
|
||||
.cython.score-188 {background-color: #FFFF0c;}
|
||||
.cython.score-189 {background-color: #FFFF0c;}
|
||||
.cython.score-190 {background-color: #FFFF0c;}
|
||||
.cython.score-191 {background-color: #FFFF0c;}
|
||||
.cython.score-192 {background-color: #FFFF0c;}
|
||||
.cython.score-193 {background-color: #FFFF0c;}
|
||||
.cython.score-194 {background-color: #FFFF0c;}
|
||||
.cython.score-195 {background-color: #FFFF0c;}
|
||||
.cython.score-196 {background-color: #FFFF0c;}
|
||||
.cython.score-197 {background-color: #FFFF0c;}
|
||||
.cython.score-198 {background-color: #FFFF0c;}
|
||||
.cython.score-199 {background-color: #FFFF0c;}
|
||||
.cython.score-200 {background-color: #FFFF0c;}
|
||||
.cython.score-201 {background-color: #FFFF0c;}
|
||||
.cython.score-202 {background-color: #FFFF0c;}
|
||||
.cython.score-203 {background-color: #FFFF0b;}
|
||||
.cython.score-204 {background-color: #FFFF0b;}
|
||||
.cython.score-205 {background-color: #FFFF0b;}
|
||||
.cython.score-206 {background-color: #FFFF0b;}
|
||||
.cython.score-207 {background-color: #FFFF0b;}
|
||||
.cython.score-208 {background-color: #FFFF0b;}
|
||||
.cython.score-209 {background-color: #FFFF0b;}
|
||||
.cython.score-210 {background-color: #FFFF0b;}
|
||||
.cython.score-211 {background-color: #FFFF0b;}
|
||||
.cython.score-212 {background-color: #FFFF0b;}
|
||||
.cython.score-213 {background-color: #FFFF0b;}
|
||||
.cython.score-214 {background-color: #FFFF0b;}
|
||||
.cython.score-215 {background-color: #FFFF0b;}
|
||||
.cython.score-216 {background-color: #FFFF0b;}
|
||||
.cython.score-217 {background-color: #FFFF0b;}
|
||||
.cython.score-218 {background-color: #FFFF0b;}
|
||||
.cython.score-219 {background-color: #FFFF0b;}
|
||||
.cython.score-220 {background-color: #FFFF0b;}
|
||||
.cython.score-221 {background-color: #FFFF0b;}
|
||||
.cython.score-222 {background-color: #FFFF0a;}
|
||||
.cython.score-223 {background-color: #FFFF0a;}
|
||||
.cython.score-224 {background-color: #FFFF0a;}
|
||||
.cython.score-225 {background-color: #FFFF0a;}
|
||||
.cython.score-226 {background-color: #FFFF0a;}
|
||||
.cython.score-227 {background-color: #FFFF0a;}
|
||||
.cython.score-228 {background-color: #FFFF0a;}
|
||||
.cython.score-229 {background-color: #FFFF0a;}
|
||||
.cython.score-230 {background-color: #FFFF0a;}
|
||||
.cython.score-231 {background-color: #FFFF0a;}
|
||||
.cython.score-232 {background-color: #FFFF0a;}
|
||||
.cython.score-233 {background-color: #FFFF0a;}
|
||||
.cython.score-234 {background-color: #FFFF0a;}
|
||||
.cython.score-235 {background-color: #FFFF0a;}
|
||||
.cython.score-236 {background-color: #FFFF0a;}
|
||||
.cython.score-237 {background-color: #FFFF0a;}
|
||||
.cython.score-238 {background-color: #FFFF0a;}
|
||||
.cython.score-239 {background-color: #FFFF0a;}
|
||||
.cython.score-240 {background-color: #FFFF0a;}
|
||||
.cython.score-241 {background-color: #FFFF0a;}
|
||||
.cython.score-242 {background-color: #FFFF0a;}
|
||||
.cython.score-243 {background-color: #FFFF0a;}
|
||||
.cython.score-244 {background-color: #FFFF0a;}
|
||||
.cython.score-245 {background-color: #FFFF0a;}
|
||||
.cython.score-246 {background-color: #FFFF09;}
|
||||
.cython.score-247 {background-color: #FFFF09;}
|
||||
.cython.score-248 {background-color: #FFFF09;}
|
||||
.cython.score-249 {background-color: #FFFF09;}
|
||||
.cython.score-250 {background-color: #FFFF09;}
|
||||
.cython.score-251 {background-color: #FFFF09;}
|
||||
.cython.score-252 {background-color: #FFFF09;}
|
||||
.cython.score-253 {background-color: #FFFF09;}
|
||||
.cython.score-254 {background-color: #FFFF09;}
|
||||
.cython .hll { background-color: #ffffcc }
|
||||
.cython { background: #f8f8f8; }
|
||||
.cython .c { color: #408080; font-style: italic } /* Comment */
|
||||
.cython .err { border: 1px solid #FF0000 } /* Error */
|
||||
.cython .k { color: #008000; font-weight: bold } /* Keyword */
|
||||
.cython .o { color: #666666 } /* Operator */
|
||||
.cython .ch { color: #408080; font-style: italic } /* Comment.Hashbang */
|
||||
.cython .cm { color: #408080; font-style: italic } /* Comment.Multiline */
|
||||
.cython .cp { color: #BC7A00 } /* Comment.Preproc */
|
||||
.cython .cpf { color: #408080; font-style: italic } /* Comment.PreprocFile */
|
||||
.cython .c1 { color: #408080; font-style: italic } /* Comment.Single */
|
||||
.cython .cs { color: #408080; font-style: italic } /* Comment.Special */
|
||||
.cython .gd { color: #A00000 } /* Generic.Deleted */
|
||||
.cython .ge { font-style: italic } /* Generic.Emph */
|
||||
.cython .gr { color: #FF0000 } /* Generic.Error */
|
||||
.cython .gh { color: #000080; font-weight: bold } /* Generic.Heading */
|
||||
.cython .gi { color: #00A000 } /* Generic.Inserted */
|
||||
.cython .go { color: #888888 } /* Generic.Output */
|
||||
.cython .gp { color: #000080; font-weight: bold } /* Generic.Prompt */
|
||||
.cython .gs { font-weight: bold } /* Generic.Strong */
|
||||
.cython .gu { color: #800080; font-weight: bold } /* Generic.Subheading */
|
||||
.cython .gt { color: #0044DD } /* Generic.Traceback */
|
||||
.cython .kc { color: #008000; font-weight: bold } /* Keyword.Constant */
|
||||
.cython .kd { color: #008000; font-weight: bold } /* Keyword.Declaration */
|
||||
.cython .kn { color: #008000; font-weight: bold } /* Keyword.Namespace */
|
||||
.cython .kp { color: #008000 } /* Keyword.Pseudo */
|
||||
.cython .kr { color: #008000; font-weight: bold } /* Keyword.Reserved */
|
||||
.cython .kt { color: #B00040 } /* Keyword.Type */
|
||||
.cython .m { color: #666666 } /* Literal.Number */
|
||||
.cython .s { color: #BA2121 } /* Literal.String */
|
||||
.cython .na { color: #7D9029 } /* Name.Attribute */
|
||||
.cython .nb { color: #008000 } /* Name.Builtin */
|
||||
.cython .nc { color: #0000FF; font-weight: bold } /* Name.Class */
|
||||
.cython .no { color: #880000 } /* Name.Constant */
|
||||
.cython .nd { color: #AA22FF } /* Name.Decorator */
|
||||
.cython .ni { color: #999999; font-weight: bold } /* Name.Entity */
|
||||
.cython .ne { color: #D2413A; font-weight: bold } /* Name.Exception */
|
||||
.cython .nf { color: #0000FF } /* Name.Function */
|
||||
.cython .nl { color: #A0A000 } /* Name.Label */
|
||||
.cython .nn { color: #0000FF; font-weight: bold } /* Name.Namespace */
|
||||
.cython .nt { color: #008000; font-weight: bold } /* Name.Tag */
|
||||
.cython .nv { color: #19177C } /* Name.Variable */
|
||||
.cython .ow { color: #AA22FF; font-weight: bold } /* Operator.Word */
|
||||
.cython .w { color: #bbbbbb } /* Text.Whitespace */
|
||||
.cython .mb { color: #666666 } /* Literal.Number.Bin */
|
||||
.cython .mf { color: #666666 } /* Literal.Number.Float */
|
||||
.cython .mh { color: #666666 } /* Literal.Number.Hex */
|
||||
.cython .mi { color: #666666 } /* Literal.Number.Integer */
|
||||
.cython .mo { color: #666666 } /* Literal.Number.Oct */
|
||||
.cython .sa { color: #BA2121 } /* Literal.String.Affix */
|
||||
.cython .sb { color: #BA2121 } /* Literal.String.Backtick */
|
||||
.cython .sc { color: #BA2121 } /* Literal.String.Char */
|
||||
.cython .dl { color: #BA2121 } /* Literal.String.Delimiter */
|
||||
.cython .sd { color: #BA2121; font-style: italic } /* Literal.String.Doc */
|
||||
.cython .s2 { color: #BA2121 } /* Literal.String.Double */
|
||||
.cython .se { color: #BB6622; font-weight: bold } /* Literal.String.Escape */
|
||||
.cython .sh { color: #BA2121 } /* Literal.String.Heredoc */
|
||||
.cython .si { color: #BB6688; font-weight: bold } /* Literal.String.Interpol */
|
||||
.cython .sx { color: #008000 } /* Literal.String.Other */
|
||||
.cython .sr { color: #BB6688 } /* Literal.String.Regex */
|
||||
.cython .s1 { color: #BA2121 } /* Literal.String.Single */
|
||||
.cython .ss { color: #19177C } /* Literal.String.Symbol */
|
||||
.cython .bp { color: #008000 } /* Name.Builtin.Pseudo */
|
||||
.cython .fm { color: #0000FF } /* Name.Function.Magic */
|
||||
.cython .vc { color: #19177C } /* Name.Variable.Class */
|
||||
.cython .vg { color: #19177C } /* Name.Variable.Global */
|
||||
.cython .vi { color: #19177C } /* Name.Variable.Instance */
|
||||
.cython .vm { color: #19177C } /* Name.Variable.Magic */
|
||||
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */
|
||||
</style>
|
||||
</head>
|
||||
<body class="cython">
|
||||
<p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p>
|
||||
<p>
|
||||
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br />
|
||||
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
|
||||
</p>
|
||||
<p>Raw output: <a href="lcn.c">lcn.c</a></p>
|
||||
<div class="cython"><pre class="cython line score-16" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre>
|
||||
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||
/* … */
|
||||
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyDict_NewPresized</span>(0);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||
</pre><pre class="cython line score-0"> <span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">03</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">04</span>: <span class="c"># use c square root function</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">05</span>: <span class="k">cdef</span> <span class="kr">extern</span> <span class="k">from</span> <span class="s">"math.h"</span><span class="p">:</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">07</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">08</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">09</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">10</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">11</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">13</span>: <span class="c"># - float image</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
|
||||
<pre class="cython line score-67" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span></pre>
|
||||
<pre class='cython code score-67 '>/* Python wrapper */
|
||||
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
|
||||
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
|
||||
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
|
||||
__Pyx_memviewslice __pyx_v_img = { 0, 0, { 0 }, { 0 }, { 0 } };
|
||||
int __pyx_v_kernel_size;
|
||||
float __pyx_v_epsilon;
|
||||
PyObject *__pyx_r = 0;
|
||||
<span class='refnanny'>__Pyx_RefNannyDeclarations</span>
|
||||
<span class='refnanny'>__Pyx_RefNannySetupContext</span>("normalize (wrapper)", 0);
|
||||
{
|
||||
static PyObject **__pyx_pyargnames[] = {&__pyx_n_s_img,&__pyx_n_s_kernel_size,&__pyx_n_s_epsilon,0};
|
||||
PyObject* values[3] = {0,0,0};
|
||||
if (unlikely(__pyx_kwds)) {
|
||||
Py_ssize_t kw_args;
|
||||
const Py_ssize_t pos_args = <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args);
|
||||
switch (pos_args) {
|
||||
case 3: values[2] = <span class='py_macro_api'>PyTuple_GET_ITEM</span>(__pyx_args, 2);
|
||||
CYTHON_FALLTHROUGH;
|
||||
case 2: values[1] = <span class='py_macro_api'>PyTuple_GET_ITEM</span>(__pyx_args, 1);
|
||||
CYTHON_FALLTHROUGH;
|
||||
case 1: values[0] = <span class='py_macro_api'>PyTuple_GET_ITEM</span>(__pyx_args, 0);
|
||||
CYTHON_FALLTHROUGH;
|
||||
case 0: break;
|
||||
default: goto __pyx_L5_argtuple_error;
|
||||
}
|
||||
kw_args = <span class='py_c_api'>PyDict_Size</span>(__pyx_kwds);
|
||||
switch (pos_args) {
|
||||
case 0:
|
||||
if (likely((values[0] = <span class='pyx_c_api'>__Pyx_PyDict_GetItemStr</span>(__pyx_kwds, __pyx_n_s_img)) != 0)) kw_args--;
|
||||
else goto __pyx_L5_argtuple_error;
|
||||
CYTHON_FALLTHROUGH;
|
||||
case 1:
|
||||
if (kw_args > 0) {
|
||||
PyObject* value = <span class='pyx_c_api'>__Pyx_PyDict_GetItemStr</span>(__pyx_kwds, __pyx_n_s_kernel_size);
|
||||
if (value) { values[1] = value; kw_args--; }
|
||||
}
|
||||
CYTHON_FALLTHROUGH;
|
||||
case 2:
|
||||
if (kw_args > 0) {
|
||||
PyObject* value = <span class='pyx_c_api'>__Pyx_PyDict_GetItemStr</span>(__pyx_kwds, __pyx_n_s_epsilon);
|
||||
if (value) { values[2] = value; kw_args--; }
|
||||
}
|
||||
}
|
||||
if (unlikely(kw_args > 0)) {
|
||||
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") < 0)) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||
}
|
||||
} else {
|
||||
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
|
||||
case 3: values[2] = <span class='py_macro_api'>PyTuple_GET_ITEM</span>(__pyx_args, 2);
|
||||
CYTHON_FALLTHROUGH;
|
||||
case 2: values[1] = <span class='py_macro_api'>PyTuple_GET_ITEM</span>(__pyx_args, 1);
|
||||
CYTHON_FALLTHROUGH;
|
||||
case 1: values[0] = <span class='py_macro_api'>PyTuple_GET_ITEM</span>(__pyx_args, 0);
|
||||
break;
|
||||
default: goto __pyx_L5_argtuple_error;
|
||||
}
|
||||
}
|
||||
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||
if (values[1]) {
|
||||
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) && <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||
} else {
|
||||
__pyx_v_kernel_size = ((int)4);
|
||||
}
|
||||
if (values[2]) {
|
||||
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) && <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||
} else {
|
||||
__pyx_v_epsilon = ((float)0.01);
|
||||
}
|
||||
}
|
||||
goto __pyx_L4_argument_unpacking_done;
|
||||
__pyx_L5_argtuple_error:;
|
||||
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
|
||||
__pyx_L3_error:;
|
||||
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
|
||||
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
|
||||
return NULL;
|
||||
__pyx_L4_argument_unpacking_done:;
|
||||
__pyx_r = __pyx_pf_3lcn_normalize(__pyx_self, __pyx_v_img, __pyx_v_kernel_size, __pyx_v_epsilon);
|
||||
|
||||
/* function exit code */
|
||||
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
|
||||
return __pyx_r;
|
||||
}
|
||||
|
||||
static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __Pyx_memviewslice __pyx_v_img, int __pyx_v_kernel_size, float __pyx_v_epsilon) {
|
||||
Py_ssize_t __pyx_v_M;
|
||||
Py_ssize_t __pyx_v_N;
|
||||
PyObject *__pyx_v_img_lcn = NULL;
|
||||
PyObject *__pyx_v_img_std = NULL;
|
||||
__Pyx_memviewslice __pyx_v_img_lcn_view = { 0, 0, { 0 }, { 0 }, { 0 } };
|
||||
__Pyx_memviewslice __pyx_v_img_std_view = { 0, 0, { 0 }, { 0 }, { 0 } };
|
||||
float __pyx_v_mean;
|
||||
float __pyx_v_stddev;
|
||||
Py_ssize_t __pyx_v_m;
|
||||
Py_ssize_t __pyx_v_n;
|
||||
Py_ssize_t __pyx_v_i;
|
||||
Py_ssize_t __pyx_v_j;
|
||||
Py_ssize_t __pyx_v_ks;
|
||||
float __pyx_v_eps;
|
||||
float __pyx_v_num;
|
||||
PyObject *__pyx_r = NULL;
|
||||
<span class='refnanny'>__Pyx_RefNannyDeclarations</span>
|
||||
<span class='refnanny'>__Pyx_RefNannySetupContext</span>("normalize", 0);
|
||||
/* … */
|
||||
/* function exit code */
|
||||
__pyx_L1_error:;
|
||||
<span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_t_1);
|
||||
<span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_t_2);
|
||||
<span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_t_3);
|
||||
<span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_t_4);
|
||||
<span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_t_5);
|
||||
__PYX_XDEC_MEMVIEW(&__pyx_t_6, 1);
|
||||
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
|
||||
__pyx_r = NULL;
|
||||
__pyx_L0:;
|
||||
<span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_v_img_lcn);
|
||||
<span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_v_img_std);
|
||||
__PYX_XDEC_MEMVIEW(&__pyx_v_img_lcn_view, 1);
|
||||
__PYX_XDEC_MEMVIEW(&__pyx_v_img_std_view, 1);
|
||||
__PYX_XDEC_MEMVIEW(&__pyx_v_img, 1);
|
||||
<span class='refnanny'>__Pyx_XGIVEREF</span>(__pyx_r);
|
||||
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
|
||||
return __pyx_r;
|
||||
}
|
||||
/* … */
|
||||
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
|
||||
/* … */
|
||||
__pyx_t_1 = PyCFunction_NewEx(&__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
|
||||
</pre><pre class="cython line score-0"> <span class="">17</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">18</span>: <span class="c"># image dimensions</span></pre>
|
||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
|
||||
</pre><pre class="cython line score-0"> <span class="">21</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
|
||||
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
|
||||
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
__pyx_t_3 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_N);<span class='error_goto'> if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_3);
|
||||
__pyx_t_4 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_t_1);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_4, 0, __pyx_t_1);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_t_3);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_4, 1, __pyx_t_3);
|
||||
__pyx_t_1 = 0;
|
||||
__pyx_t_3 = 0;
|
||||
__pyx_t_3 = <span class='py_c_api'>PyTuple_New</span>(1);<span class='error_goto'> if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_3);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_t_4);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_3, 0, __pyx_t_4);
|
||||
__pyx_t_4 = 0;
|
||||
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyDict_NewPresized</span>(1);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
||||
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) < 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
||||
__pyx_v_img_lcn = __pyx_t_5;
|
||||
__pyx_t_5 = 0;
|
||||
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
|
||||
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||
__pyx_t_3 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_N);<span class='error_goto'> if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_3);
|
||||
__pyx_t_2 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_t_5);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_2, 0, __pyx_t_5);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_t_3);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_2, 1, __pyx_t_3);
|
||||
__pyx_t_5 = 0;
|
||||
__pyx_t_3 = 0;
|
||||
__pyx_t_3 = <span class='py_c_api'>PyTuple_New</span>(1);<span class='error_goto'> if (unlikely(!__pyx_t_3)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_3);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_t_2);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_3, 0, __pyx_t_2);
|
||||
__pyx_t_2 = 0;
|
||||
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyDict_NewPresized</span>(1);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
|
||||
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
|
||||
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
|
||||
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) < 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
|
||||
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
|
||||
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
|
||||
__pyx_v_img_std = __pyx_t_1;
|
||||
__pyx_t_1 = 0;
|
||||
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre>
|
||||
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
|
||||
__pyx_v_img_lcn_view = __pyx_t_6;
|
||||
__pyx_t_6.memview = NULL;
|
||||
__pyx_t_6.data = NULL;
|
||||
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre>
|
||||
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
|
||||
__pyx_v_img_std_view = __pyx_t_6;
|
||||
__pyx_t_6.memview = NULL;
|
||||
__pyx_t_6.data = NULL;
|
||||
</pre><pre class="cython line score-0"> <span class="">27</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">28</span>: <span class="c"># temporary c variables</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre>
|
||||
<pre class="cython line score-0"> <span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
|
||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
|
||||
</pre><pre class="cython line score-0"> <span class="">34</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">35</span>: <span class="c"># for all pixels do</span></pre>
|
||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
|
||||
__pyx_t_8 = __pyx_t_7;
|
||||
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 < __pyx_t_8; __pyx_t_9+=1) {
|
||||
__pyx_v_m = __pyx_t_9;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
|
||||
__pyx_t_11 = __pyx_t_10;
|
||||
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 < __pyx_t_11; __pyx_t_12+=1) {
|
||||
__pyx_v_n = __pyx_t_12;
|
||||
</pre><pre class="cython line score-0"> <span class="">38</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">39</span>: <span class="c"># calculate mean</span></pre>
|
||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
||||
__pyx_t_14 = __pyx_t_13;
|
||||
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
||||
__pyx_v_i = __pyx_t_15;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
||||
__pyx_t_17 = __pyx_t_16;
|
||||
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
||||
__pyx_v_j = __pyx_t_18;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
|
||||
__pyx_t_20 = (__pyx_v_n + __pyx_v_j);
|
||||
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
|
||||
}
|
||||
}
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
|
||||
</pre><pre class="cython line score-0"> <span class="">45</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">46</span>: <span class="c"># calculate std dev</span></pre>
|
||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
|
||||
__pyx_t_14 = __pyx_t_13;
|
||||
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 < __pyx_t_14; __pyx_t_15+=1) {
|
||||
__pyx_v_i = __pyx_t_15;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
|
||||
__pyx_t_17 = __pyx_t_16;
|
||||
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 < __pyx_t_17; __pyx_t_18+=1) {
|
||||
__pyx_v_j = __pyx_t_18;
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
|
||||
__pyx_t_22 = (__pyx_v_n + __pyx_v_j);
|
||||
__pyx_t_23 = (__pyx_v_m + __pyx_v_i);
|
||||
__pyx_t_24 = (__pyx_v_n + __pyx_v_j);
|
||||
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
|
||||
}
|
||||
}
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
|
||||
</pre><pre class="cython line score-0"> <span class="">52</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
|
||||
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
|
||||
__pyx_t_26 = __pyx_v_n;
|
||||
__pyx_t_27 = __pyx_v_m;
|
||||
__pyx_t_28 = __pyx_v_n;
|
||||
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
|
||||
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre>
|
||||
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
|
||||
__pyx_t_30 = __pyx_v_n;
|
||||
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
|
||||
}
|
||||
}
|
||||
</pre><pre class="cython line score-0"> <span class="">56</span>: </pre>
|
||||
<pre class="cython line score-0"> <span class="">57</span>: <span class="c"># return both</span></pre>
|
||||
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre>
|
||||
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
|
||||
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
|
||||
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
|
||||
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_v_img_lcn);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_1, 0, __pyx_v_img_lcn);
|
||||
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_std);
|
||||
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_v_img_std);
|
||||
<span class='py_macro_api'>PyTuple_SET_ITEM</span>(__pyx_t_1, 1, __pyx_v_img_std);
|
||||
__pyx_r = __pyx_t_1;
|
||||
__pyx_t_1 = 0;
|
||||
goto __pyx_L0;
|
||||
</pre></div></body></html>
|
58
data/lcn/lcn.pyx
Normal file
58
data/lcn/lcn.pyx
Normal file
@ -0,0 +1,58 @@
|
||||
import numpy as np
|
||||
cimport cython
|
||||
|
||||
# use c square root function
|
||||
cdef extern from "math.h":
|
||||
float sqrt(float x)
|
||||
|
||||
@cython.boundscheck(False)
|
||||
@cython.wraparound(False)
|
||||
@cython.cdivision(True)
|
||||
|
||||
# 3 parameters:
|
||||
# - float image
|
||||
# - kernel size (actually this is the radius, kernel is 2*k+1)
|
||||
# - small constant epsilon that is used to avoid division by zero
|
||||
def normalize(float[:, :] img, int kernel_size = 4, float epsilon = 0.01):
|
||||
|
||||
# image dimensions
|
||||
cdef Py_ssize_t M = img.shape[0]
|
||||
cdef Py_ssize_t N = img.shape[1]
|
||||
|
||||
# create outputs and output views
|
||||
img_lcn = np.zeros((M, N), dtype=np.float32)
|
||||
img_std = np.zeros((M, N), dtype=np.float32)
|
||||
cdef float[:, :] img_lcn_view = img_lcn
|
||||
cdef float[:, :] img_std_view = img_std
|
||||
|
||||
# temporary c variables
|
||||
cdef float tmp, mean, stddev
|
||||
cdef Py_ssize_t m, n, i, j
|
||||
cdef Py_ssize_t ks = kernel_size
|
||||
cdef float eps = epsilon
|
||||
cdef float num = (ks*2+1)**2
|
||||
|
||||
# for all pixels do
|
||||
for m in range(ks,M-ks):
|
||||
for n in range(ks,N-ks):
|
||||
|
||||
# calculate mean
|
||||
mean = 0;
|
||||
for i in range(-ks,ks+1):
|
||||
for j in range(-ks,ks+1):
|
||||
mean += img[m+i, n+j]
|
||||
mean = mean/num
|
||||
|
||||
# calculate std dev
|
||||
stddev = 0;
|
||||
for i in range(-ks,ks+1):
|
||||
for j in range(-ks,ks+1):
|
||||
stddev = stddev + (img[m+i, n+j]-mean)*(img[m+i, n+j]-mean)
|
||||
stddev = sqrt(stddev/num)
|
||||
|
||||
# compute normalized image (add epsilon) and std dev image
|
||||
img_lcn_view[m, n] = (img[m, n]-mean)/(stddev+eps)
|
||||
img_std_view[m, n] = stddev
|
||||
|
||||
# return both
|
||||
return img_lcn, img_std
|
5
data/lcn/readme.txt
Normal file
5
data/lcn/readme.txt
Normal file
@ -0,0 +1,5 @@
|
||||
compile:
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
run:
|
||||
python test_lcn.py
|
6
data/lcn/setup.py
Normal file
6
data/lcn/setup.py
Normal file
@ -0,0 +1,6 @@
|
||||
from distutils.core import setup
|
||||
from Cython.Build import cythonize
|
||||
|
||||
setup(
|
||||
ext_modules = cythonize("lcn.pyx",annotate=True)
|
||||
)
|
47
data/lcn/test_lcn.py
Normal file
47
data/lcn/test_lcn.py
Normal file
@ -0,0 +1,47 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import lcn
|
||||
from scipy import misc
|
||||
|
||||
# load and convert to float
|
||||
img = misc.imread('img.png')
|
||||
img = img.astype(np.float32)/255.0
|
||||
|
||||
# normalize
|
||||
img_lcn, img_std = lcn.normalize(img,5,0.05)
|
||||
|
||||
# normalize to reasonable range between 0 and 1
|
||||
#img_lcn = img_lcn/3.0
|
||||
#img_lcn = np.maximum(img_lcn,0.0)
|
||||
#img_lcn = np.minimum(img_lcn,1.0)
|
||||
|
||||
# save to file
|
||||
#misc.imsave('lcn2.png',img_lcn)
|
||||
|
||||
print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
|
||||
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
|
||||
print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
|
||||
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
|
||||
|
||||
# plot original image
|
||||
plt.figure(1)
|
||||
img_plot = plt.imshow(img)
|
||||
img_plot.set_cmap('gray')
|
||||
plt.clim(0, 1) # fix range
|
||||
plt.tight_layout()
|
||||
|
||||
# plot normalized image
|
||||
plt.figure(2)
|
||||
img_lcn_plot = plt.imshow(img_lcn)
|
||||
img_lcn_plot.set_cmap('gray')
|
||||
#plt.clim(0, 1) # fix range
|
||||
plt.tight_layout()
|
||||
|
||||
# plot stddev image
|
||||
plt.figure(3)
|
||||
img_std_plot = plt.imshow(img_std)
|
||||
img_std_plot.set_cmap('gray')
|
||||
#plt.clim(0, 0.1) # fix range
|
||||
plt.tight_layout()
|
||||
|
||||
plt.show()
|
46
hyperdepth/eval.cpp
Normal file
46
hyperdepth/eval.cpp
Normal file
@ -0,0 +1,46 @@
|
||||
#include "hyperdepth.h"
|
||||
|
||||
|
||||
int main() {
|
||||
cv::Mat_<uint8_t> im = read_im(0);
|
||||
cv::Mat_<uint16_t> disp = read_disp(0);
|
||||
int im_rows = im.rows;
|
||||
int im_cols = im.cols;
|
||||
std::cout << im.rows << "/" << im.cols << std::endl;
|
||||
std::cout << disp.rows << "/" << disp.cols << std::endl;
|
||||
|
||||
cv::Mat_<uint16_t> ta_disp(im_rows, im_cols);
|
||||
cv::Mat_<uint16_t> es_disp(im_rows, im_cols);
|
||||
|
||||
int n_disp_bins = 16;
|
||||
|
||||
for(int row = 0; row < im_rows; ++row) {
|
||||
std::vector<TrainDatum> data;
|
||||
extract_row_samples(im, disp, row, data, false, n_disp_bins);
|
||||
|
||||
std::ostringstream forest_path;
|
||||
forest_path << "cforest_" << row << ".bin";
|
||||
BinarySerializationIn fin(forest_path.str());
|
||||
HDForest forest;
|
||||
forest.Load(fin);
|
||||
|
||||
auto res = forest.inferencemt(data, 18);
|
||||
for(int col = 0; col < im_cols; ++col) {
|
||||
auto fcn = res[col];
|
||||
auto target = std::static_pointer_cast<ClassificationTarget>(data[col].target);
|
||||
|
||||
float ta = col - float(target->cl()) / n_disp_bins;
|
||||
float es = col - float(fcn->argmax()) / n_disp_bins;
|
||||
es = std::max(0.f, es);
|
||||
|
||||
ta_disp(row, col) = int(ta * 16);
|
||||
es_disp(row, col) = int(es * 16);
|
||||
}
|
||||
}
|
||||
|
||||
cv::imwrite("disp_orig.png", disp);
|
||||
cv::imwrite("disp_ta.png", ta_disp);
|
||||
cv::imwrite("disp_es.png", es_disp);
|
||||
}
|
||||
|
||||
|
287
hyperdepth/hyperdepth.h
Normal file
287
hyperdepth/hyperdepth.h
Normal file
@ -0,0 +1,287 @@
|
||||
#include <sstream>
|
||||
#include <iomanip>
|
||||
|
||||
#include "rf/forest.h"
|
||||
#include "rf/spliteval.h"
|
||||
|
||||
class HyperdepthSplitEvaluator : public SplitEvaluator {
|
||||
public:
|
||||
HyperdepthSplitEvaluator(bool normalize, int n_classes, int n_disp_bins, int depth_switch)
|
||||
: SplitEvaluator(normalize), n_classes_(n_classes), n_disp_bins_(n_disp_bins), depth_switch_(depth_switch) {}
|
||||
virtual ~HyperdepthSplitEvaluator() {}
|
||||
|
||||
protected:
|
||||
virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const {
|
||||
if(targets.size() == 0) return 0;
|
||||
|
||||
int n_classes = n_classes_;
|
||||
if(depth >= depth_switch_) {
|
||||
n_classes *= n_disp_bins_;
|
||||
}
|
||||
|
||||
std::vector<int> ps;
|
||||
ps.resize(n_classes, 0);
|
||||
for(auto target : targets) {
|
||||
auto ctarget = std::static_pointer_cast<ClassificationTarget>(target.optimize_target);
|
||||
int cl = ctarget->cl();
|
||||
if(depth < depth_switch_) {
|
||||
cl /= n_disp_bins_;
|
||||
}
|
||||
ps[cl] += 1;
|
||||
}
|
||||
|
||||
float h = 0;
|
||||
for(int cl = 0; cl < n_classes; ++cl) {
|
||||
float fi = float(ps[cl]) / float(targets.size());
|
||||
if(fi > 0) {
|
||||
h = h - fi * std::log(fi);
|
||||
}
|
||||
}
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
private:
|
||||
int n_classes_;
|
||||
int n_disp_bins_;
|
||||
int depth_switch_;
|
||||
};
|
||||
|
||||
|
||||
class HyperdepthLeafFunction {
|
||||
public:
|
||||
HyperdepthLeafFunction() : n_classes_(-1) {}
|
||||
HyperdepthLeafFunction(int n_classes) : n_classes_(n_classes) {}
|
||||
virtual ~HyperdepthLeafFunction() {}
|
||||
|
||||
virtual std::shared_ptr<HyperdepthLeafFunction> Copy() const {
|
||||
auto fcn = std::make_shared<HyperdepthLeafFunction>();
|
||||
fcn->n_classes_ = n_classes_;
|
||||
fcn->counts_.resize(counts_.size());
|
||||
for(size_t idx = 0; idx < counts_.size(); ++idx) {
|
||||
fcn->counts_[idx] = counts_[idx];
|
||||
}
|
||||
fcn->sum_counts_ = sum_counts_;
|
||||
|
||||
return fcn;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<HyperdepthLeafFunction> Create(const std::vector<TrainDatum>& samples) {
|
||||
auto stat = std::make_shared<HyperdepthLeafFunction>();
|
||||
|
||||
stat->counts_.resize(n_classes_, 0);
|
||||
for(auto sample : samples) {
|
||||
auto ctarget = std::static_pointer_cast<ClassificationTarget>(sample.target);
|
||||
stat->counts_[ctarget->cl()] += 1;
|
||||
}
|
||||
stat->sum_counts_ = samples.size();
|
||||
|
||||
return stat;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<HyperdepthLeafFunction> Reduce(const std::vector<std::shared_ptr<HyperdepthLeafFunction>>& fcns) const {
|
||||
auto stat = std::make_shared<HyperdepthLeafFunction>();
|
||||
auto cfcn0 = std::static_pointer_cast<HyperdepthLeafFunction>(fcns[0]);
|
||||
stat->counts_.resize(cfcn0->counts_.size(), 0);
|
||||
stat->sum_counts_ = 0;
|
||||
|
||||
for(auto fcn : fcns) {
|
||||
auto cfcn = std::static_pointer_cast<HyperdepthLeafFunction>(fcn);
|
||||
for(size_t cl = 0; cl < stat->counts_.size(); ++cl) {
|
||||
stat->counts_[cl] += cfcn->counts_[cl];
|
||||
}
|
||||
stat->sum_counts_ += cfcn->sum_counts_;
|
||||
}
|
||||
|
||||
return stat;
|
||||
}
|
||||
|
||||
virtual std::tuple<int,int> argmax() const {
|
||||
int max_idx = 0;
|
||||
int max_count = counts_[0];
|
||||
int max2_idx = -1;
|
||||
int max2_count = -1;
|
||||
for(size_t idx = 1; idx < counts_.size(); ++idx) {
|
||||
if(counts_[idx] > max_count) {
|
||||
max2_count = max_count;
|
||||
max2_idx = max_idx;
|
||||
max_count = counts_[idx];
|
||||
max_idx = idx;
|
||||
}
|
||||
else if(counts_[idx] > max2_count) {
|
||||
max2_count = counts_[idx];
|
||||
max2_idx = idx;
|
||||
}
|
||||
}
|
||||
return std::make_tuple(max_idx, max2_idx);
|
||||
}
|
||||
|
||||
virtual std::vector<float> prob_vec() const {
|
||||
std::vector<float> probs(counts_.size(), 0.f);
|
||||
int sum = 0;
|
||||
for(int cnt : counts_) {
|
||||
sum += cnt;
|
||||
}
|
||||
for(size_t idx = 0; idx < counts_.size(); ++idx) {
|
||||
probs[idx] = float(counts_[idx]) / sum;
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
ar << n_classes_;
|
||||
int n_counts = counts_.size();
|
||||
ar << n_counts;
|
||||
for(int idx = 0; idx < n_counts; ++idx) {
|
||||
ar << counts_[idx];
|
||||
}
|
||||
ar << sum_counts_;
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
ar >> n_classes_;
|
||||
int n_counts;
|
||||
ar >> n_counts;
|
||||
counts_.resize(n_counts);
|
||||
for(int idx = 0; idx < n_counts; ++idx) {
|
||||
ar >> counts_[idx];
|
||||
}
|
||||
ar >> sum_counts_;
|
||||
}
|
||||
|
||||
public:
|
||||
int n_classes_;
|
||||
|
||||
std::vector<int> counts_;
|
||||
int sum_counts_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(HyperdepthLeafFunction);
|
||||
};
|
||||
|
||||
|
||||
typedef SplitFunctionPixelDifference HDSplitFunctionT;
|
||||
typedef HyperdepthLeafFunction HDLeafFunctionT;
|
||||
typedef HyperdepthSplitEvaluator HDSplitEvaluatorT;
|
||||
typedef Forest<HDSplitFunctionT, HDLeafFunctionT> HDForest;
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
class Raw {
|
||||
public:
|
||||
const T* raw;
|
||||
const int nsamples;
|
||||
const int rows;
|
||||
const int cols;
|
||||
Raw(const T* raw, int nsamples, int rows, int cols)
|
||||
: raw(raw), nsamples(nsamples), rows(rows), cols(cols) {}
|
||||
|
||||
T operator()(int n, int r, int c) const {
|
||||
return raw[(n * rows + r) * cols + c];
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
class RawSample : public Sample {
|
||||
public:
|
||||
RawSample(const Raw<uint8_t>& raw, int n, int rc, int cc, int patch_height, int patch_width)
|
||||
: Sample(1, patch_height, patch_width), raw(raw), n(n), rc(rc), cc(cc) {}
|
||||
|
||||
virtual float at(int ch, int r, int c) const {
|
||||
r += rc - height_ / 2;
|
||||
c += cc - width_ / 2;
|
||||
r = std::max(0, std::min(raw.rows-1, r));
|
||||
c = std::max(0, std::min(raw.cols-1, c));
|
||||
return raw(n, r, c);
|
||||
}
|
||||
|
||||
protected:
|
||||
const Raw<uint8_t>& raw;
|
||||
int n;
|
||||
int rc;
|
||||
int cc;
|
||||
};
|
||||
|
||||
void extract_row_samples(const Raw<uint8_t>& im, const Raw<float>& disp, int row, int n_disp_bins, bool only_valid, std::vector<TrainDatum>& data) {
|
||||
for(int n = 0; n < im.nsamples; ++n) {
|
||||
for(int col = 0; col < im.cols; ++col) {
|
||||
float d = disp(n, row, col);
|
||||
float pos = col - d;
|
||||
int cl = pos * n_disp_bins;
|
||||
if((d < 0 || cl < 0) && only_valid) continue;
|
||||
|
||||
auto sample = std::make_shared<RawSample>(im, n, row, col, 32, 32);
|
||||
auto target = std::make_shared<ClassificationTarget>(cl);
|
||||
auto datum = TrainDatum(sample, target);
|
||||
data.push_back(datum);
|
||||
}
|
||||
}
|
||||
std::cout << "extracted " << data.size() << " train samples" << std::endl;
|
||||
std::cout << "n_classes (" << im.cols << ") * n_disp_bins (" << n_disp_bins << ") = " << (im.cols * n_disp_bins) << std::endl;
|
||||
}
|
||||
|
||||
|
||||
void train(int row_from, int row_to, TrainParameters params, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, std::string forest_prefix) {
|
||||
Raw<uint8_t> raw_ims(ims, n, h, w);
|
||||
Raw<float> raw_disps(disps, n, h, w);
|
||||
|
||||
int n_classes = w;
|
||||
|
||||
auto gen_split_fcn = std::make_shared<HDSplitFunctionT>();
|
||||
auto gen_leaf_fcn = std::make_shared<HDLeafFunctionT>(n_classes * n_disp_bins);
|
||||
auto split_eval = std::make_shared<HDSplitEvaluatorT>(true, n_classes, n_disp_bins, depth_switch);
|
||||
|
||||
for(int row = row_from; row < row_to; ++row) {
|
||||
std::cout << "train row " << row << std::endl;
|
||||
|
||||
std::vector<TrainDatum> data;
|
||||
extract_row_samples(raw_ims, raw_disps, row, n_disp_bins, true, data);
|
||||
|
||||
TrainForestQueued<HDSplitFunctionT, HDLeafFunctionT, HDSplitEvaluatorT> train(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, true);
|
||||
|
||||
auto forest = train.Train(data, TrainType::TRAIN, nullptr);
|
||||
|
||||
std::ostringstream forest_path;
|
||||
forest_path << forest_prefix << row << ".bin";
|
||||
std::cout << "save forest of row " << row << " to " << forest_path.str() << std::endl;
|
||||
BinarySerializationOut fout(forest_path.str());
|
||||
forest->Save(fout);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
void eval(int row_from, int row_to, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, std::string forest_prefix, float* out) {
|
||||
Raw<uint8_t> raw_ims(ims, n, h, w);
|
||||
Raw<float> raw_disps(disps, n, h, w);
|
||||
|
||||
for(int row = row_from; row < row_to; ++row) {
|
||||
std::vector<TrainDatum> data;
|
||||
extract_row_samples(raw_ims, raw_disps, row, n_disp_bins, false, data);
|
||||
|
||||
std::ostringstream forest_path;
|
||||
forest_path << forest_prefix << row << ".bin";
|
||||
std::cout << "eval row " << row << " - " << forest_path.str() << std::endl;
|
||||
|
||||
BinarySerializationIn fin(forest_path.str());
|
||||
HDForest forest;
|
||||
forest.Load(fin);
|
||||
|
||||
auto res = forest.inferencemt(data, n_threads);
|
||||
|
||||
for(int nidx = 0; nidx < n; ++nidx) {
|
||||
for(int col = 0; col < w; ++col) {
|
||||
auto fcn = res[nidx * w + col];
|
||||
int pos, pos2;
|
||||
std::tie(pos, pos2) = fcn->argmax();
|
||||
float disp = col - float(pos) / n_disp_bins;
|
||||
float disp2 = col - float(pos2) / n_disp_bins;
|
||||
|
||||
float prob = fcn->prob_vec()[pos];
|
||||
|
||||
out[((nidx * h + row) * w + col) * 3 + 0] = disp;
|
||||
out[((nidx * h + row) * w + col) * 3 + 1] = prob;
|
||||
out[((nidx * h + row) * w + col) * 3 + 2] = std::abs(disp - disp2);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
86
hyperdepth/hyperdepth.pyx
Normal file
86
hyperdepth/hyperdepth.pyx
Normal file
@ -0,0 +1,86 @@
|
||||
cimport cython
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
from libc.stdlib cimport free, malloc
|
||||
from libcpp cimport bool
|
||||
from libcpp.string cimport string
|
||||
from cpython cimport PyObject, Py_INCREF
|
||||
|
||||
CREATE_INIT = True # workaround, so cython builds a init function
|
||||
|
||||
np.import_array()
|
||||
|
||||
|
||||
ctypedef unsigned char uint8_t
|
||||
|
||||
cdef extern from "rf/train.h":
|
||||
cdef cppclass TrainParameters:
|
||||
int n_trees;
|
||||
int max_tree_depth;
|
||||
int n_test_split_functions;
|
||||
int n_test_thresholds;
|
||||
int n_test_samples;
|
||||
int min_samples_to_split;
|
||||
int min_samples_for_leaf;
|
||||
int print_node_info;
|
||||
TrainParameters();
|
||||
|
||||
|
||||
cdef extern from "hyperdepth.h":
|
||||
void train(int row_from, int row_to, TrainParameters params, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, string forest_prefix);
|
||||
void eval(int row_from, int row_to, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, string forest_prefix, float* out);
|
||||
|
||||
|
||||
|
||||
|
||||
cdef class TrainParams:
|
||||
cdef TrainParameters params;
|
||||
|
||||
def __cinit__(self, int n_trees=6, int max_tree_depth=8, int n_test_split_functions=50, int n_test_thresholds=10, int n_test_samples=4096, int min_samples_to_split=16, int min_samples_for_leaf=8, int print_node_info=100):
|
||||
self.params.n_trees = n_trees
|
||||
self.params.max_tree_depth = max_tree_depth
|
||||
self.params.n_test_split_functions = n_test_split_functions
|
||||
self.params.n_test_thresholds = n_test_thresholds
|
||||
self.params.n_test_samples = n_test_samples
|
||||
self.params.min_samples_to_split = min_samples_to_split
|
||||
self.params.min_samples_for_leaf = min_samples_for_leaf
|
||||
self.params.print_node_info = print_node_info
|
||||
|
||||
def __str__(self):
|
||||
return f'n_trees={self.params.n_trees}, max_tree_depth={self.params.max_tree_depth}, n_test_split_functions={self.params.n_test_split_functions}, n_test_thresholds={self.params.n_test_thresholds}, n_test_samples={self.params.n_test_samples}, min_samples_to_split={self.params.min_samples_to_split}, min_samples_for_leaf={self.params.min_samples_for_leaf}'
|
||||
|
||||
|
||||
def train_forest(TrainParams params, uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
|
||||
cdef int n = ims.shape[0]
|
||||
cdef int h = ims.shape[1]
|
||||
cdef int w = ims.shape[2]
|
||||
|
||||
if row_from < 0:
|
||||
row_from = 0
|
||||
if row_to > h or row_to < 0:
|
||||
row_to = h
|
||||
|
||||
if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
|
||||
raise Exception('ims.shape != disps.shape')
|
||||
|
||||
train(row_from, row_to, params.params, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode())
|
||||
|
||||
|
||||
def eval_forest(uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
|
||||
cdef int n = ims.shape[0]
|
||||
cdef int h = ims.shape[1]
|
||||
cdef int w = ims.shape[2]
|
||||
|
||||
if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
|
||||
raise Exception('ims.shape != disps.shape')
|
||||
|
||||
if row_from < 0:
|
||||
row_from = 0
|
||||
if row_to > h or row_to < 0:
|
||||
row_to = h
|
||||
|
||||
out = np.empty((n, h, w, 3), dtype=np.float32)
|
||||
cdef float[:,:,:,::1] out_view = out
|
||||
eval(row_from, row_to, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode(), &out_view[0,0,0,0])
|
||||
return out
|
65
hyperdepth/hyperparam_search.py
Normal file
65
hyperdepth/hyperparam_search.py
Normal file
@ -0,0 +1,65 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
from pathlib import Path
|
||||
import sys
|
||||
|
||||
import hyperdepth as hd
|
||||
|
||||
sys.path.append('../')
|
||||
import dataset
|
||||
|
||||
|
||||
def get_data(n, row_from, row_to, train):
|
||||
imsizes = [(256,384)]
|
||||
focal_lengths = [160]
|
||||
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
|
||||
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8)
|
||||
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32)
|
||||
for idx in range(n):
|
||||
print(f'load sample {idx} train={train}')
|
||||
sample = dset[idx]
|
||||
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8)
|
||||
disps[idx] = sample['disp0'][0,row_from:row_to]
|
||||
return ims, disps
|
||||
|
||||
|
||||
|
||||
params = hd.TrainParams(
|
||||
n_trees=4,
|
||||
max_tree_depth=,
|
||||
n_test_split_functions=50,
|
||||
n_test_thresholds=10,
|
||||
n_test_samples=4096,
|
||||
min_samples_to_split=16,
|
||||
min_samples_for_leaf=8)
|
||||
|
||||
n_disp_bins = 20
|
||||
depth_switch = 0
|
||||
|
||||
row_from = 100
|
||||
row_to = 108
|
||||
n_train_samples = 1024
|
||||
n_test_samples = 32
|
||||
|
||||
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
|
||||
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
|
||||
|
||||
for tree_depth in [8,10,12,14,16]:
|
||||
depth_switch = tree_depth - 4
|
||||
|
||||
prefix = f'td{tree_depth}_ds{depth_switch}'
|
||||
prefix = Path(f'./forests/{prefix}/')
|
||||
prefix.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
|
||||
|
||||
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
|
||||
|
||||
np.save(str(prefix / 'ta.npy'), test_disps)
|
||||
np.save(str(prefix / 'es.npy'), es)
|
||||
|
||||
# plt.figure();
|
||||
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
|
||||
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
|
||||
# plt.show()
|
18
hyperdepth/rf/common.h
Normal file
18
hyperdepth/rf/common.h
Normal file
@ -0,0 +1,18 @@
|
||||
#ifndef COMMON_H
|
||||
#define COMMON_H
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
|
||||
#define DISABLE_COPY_AND_ASSIGN(classname) \
|
||||
private:\
|
||||
classname(const classname&) = delete;\
|
||||
classname& operator=(const classname&) = delete;
|
||||
|
||||
|
||||
#endif
|
72
hyperdepth/rf/data.h
Normal file
72
hyperdepth/rf/data.h
Normal file
@ -0,0 +1,72 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
||||
class Sample {
|
||||
public:
|
||||
Sample(int channels, int height, int width)
|
||||
: channels_(channels), height_(height), width_(width) {}
|
||||
|
||||
virtual ~Sample() {}
|
||||
|
||||
virtual float at(int c, int h, int w) const = 0;
|
||||
|
||||
virtual float operator()(int c, int h, int w) const {
|
||||
return at(c,h,w);
|
||||
}
|
||||
|
||||
virtual int channels() const { return channels_; }
|
||||
virtual int height() const { return height_; }
|
||||
virtual int width() const { return width_; }
|
||||
|
||||
protected:
|
||||
int channels_;
|
||||
int height_;
|
||||
int width_;
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<Sample> SamplePtr;
|
||||
|
||||
|
||||
|
||||
|
||||
class Target {
|
||||
public:
|
||||
Target() {}
|
||||
virtual ~Target() {}
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<Target> TargetPtr;
|
||||
typedef std::vector<TargetPtr> VecTargetPtr;
|
||||
typedef std::shared_ptr<VecTargetPtr> VecPtrTargetPtr;
|
||||
|
||||
|
||||
class ClassificationTarget : public Target {
|
||||
public:
|
||||
ClassificationTarget(int cl) : cl_(cl) {}
|
||||
virtual ~ClassificationTarget() {}
|
||||
int cl() const { return cl_; }
|
||||
|
||||
private:
|
||||
int cl_;
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<ClassificationTarget> ClassificationTargetPtr;
|
||||
|
||||
|
||||
|
||||
|
||||
struct TrainDatum {
|
||||
SamplePtr sample;
|
||||
TargetPtr target;
|
||||
TargetPtr optimize_target;
|
||||
|
||||
TrainDatum() : sample(nullptr), target(nullptr), optimize_target(nullptr) {}
|
||||
|
||||
TrainDatum(SamplePtr sample, TargetPtr target)
|
||||
: sample(sample), target(target), optimize_target(target) {}
|
||||
|
||||
TrainDatum(SamplePtr sample, TargetPtr target, TargetPtr optimize_target)
|
||||
: sample(sample), target(target), optimize_target(optimize_target) {}
|
||||
};
|
92
hyperdepth/rf/forest.h
Normal file
92
hyperdepth/rf/forest.h
Normal file
@ -0,0 +1,92 @@
|
||||
#pragma once
|
||||
|
||||
#include "tree.h"
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
class Forest {
|
||||
public:
|
||||
Forest() {}
|
||||
virtual ~Forest() {}
|
||||
|
||||
std::shared_ptr<LeafFunctionT> inferencest(const SamplePtr& sample) const {
|
||||
int n_trees = trees_.size();
|
||||
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> fcns;
|
||||
|
||||
//inference of individual trees
|
||||
for(int tree_idx = 0; tree_idx < n_trees; ++tree_idx) {
|
||||
std::shared_ptr<LeafFunctionT> tree_fcn = trees_[tree_idx]->inference(sample);
|
||||
fcns.push_back(tree_fcn);
|
||||
}
|
||||
|
||||
//combine tree fcns/results and collect all results
|
||||
return fcns[0]->Reduce(fcns);
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> inferencemt(const std::vector<SamplePtr>& samples, int n_threads) const {
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> targets(samples.size());
|
||||
|
||||
omp_set_num_threads(n_threads);
|
||||
#pragma omp parallel for
|
||||
for(size_t sample_idx = 0; sample_idx < samples.size(); ++sample_idx) {
|
||||
targets[sample_idx] = inferencest(samples[sample_idx]);
|
||||
}
|
||||
|
||||
return targets;
|
||||
}
|
||||
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> inferencemt(const std::vector<TrainDatum>& samples, int n_threads) const {
|
||||
std::vector<std::shared_ptr<LeafFunctionT>> targets(samples.size());
|
||||
|
||||
omp_set_num_threads(n_threads);
|
||||
#pragma omp parallel for
|
||||
for(size_t sample_idx = 0; sample_idx < samples.size(); ++sample_idx) {
|
||||
targets[sample_idx] = inferencest(samples[sample_idx].sample);
|
||||
}
|
||||
|
||||
return targets;
|
||||
}
|
||||
|
||||
void AddTree(std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> tree) {
|
||||
trees_.push_back(tree);
|
||||
}
|
||||
|
||||
size_t trees_size() const { return trees_.size(); }
|
||||
// TreePtr trees(int idx) const { return trees_[idx]; }
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
size_t n_trees = trees_.size();
|
||||
std::cout << "[DEBUG] write " << n_trees << " trees" << std::endl;
|
||||
ar << n_trees;
|
||||
|
||||
if(true) std::cout << "[Forest][write] write number of trees " << n_trees << std::endl;
|
||||
|
||||
for(size_t tree_idx = 0; tree_idx < trees_.size(); ++tree_idx) {
|
||||
if(true) std::cout << "[Forest][write] write tree nb. " << tree_idx << std::endl;
|
||||
trees_[tree_idx]->Save(ar);
|
||||
}
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
size_t n_trees;
|
||||
ar >> n_trees;
|
||||
|
||||
if(true) std::cout << "[Forest][read] nTrees: " << n_trees << std::endl;
|
||||
|
||||
trees_.clear();
|
||||
for(size_t i = 0; i < n_trees; ++i) {
|
||||
if(true) std::cout << "[Forest][read] read tree " << (i+1) << " of " << n_trees << " - " << std::endl;
|
||||
|
||||
auto tree = std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>();
|
||||
tree->Load(ar);
|
||||
trees_.push_back(tree);
|
||||
|
||||
if(true) std::cout << "[Forest][read] finished read tree " << (i+1) << " of " << n_trees << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private:
|
||||
std::vector<std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>>> trees_;
|
||||
};
|
||||
|
99
hyperdepth/rf/leaffcn.h
Normal file
99
hyperdepth/rf/leaffcn.h
Normal file
@ -0,0 +1,99 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "common.h"
|
||||
#include "data.h"
|
||||
|
||||
|
||||
class ClassProbabilitiesLeafFunction {
|
||||
public:
|
||||
ClassProbabilitiesLeafFunction() : n_classes_(-1) {}
|
||||
ClassProbabilitiesLeafFunction(int n_classes) : n_classes_(n_classes) {}
|
||||
virtual ~ClassProbabilitiesLeafFunction() {}
|
||||
|
||||
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Copy() const {
|
||||
auto fcn = std::make_shared<ClassProbabilitiesLeafFunction>();
|
||||
fcn->n_classes_ = n_classes_;
|
||||
fcn->counts_.resize(counts_.size());
|
||||
for(size_t idx = 0; idx < counts_.size(); ++idx) {
|
||||
fcn->counts_[idx] = counts_[idx];
|
||||
}
|
||||
fcn->sum_counts_ = sum_counts_;
|
||||
|
||||
return fcn;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Create(const std::vector<TrainDatum>& samples) {
|
||||
auto stat = std::make_shared<ClassProbabilitiesLeafFunction>();
|
||||
|
||||
stat->counts_.resize(n_classes_, 0);
|
||||
for(auto sample : samples) {
|
||||
auto ctarget = std::static_pointer_cast<ClassificationTarget>(sample.target);
|
||||
stat->counts_[ctarget->cl()] += 1;
|
||||
}
|
||||
stat->sum_counts_ = samples.size();
|
||||
|
||||
return stat;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<ClassProbabilitiesLeafFunction> Reduce(const std::vector<std::shared_ptr<ClassProbabilitiesLeafFunction>>& fcns) const {
|
||||
auto stat = std::make_shared<ClassProbabilitiesLeafFunction>();
|
||||
auto cfcn0 = std::static_pointer_cast<ClassProbabilitiesLeafFunction>(fcns[0]);
|
||||
stat->counts_.resize(cfcn0->counts_.size(), 0);
|
||||
stat->sum_counts_ = 0;
|
||||
|
||||
for(auto fcn : fcns) {
|
||||
auto cfcn = std::static_pointer_cast<ClassProbabilitiesLeafFunction>(fcn);
|
||||
for(size_t cl = 0; cl < stat->counts_.size(); ++cl) {
|
||||
stat->counts_[cl] += cfcn->counts_[cl];
|
||||
}
|
||||
stat->sum_counts_ += cfcn->sum_counts_;
|
||||
}
|
||||
|
||||
return stat;
|
||||
}
|
||||
|
||||
virtual int argmax() const {
|
||||
int max_idx = 0;
|
||||
int max_count = counts_[0];
|
||||
for(size_t idx = 1; idx < counts_.size(); ++idx) {
|
||||
if(counts_[idx] > max_count) {
|
||||
max_count = counts_[idx];
|
||||
max_idx = idx;
|
||||
}
|
||||
}
|
||||
return max_idx;
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
ar << n_classes_;
|
||||
int n_counts = counts_.size();
|
||||
ar << n_counts;
|
||||
for(int idx = 0; idx < n_counts; ++idx) {
|
||||
ar << counts_[idx];
|
||||
}
|
||||
ar << sum_counts_;
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
ar >> n_classes_;
|
||||
int n_counts;
|
||||
ar >> n_counts;
|
||||
counts_.resize(n_counts);
|
||||
for(int idx = 0; idx < n_counts; ++idx) {
|
||||
ar >> counts_[idx];
|
||||
}
|
||||
ar >> sum_counts_;
|
||||
}
|
||||
|
||||
public:
|
||||
int n_classes_;
|
||||
|
||||
std::vector<int> counts_;
|
||||
int sum_counts_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(ClassProbabilitiesLeafFunction);
|
||||
};
|
||||
|
||||
|
158
hyperdepth/rf/node.h
Normal file
158
hyperdepth/rf/node.h
Normal file
@ -0,0 +1,158 @@
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "serialization.h"
|
||||
#include "leaffcn.h"
|
||||
#include "splitfcn.h"
|
||||
|
||||
class Node {
|
||||
public:
|
||||
Node() {}
|
||||
virtual ~Node() {}
|
||||
|
||||
virtual std::shared_ptr<Node> Copy() const = 0;
|
||||
|
||||
virtual int type() const = 0;
|
||||
|
||||
virtual void Save(SerializationOut& ar) const = 0;
|
||||
virtual void Load(SerializationIn& ar) = 0;
|
||||
|
||||
};
|
||||
|
||||
typedef std::shared_ptr<Node> NodePtr;
|
||||
|
||||
|
||||
template <typename LeafFunctionT>
|
||||
class LeafNode : public Node {
|
||||
public:
|
||||
static const int TYPE = 0;
|
||||
|
||||
LeafNode() {}
|
||||
LeafNode(std::shared_ptr<LeafFunctionT> leaf_node_fcn) : leaf_node_fcn_(leaf_node_fcn) {}
|
||||
|
||||
virtual ~LeafNode() {}
|
||||
|
||||
virtual NodePtr Copy() const {
|
||||
auto node = std::make_shared<LeafNode>();
|
||||
node->leaf_node_fcn_ = leaf_node_fcn_->Copy();
|
||||
return node;
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
leaf_node_fcn_->Save(ar);
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
leaf_node_fcn_ = std::make_shared<LeafFunctionT>();
|
||||
leaf_node_fcn_->Load(ar);
|
||||
}
|
||||
|
||||
virtual int type() const { return TYPE; };
|
||||
std::shared_ptr<LeafFunctionT> leaf_node_fcn() const { return leaf_node_fcn_; }
|
||||
|
||||
private:
|
||||
std::shared_ptr<LeafFunctionT> leaf_node_fcn_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(LeafNode);
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
class SplitNode : public Node {
|
||||
public:
|
||||
static const int TYPE = 1;
|
||||
|
||||
SplitNode() {}
|
||||
|
||||
SplitNode(NodePtr left, NodePtr right, std::shared_ptr<SplitFunctionT> split_fcn) :
|
||||
left_(left), right_(right), split_fcn_(split_fcn)
|
||||
{}
|
||||
|
||||
virtual ~SplitNode() {}
|
||||
|
||||
virtual std::shared_ptr<Node> Copy() const {
|
||||
std::shared_ptr<SplitNode> node = std::make_shared<SplitNode>();
|
||||
node->left_ = left_->Copy();
|
||||
node->right_ = right_->Copy();
|
||||
node->split_fcn_ = split_fcn_->Copy();
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
bool Split(SamplePtr sample) {
|
||||
return split_fcn_->Split(sample);
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
split_fcn_->Save(ar);
|
||||
|
||||
//left
|
||||
int type = left_->type();
|
||||
ar << type;
|
||||
left_->Save(ar);
|
||||
|
||||
//right
|
||||
type = right_->type();
|
||||
ar << type;
|
||||
right_->Save(ar);
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar);
|
||||
|
||||
|
||||
virtual int type() const { return TYPE; }
|
||||
|
||||
NodePtr left() const { return left_; }
|
||||
NodePtr right() const { return right_; }
|
||||
std::shared_ptr<SplitFunctionT> split_fcn() const { return split_fcn_; }
|
||||
|
||||
void set_left(NodePtr left) { left_ = left; }
|
||||
void set_right(NodePtr right) { right_ = right; }
|
||||
void set_split_fcn(std::shared_ptr<SplitFunctionT> split_fcn) { split_fcn_ = split_fcn; }
|
||||
|
||||
public:
|
||||
NodePtr left_;
|
||||
NodePtr right_;
|
||||
std::shared_ptr<SplitFunctionT> split_fcn_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(SplitNode);
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
NodePtr MakeNode(int type) {
|
||||
NodePtr node;
|
||||
if(type == LeafNode<LeafFunctionT>::TYPE) {
|
||||
node = std::make_shared<LeafNode<LeafFunctionT>>();
|
||||
}
|
||||
else if(type == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
|
||||
node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>();
|
||||
}
|
||||
else {
|
||||
std::cout << "[ERROR] unknown node type" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
void SplitNode<SplitFunctionT, LeafFunctionT>::Load(SerializationIn& ar) {
|
||||
|
||||
split_fcn_ = std::make_shared<SplitFunctionT>();
|
||||
split_fcn_->Load(ar);
|
||||
|
||||
//left
|
||||
int left_type;
|
||||
ar >> left_type;
|
||||
left_ = MakeNode<SplitFunctionT, LeafFunctionT>(left_type);
|
||||
left_->Load(ar);
|
||||
|
||||
//right
|
||||
int right_type;
|
||||
ar >> right_type;
|
||||
right_ = MakeNode<SplitFunctionT, LeafFunctionT>(right_type);
|
||||
right_->Load(ar);
|
||||
}
|
256
hyperdepth/rf/serialization.h
Normal file
256
hyperdepth/rf/serialization.h
Normal file
@ -0,0 +1,256 @@
|
||||
#pragma once
|
||||
|
||||
#include <fstream>
|
||||
|
||||
class SerializationOut {
|
||||
public:
|
||||
SerializationOut(const std::string& path) : path_(path) {}
|
||||
virtual ~SerializationOut() {}
|
||||
|
||||
virtual SerializationOut& operator<<(const bool& v) = 0;
|
||||
virtual SerializationOut& operator<<(const char& v) = 0;
|
||||
virtual SerializationOut& operator<<(const int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const unsigned int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const unsigned long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const long long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const unsigned long long int& v) = 0;
|
||||
virtual SerializationOut& operator<<(const float& v) = 0;
|
||||
virtual SerializationOut& operator<<(const double& v) = 0;
|
||||
|
||||
protected:
|
||||
const std::string& path_;
|
||||
};
|
||||
|
||||
class SerializationIn {
|
||||
public:
|
||||
SerializationIn(const std::string& path) : path_(path) {}
|
||||
virtual ~SerializationIn() {}
|
||||
|
||||
virtual SerializationIn& operator>>(bool& v) = 0;
|
||||
virtual SerializationIn& operator>>(char& v) = 0;
|
||||
virtual SerializationIn& operator>>(int& v) = 0;
|
||||
virtual SerializationIn& operator>>(unsigned int& v) = 0;
|
||||
virtual SerializationIn& operator>>(long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(unsigned long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(long long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(unsigned long long int& v) = 0;
|
||||
virtual SerializationIn& operator>>(float& v) = 0;
|
||||
virtual SerializationIn& operator>>(double& v) = 0;
|
||||
|
||||
protected:
|
||||
const std::string& path_;
|
||||
};
|
||||
|
||||
class TextSerializationOut : public SerializationOut {
|
||||
public:
|
||||
TextSerializationOut(const std::string& path) : SerializationOut(path),
|
||||
f_(path.c_str()) {}
|
||||
virtual ~TextSerializationOut() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationOut& operator<<(const bool& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const char& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long long int& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const float& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const double& v) {
|
||||
f_ << v << std::endl;
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ofstream f_;
|
||||
};
|
||||
|
||||
class TextSerializationIn : public SerializationIn {
|
||||
public:
|
||||
TextSerializationIn(const std::string& path) : SerializationIn(path),
|
||||
f_(path.c_str()) {}
|
||||
virtual ~TextSerializationIn() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationIn& operator>>(bool& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(char& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long long int& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(float& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(double& v) {
|
||||
f_ >> v;
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ifstream f_;
|
||||
};
|
||||
|
||||
class BinarySerializationOut : public SerializationOut {
|
||||
public:
|
||||
BinarySerializationOut(const std::string& path) : SerializationOut(path),
|
||||
f_(path.c_str(), std::ios::binary) {}
|
||||
virtual ~BinarySerializationOut() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationOut& operator<<(const bool& v) {
|
||||
f_.write((char*)&v, sizeof(bool));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const char& v) {
|
||||
f_.write((char*)&v, sizeof(char));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const int& v) {
|
||||
f_.write((char*)&v, sizeof(int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned int& v) {
|
||||
f_.write((char*)&v, sizeof(unsigned int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long int& v) {
|
||||
f_.write((char*)&v, sizeof(long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long int& v) {
|
||||
f_.write((char*)&v, sizeof(unsigned long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const long long int& v) {
|
||||
f_.write((char*)&v, sizeof(long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const unsigned long long int& v) {
|
||||
f_.write((char*)&v, sizeof(unsigned long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const float& v) {
|
||||
f_.write((char*)&v, sizeof(float));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationOut& operator<<(const double& v) {
|
||||
f_.write((char*)&v, sizeof(double));
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ofstream f_;
|
||||
};
|
||||
|
||||
class BinarySerializationIn : public SerializationIn {
|
||||
public:
|
||||
BinarySerializationIn(const std::string& path) : SerializationIn(path),
|
||||
f_(path.c_str(), std::ios::binary) {}
|
||||
virtual ~BinarySerializationIn() {
|
||||
f_.close();
|
||||
}
|
||||
|
||||
virtual SerializationIn& operator>>(bool& v) {
|
||||
f_.read((char*)&v, sizeof(bool));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(char& v) {
|
||||
f_.read((char*)&v, sizeof(char));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(int& v) {
|
||||
f_.read((char*)&v, sizeof(int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned int& v) {
|
||||
f_.read((char*)&v, sizeof(unsigned int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long int& v) {
|
||||
f_.read((char*)&v, sizeof(long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long int& v) {
|
||||
f_.read((char*)&v, sizeof(unsigned long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(long long int& v) {
|
||||
f_.read((char*)&v, sizeof(long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(unsigned long long int& v) {
|
||||
f_.read((char*)&v, sizeof(unsigned long long int));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(float& v) {
|
||||
f_.read((char*)&v, sizeof(float));
|
||||
return (*this);
|
||||
}
|
||||
virtual SerializationIn& operator>>(double& v) {
|
||||
f_.read((char*)&v, sizeof(double));
|
||||
return (*this);
|
||||
}
|
||||
|
||||
protected:
|
||||
std::ifstream f_;
|
||||
};
|
||||
|
71
hyperdepth/rf/spliteval.h
Normal file
71
hyperdepth/rf/spliteval.h
Normal file
@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
|
||||
class SplitEvaluator {
|
||||
public:
|
||||
SplitEvaluator(bool normalize)
|
||||
: normalize_(normalize) {}
|
||||
|
||||
virtual ~SplitEvaluator() {}
|
||||
|
||||
virtual float Eval(const std::vector<TrainDatum>& lefttargets, const std::vector<TrainDatum>& righttargets, int depth) const {
|
||||
float purity_left = Purity(lefttargets, depth);
|
||||
float purity_right = Purity(righttargets, depth);
|
||||
|
||||
float normalize_left = 1.0;
|
||||
float normalize_right = 1.0;
|
||||
|
||||
if(normalize_) {
|
||||
unsigned int n_left = lefttargets.size();
|
||||
unsigned int n_right = righttargets.size();
|
||||
unsigned int n_total = n_left + n_right;
|
||||
|
||||
normalize_left = float(n_left) / float(n_total);
|
||||
normalize_right = float(n_right) / float(n_total);
|
||||
}
|
||||
|
||||
float purity = purity_left * normalize_left + purity_right * normalize_right;
|
||||
|
||||
return purity;
|
||||
}
|
||||
|
||||
protected:
|
||||
virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const = 0;
|
||||
|
||||
protected:
|
||||
bool normalize_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class ClassificationIGSplitEvaluator : public SplitEvaluator {
|
||||
public:
|
||||
ClassificationIGSplitEvaluator(bool normalize, int n_classes)
|
||||
: SplitEvaluator(normalize), n_classes_(n_classes) {}
|
||||
virtual ~ClassificationIGSplitEvaluator() {}
|
||||
|
||||
protected:
|
||||
virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const {
|
||||
if(targets.size() == 0) return 0;
|
||||
|
||||
std::vector<int> ps;
|
||||
ps.resize(n_classes_, 0);
|
||||
for(auto target : targets) {
|
||||
auto ctarget = std::static_pointer_cast<ClassificationTarget>(target.optimize_target);
|
||||
ps[ctarget->cl()] += 1;
|
||||
}
|
||||
|
||||
float h = 0;
|
||||
for(int cl = 0; cl < n_classes_; ++cl) {
|
||||
float fi = float(ps[cl]) / float(targets.size());
|
||||
if(fi > 0) {
|
||||
h = h - fi * std::log(fi);
|
||||
}
|
||||
}
|
||||
|
||||
return h;
|
||||
}
|
||||
|
||||
private:
|
||||
int n_classes_;
|
||||
};
|
||||
|
106
hyperdepth/rf/splitfcn.h
Normal file
106
hyperdepth/rf/splitfcn.h
Normal file
@ -0,0 +1,106 @@
|
||||
#pragma once
|
||||
|
||||
#include <random>
|
||||
|
||||
class SplitFunction {
|
||||
public:
|
||||
SplitFunction() {}
|
||||
virtual ~SplitFunction() {}
|
||||
|
||||
virtual float Compute(SamplePtr sample) const = 0;
|
||||
|
||||
virtual bool Split(SamplePtr sample) const {
|
||||
return Compute(sample) < threshold_;
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
ar << threshold_;
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
ar >> threshold_;
|
||||
}
|
||||
|
||||
virtual float threshold() const { return threshold_; }
|
||||
virtual void set_threshold(float threshold) { threshold_ = threshold; }
|
||||
|
||||
protected:
|
||||
float threshold_;
|
||||
};
|
||||
|
||||
|
||||
|
||||
class SplitFunctionPixelDifference : public SplitFunction {
|
||||
public:
|
||||
|
||||
SplitFunctionPixelDifference() {}
|
||||
virtual ~SplitFunctionPixelDifference() {}
|
||||
|
||||
virtual std::shared_ptr<SplitFunctionPixelDifference> Copy() const {
|
||||
std::shared_ptr<SplitFunctionPixelDifference> split_fcn = std::make_shared<SplitFunctionPixelDifference>();
|
||||
split_fcn->threshold_ = threshold_;
|
||||
split_fcn->c0_ = c0_;
|
||||
split_fcn->c1_ = c1_;
|
||||
split_fcn->h0_ = h0_;
|
||||
split_fcn->h1_ = h1_;
|
||||
split_fcn->w0_ = w0_;
|
||||
split_fcn->w1_ = w1_;
|
||||
|
||||
return split_fcn;
|
||||
}
|
||||
|
||||
virtual std::shared_ptr<SplitFunctionPixelDifference> Generate(std::mt19937& rng, const SamplePtr sample) const {
|
||||
std::shared_ptr<SplitFunctionPixelDifference> split_fcn = std::make_shared<SplitFunctionPixelDifference>();
|
||||
|
||||
std::uniform_int_distribution<int> cdist(0, sample->channels()-1);
|
||||
split_fcn->c0_ = cdist(rng);
|
||||
split_fcn->c1_ = cdist(rng);
|
||||
|
||||
std::uniform_int_distribution<int> hdist(0, sample->height()-1);
|
||||
split_fcn->h0_ = hdist(rng);
|
||||
split_fcn->h1_ = hdist(rng);
|
||||
|
||||
std::uniform_int_distribution<int> wdist(0, sample->width()-1);
|
||||
split_fcn->w0_ = wdist(rng);
|
||||
split_fcn->w1_ = wdist(rng);
|
||||
|
||||
return split_fcn;
|
||||
}
|
||||
|
||||
virtual float Compute(SamplePtr sample) const {
|
||||
return (*sample)(c0_, h0_, w0_) - (*sample)(c1_, h1_, w1_);
|
||||
}
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
SplitFunction::Save(ar);
|
||||
ar << c0_;
|
||||
ar << c1_;
|
||||
ar << h0_;
|
||||
ar << h1_;
|
||||
ar << w0_;
|
||||
ar << w1_;
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
SplitFunction::Load(ar);
|
||||
|
||||
ar >> c0_;
|
||||
ar >> c1_;
|
||||
ar >> h0_;
|
||||
ar >> h1_;
|
||||
ar >> w0_;
|
||||
ar >> w1_;
|
||||
}
|
||||
|
||||
private:
|
||||
int c0_;
|
||||
int c1_;
|
||||
int h0_;
|
||||
int h1_;
|
||||
int w0_;
|
||||
int w1_;
|
||||
|
||||
DISABLE_COPY_AND_ASSIGN(SplitFunctionPixelDifference);
|
||||
};
|
||||
|
||||
|
112
hyperdepth/rf/threadpool.h
Normal file
112
hyperdepth/rf/threadpool.h
Normal file
@ -0,0 +1,112 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <future>
|
||||
#include <functional>
|
||||
#include <stdexcept>
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t);
|
||||
template<class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>;
|
||||
~ThreadPool();
|
||||
|
||||
bool has_running_tasks() {
|
||||
std::unique_lock<std::mutex> lock(running_tasks_mutex);
|
||||
return n_running_tasks > 0;
|
||||
}
|
||||
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector< std::thread > workers;
|
||||
// the task queue
|
||||
std::queue< std::function<void()> > tasks;
|
||||
|
||||
int n_running_tasks;
|
||||
|
||||
// synchronization
|
||||
std::mutex queue_mutex;
|
||||
std::mutex running_tasks_mutex;
|
||||
std::condition_variable condition;
|
||||
bool stop;
|
||||
};
|
||||
|
||||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads)
|
||||
: n_running_tasks(0), stop(false)
|
||||
{
|
||||
for(size_t i = 0;i<threads;++i)
|
||||
workers.emplace_back(
|
||||
[this]
|
||||
{
|
||||
for(;;)
|
||||
{
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this]{ return this->stop || !this->tasks.empty(); });
|
||||
if(this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->running_tasks_mutex);
|
||||
n_running_tasks++;
|
||||
}
|
||||
task();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->running_tasks_mutex);
|
||||
n_running_tasks--;
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// add new work item to the pool
|
||||
template<class F, class... Args>
|
||||
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if(stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task](){ (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
// the destructor joins all threads
|
||||
inline ThreadPool::~ThreadPool()
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for(std::thread &worker: workers)
|
||||
worker.join();
|
||||
}
|
423
hyperdepth/rf/train.h
Normal file
423
hyperdepth/rf/train.h
Normal file
@ -0,0 +1,423 @@
|
||||
#pragma once
|
||||
|
||||
#include <chrono>
|
||||
#include <set>
|
||||
#include <queue>
|
||||
|
||||
#include "threadpool.h"
|
||||
|
||||
#include "forest.h"
|
||||
#include "spliteval.h"
|
||||
|
||||
|
||||
enum class TrainType : int {
|
||||
TRAIN = 0,
|
||||
RETRAIN = 1,
|
||||
RETRAIN_WITH_REPLACEMENT = 2
|
||||
};
|
||||
|
||||
struct TrainParameters {
|
||||
TrainType train_type;
|
||||
int n_trees;
|
||||
int max_tree_depth;
|
||||
int n_test_split_functions;
|
||||
int n_test_thresholds;
|
||||
int n_test_samples;
|
||||
int min_samples_to_split;
|
||||
int min_samples_for_leaf;
|
||||
int print_node_info;
|
||||
|
||||
TrainParameters() :
|
||||
train_type(TrainType::TRAIN),
|
||||
n_trees(5),
|
||||
max_tree_depth(7),
|
||||
n_test_split_functions(50),
|
||||
n_test_thresholds(10),
|
||||
n_test_samples(100),
|
||||
min_samples_to_split(14),
|
||||
min_samples_for_leaf(7),
|
||||
print_node_info(100)
|
||||
{}
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
|
||||
class TrainForest {
|
||||
public:
|
||||
TrainForest(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
|
||||
: params_(params), gen_split_fcn_(gen_split_fcn), gen_leaf_fcn_(gen_leaf_fcn), split_eval_(split_eval), n_threads(n_threads), verbose_(verbose) {
|
||||
|
||||
n_created_nodes_ = 0;
|
||||
n_max_nodes_ = 1;
|
||||
unsigned long n_nodes_d = 1;
|
||||
for(int depth = 0; depth < params.max_tree_depth; ++depth) {
|
||||
n_nodes_d *= 2;
|
||||
n_max_nodes_ += n_nodes_d;
|
||||
}
|
||||
n_max_nodes_ *= params.n_trees;
|
||||
}
|
||||
|
||||
virtual ~TrainForest() {}
|
||||
|
||||
virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) = 0;
|
||||
|
||||
protected:
|
||||
virtual void PrintParams() {
|
||||
if(verbose_){
|
||||
#pragma omp critical (TrainForest_train)
|
||||
{
|
||||
std::cout << "[TRAIN] training forest " << std::endl;
|
||||
std::cout << "[TRAIN] n_trees : " << params_.n_trees << std::endl;
|
||||
std::cout << "[TRAIN] max_tree_depth : " << params_.max_tree_depth << std::endl;
|
||||
std::cout << "[TRAIN] n_test_split_functions: " << params_.n_test_split_functions << std::endl;
|
||||
std::cout << "[TRAIN] n_test_thresholds : " << params_.n_test_thresholds << std::endl;
|
||||
std::cout << "[TRAIN] n_test_samples : " << params_.n_test_samples << std::endl;
|
||||
std::cout << "[TRAIN] min_samples_to_split : " << params_.min_samples_to_split << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual void UpdateNodeInfo(unsigned int depth, bool leaf) {
|
||||
if(verbose_) {
|
||||
n_created_nodes_ += 1;
|
||||
|
||||
if(leaf) {
|
||||
unsigned long n_nodes_d = 1;
|
||||
unsigned int n_remove_max_nodes = 0;
|
||||
for(int d = depth; d < params_.max_tree_depth; ++d) {
|
||||
n_nodes_d *= 2;
|
||||
n_remove_max_nodes += n_nodes_d;
|
||||
}
|
||||
n_max_nodes_ -= n_remove_max_nodes;
|
||||
}
|
||||
|
||||
if(n_created_nodes_ % params_.print_node_info == 0 || n_created_nodes_ == n_max_nodes_) {
|
||||
std::cout << "[Forest]"
|
||||
<< " created node number " << n_created_nodes_
|
||||
<< " @ depth " << depth
|
||||
<< ", max. " << n_max_nodes_ << " left"
|
||||
<< " => " << (double(n_created_nodes_) / double(n_max_nodes_))
|
||||
<< " done" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
virtual void SampleData(const std::vector<TrainDatum>& all, std::vector<TrainDatum>& sampled, std::mt19937& rng) {
|
||||
unsigned int n = all.size();
|
||||
unsigned int k = params_.n_test_samples;
|
||||
k = n < k ? n : k;
|
||||
|
||||
std::set<int> indices;
|
||||
std::uniform_int_distribution<int> udist(0, all.size()-1);
|
||||
while(indices.size() < k) {
|
||||
int idx = udist(rng);
|
||||
indices.insert(idx);
|
||||
}
|
||||
|
||||
sampled.resize(k);
|
||||
int sidx = 0;
|
||||
for(int idx : indices) {
|
||||
sampled[sidx] = all[idx];
|
||||
sidx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
virtual void Split(const std::shared_ptr<SplitFunctionT>& split_function, const std::vector<TrainDatum>& samples, std::vector<TrainDatum>& left, std::vector<TrainDatum>& right) {
|
||||
for(auto sample : samples) {
|
||||
if(split_function->Split(sample.sample)) {
|
||||
left.push_back(sample);
|
||||
}
|
||||
else {
|
||||
right.push_back(sample);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
virtual std::shared_ptr<SplitFunctionT> OptimizeSplitFunction(const std::vector<TrainDatum>& samples, int depth, std::mt19937& rng) {
|
||||
std::vector<TrainDatum> split_samples;
|
||||
SampleData(samples, split_samples, rng);
|
||||
|
||||
unsigned int min_samples_for_leaf = params_.min_samples_for_leaf;
|
||||
|
||||
float min_cost = std::numeric_limits<float>::max();
|
||||
std::shared_ptr<SplitFunctionT> best_split_fcn;
|
||||
float best_threshold = 0;
|
||||
|
||||
for(int split_fcn_idx = 0; split_fcn_idx < params_.n_test_split_functions; ++split_fcn_idx) {
|
||||
auto split_fcn = gen_split_fcn_->Generate(rng, samples[0].sample);
|
||||
|
||||
for(int threshold_idx = 0; threshold_idx < params_.n_test_thresholds; ++threshold_idx) {
|
||||
std::uniform_int_distribution<int> udist(0, split_samples.size()-1);
|
||||
int rand_split_sample_idx = udist(rng);
|
||||
float threshold = split_fcn->Compute(split_samples[rand_split_sample_idx].sample);
|
||||
split_fcn->set_threshold(threshold);
|
||||
|
||||
std::vector<TrainDatum> left;
|
||||
std::vector<TrainDatum> right;
|
||||
Split(split_fcn, split_samples, left, right);
|
||||
if(left.size() < min_samples_for_leaf || right.size() < min_samples_for_leaf) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// std::cout << "split done " << left.size() << "," << right.size() << std::endl;
|
||||
float split_cost = split_eval_->Eval(left, right, depth);
|
||||
// std::cout << ", " << split_cost << ", " << threshold << "; " << std::endl;
|
||||
|
||||
if(split_cost < min_cost) {
|
||||
min_cost = split_cost;
|
||||
best_split_fcn = split_fcn;
|
||||
best_threshold = threshold; //need theshold extra because of pointer
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(best_split_fcn != nullptr) {
|
||||
best_split_fcn->set_threshold(best_threshold);
|
||||
}
|
||||
|
||||
return best_split_fcn;
|
||||
}
|
||||
|
||||
|
||||
virtual NodePtr CreateLeafNode(const std::vector<TrainDatum>& samples, unsigned int depth) {
|
||||
auto leaf_fct = gen_leaf_fcn_->Create(samples);
|
||||
auto node = std::make_shared<LeafNode<LeafFunctionT>>(leaf_fct);
|
||||
|
||||
UpdateNodeInfo(depth, true);
|
||||
|
||||
return node;
|
||||
}
|
||||
|
||||
protected:
|
||||
const TrainParameters& params_;
|
||||
const std::shared_ptr<SplitFunctionT> gen_split_fcn_;
|
||||
const std::shared_ptr<LeafFunctionT> gen_leaf_fcn_;
|
||||
const std::shared_ptr<SplitEvaluatorT> split_eval_;
|
||||
int n_threads;
|
||||
bool verbose_;
|
||||
|
||||
unsigned long n_created_nodes_;
|
||||
unsigned long n_max_nodes_;
|
||||
};
|
||||
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
|
||||
class TrainForestRecursive : public TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT> {
|
||||
public:
|
||||
TrainForestRecursive(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
|
||||
: TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT>(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, verbose) {}
|
||||
|
||||
virtual ~TrainForestRecursive() {}
|
||||
|
||||
virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) {
|
||||
|
||||
this->PrintParams();
|
||||
|
||||
auto tim = std::chrono::system_clock::now();
|
||||
auto forest = std::make_shared<Forest<SplitFunctionT, LeafFunctionT>>();
|
||||
|
||||
omp_set_num_threads(this->n_threads);
|
||||
#pragma omp parallel for ordered
|
||||
for(size_t treeIdx = 0; treeIdx < this->params_.n_trees; ++treeIdx) {
|
||||
auto treetim = std::chrono::system_clock::now();
|
||||
|
||||
#pragma omp critical (TrainForest_train)
|
||||
{
|
||||
if(this->verbose_){
|
||||
std::cout << "[TRAIN][START] training tree " << treeIdx << " of " << this->params_.n_trees << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> old_tree;
|
||||
if(old_forest != 0 && treeIdx < old_forest->trees_size()) {
|
||||
old_tree = old_forest->trees(treeIdx);
|
||||
}
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
|
||||
auto tree = Train(samples, train_type, old_tree,rng);
|
||||
|
||||
#pragma omp critical (TrainForest_train)
|
||||
{
|
||||
forest->AddTree(tree);
|
||||
if(this->verbose_){
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now - treetim);
|
||||
std::cout << "[TRAIN][FINISHED] training tree " << treeIdx << " of " << this->params_.n_trees << " - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
|
||||
std::cout << "[TRAIN][FINISHED] " << (this->params_.n_trees - forest->trees_size()) << " left for training" << std::endl;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(this->verbose_){
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now - tim);
|
||||
std::cout << "[TRAIN][FINISHED] training forest - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
|
||||
}
|
||||
|
||||
return forest;
|
||||
}
|
||||
|
||||
private:
|
||||
virtual std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Tree<SplitFunctionT, LeafFunctionT>>& old_tree, std::mt19937& rng) {
|
||||
NodePtr old_root;
|
||||
if(old_tree != nullptr) {
|
||||
old_root = old_tree->root();
|
||||
}
|
||||
|
||||
NodePtr root = Train(samples, train_type, old_root, 0, rng);
|
||||
return std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>(root);
|
||||
}
|
||||
|
||||
virtual NodePtr Train(const std::vector<TrainDatum>& samples, TrainType train_type, const NodePtr& old_node, unsigned int depth, std::mt19937& rng) {
|
||||
|
||||
if(depth < this->params_.max_tree_depth && samples.size() > this->params_.min_samples_to_split) {
|
||||
std::shared_ptr<SplitFunctionT> best_split_fcn;
|
||||
bool was_split_node = false;
|
||||
if(old_node == nullptr || old_node->type() == LeafNode<LeafFunctionT>::TYPE) {
|
||||
best_split_fcn = this->OptimizeSplitFunction(samples, depth, rng);
|
||||
was_split_node = false;
|
||||
}
|
||||
else if(old_node->type() == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
|
||||
auto split_node = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(old_node);
|
||||
best_split_fcn = split_node->split_fcn()->Copy();
|
||||
was_split_node = true;
|
||||
}
|
||||
|
||||
if(best_split_fcn == nullptr) {
|
||||
if(old_node == nullptr || train_type == TrainType::TRAIN || train_type == TrainType::RETRAIN_WITH_REPLACEMENT) {
|
||||
return this->CreateLeafNode(samples, depth);
|
||||
}
|
||||
else if(train_type == TrainType::RETRAIN) {
|
||||
return old_node->Copy();
|
||||
}
|
||||
else {
|
||||
std::cout << "[ERROR] unknown train type" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
|
||||
// (1) split samples
|
||||
std::vector<TrainDatum> leftsamples, rightsamples;
|
||||
this->Split(best_split_fcn, samples, leftsamples, rightsamples);
|
||||
|
||||
//output node information
|
||||
this->UpdateNodeInfo(depth, false);
|
||||
|
||||
//create split node - recursively train the siblings
|
||||
if(was_split_node) {
|
||||
auto split_node = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(old_node);
|
||||
NodePtr left = this->Train(leftsamples, train_type, split_node->left(), depth + 1, rng);
|
||||
NodePtr right = this->Train(rightsamples, train_type, split_node->right(), depth + 1, rng);
|
||||
auto new_node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>(left, right, best_split_fcn);
|
||||
return new_node;
|
||||
}
|
||||
else {
|
||||
NodePtr left = this->Train(leftsamples, train_type, nullptr, depth + 1, rng);
|
||||
NodePtr right = this->Train(rightsamples, train_type, nullptr, depth + 1, rng);
|
||||
auto new_node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>(left, right, best_split_fcn);
|
||||
return new_node;
|
||||
}
|
||||
} // if samples < min_samples || depth >= max_depth then make leaf node
|
||||
else {
|
||||
if(old_node == 0 || train_type == TrainType::TRAIN || train_type == TrainType::RETRAIN_WITH_REPLACEMENT) {
|
||||
return this->CreateLeafNode(samples, depth);
|
||||
}
|
||||
else if(train_type == TrainType::RETRAIN) {
|
||||
return old_node->Copy();
|
||||
}
|
||||
else {
|
||||
std::cout << "[ERROR] unknown train type" << std::endl;
|
||||
exit(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
|
||||
struct QueueTuple {
|
||||
int depth;
|
||||
std::vector<TrainDatum> train_data;
|
||||
NodePtr* parent;
|
||||
|
||||
QueueTuple() : depth(-1), train_data(), parent(nullptr) {}
|
||||
QueueTuple(int depth, std::vector<TrainDatum> train_data, NodePtr* parent) :
|
||||
depth(depth), train_data(train_data), parent(parent) {}
|
||||
};
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT, typename SplitEvaluatorT>
|
||||
class TrainForestQueued : public TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT> {
|
||||
public:
|
||||
TrainForestQueued(const TrainParameters& params, const std::shared_ptr<SplitFunctionT> gen_split_fcn, const std::shared_ptr<LeafFunctionT> gen_leaf_fcn, const std::shared_ptr<SplitEvaluatorT> split_eval, int n_threads, bool verbose)
|
||||
: TrainForest<SplitFunctionT, LeafFunctionT, SplitEvaluatorT>(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, verbose) {}
|
||||
|
||||
virtual ~TrainForestQueued() {}
|
||||
|
||||
virtual std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>> Train(const std::vector<TrainDatum>& samples, TrainType train_type, const std::shared_ptr<Forest<SplitFunctionT, LeafFunctionT>>& old_forest) {
|
||||
this->PrintParams();
|
||||
|
||||
auto tim = std::chrono::system_clock::now();
|
||||
auto forest = std::make_shared<Forest<SplitFunctionT, LeafFunctionT>>();
|
||||
|
||||
std::cout << "[TRAIN] create pool with " << this->n_threads << " threads" << std::endl;
|
||||
auto pool = std::make_shared<ThreadPool>(this->n_threads);
|
||||
for(int treeidx = 0; treeidx < this->params_.n_trees; ++treeidx) {
|
||||
auto tree = std::make_shared<Tree<SplitFunctionT, LeafFunctionT>>();
|
||||
forest->AddTree(tree);
|
||||
AddJob(pool, QueueTuple(0, samples, &(tree->root_)));
|
||||
}
|
||||
|
||||
while(pool->has_running_tasks()) {
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds(100));
|
||||
}
|
||||
|
||||
if(this->verbose_){
|
||||
auto now = std::chrono::system_clock::now();
|
||||
auto ms = std::chrono::duration_cast<std::chrono::milliseconds>(now - tim);
|
||||
std::cout << "[TRAIN][FINISHED] training forest - took " << (ms.count() * 1e-3) << "[s]" << std::endl;
|
||||
}
|
||||
|
||||
return forest;
|
||||
}
|
||||
|
||||
private:
|
||||
virtual void AddJob(std::shared_ptr<ThreadPool> pool, QueueTuple data) {
|
||||
pool->enqueue([this](std::shared_ptr<ThreadPool> pool, QueueTuple data) {
|
||||
std::random_device rd;
|
||||
std::mt19937 rng(rd());
|
||||
|
||||
std::shared_ptr<SplitFunctionT> best_split_fcn = nullptr;
|
||||
|
||||
if(data.depth < this->params_.max_tree_depth && int(data.train_data.size()) > this->params_.min_samples_to_split) {
|
||||
best_split_fcn = this->OptimizeSplitFunction(data.train_data, data.depth, rng);
|
||||
}
|
||||
|
||||
if(best_split_fcn == nullptr) {
|
||||
auto node = this->CreateLeafNode(data.train_data, data.depth);
|
||||
*(data.parent) = node;
|
||||
}
|
||||
else {
|
||||
this->UpdateNodeInfo(data.depth, false);
|
||||
auto node = std::make_shared<SplitNode<SplitFunctionT, LeafFunctionT>>();
|
||||
node->split_fcn_ = best_split_fcn;
|
||||
*(data.parent) = node;
|
||||
|
||||
QueueTuple left;
|
||||
QueueTuple right;
|
||||
this->Split(best_split_fcn, data.train_data, left.train_data, right.train_data);
|
||||
|
||||
left.depth = data.depth + 1;
|
||||
right.depth = data.depth + 1;
|
||||
|
||||
left.parent = &(node->left_);
|
||||
right.parent = &(node->right_);
|
||||
|
||||
this->AddJob(pool, left);
|
||||
this->AddJob(pool, right);
|
||||
}
|
||||
}, pool, data);
|
||||
}
|
||||
};
|
55
hyperdepth/rf/tree.h
Normal file
55
hyperdepth/rf/tree.h
Normal file
@ -0,0 +1,55 @@
|
||||
#pragma once
|
||||
|
||||
#include "node.h"
|
||||
|
||||
template <typename SplitFunctionT, typename LeafFunctionT>
|
||||
class Tree {
|
||||
public:
|
||||
Tree() : root_(nullptr) {}
|
||||
Tree(NodePtr root) : root_(root) {}
|
||||
|
||||
virtual ~Tree() {}
|
||||
|
||||
std::shared_ptr<LeafFunctionT> inference(const SamplePtr sample) const {
|
||||
if(root_ == nullptr) {
|
||||
std::cout << "[ERROR] tree inference root node is NULL";
|
||||
exit(-1);
|
||||
}
|
||||
|
||||
NodePtr node = root_;
|
||||
while(node->type() == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
|
||||
auto splitNode = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(node);
|
||||
bool left = splitNode->Split(sample);
|
||||
if(left) {
|
||||
node = splitNode->left();
|
||||
}
|
||||
else {
|
||||
node = splitNode->right();
|
||||
}
|
||||
}
|
||||
|
||||
auto leaf_node = std::static_pointer_cast<LeafNode<LeafFunctionT>>(node);
|
||||
return leaf_node->leaf_node_fcn();
|
||||
}
|
||||
|
||||
NodePtr root() const { return root_; }
|
||||
void set_root(NodePtr root) { root_ = root; }
|
||||
|
||||
virtual void Save(SerializationOut& ar) const {
|
||||
int type = root_->type();
|
||||
ar << type;
|
||||
root_->Save(ar);
|
||||
}
|
||||
|
||||
virtual void Load(SerializationIn& ar) {
|
||||
int type;
|
||||
ar >> type;
|
||||
root_ = MakeNode<SplitFunctionT, LeafFunctionT>(type);
|
||||
root_->Load(ar);
|
||||
}
|
||||
|
||||
|
||||
public:
|
||||
NodePtr root_;
|
||||
};
|
||||
|
45
hyperdepth/setup.py
Normal file
45
hyperdepth/setup.py
Normal file
@ -0,0 +1,45 @@
|
||||
from distutils.core import setup
|
||||
from Cython.Build import cythonize
|
||||
from distutils.extension import Extension
|
||||
from Cython.Distutils import build_ext
|
||||
import numpy as np
|
||||
import platform
|
||||
import os
|
||||
|
||||
this_dir = os.path.dirname(__file__)
|
||||
|
||||
|
||||
extra_compile_args = ['-O3', '-std=c++11']
|
||||
extra_link_args = []
|
||||
|
||||
print('using openmp')
|
||||
extra_compile_args.append('-fopenmp')
|
||||
extra_link_args.append('-fopenmp')
|
||||
|
||||
sources = ['hyperdepth.pyx']
|
||||
extra_objects = []
|
||||
library_dirs = []
|
||||
libraries = ['m']
|
||||
|
||||
setup(
|
||||
name="hyperdepth",
|
||||
cmdclass= {'build_ext': build_ext},
|
||||
ext_modules=[
|
||||
Extension('hyperdepth',
|
||||
sources,
|
||||
extra_objects=extra_objects,
|
||||
language='c++',
|
||||
library_dirs=library_dirs,
|
||||
libraries=libraries,
|
||||
include_dirs=[
|
||||
np.get_include(),
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
extra_link_args=extra_link_args
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
52
hyperdepth/train.cpp
Normal file
52
hyperdepth/train.cpp
Normal file
@ -0,0 +1,52 @@
|
||||
#include "hyperdepth.h"
|
||||
#include "rf/train.h"
|
||||
|
||||
int main() {
|
||||
cv::Mat_<uint8_t> im = read_im(0);
|
||||
cv::Mat_<uint16_t> disp = read_disp(0);
|
||||
int im_rows = im.rows;
|
||||
int im_cols = im.cols;
|
||||
std::cout << im.rows << "/" << im.cols << std::endl;
|
||||
std::cout << disp.rows << "/" << disp.cols << std::endl;
|
||||
|
||||
TrainParameters params;
|
||||
params.n_trees = 6;
|
||||
params.n_test_samples = 2048;
|
||||
params.min_samples_to_split = 16;
|
||||
params.min_samples_for_leaf = 8;
|
||||
params.n_test_split_functions = 50;
|
||||
params.n_test_thresholds = 10;
|
||||
params.max_tree_depth = 8;
|
||||
|
||||
int n_classes = im_cols;
|
||||
int n_disp_bins = 16;
|
||||
int depth_switch = 4;
|
||||
|
||||
auto gen_split_fcn = std::make_shared<HDSplitFunctionT>();
|
||||
auto gen_leaf_fcn = std::make_shared<HDLeafFunctionT>(n_classes * n_disp_bins);
|
||||
auto split_eval = std::make_shared<HDSplitEvaluatorT>(true, n_classes, n_disp_bins, depth_switch);
|
||||
|
||||
for(int row = 0; row < im_rows; ++row) {
|
||||
std::vector<TrainDatum> train_data;
|
||||
for(int idx = 0; idx < 12; ++idx) {
|
||||
std::cout << "read sample " << idx << std::endl;
|
||||
im = read_im(idx);
|
||||
disp = read_disp(idx);
|
||||
|
||||
extract_row_samples(im, disp, row, train_data, true, n_disp_bins);
|
||||
}
|
||||
std::cout << "extracted " << train_data.size() << " train samples" << std::endl;
|
||||
std::cout << "n_classes (" << n_classes << ") * n_disp_bins (" << n_disp_bins << ") = " << (n_classes * n_disp_bins) << std::endl;
|
||||
|
||||
TrainForestQueued<HDSplitFunctionT, HDLeafFunctionT, HDSplitEvaluatorT> train(params, gen_split_fcn, gen_leaf_fcn, split_eval, true);
|
||||
|
||||
auto forest = train.Train(train_data, TrainType::TRAIN, nullptr);
|
||||
std::cout << "training done" << std::endl;
|
||||
|
||||
std::ostringstream forest_path;
|
||||
forest_path << "cforest_" << row << ".bin";
|
||||
BinarySerializationOut fout(forest_path.str());
|
||||
forest->Save(fout);
|
||||
}
|
||||
}
|
||||
|
15
hyperdepth/vis_eval.py
Normal file
15
hyperdepth/vis_eval.py
Normal file
@ -0,0 +1,15 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
import cv2
|
||||
|
||||
orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
||||
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
||||
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
|
||||
|
||||
|
||||
plt.figure()
|
||||
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
|
||||
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
|
||||
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
|
||||
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
|
||||
plt.show()
|
BIN
img/img.png
Normal file
BIN
img/img.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 1.0 MiB |
0
model/__init__.py
Normal file
0
model/__init__.py
Normal file
237
model/exp_synph.py
Normal file
237
model/exp_synph.py
Normal file
@ -0,0 +1,237 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import sys
|
||||
import itertools
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import co
|
||||
import torchext
|
||||
from model import networks
|
||||
from data import dataset
|
||||
|
||||
class Worker(torchext.Worker):
|
||||
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
||||
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
|
||||
|
||||
self.ms = args.ms
|
||||
self.pattern_path = args.pattern_path
|
||||
self.lcn_radius = args.lcn_radius
|
||||
self.dp_weight = args.dp_weight
|
||||
self.data_type = args.data_type
|
||||
|
||||
self.imsizes = [(480,640)]
|
||||
for iter in range(3):
|
||||
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
|
||||
|
||||
with open('config.json') as fp:
|
||||
config = json.load(fp)
|
||||
data_root = Path(config['DATA_ROOT'])
|
||||
self.settings_path = data_root / self.data_type / 'settings.pkl'
|
||||
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
|
||||
|
||||
self.train_paths = sample_paths[2**10:]
|
||||
self.test_paths = sample_paths[:2**8]
|
||||
|
||||
# supervise the edge encoder with only 2**8 samples
|
||||
self.train_edge = len(self.train_paths) - 2**8
|
||||
|
||||
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
|
||||
self.disparity_loss = networks.DisparityLoss()
|
||||
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
|
||||
|
||||
# evaluate in the region where opencv Block Matching has valid values
|
||||
self.eval_mask = np.zeros(self.imsizes[0])
|
||||
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
|
||||
self.eval_mask = self.eval_mask.astype(np.bool)
|
||||
self.eval_h = self.imsizes[0][0]-2*13
|
||||
self.eval_w = self.imsizes[0][1]-13-140
|
||||
|
||||
def get_train_set(self):
|
||||
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1)
|
||||
|
||||
return train_set
|
||||
|
||||
def get_test_sets(self):
|
||||
test_sets = torchext.TestSets()
|
||||
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
|
||||
test_sets.append('simple', test_set, test_frequency=1)
|
||||
|
||||
# initialize photometric loss modules according to image sizes
|
||||
self.losses = []
|
||||
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
|
||||
pat = pat.mean(axis=2)
|
||||
pat = torch.from_numpy(pat[None][None].astype(np.float32))
|
||||
pat = pat.to(self.train_device)
|
||||
self.lcn_in = self.lcn_in.to(self.train_device)
|
||||
pat,_ = self.lcn_in(pat)
|
||||
pat = torch.cat([pat for idx in range(3)], dim=1)
|
||||
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) )
|
||||
|
||||
return test_sets
|
||||
|
||||
def copy_data(self, data, device, requires_grad, train):
|
||||
self.lcn_in = self.lcn_in.to(device)
|
||||
|
||||
self.data = {}
|
||||
for key, val in data.items():
|
||||
grad = 'im' in key and requires_grad
|
||||
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
|
||||
|
||||
# apply lcn to IR input
|
||||
# concatenate the normalized IR input and the original IR image
|
||||
if 'im' in key and 'blend' not in key:
|
||||
im = self.data[key]
|
||||
im_lcn,im_std = self.lcn_in(im)
|
||||
im_cat = torch.cat((im_lcn, im), dim=1)
|
||||
key_std = key.replace('im','std')
|
||||
self.data[key]=im_cat
|
||||
self.data[key_std] = im_std.to(device).detach()
|
||||
|
||||
def net_forward(self, net, train):
|
||||
out = net(self.data['im0'])
|
||||
return out
|
||||
|
||||
def loss_forward(self, out, train):
|
||||
out, edge = out
|
||||
if not(isinstance(out, tuple) or isinstance(out, list)):
|
||||
out = [out]
|
||||
if not(isinstance(edge, tuple) or isinstance(edge, list)):
|
||||
edge = [edge]
|
||||
|
||||
vals = []
|
||||
|
||||
# apply photometric loss
|
||||
for s,l,o in zip(itertools.count(), self.losses, out):
|
||||
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}'])
|
||||
if s == 0:
|
||||
self.pattern_proj = pattern_proj.detach()
|
||||
vals.append(val)
|
||||
|
||||
# apply disparity loss
|
||||
# 1-edge as ground truth edge if inversed
|
||||
edge0 = 1-torch.sigmoid(edge[0])
|
||||
val = self.disparity_loss(out[0], edge0)
|
||||
if self.dp_weight>0:
|
||||
vals.append(val * self.dp_weight)
|
||||
|
||||
# apply edge loss on a subset of training samples
|
||||
for s,e in zip(itertools.count(), edge):
|
||||
# inversed ground truth edge where 0 means edge
|
||||
grad = self.data[f'grad{s}']<0.2
|
||||
grad = grad.to(torch.float32)
|
||||
ids = self.data['id']
|
||||
mask = ids>self.train_edge
|
||||
if mask.sum()>0:
|
||||
val = self.edge_loss(e[mask], grad[mask])
|
||||
else:
|
||||
val = torch.zeros_like(vals[0])
|
||||
if s == 0:
|
||||
self.edge = e.detach()
|
||||
self.edge = torch.sigmoid(self.edge)
|
||||
self.edge_gt = grad.detach()
|
||||
vals.append(val)
|
||||
|
||||
return vals
|
||||
|
||||
def numpy_in_out(self, output):
|
||||
output, edge = output
|
||||
if not(isinstance(output, tuple) or isinstance(output, list)):
|
||||
output = [output]
|
||||
es = output[0].detach().to('cpu').numpy()
|
||||
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
|
||||
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy()
|
||||
|
||||
ma = gt>0
|
||||
return es, gt, im, ma
|
||||
|
||||
def write_img(self, out_path, es, gt, im, ma):
|
||||
logging.info(f'write img {out_path}')
|
||||
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
|
||||
|
||||
diff = np.abs(es - gt)
|
||||
|
||||
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
|
||||
vmin = vmin - 0.2*(vmax-vmin)
|
||||
vmax = vmax + 0.2*(vmax-vmin)
|
||||
|
||||
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
|
||||
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0]
|
||||
pattern_diff = np.abs(im_orig - pattern_proj)
|
||||
|
||||
|
||||
fig = plt.figure(figsize=(16,16))
|
||||
es_ = co.cmap.color_depth_map(es, scale=vmax)
|
||||
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
|
||||
diff_ = co.cmap.color_error_image(diff, BGR=True)
|
||||
|
||||
# plot disparities, ground truth disparity is shown only for reference
|
||||
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
|
||||
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
|
||||
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}')
|
||||
|
||||
# plot edges
|
||||
edge = self.edge.to('cpu').numpy()[0,0]
|
||||
edge_gt = self.edge_gt.to('cpu').numpy()[0,0]
|
||||
edge_err = np.abs(edge - edge_gt)
|
||||
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
|
||||
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
|
||||
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
|
||||
|
||||
# plot normalized IR input and warped pattern
|
||||
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
|
||||
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
|
||||
im_std = self.data['std0'].to('cpu').numpy()[0,0]
|
||||
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(str(out_path))
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
|
||||
if batch_idx % 512 == 0:
|
||||
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
||||
es, gt, im, ma = self.numpy_in_out(output)
|
||||
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
|
||||
|
||||
|
||||
def callback_test_start(self, epoch, set_idx):
|
||||
self.metric = co.metric.MultipleMetric(
|
||||
co.metric.DistanceMetric(vec_length=1),
|
||||
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
|
||||
)
|
||||
|
||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
|
||||
es, gt, im, ma = self.numpy_in_out(output)
|
||||
|
||||
if batch_idx % 8 == 0:
|
||||
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
|
||||
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
|
||||
|
||||
es, gt, im, ma = self.crop_output(es, gt, im, ma)
|
||||
|
||||
es = es.reshape(-1,1)
|
||||
gt = gt.reshape(-1,1)
|
||||
ma = ma.ravel()
|
||||
self.metric.add(es, gt, ma)
|
||||
|
||||
def callback_test_stop(self, epoch, set_idx, loss):
|
||||
logging.info(f'{self.metric}')
|
||||
for k, v in self.metric.items():
|
||||
self.metric_add_test(epoch, set_idx, k, v)
|
||||
|
||||
def crop_output(self, es, gt, im, ma):
|
||||
bs = es.shape[0]
|
||||
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
|
||||
return es, gt, im, ma
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
298
model/exp_synphge.py
Normal file
298
model/exp_synphge.py
Normal file
@ -0,0 +1,298 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import time
|
||||
from pathlib import Path
|
||||
import logging
|
||||
import sys
|
||||
import itertools
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import co
|
||||
import torchext
|
||||
from model import networks
|
||||
from data import dataset
|
||||
|
||||
class Worker(torchext.Worker):
|
||||
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
|
||||
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
|
||||
|
||||
self.ms = args.ms
|
||||
self.pattern_path = args.pattern_path
|
||||
self.lcn_radius = args.lcn_radius
|
||||
self.dp_weight = args.dp_weight
|
||||
self.ge_weight = args.ge_weight
|
||||
self.track_length = args.track_length
|
||||
self.data_type = args.data_type
|
||||
assert(self.track_length>1)
|
||||
|
||||
self.imsizes = [(480,640)]
|
||||
for iter in range(3):
|
||||
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
|
||||
|
||||
with open('config.json') as fp:
|
||||
config = json.load(fp)
|
||||
data_root = Path(config['DATA_ROOT'])
|
||||
self.settings_path = data_root / self.data_type / 'settings.pkl'
|
||||
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
|
||||
|
||||
self.train_paths = sample_paths[2**10:]
|
||||
self.test_paths = sample_paths[:2**8]
|
||||
|
||||
# supervise the edge encoder with only 2**8 samples
|
||||
self.train_edge = len(self.train_paths) - 2**8
|
||||
|
||||
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
|
||||
self.disparity_loss = networks.DisparityLoss()
|
||||
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
|
||||
|
||||
# evaluate in the region where opencv Block Matching has valid values
|
||||
self.eval_mask = np.zeros(self.imsizes[0])
|
||||
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
|
||||
self.eval_mask = self.eval_mask.astype(np.bool)
|
||||
self.eval_h = self.imsizes[0][0]-2*13
|
||||
self.eval_w = self.imsizes[0][1]-13-140
|
||||
|
||||
|
||||
def get_train_set(self):
|
||||
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length)
|
||||
return train_set
|
||||
|
||||
def get_test_sets(self):
|
||||
test_sets = torchext.TestSets()
|
||||
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
|
||||
test_sets.append('simple', test_set, test_frequency=1)
|
||||
|
||||
self.ph_losses = []
|
||||
self.ge_losses = []
|
||||
self.d2ds = []
|
||||
|
||||
self.lcn_in = self.lcn_in.to('cuda')
|
||||
for sidx in range(len(test_set.imsizes)):
|
||||
imsize = test_set.imsizes[sidx]
|
||||
pat = test_set.patterns[sidx]
|
||||
pat = pat.mean(axis=2)
|
||||
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
|
||||
pat,_ = self.lcn_in(pat)
|
||||
pat = torch.cat([pat for idx in range(3)], dim=1)
|
||||
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
|
||||
|
||||
K = test_set.getK(sidx)
|
||||
Ki = np.linalg.inv(K)
|
||||
K = torch.from_numpy(K)
|
||||
Ki = torch.from_numpy(Ki)
|
||||
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
|
||||
|
||||
self.ph_losses.append( ph_loss )
|
||||
self.ge_losses.append( ge_loss )
|
||||
|
||||
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
|
||||
self.d2ds.append( d2d )
|
||||
|
||||
return test_sets
|
||||
|
||||
def copy_data(self, data, device, requires_grad, train):
|
||||
self.data = {}
|
||||
|
||||
self.lcn_in = self.lcn_in.to(device)
|
||||
for key, val in data.items():
|
||||
# from
|
||||
# batch_size x track_length x ...
|
||||
# to
|
||||
# track_length x batch_size x ...
|
||||
if len(val.shape)>2:
|
||||
if train:
|
||||
val = val.transpose(0,1)
|
||||
else:
|
||||
val = val.unsqueeze(0)
|
||||
grad = 'im' in key and requires_grad
|
||||
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
|
||||
if 'im' in key and 'blend' not in key:
|
||||
im = self.data[key]
|
||||
tl = im.shape[0]
|
||||
bs = im.shape[1]
|
||||
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
|
||||
key_std = key.replace('im','std')
|
||||
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
|
||||
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
|
||||
self.data[key] = im_cat
|
||||
|
||||
def net_forward(self, net, train):
|
||||
im0 = self.data['im0']
|
||||
tl = im0.shape[0]
|
||||
bs = im0.shape[1]
|
||||
im0 = im0.view(-1, *im0.shape[2:])
|
||||
out, edge = net(im0)
|
||||
if not(isinstance(out, tuple) or isinstance(out, list)):
|
||||
out = out.view(tl, bs, *out.shape[1:])
|
||||
edge = edge.view(tl, bs, *out.shape[1:])
|
||||
else:
|
||||
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
|
||||
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
|
||||
return out, edge
|
||||
|
||||
def loss_forward(self, out, train):
|
||||
out, edge = out
|
||||
if not(isinstance(out, tuple) or isinstance(out, list)):
|
||||
out = [out]
|
||||
vals = []
|
||||
diffs = []
|
||||
|
||||
# apply photometric loss
|
||||
for s,l,o in zip(itertools.count(), self.ph_losses, out):
|
||||
im = self.data[f'im{s}']
|
||||
im = im.view(-1, *im.shape[2:])
|
||||
o = o.view(-1, *o.shape[2:])
|
||||
std = self.data[f'std{s}']
|
||||
std = std.view(-1, *std.shape[2:])
|
||||
val, pattern_proj = l(o, im[:,0:1,...], std)
|
||||
vals.append(val)
|
||||
if s == 0:
|
||||
self.pattern_proj = pattern_proj.detach()
|
||||
|
||||
# apply disparity loss
|
||||
# 1-edge as ground truth edge if inversed
|
||||
edge0 = 1-torch.sigmoid(edge[0])
|
||||
edge0 = edge0.view(-1, *edge0.shape[2:])
|
||||
out0 = out[0].view(-1, *out[0].shape[2:])
|
||||
val = self.disparity_loss(out0, edge0)
|
||||
if self.dp_weight>0:
|
||||
vals.append(val * self.dp_weight)
|
||||
|
||||
# apply edge loss on a subset of training samples
|
||||
for s,e in zip(itertools.count(), edge):
|
||||
# inversed ground truth edge where 0 means edge
|
||||
grad = self.data[f'grad{s}']<0.2
|
||||
grad = grad.to(torch.float32)
|
||||
ids = self.data['id']
|
||||
mask = ids>self.train_edge
|
||||
if mask.sum()>0:
|
||||
e = e[:,mask,:]
|
||||
grad = grad[:,mask,:]
|
||||
e = e.view(-1, *e.shape[2:])
|
||||
grad = grad.view(-1, *grad.shape[2:])
|
||||
val = self.edge_loss(e, grad)
|
||||
else:
|
||||
val = torch.zeros_like(vals[0])
|
||||
vals.append(val)
|
||||
|
||||
if train is False:
|
||||
return vals
|
||||
|
||||
# apply geometric loss
|
||||
R = self.data['R']
|
||||
t = self.data['t']
|
||||
ge_num = self.track_length * (self.track_length-1) / 2
|
||||
for sidx in range(len(out)):
|
||||
d2d = self.d2ds[sidx]
|
||||
depth = d2d(out[sidx])
|
||||
ge_loss = self.ge_losses[sidx]
|
||||
imsize = self.imsizes[sidx]
|
||||
for tidx0 in range(depth.shape[0]):
|
||||
for tidx1 in range(tidx0+1, depth.shape[0]):
|
||||
depth0 = depth[tidx0]
|
||||
R0 = R[tidx0]
|
||||
t0 = t[tidx0]
|
||||
depth1 = depth[tidx1]
|
||||
R1 = R[tidx1]
|
||||
t1 = t[tidx1]
|
||||
|
||||
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
|
||||
vals.append(val * self.ge_weight / ge_num)
|
||||
|
||||
return vals
|
||||
|
||||
def numpy_in_out(self, output):
|
||||
output, edge = output
|
||||
if not(isinstance(output, tuple) or isinstance(output, list)):
|
||||
output = [output]
|
||||
es = output[0].detach().to('cpu').numpy()
|
||||
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
|
||||
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
|
||||
ma = gt>0
|
||||
return es, gt, im, ma
|
||||
|
||||
def write_img(self, out_path, es, gt, im, ma):
|
||||
logging.info(f'write img {out_path}')
|
||||
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
|
||||
|
||||
diff = np.abs(es - gt)
|
||||
|
||||
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
|
||||
vmin = vmin - 0.2*(vmax-vmin)
|
||||
vmax = vmax + 0.2*(vmax-vmin)
|
||||
|
||||
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
|
||||
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0]
|
||||
pattern_diff = np.abs(im_orig - pattern_proj)
|
||||
|
||||
fig = plt.figure(figsize=(16,16))
|
||||
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
|
||||
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
|
||||
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
|
||||
|
||||
# plot disparities, ground truth disparity is shown only for reference
|
||||
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
|
||||
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
|
||||
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
|
||||
|
||||
# plot disparities of the second frame in the track if exists
|
||||
if es.shape[0]>=2:
|
||||
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
|
||||
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
|
||||
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
|
||||
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
|
||||
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
|
||||
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
|
||||
|
||||
# plot normalized IR inputs
|
||||
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
|
||||
if es.shape[0]>=2:
|
||||
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
|
||||
|
||||
plt.tight_layout()
|
||||
plt.savefig(str(out_path))
|
||||
plt.close(fig)
|
||||
|
||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
||||
if batch_idx % 512 == 0:
|
||||
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
|
||||
es, gt, im, ma = self.numpy_in_out(output)
|
||||
masks = [ m.detach().to('cpu').numpy() for m in masks ]
|
||||
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
|
||||
|
||||
def callback_test_start(self, epoch, set_idx):
|
||||
self.metric = co.metric.MultipleMetric(
|
||||
co.metric.DistanceMetric(vec_length=1),
|
||||
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
|
||||
)
|
||||
|
||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
||||
es, gt, im, ma = self.numpy_in_out(output)
|
||||
|
||||
if batch_idx % 8 == 0:
|
||||
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
|
||||
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
|
||||
|
||||
es, gt, im, ma = self.crop_output(es, gt, im, ma)
|
||||
|
||||
es = es.reshape(-1,1)
|
||||
gt = gt.reshape(-1,1)
|
||||
ma = ma.ravel()
|
||||
self.metric.add(es, gt, ma)
|
||||
|
||||
def callback_test_stop(self, epoch, set_idx, loss):
|
||||
logging.info(f'{self.metric}')
|
||||
for k, v in self.metric.items():
|
||||
self.metric_add_test(epoch, set_idx, k, v)
|
||||
|
||||
def crop_output(self, es, gt, im, ma):
|
||||
tl = es.shape[0]
|
||||
bs = es.shape[1]
|
||||
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
||||
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
||||
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
||||
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
|
||||
return es, gt, im, ma
|
||||
|
||||
if __name__ == '__main__':
|
||||
pass
|
566
model/networks.py
Normal file
566
model/networks.py
Normal file
@ -0,0 +1,566 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
import torchext
|
||||
import co
|
||||
|
||||
|
||||
class TimedModule(torch.nn.Module):
|
||||
def __init__(self, mod_name):
|
||||
super().__init__()
|
||||
self.mod_name = mod_name
|
||||
|
||||
def tforward(self, *args, **kwargs):
|
||||
raise Exception('not implemented')
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
torch.cuda.synchronize()
|
||||
with co.gtimer.Ctx(self.mod_name):
|
||||
x = self.tforward(*args, **kwargs)
|
||||
torch.cuda.synchronize()
|
||||
return x
|
||||
|
||||
|
||||
class PosOutput(TimedModule):
|
||||
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
|
||||
super().__init__(mod_name='PosOutput')
|
||||
self.im_width = im_width
|
||||
self.im_width = im_width
|
||||
|
||||
if type == 'pos':
|
||||
self.layer = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
|
||||
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
|
||||
)
|
||||
elif type == 'pos_row':
|
||||
self.layer = torch.nn.Sequential(
|
||||
MultiLinear(im_height, channels_in, 1),
|
||||
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
|
||||
)
|
||||
|
||||
self.u_pos = None
|
||||
|
||||
def tforward(self, x):
|
||||
if self.u_pos is None:
|
||||
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1)
|
||||
self.u_pos = self.u_pos.to(x.device)
|
||||
pos = self.layer(x)
|
||||
disp = self.u_pos - pos
|
||||
return disp
|
||||
|
||||
|
||||
class OutputLayerFactory(object):
|
||||
'''
|
||||
Define type of output
|
||||
type options:
|
||||
linear: apply only conv channel, used for the edge decoder
|
||||
disp: estimate the disparity
|
||||
disp_row: independently estimate the disparity per row
|
||||
pos: estimate the absolute location
|
||||
pos_row: independently estimate the absolute location per row
|
||||
'''
|
||||
def __init__(self, type='disp', params={}):
|
||||
self.type = type
|
||||
self.params = params
|
||||
|
||||
def __call__(self, channels_in, imsize):
|
||||
|
||||
if self.type == 'linear':
|
||||
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
|
||||
|
||||
elif self.type == 'disp':
|
||||
return torch.nn.Sequential(
|
||||
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
|
||||
SigmoidAffine(**self.params)
|
||||
)
|
||||
|
||||
elif self.type == 'disp_row':
|
||||
return torch.nn.Sequential(
|
||||
MultiLinear(imsize[0], channels_in, 1),
|
||||
SigmoidAffine(**self.params)
|
||||
)
|
||||
|
||||
elif self.type == 'pos' or self.type == 'pos_row':
|
||||
return PosOutput(channels_in, **self.params)
|
||||
|
||||
else:
|
||||
raise Exception('unknown output layer type')
|
||||
|
||||
|
||||
class SigmoidAffine(TimedModule):
|
||||
def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
|
||||
super().__init__(mod_name='SigmoidAffine')
|
||||
self.alpha = alpha
|
||||
self.beta = beta
|
||||
self.gamma = gamma
|
||||
self.offset = offset
|
||||
|
||||
def tforward(self, x):
|
||||
return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta
|
||||
|
||||
|
||||
class MultiLinear(TimedModule):
|
||||
def __init__(self, n, channels_in, channels_out):
|
||||
super().__init__(mod_name='MultiLinear')
|
||||
self.channels_out = channels_out
|
||||
self.mods = torch.nn.ModuleList()
|
||||
for idx in range(n):
|
||||
self.mods.append(torch.nn.Linear(channels_in, channels_out))
|
||||
|
||||
def tforward(self, x):
|
||||
x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC
|
||||
y = x.new_empty(*x.shape[:-1], self.channels_out)
|
||||
for hidx in range(x.shape[0]):
|
||||
y[hidx] = self.mods[hidx](x[hidx])
|
||||
y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW
|
||||
return y
|
||||
|
||||
|
||||
|
||||
class DispNetS(TimedModule):
|
||||
'''
|
||||
Disparity Decoder based on DispNetS
|
||||
'''
|
||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
|
||||
super(DispNetS, self).__init__(mod_name='DispNetS')
|
||||
|
||||
self.output_ms = output_ms
|
||||
self.coordconv = coordconv
|
||||
|
||||
conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] )
|
||||
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
|
||||
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
|
||||
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
|
||||
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
|
||||
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
|
||||
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
|
||||
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
|
||||
|
||||
upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] )
|
||||
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
|
||||
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
|
||||
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
|
||||
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
|
||||
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
|
||||
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
|
||||
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
|
||||
|
||||
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
|
||||
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
|
||||
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
|
||||
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
|
||||
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
|
||||
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
|
||||
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
|
||||
|
||||
if isinstance(output_facs, list):
|
||||
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
|
||||
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
|
||||
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
|
||||
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
|
||||
else:
|
||||
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
|
||||
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
|
||||
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
|
||||
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
|
||||
|
||||
|
||||
def init_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
|
||||
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
|
||||
if m.bias is not None:
|
||||
torch.nn.init.zeros_(m.bias)
|
||||
|
||||
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
|
||||
if self.coordconv:
|
||||
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
|
||||
else:
|
||||
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
|
||||
return torch.nn.Sequential(
|
||||
conv,
|
||||
torch.nn.ReLU(inplace=True),
|
||||
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2),
|
||||
torch.nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
def conv(self, in_planes, out_planes):
|
||||
return torch.nn.Sequential(
|
||||
torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
|
||||
torch.nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
def upconv(self, in_planes, out_planes):
|
||||
return torch.nn.Sequential(
|
||||
torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
|
||||
torch.nn.ReLU(inplace=True)
|
||||
)
|
||||
|
||||
def crop_like(self, input, ref):
|
||||
assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
|
||||
return input[:, :, :ref.size(2), :ref.size(3)]
|
||||
|
||||
def tforward(self, x):
|
||||
out_conv1 = self.conv1(x)
|
||||
out_conv2 = self.conv2(out_conv1)
|
||||
out_conv3 = self.conv3(out_conv2)
|
||||
out_conv4 = self.conv4(out_conv3)
|
||||
out_conv5 = self.conv5(out_conv4)
|
||||
out_conv6 = self.conv6(out_conv5)
|
||||
out_conv7 = self.conv7(out_conv6)
|
||||
|
||||
out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
|
||||
concat7 = torch.cat((out_upconv7, out_conv6), 1)
|
||||
out_iconv7 = self.iconv7(concat7)
|
||||
|
||||
out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
|
||||
concat6 = torch.cat((out_upconv6, out_conv5), 1)
|
||||
out_iconv6 = self.iconv6(concat6)
|
||||
|
||||
out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
|
||||
concat5 = torch.cat((out_upconv5, out_conv4), 1)
|
||||
out_iconv5 = self.iconv5(concat5)
|
||||
|
||||
out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
|
||||
concat4 = torch.cat((out_upconv4, out_conv3), 1)
|
||||
out_iconv4 = self.iconv4(concat4)
|
||||
disp4 = self.predict_disp4(out_iconv4)
|
||||
|
||||
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
|
||||
disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
|
||||
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
|
||||
out_iconv3 = self.iconv3(concat3)
|
||||
disp3 = self.predict_disp3(out_iconv3)
|
||||
|
||||
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
||||
out_iconv2 = self.iconv2(concat2)
|
||||
disp2 = self.predict_disp2(out_iconv2)
|
||||
|
||||
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
||||
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
||||
out_iconv1 = self.iconv1(concat1)
|
||||
disp1 = self.predict_disp1(out_iconv1)
|
||||
|
||||
if self.output_ms:
|
||||
return disp1, disp2, disp3, disp4
|
||||
else:
|
||||
return disp1
|
||||
|
||||
|
||||
class DispNetShallow(DispNetS):
|
||||
'''
|
||||
Edge Decoder based on DispNetS with fewer layers
|
||||
'''
|
||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
|
||||
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
|
||||
self.mod_name = 'DispNetShallow'
|
||||
conv_planes = [32, 64, 128, 256, 512, 512, 512]
|
||||
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
|
||||
self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])
|
||||
|
||||
def tforward(self, x):
|
||||
out_conv1 = self.conv1(x)
|
||||
out_conv2 = self.conv2(out_conv1)
|
||||
out_conv3 = self.conv3(out_conv2)
|
||||
|
||||
out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
|
||||
concat3 = torch.cat((out_upconv3, out_conv2), 1)
|
||||
out_iconv3 = self.iconv3(concat3)
|
||||
disp3 = self.predict_disp3(out_iconv3)
|
||||
|
||||
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
|
||||
out_iconv2 = self.iconv2(concat2)
|
||||
disp2 = self.predict_disp2(out_iconv2)
|
||||
|
||||
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
|
||||
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||
concat1 = torch.cat((out_upconv1, disp2_up), 1)
|
||||
out_iconv1 = self.iconv1(concat1)
|
||||
disp1 = self.predict_disp1(out_iconv1)
|
||||
|
||||
if self.output_ms:
|
||||
return disp1, disp2, disp3
|
||||
else:
|
||||
return disp1
|
||||
|
||||
|
||||
class DispEdgeDecoders(TimedModule):
|
||||
'''
|
||||
Disparity Decoder and Edge Decoder
|
||||
'''
|
||||
def __init__(self, *args, max_disp=128, **kwargs):
|
||||
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
|
||||
|
||||
output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)]
|
||||
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
|
||||
|
||||
output_facs = [OutputLayerFactory( type='linear' ) for s in range(4)]
|
||||
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
|
||||
|
||||
def tforward(self, x):
|
||||
disp = self.disp_decoder(x)
|
||||
edge = self.edge_decoder(x)
|
||||
return disp, edge
|
||||
|
||||
|
||||
class DispToDepth(TimedModule):
|
||||
def __init__(self, focal_length, baseline):
|
||||
super().__init__(mod_name='DispToDepth')
|
||||
self.baseline_focal_length = baseline * focal_length
|
||||
|
||||
def tforward(self, disp):
|
||||
disp = torch.nn.functional.relu(disp) + 1e-12
|
||||
depth = self.baseline_focal_length / disp
|
||||
return depth
|
||||
|
||||
|
||||
class PosToDepth(DispToDepth):
|
||||
def __init__(self, focal_length, baseline, im_height, im_width):
|
||||
super().__init__(focal_length, baseline)
|
||||
self.mod_name = 'PosToDepth'
|
||||
|
||||
self.im_height = im_height
|
||||
self.im_width = im_width
|
||||
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1)
|
||||
|
||||
def tforward(self, pos):
|
||||
self.u_pos = self.u_pos.to(pos.device)
|
||||
disp = self.u_pos - pos
|
||||
return super().forward(disp)
|
||||
|
||||
|
||||
|
||||
class RectifiedPatternSimilarityLoss(TimedModule):
|
||||
'''
|
||||
Photometric Loss
|
||||
'''
|
||||
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
|
||||
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
|
||||
self.im_height = im_height
|
||||
self.im_width = im_width
|
||||
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
|
||||
|
||||
u, v = np.meshgrid(range(im_width), range(im_height))
|
||||
uv0 = np.stack((u,v), axis=2).reshape(-1,1)
|
||||
uv0 = uv0.astype(np.float32).reshape(1,-1,2)
|
||||
self.uv0 = torch.from_numpy(uv0)
|
||||
|
||||
self.loss_type = loss_type
|
||||
self.loss_eps = loss_eps
|
||||
|
||||
def tforward(self, disp0, im, std=None):
|
||||
self.pattern = self.pattern.to(disp0.device)
|
||||
self.uv0 = self.uv0.to(disp0.device)
|
||||
|
||||
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
|
||||
uv1 = torch.empty_like(uv0)
|
||||
uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1)
|
||||
uv1[...,1] = uv0[...,1]
|
||||
|
||||
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
|
||||
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
|
||||
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
|
||||
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
|
||||
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
|
||||
mask = torch.ones_like(im)
|
||||
if std is not None:
|
||||
mask = mask*std
|
||||
|
||||
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
|
||||
val = (mask*diff).sum() / mask.sum()
|
||||
return val, pattern_proj
|
||||
|
||||
class DisparityLoss(TimedModule):
|
||||
'''
|
||||
Disparity Loss
|
||||
'''
|
||||
def __init__(self):
|
||||
super().__init__(mod_name='DisparityLoss')
|
||||
self.sobel = SobelFilter(norm=False)
|
||||
|
||||
#if not edge_gt:
|
||||
self.b0=0.0503428816795
|
||||
self.b1=1.07274045944
|
||||
#else:
|
||||
# self.b0=0.0587115108967
|
||||
# self.b1=1.51931190491
|
||||
|
||||
def tforward(self, disp, edge=None):
|
||||
self.sobel=self.sobel.to(disp.device)
|
||||
|
||||
if edge is not None:
|
||||
grad = self.sobel(disp)
|
||||
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
|
||||
pdf = (1-edge)/self.b0 * torch.exp(-torch.abs(grad)/self.b0) + \
|
||||
edge/self.b1 * torch.exp(-torch.abs(grad)/self.b1)
|
||||
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
|
||||
else:
|
||||
# on qifeng's data we don't have ambient info
|
||||
# therefore we supress edge everywhere
|
||||
grad = self.sobel(disp)
|
||||
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
|
||||
grad= torch.clamp(grad, 0, 1.0)
|
||||
val = torch.mean(grad)
|
||||
|
||||
return val
|
||||
|
||||
|
||||
|
||||
class ProjectionBaseLoss(TimedModule):
|
||||
'''
|
||||
Base module of the Geometric Loss
|
||||
'''
|
||||
def __init__(self, K, Ki, im_height, im_width):
|
||||
super().__init__(mod_name='ProjectionBaseLoss')
|
||||
|
||||
self.K = K.view(-1,3,3)
|
||||
|
||||
self.im_height = im_height
|
||||
self.im_width = im_width
|
||||
|
||||
u, v = np.meshgrid(range(im_width), range(im_height))
|
||||
uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3)
|
||||
|
||||
ray = uv @ Ki.numpy().T
|
||||
|
||||
ray = ray.reshape(1,-1,3).astype(np.float32)
|
||||
self.ray = torch.from_numpy(ray)
|
||||
|
||||
def transform(self, xyz, R=None, t=None):
|
||||
if t is not None:
|
||||
bs = xyz.shape[0]
|
||||
xyz = xyz - t.reshape(bs,1,3)
|
||||
if R is not None:
|
||||
xyz = torch.bmm(xyz, R)
|
||||
return xyz
|
||||
|
||||
def unproject(self, depth, R=None, t=None):
|
||||
self.ray = self.ray.to(depth.device)
|
||||
bs = depth.shape[0]
|
||||
|
||||
xyz = depth.reshape(bs,-1,1) * self.ray
|
||||
xyz = self.transform(xyz, R, t)
|
||||
return xyz
|
||||
|
||||
def project(self, xyz, R, t):
|
||||
self.K = self.K.to(xyz.device)
|
||||
bs = xyz.shape[0]
|
||||
|
||||
xyz = torch.bmm(xyz, R.transpose(1,2))
|
||||
xyz = xyz + t.reshape(bs,1,3)
|
||||
|
||||
Kt = self.K.transpose(1,2).expand(bs,-1,-1)
|
||||
uv = torch.bmm(xyz, Kt)
|
||||
|
||||
d = uv[:,:,2:3]
|
||||
|
||||
# avoid division by zero
|
||||
uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12)
|
||||
return uv, d
|
||||
|
||||
|
||||
def tforward(self, depth0, R0, t0, R1, t1):
|
||||
xyz = self.unproject(depth0, R0, t0)
|
||||
return self.project(xyz, R1, t1)
|
||||
|
||||
|
||||
class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
|
||||
'''
|
||||
Geometric Loss
|
||||
'''
|
||||
def __init__(self, *args, clamp=-1):
|
||||
super().__init__(*args)
|
||||
self.mod_name = 'ProjectionDepthSimilarityLoss'
|
||||
self.clamp = clamp
|
||||
|
||||
def fwd(self, depth0, depth1, R0, t0, R1, t1):
|
||||
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
|
||||
|
||||
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
|
||||
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
|
||||
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
|
||||
|
||||
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
|
||||
|
||||
diff = torch.abs(d1.view(-1) - depth10.view(-1))
|
||||
|
||||
if self.clamp > 0:
|
||||
diff = torch.clamp(diff, 0, self.clamp)
|
||||
|
||||
# return diff without clamping for debugging
|
||||
return diff.mean()
|
||||
|
||||
def tforward(self, depth0, depth1, R0, t0, R1, t1):
|
||||
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
|
||||
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
|
||||
return l0+l1
|
||||
|
||||
|
||||
|
||||
class LCN(TimedModule):
|
||||
'''
|
||||
Local Contract Normalization
|
||||
'''
|
||||
def __init__(self, radius, epsilon):
|
||||
super().__init__(mod_name='LCN')
|
||||
self.box_conv = torch.nn.Sequential(
|
||||
torch.nn.ReflectionPad2d(radius),
|
||||
torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False)
|
||||
)
|
||||
self.box_conv[1].weight.requires_grad=False
|
||||
self.box_conv[1].weight.fill_(1.)
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.radius = radius
|
||||
|
||||
def tforward(self, data):
|
||||
boxs = self.box_conv(data)
|
||||
|
||||
avgs = boxs / (2*self.radius+1)**2
|
||||
boxs_n2 = boxs**2
|
||||
boxs_2n = self.box_conv(data**2)
|
||||
|
||||
stds = torch.sqrt(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6)
|
||||
stds = stds + self.epsilon
|
||||
|
||||
return (data - avgs) / stds, stds
|
||||
|
||||
|
||||
|
||||
class SobelFilter(TimedModule):
|
||||
'''
|
||||
Sobel Filter
|
||||
'''
|
||||
def __init__(self, norm=False):
|
||||
super(SobelFilter, self).__init__(mod_name='SobelFilter')
|
||||
kx = np.array([[-5, -4, 0, 4, 5],
|
||||
[-8, -10, 0, 10, 8],
|
||||
[-10, -20, 0, 20, 10],
|
||||
[-8, -10, 0, 10, 8],
|
||||
[-5, -4, 0, 4, 5]])/240.0
|
||||
ky = kx.copy().transpose(1,0)
|
||||
|
||||
self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
|
||||
self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
|
||||
|
||||
self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
|
||||
self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
|
||||
|
||||
self.norm=norm
|
||||
|
||||
def tforward(self,x):
|
||||
x = F.pad(x, (2,2,2,2), "replicate")
|
||||
gx = self.conv_x(x)
|
||||
gy = self.conv_y(x)
|
||||
if self.norm:
|
||||
return torch.sqrt(gx**2 + gy**2 + 1e-8)
|
||||
else:
|
||||
return torch.cat((gx, gy), dim=1)
|
||||
|
95
readme.md
Normal file
95
readme.md
Normal file
@ -0,0 +1,95 @@
|
||||
# Connecting the Dots: Learning Representations for Active Monocular Depth Estimation
|
||||
|
||||
![example](img/img.png)
|
||||
|
||||
This repository contains the code for the paper
|
||||
|
||||
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](TODO)**
|
||||
<br>
|
||||
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/)
|
||||
<br>
|
||||
[CVPR 2019](http://cvpr2019.thecvf.com/)
|
||||
|
||||
> We propose a technique for depth estimation with a monocular structured-light camera, i.e., a calibrated stereo set-up with one camera and one laser projector. Instead of formulating the depth estimation via a correspondence search problem, we show that a simple convolutional architecture is sufficient for high-quality disparity estimates in this setting. As accurate ground-truth is hard to obtain, we train our model in a self-supervised fashion with a combination of photometric and geometric losses. Further, we demonstrate that the projected pattern of the structured light sensor can be reliably separated from the ambient information. This can then be used to improve depth boundaries in a weakly supervised fashion by modeling the joint statistics of image and depth edges. The model trained in this fashion compares favorably to the state-of-the-art on challenging synthetic and real-world datasets. In addition, we contribute a novel simulator, which allows to benchmark active depth prediction algorithms in controlled conditions.
|
||||
|
||||
|
||||
If you find this code useful for your research, please cite
|
||||
|
||||
```
|
||||
@inproceedings{Riegler2019Connecting,
|
||||
title={Connecting the Dots: Learning Representations for Active Monocular Depth Estimation},
|
||||
author={Riegler, Gernot and Liao, Yiyi and Donne, Simon and Koltun, Vladlen and Geiger, Andreas},
|
||||
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## Dependencies
|
||||
|
||||
The network training/evaluation code is based on `Pytorch`.
|
||||
```
|
||||
PyTorch>=1.1
|
||||
Cuda>=10.0
|
||||
```
|
||||
|
||||
The other python packages can be installed with `anaconda`:
|
||||
```
|
||||
conda install --file requirements.txt
|
||||
```
|
||||
|
||||
### Structured Light Renderer
|
||||
To train and evaluate our method in a controlled setting, we implemented an structured light renderer.
|
||||
It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location.
|
||||
To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`.
|
||||
Afterwards, the renderer can be build by running `make` within the `renderer` directory.
|
||||
|
||||
### PyTorch Extensions
|
||||
The network training/evaluation code is based on `PyTorch`.
|
||||
We implemented some custom layers that need to be built in the `torchext` directory.
|
||||
Simply change into this directory and run
|
||||
|
||||
```
|
||||
python setup.py build_ext --inplace
|
||||
```
|
||||
|
||||
### Baseline HyperDepth
|
||||
As baseline we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf).
|
||||
The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`.
|
||||
To build it change into the directory and run
|
||||
|
||||
```
|
||||
python setup.py build_ext --inplace
|
||||
```
|
||||
|
||||
## Running
|
||||
|
||||
|
||||
### Creating synthetic data
|
||||
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
|
||||
```
|
||||
./create_syn_data.sh
|
||||
```
|
||||
|
||||
### Training Network
|
||||
|
||||
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run
|
||||
```
|
||||
python train_val.py
|
||||
```
|
||||
|
||||
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running
|
||||
```
|
||||
python train_val.py --loss phge
|
||||
```
|
||||
|
||||
|
||||
### Evaluating Network
|
||||
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
|
||||
```
|
||||
python train_val.py --cmd retest --epoch 50
|
||||
```
|
||||
|
||||
|
||||
## Acknowledgement
|
||||
This work was supported by the Intel Network on Intelligent Systems.
|
32
renderer/Makefile
Normal file
32
renderer/Makefile
Normal file
@ -0,0 +1,32 @@
|
||||
INCLUDE_DIR =
|
||||
C = gcc -c
|
||||
C_FLAGS = -O3 -msse -msse2 -msse3 -msse4.2 -fPIC -Wall
|
||||
CXX = g++ -c
|
||||
CXX_FLAGS = -O3 -std=c++11 -msse -msse2 -msse3 -msse4.2 -fPIC -Wall
|
||||
CUDA = nvcc -c
|
||||
CUDA_FLAGS = -x cu -Xcompiler -fPIC -arch=sm_30 -std=c++11 --expt-extended-lambda
|
||||
|
||||
|
||||
PYRENDER_DEPENDENCIES = setup.py \
|
||||
render/render_cpu.cpp.o \
|
||||
render/stdlib_cuda_dummy.cpp.o \
|
||||
render/render_gpu_dummy.cpp.o
|
||||
|
||||
PYRENDER_DEPENDENCIES += render/render_gpu.cu.o \
|
||||
render/stdlib_cuda.cu.o
|
||||
|
||||
all: pyrender
|
||||
|
||||
clean:
|
||||
rm render/*.o
|
||||
|
||||
pyrender: $(PYRENDER_DEPENDENCIES)
|
||||
cd pyrender; \
|
||||
python setup.py build_ext --inplace
|
||||
|
||||
%.c.o: %.c
|
||||
$(C) $(C_FLAGS) -o $@ $< $(INCLUDE_DIR)
|
||||
%.cpp.o: %.cpp
|
||||
$(CXX) $(CXX_FLAGS) -o $@ $< $(INCLUDE_DIR)
|
||||
%.cu.o: %.cu
|
||||
$(CUDA) -o $@ $< $(CUDA_FLAGS) $(INCLUDE_DIR)
|
4
renderer/__init__.py
Normal file
4
renderer/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
import ctypes
|
||||
import os
|
||||
|
||||
from .cyrender import *
|
200
renderer/cyrender.pyx
Normal file
200
renderer/cyrender.pyx
Normal file
@ -0,0 +1,200 @@
|
||||
cimport cython
|
||||
import numpy as np
|
||||
cimport numpy as np
|
||||
|
||||
from libc.stdlib cimport free, malloc
|
||||
from libcpp cimport bool
|
||||
from cpython cimport PyObject, Py_INCREF
|
||||
|
||||
CREATE_INIT = True # workaround, so cython builds a init function
|
||||
|
||||
np.import_array()
|
||||
|
||||
|
||||
ctypedef unsigned char uint8_t
|
||||
|
||||
cdef extern from "render/render.h":
|
||||
cdef cppclass Camera[T]:
|
||||
const T fx;
|
||||
const T fy;
|
||||
const T px;
|
||||
const T py;
|
||||
const T R0, R1, R2, R3, R4, R5, R6, R7, R8;
|
||||
const T t0, t1, t2;
|
||||
const T C0, C1, C2;
|
||||
const int height;
|
||||
const int width;
|
||||
Camera(const T fx, const T fy, const T px, const T py, const T* R, const T* t, int width, int height)
|
||||
|
||||
cdef cppclass RenderInput[T]:
|
||||
T* verts;
|
||||
T* radii;
|
||||
T* colors;
|
||||
T* normals;
|
||||
int n_verts;
|
||||
int* faces;
|
||||
int n_faces;
|
||||
|
||||
T* tex_coords;
|
||||
T* tex;
|
||||
int tex_height;
|
||||
int tex_width;
|
||||
int tex_channels;
|
||||
|
||||
RenderInput();
|
||||
|
||||
cdef cppclass Buffer[T]:
|
||||
T* depth;
|
||||
T* color;
|
||||
T* normal;
|
||||
Buffer();
|
||||
|
||||
cdef cppclass Shader[T]:
|
||||
const T ka;
|
||||
const T kd;
|
||||
const T ks;
|
||||
const T alpha;
|
||||
Shader(T ka, T kd, T ks, T alpha)
|
||||
|
||||
cdef cppclass BaseRenderer[T]:
|
||||
const Camera[T] cam;
|
||||
const Shader[T] shader;
|
||||
Buffer[T] buffer;
|
||||
BaseRenderer(const Camera[T] cam, const Shader[T] shader, Buffer[T] buffer)
|
||||
void render_mesh(const RenderInput[T] input);
|
||||
void render_mesh_proj(const RenderInput[T] input, const Camera[T] proj, const float* pattern, float d_alpha, float d_beta);
|
||||
|
||||
|
||||
cdef extern from "render/render_cpu.h":
|
||||
cdef cppclass RendererCpu[T](BaseRenderer[T]):
|
||||
RendererCpu(const Camera[T] cam, const Shader[T] shader, Buffer[T] buffer, int n_threads)
|
||||
void render_mesh(const RenderInput[T] input);
|
||||
void render_mesh_proj(const RenderInput[T] input, const Camera[T] proj, const float* pattern, float d_alpha, float d_beta);
|
||||
|
||||
cdef extern from "render/render_gpu.h":
|
||||
cdef cppclass RendererGpu[T](BaseRenderer[T]):
|
||||
RendererGpu(const Camera[T] cam, const Shader[T] shader, Buffer[T] buffer)
|
||||
void render_mesh(const RenderInput[T] input);
|
||||
void render_mesh_proj(const RenderInput[T] input, const Camera[T] proj, const float* pattern, float d_alpha, float d_beta);
|
||||
|
||||
|
||||
cdef class PyCamera:
|
||||
cdef Camera[float]* cam;
|
||||
|
||||
def __cinit__(self, float fx, float fy, float px, float py, float[:,::1] R, float[::1] t, int width, int height):
|
||||
if R.shape[0] != 3 or R.shape[1] != 3:
|
||||
raise Exception('invalid R matrix')
|
||||
if t.shape[0] != 3:
|
||||
raise Exception('invalid t vector')
|
||||
|
||||
self.cam = new Camera[float](fx,fy, px,py, &R[0,0], &t[0], width, height)
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.cam
|
||||
|
||||
|
||||
cdef class PyRenderInput:
|
||||
cdef RenderInput[float] input;
|
||||
cdef verts
|
||||
cdef colors
|
||||
cdef normals
|
||||
cdef faces
|
||||
|
||||
def __cinit__(self, float[:,::1] verts=None, float[:,::1] colors=None, float[:,::1] normals=None, int[:,::1] faces=None):
|
||||
self.input = RenderInput[float]()
|
||||
if verts is not None:
|
||||
self.set_verts(verts)
|
||||
if normals is not None:
|
||||
self.set_normals(normals)
|
||||
if colors is not None:
|
||||
self.set_colors(colors)
|
||||
if faces is not None:
|
||||
self.set_faces(faces)
|
||||
|
||||
def set_verts(self, float[:,::1] verts):
|
||||
if verts.shape[1] != 3:
|
||||
raise Exception('verts has to be a Nx3 matrix')
|
||||
self.verts = verts
|
||||
cdef float[:,::1] verts_view = self.verts
|
||||
self.input.verts = &verts_view[0,0]
|
||||
self.input.n_verts = self.verts.shape[0]
|
||||
|
||||
def set_colors(self, float[:,::1] colors):
|
||||
if colors.shape[1] != 3:
|
||||
raise Exception('colors has to be a Nx3 matrix')
|
||||
self.colors = colors
|
||||
cdef float[:,::1] colors_view = self.colors
|
||||
self.input.colors = &colors_view[0,0]
|
||||
|
||||
def set_normals(self, float[:,::1] normals):
|
||||
if normals.shape[1] != 3:
|
||||
raise Exception('normals has to be a Nx3 matrix')
|
||||
self.normals = normals
|
||||
cdef float[:,::1] normals_view = self.normals
|
||||
self.input.normals = &normals_view[0,0]
|
||||
|
||||
def set_faces(self, int[:,::1] faces):
|
||||
if faces.shape[1] != 3:
|
||||
raise Exception('faces has to be a Nx3 matrix')
|
||||
self.faces = faces
|
||||
cdef int[:,::1] faces_view = self.faces
|
||||
self.input.faces = &faces_view[0,0]
|
||||
self.input.n_faces = self.faces.shape[0]
|
||||
|
||||
cdef class PyShader:
|
||||
cdef Shader[float]* shader
|
||||
|
||||
def __cinit__(self, float ka, float kd, float ks, float alpha):
|
||||
self.shader = new Shader[float](ka, kd, ks, alpha)
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.shader
|
||||
|
||||
|
||||
cdef class PyRenderer:
|
||||
cdef BaseRenderer[float]* renderer
|
||||
|
||||
cdef Buffer[float] buffer
|
||||
cdef depth_buffer
|
||||
cdef color_buffer
|
||||
cdef normal_buffer
|
||||
|
||||
def depth(self):
|
||||
return self.depth_buffer
|
||||
|
||||
def color(self):
|
||||
return self.color_buffer
|
||||
|
||||
def normal(self):
|
||||
return self.normal_buffer
|
||||
|
||||
def __cinit__(self, PyCamera cam, PyShader shader, engine='cpu', int n_threads=1):
|
||||
self.depth_buffer = np.empty((cam.cam[0].height, cam.cam[0].width), dtype=np.float32)
|
||||
self.color_buffer = np.empty((cam.cam[0].height, cam.cam[0].width, 3), dtype=np.float32)
|
||||
self.normal_buffer = np.empty((cam.cam[0].height, cam.cam[0].width, 3), dtype=np.float32)
|
||||
|
||||
cdef float[:,::1] dbv = self.depth_buffer
|
||||
cdef float[:,:,::1] cbv = self.color_buffer
|
||||
cdef float[:,:,::1] nbv = self.normal_buffer
|
||||
self.buffer.depth = &dbv[0,0]
|
||||
self.buffer.color = &cbv[0,0,0]
|
||||
self.buffer.normal = &nbv[0,0,0]
|
||||
|
||||
if engine == 'cpu':
|
||||
self.renderer = new RendererCpu[float](cam.cam[0], shader.shader[0], self.buffer, n_threads)
|
||||
elif engine == 'gpu':
|
||||
self.renderer = new RendererGpu[float](cam.cam[0], shader.shader[0], self.buffer)
|
||||
else:
|
||||
raise Exception('invalid engine')
|
||||
|
||||
def __dealloc__(self):
|
||||
del self.renderer
|
||||
|
||||
def mesh(self, PyRenderInput input):
|
||||
self.renderer.render_mesh(input.input)
|
||||
|
||||
def mesh_proj(self, PyRenderInput input, PyCamera proj, float[:,:,::1] pattern, float d_alpha=1, float d_beta=0):
|
||||
if pattern.shape[0] != proj.cam[0].height or pattern.shape[1] != proj.cam[0].width or pattern.shape[2] != 3:
|
||||
raise Exception(f'pattern has to be a {proj.cam[0].height}x{proj.cam[0].width}x3 tensor')
|
||||
self.renderer.render_mesh_proj(input.input, proj.cam[0], &pattern[0,0,0], d_alpha, d_beta)
|
||||
|
10
renderer/render/co_types.h
Normal file
10
renderer/render/co_types.h
Normal file
@ -0,0 +1,10 @@
|
||||
#ifndef TYPES_H
|
||||
#define TYPES_H
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
#define CPU_GPU_FUNCTION __host__ __device__
|
||||
#else
|
||||
#define CPU_GPU_FUNCTION
|
||||
#endif
|
||||
|
||||
#endif
|
135
renderer/render/common.h
Normal file
135
renderer/render/common.h
Normal file
@ -0,0 +1,135 @@
|
||||
#ifndef COMMON_H
|
||||
#define COMMON_H
|
||||
|
||||
#include "co_types.h"
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
|
||||
#define DISABLE_COPY_AND_ASSIGN(classname) \
|
||||
private:\
|
||||
classname(const classname&) = delete;\
|
||||
classname& operator=(const classname&) = delete;
|
||||
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void fill(T* arr, int N, T val) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
arr[idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void fill_zero(T* arr, int N) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
arr[idx] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T distance_euclidean(const T* q, const T* t, int N) {
|
||||
T out = 0;
|
||||
for(int idx = 0; idx < N; idx++) {
|
||||
T diff = q[idx] - t[idx];
|
||||
out += diff * diff;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T distance_l2(const T* q, const T* t, int N) {
|
||||
T out = distance_euclidean(q, t, N);
|
||||
out = std::sqrt(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct FillFunctor {
|
||||
T* arr;
|
||||
const T val;
|
||||
|
||||
FillFunctor(T* arr, const T val) : arr(arr), val(val) {}
|
||||
CPU_GPU_FUNCTION void operator()(const int idx) {
|
||||
arr[idx] = val;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mmin(const T& a, const T& b) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return min(a, b);
|
||||
#else
|
||||
return std::min(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mmax(const T& a, const T& b) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return max(a, b);
|
||||
#else
|
||||
return std::max(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mround(const T& a) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return round(a);
|
||||
#else
|
||||
return round(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 600
|
||||
__device__ double atomicAdd(double* address, double val)
|
||||
{
|
||||
unsigned long long int* address_as_ull =
|
||||
(unsigned long long int*)address;
|
||||
unsigned long long int old = *address_as_ull, assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed,
|
||||
__double_as_longlong(val +
|
||||
__longlong_as_double(assumed)));
|
||||
|
||||
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
||||
} while (assumed != old);
|
||||
|
||||
return __longlong_as_double(old);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void matomic_add(T* addr, T val) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
atomicAdd(addr, val);
|
||||
#else
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp atomic
|
||||
#endif
|
||||
*addr += val;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
26
renderer/render/common_cpu.h
Normal file
26
renderer/render/common_cpu.h
Normal file
@ -0,0 +1,26 @@
|
||||
#ifndef COMMON_CPU
|
||||
#define COMMON_CPU
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
template <typename FunctorT>
|
||||
void iterate_cpu(FunctorT functor, int N) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
functor(idx);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FunctorT>
|
||||
void iterate_omp_cpu(FunctorT functor, int N, int n_threads) {
|
||||
#if defined(_OPENMP)
|
||||
omp_set_num_threads(n_threads);
|
||||
#pragma omp parallel for
|
||||
#endif
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
functor(idx);
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
173
renderer/render/common_cuda.h
Normal file
173
renderer/render/common_cuda.h
Normal file
@ -0,0 +1,173 @@
|
||||
#ifndef COMMON_CUDA
|
||||
#define COMMON_CUDA
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#define DEBUG 0
|
||||
#define CUDA_DEBUG_DEVICE_SYNC 0
|
||||
|
||||
// cuda check for cudaMalloc and so on
|
||||
#define CUDA_CHECK(condition) \
|
||||
/* Code block avoids redefinition of cudaError_t error */ \
|
||||
do { \
|
||||
if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
|
||||
cudaError_t error = condition; \
|
||||
if(error != cudaSuccess) { \
|
||||
printf("%s in %s at %d\n", cudaGetErrorString(error), __FILE__, __LINE__); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/// Get error string for error code.
|
||||
/// @param error
|
||||
inline const char* cublasGetErrorString(cublasStatus_t error) {
|
||||
switch (error) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return "CUBLAS_STATUS_SUCCESS";
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED:
|
||||
return "CUBLAS_STATUS_NOT_INITIALIZED";
|
||||
case CUBLAS_STATUS_ALLOC_FAILED:
|
||||
return "CUBLAS_STATUS_ALLOC_FAILED";
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return "CUBLAS_STATUS_INVALID_VALUE";
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH:
|
||||
return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
}
|
||||
return "Unknown cublas status";
|
||||
}
|
||||
|
||||
#define CUBLAS_CHECK(condition) \
|
||||
do { \
|
||||
if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
|
||||
cublasStatus_t status = condition; \
|
||||
if(status != CUBLAS_STATUS_SUCCESS) { \
|
||||
printf("%s in %s at %d\n", cublasGetErrorString(status), __FILE__, __LINE__); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// check if there is a error after kernel execution
|
||||
#define CUDA_POST_KERNEL_CHECK \
|
||||
CUDA_CHECK(cudaPeekAtLastError()); \
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
#define CUDA_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
|
||||
|
||||
const int CUDA_NUM_THREADS = 1024;
|
||||
|
||||
inline int GET_BLOCKS(const int N, const int N_THREADS=CUDA_NUM_THREADS) {
|
||||
return (N + N_THREADS - 1) / N_THREADS;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* device_malloc(long N) {
|
||||
T* dptr;
|
||||
CUDA_CHECK(cudaMalloc(&dptr, N * sizeof(T)));
|
||||
if(DEBUG) { printf("[DEBUG] device_malloc %p, %ld\n", dptr, N); }
|
||||
return dptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void device_free(T* dptr) {
|
||||
if(DEBUG) { printf("[DEBUG] device_free %p\n", dptr); }
|
||||
CUDA_CHECK(cudaFree(dptr));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void host_to_device(const T* hptr, T* dptr, long N) {
|
||||
if(DEBUG) { printf("[DEBUG] host_to_device %p => %p, %ld\n", hptr, dptr, N); }
|
||||
CUDA_CHECK(cudaMemcpy(dptr, hptr, N * sizeof(T), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* host_to_device_malloc(const T* hptr, long N) {
|
||||
T* dptr = device_malloc<T>(N);
|
||||
host_to_device(hptr, dptr, N);
|
||||
return dptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void device_to_host(const T* dptr, T* hptr, long N) {
|
||||
if(DEBUG) { printf("[DEBUG] device_to_host %p => %p, %ld\n", dptr, hptr, N); }
|
||||
CUDA_CHECK(cudaMemcpy(hptr, dptr, N * sizeof(T), cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* device_to_host_malloc(const T* dptr, long N) {
|
||||
T* hptr = new T[N];
|
||||
device_to_host(dptr, hptr, N);
|
||||
return hptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void device_to_device(const T* dptr, T* hptr, long N) {
|
||||
if(DEBUG) { printf("[DEBUG] device_to_device %p => %p, %ld\n", dptr, hptr, N); }
|
||||
CUDA_CHECK(cudaMemcpy(hptr, dptr, N * sizeof(T), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
// https://github.com/parallel-forall/code-samples/blob/master/posts/cuda-aware-mpi-example/src/Device.cu
|
||||
// https://github.com/treecode/Bonsai/blob/master/runtime/profiling/derived_atomic_functions.h
|
||||
__device__ __forceinline__ void atomicMaxF(float * const address, const float value) {
|
||||
if (*address >= value) {
|
||||
return;
|
||||
}
|
||||
|
||||
int * const address_as_i = (int *)address;
|
||||
int old = * address_as_i, assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
if (__int_as_float(assumed) >= value) {
|
||||
break;
|
||||
}
|
||||
|
||||
old = atomicCAS(address_as_i, assumed, __float_as_int(value));
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void atomicMinF(float * const address, const float value) {
|
||||
if (*address <= value) {
|
||||
return;
|
||||
}
|
||||
|
||||
int * const address_as_i = (int *)address;
|
||||
int old = * address_as_i, assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
if (__int_as_float(assumed) <= value) {
|
||||
break;
|
||||
}
|
||||
|
||||
old = atomicCAS(address_as_i, assumed, __float_as_int(value));
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
|
||||
template <typename FunctorT>
|
||||
__global__ void iterate_kernel(FunctorT functor, int N) {
|
||||
CUDA_KERNEL_LOOP(idx, N) {
|
||||
functor(idx);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FunctorT>
|
||||
void iterate_cuda(FunctorT functor, int N, int N_THREADS=CUDA_NUM_THREADS) {
|
||||
iterate_kernel<<<GET_BLOCKS(N, N_THREADS), N_THREADS>>>(functor, N);
|
||||
CUDA_POST_KERNEL_CHECK;
|
||||
}
|
||||
|
||||
|
||||
#endif
|
294
renderer/render/geometry.h
Normal file
294
renderer/render/geometry.h
Normal file
@ -0,0 +1,294 @@
|
||||
#ifndef GEOMETRY_H
|
||||
#define GEOMETRY_H
|
||||
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <cmath>
|
||||
|
||||
#include "co_types.h"
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_fill(T* v, const T fill) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
v[idx] = fill;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_fill<float, 3>(float* v, const float fill) {
|
||||
v[0] = fill;
|
||||
v[1] = fill;
|
||||
v[2] = fill;
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_add(const T* in1, const T* in2, T* out) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
out[idx] = in1[idx] + in2[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_add<float, 3>(const float* in1, const float* in2, float* out) {
|
||||
out[0] = in1[0] + in2[0];
|
||||
out[1] = in1[1] + in2[1];
|
||||
out[2] = in1[2] + in2[2];
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_add(const T lam1, const T* in1, const T lam2, const T* in2, T* out) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
out[idx] = lam1 * in1[idx] + lam2 * in2[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_add<float, 3>(const float lam1, const float* in1, const float lam2, const float* in2, float* out) {
|
||||
out[0] = lam1 * in1[0] + lam2 * in2[0];
|
||||
out[1] = lam1 * in1[1] + lam2 * in2[1];
|
||||
out[2] = lam1 * in1[2] + lam2 * in2[2];
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_sub(const T* in1, const T* in2, T* out) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
out[idx] = in1[idx] - in2[idx];
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_sub<float, 3>(const float* in1, const float* in2, float* out) {
|
||||
out[0] = in1[0] - in2[0];
|
||||
out[1] = in1[1] - in2[1];
|
||||
out[2] = in1[2] - in2[2];
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_add_scalar(const T* in, const T lam, T* out) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
out[idx] = in[idx] + lam;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_add_scalar<float, 3>(const float* in, const float lam, float* out) {
|
||||
out[0] = in[0] + lam;
|
||||
out[1] = in[1] + lam;
|
||||
out[2] = in[2] + lam;
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_mul_scalar(const T* in, const T lam, T* out) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
out[idx] = in[idx] * lam;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_mul_scalar<float, 3>(const float* in, const float lam, float* out) {
|
||||
out[0] = in[0] * lam;
|
||||
out[1] = in[1] * lam;
|
||||
out[2] = in[2] * lam;
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_div_scalar(const T* in, const T lam, T* out) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
out[idx] = in[idx] / lam;
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_div_scalar<float, 3>(const float* in, const float lam, float* out) {
|
||||
out[0] = in[0] / lam;
|
||||
out[1] = in[1] / lam;
|
||||
out[2] = in[2] / lam;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void mat_dot_vec3(const T* M, const T* v, T* w) {
|
||||
w[0] = M[0] * v[0] + M[1] * v[1] + M[2] * v[2];
|
||||
w[1] = M[3] * v[0] + M[4] * v[1] + M[5] * v[2];
|
||||
w[2] = M[6] * v[0] + M[7] * v[1] + M[8] * v[2];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void matT_dot_vec3(const T* M, const T* v, T* w) {
|
||||
w[0] = M[0] * v[0] + M[3] * v[1] + M[6] * v[2];
|
||||
w[1] = M[1] * v[0] + M[4] * v[1] + M[7] * v[2];
|
||||
w[2] = M[2] * v[0] + M[5] * v[1] + M[8] * v[2];
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T vec_dot(const T* in1, const T* in2) {
|
||||
T out = T(0);
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
out += in1[idx] * in2[idx];
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline float vec_dot<float, 3>(const float* in1, const float* in2) {
|
||||
return in1[0] * in2[0] + in1[1] * in2[1] + in1[2] * in2[2];
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_cross3(const T* u, const T* v, T* out) {
|
||||
out[0] = u[1] * v[2] - u[2] * v[1];
|
||||
out[1] = u[2] * v[0] - u[0] * v[2];
|
||||
out[2] = u[0] * v[1] - u[1] * v[0];
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T vec_norm(const T* u) {
|
||||
T norm = T(0);
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
norm += u[idx] * u[idx];
|
||||
}
|
||||
return std::sqrt(norm);
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline float vec_norm<float, 3>(const float* u) {
|
||||
return std::sqrt(u[0] * u[0] + u[1] * u[1] + u[2] * u[2]);
|
||||
}
|
||||
|
||||
template <typename T, int N=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_normalize(const T* u, T* v) {
|
||||
T denom = vec_norm(u);
|
||||
vec_div_scalar(u, denom, v);
|
||||
}
|
||||
|
||||
template <>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void vec_normalize<float, 3>(const float* u, float* v) {
|
||||
vec_div_scalar(u, vec_norm(u), v);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void vertex_normal_3d(const T* a, const T* b, const T* c, T* no) {
|
||||
T e1[3];
|
||||
T e2[3];
|
||||
vec_sub(a, b, e1);
|
||||
vec_sub(c, b, e2);
|
||||
vec_cross3(e1, e2, no);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
bool ray_triangle_intersect_3d(const T* orig, const T* dir, const T* v0, const T* v1, const T* v2, T* t, T* u, T* v, T eps = 1e-6) {
|
||||
T v0v1[3];
|
||||
vec_sub(v1, v0, v0v1);
|
||||
T v0v2[3];
|
||||
vec_sub(v2, v0, v0v2);
|
||||
T pvec[3];
|
||||
vec_cross3(dir, v0v2, pvec);
|
||||
T det = vec_dot(v0v1, pvec);
|
||||
|
||||
if(fabs(det) < eps) return false;
|
||||
|
||||
T inv_det = 1 / det;
|
||||
|
||||
T tvec[3];
|
||||
vec_sub(orig, v0, tvec);
|
||||
*u = vec_dot(tvec, pvec) * inv_det;
|
||||
if(*u < 0 || *u > 1) return false;
|
||||
|
||||
T qvec[3];
|
||||
vec_cross3(tvec, v0v1, qvec);
|
||||
*v = vec_dot(dir, qvec) * inv_det;
|
||||
if(*v < 0 || (*u + *v) > 1) return false;
|
||||
|
||||
*t = vec_dot(v0v2, qvec) * inv_det;
|
||||
T w = 1 - *u - *v;
|
||||
*v = *u;
|
||||
*u = w;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
bool ray_triangle_mesh_intersect_3d(const T* orig, const T* dir, const int* faces, int n_faces, const T* vertices, int* face_idx, T* t, T* u, T* v) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
*t = 1e9;
|
||||
#else
|
||||
*t = std::numeric_limits<T>::max();
|
||||
#endif
|
||||
bool valid = false;
|
||||
for(int fidx = 0; fidx < n_faces; ++fidx) {
|
||||
const T* v0 = vertices + faces[fidx * 3 + 0] * 3;
|
||||
const T* v1 = vertices + faces[fidx * 3 + 1] * 3;
|
||||
const T* v2 = vertices + faces[fidx * 3 + 2] * 3;
|
||||
|
||||
T ft, fu, fv;
|
||||
bool inter = ray_triangle_intersect_3d(orig, dir, v0,v1,v2, &ft,&fu,&fv);
|
||||
if(inter && ft < *t) {
|
||||
*face_idx = fidx;
|
||||
*t = ft;
|
||||
*u = fu;
|
||||
*v = fv;
|
||||
valid = true;
|
||||
}
|
||||
}
|
||||
|
||||
return valid;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void reflectance_light_dir(const T* sp, const T* lp, T* l) {
|
||||
vec_sub(lp, sp, l);
|
||||
vec_normalize(l, l);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T reflectance_lambartian(const T* sp, const T* lp, const T* n) {
|
||||
T l[3];
|
||||
reflectance_light_dir(sp, lp, l);
|
||||
return vec_dot(l, n);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T reflectance_phong(const T* orig, const T* sp, const T* lp, const T* n, const T ka, const T kd, const T ks, const T alpha) {
|
||||
T l[3];
|
||||
reflectance_light_dir(sp, lp, l);
|
||||
|
||||
T r[3];
|
||||
vec_add(2 * vec_dot(l, n), n, -1.f, l, r);
|
||||
vec_normalize(r,r); //needed?
|
||||
|
||||
T v[3];
|
||||
vec_sub(orig, sp, v);
|
||||
vec_normalize(v, v);
|
||||
|
||||
return ka + kd * vec_dot(l, n) + ks * std::pow(vec_dot(r, v), alpha);
|
||||
}
|
||||
|
||||
#endif
|
369
renderer/render/render.h
Normal file
369
renderer/render/render.h
Normal file
@ -0,0 +1,369 @@
|
||||
#ifndef RENDER_H
|
||||
#define RENDER_H
|
||||
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
#include "co_types.h"
|
||||
#include "common.h"
|
||||
#include "geometry.h"
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct Camera {
|
||||
const T fx;
|
||||
const T fy;
|
||||
const T px;
|
||||
const T py;
|
||||
const T R0, R1, R2, R3, R4, R5, R6, R7, R8;
|
||||
const T t0, t1, t2;
|
||||
const T C0, C1, C2;
|
||||
const int height;
|
||||
const int width;
|
||||
|
||||
Camera(const T fx, const T fy, const T px, const T py, const T* R, const T* t, int width, int height) :
|
||||
fx(fx), fy(fy), px(px), py(py),
|
||||
R0(R[0]), R1(R[1]), R2(R[2]), R3(R[3]), R4(R[4]), R5(R[5]), R6(R[6]), R7(R[7]), R8(R[8]),
|
||||
t0(t[0]), t1(t[1]), t2(t[2]),
|
||||
C0(-(R[0] * t[0] + R[3] * t[1] + R[6] * t[2])),
|
||||
C1(-(R[1] * t[0] + R[4] * t[1] + R[7] * t[2])),
|
||||
C2(-(R[2] * t[0] + R[5] * t[1] + R[8] * t[2])),
|
||||
height(height), width(width)
|
||||
{
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION
|
||||
inline void to_cam(const T* x, T* y) const {
|
||||
y[0] = R0 * x[0] + R1 * x[1] + R2 * x[2] + t0;
|
||||
y[1] = R3 * x[0] + R4 * x[1] + R5 * x[2] + t1;
|
||||
y[2] = R6 * x[0] + R7 * x[1] + R8 * x[2] + t2;
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION
|
||||
inline void to_world(const T* x, T* y) const {
|
||||
y[0] = R0 * (x[0] - t0) + R3 * (x[1] - t1) + R6 * (x[2] - t2);
|
||||
y[1] = R1 * (x[0] - t0) + R4 * (x[1] - t1) + R7 * (x[2] - t2);
|
||||
y[2] = R2 * (x[0] - t0) + R5 * (x[1] - t1) + R8 * (x[2] - t2);
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION
|
||||
inline void to_ray(const int h, const int w, T* dir) const {
|
||||
T uhat[2];
|
||||
uhat[0] = (w - px) / fx;
|
||||
uhat[1] = (h - py) / fy;
|
||||
dir[0] = R0 * (uhat[0]) + R3 * (uhat[1]) + R6;
|
||||
dir[1] = R1 * (uhat[0]) + R4 * (uhat[1]) + R7;
|
||||
dir[2] = R2 * (uhat[0]) + R5 * (uhat[1]) + R8;
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION
|
||||
inline void to_2d(const T* xyz, T* u, T* v, T* d) const {
|
||||
T xyz_t[3];
|
||||
to_cam(xyz, xyz_t);
|
||||
*u = fx * xyz_t[0] + px * xyz_t[2];
|
||||
*v = fy * xyz_t[1] + py * xyz_t[2];
|
||||
*d = xyz_t[2];
|
||||
*u /= *d;
|
||||
*v /= *d;
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION
|
||||
inline void get_C(T* C) const {
|
||||
C[0] = C0;
|
||||
C[1] = C1;
|
||||
C[2] = C2;
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION
|
||||
inline int num_pixel() const {
|
||||
return height * width;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct RenderInput {
|
||||
T* verts;
|
||||
T* colors;
|
||||
T* normals;
|
||||
int n_verts;
|
||||
int* faces;
|
||||
int n_faces;
|
||||
|
||||
RenderInput() : verts(nullptr), colors(nullptr), normals(nullptr), n_verts(0), faces(nullptr), n_faces(0) {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Buffer {
|
||||
T* depth;
|
||||
T* color;
|
||||
T* normal;
|
||||
|
||||
Buffer() : depth(nullptr), color(nullptr), normal(nullptr) {}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct Shader {
|
||||
const T ka;
|
||||
const T kd;
|
||||
const T ks;
|
||||
const T alpha;
|
||||
|
||||
Shader(T ka, T kd, T ks, T alpha) : ka(ka), kd(kd), ks(ks), alpha(alpha) {}
|
||||
|
||||
CPU_GPU_FUNCTION
|
||||
T operator()(const T* orig, const T* sp, const T* lp, const T* norm) const {
|
||||
return reflectance_phong(orig, sp, lp, norm, ka, kd, ks, alpha);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
class BaseRenderer {
|
||||
public:
|
||||
const Camera<T> cam;
|
||||
const Shader<T> shader;
|
||||
Buffer<T> buffer;
|
||||
|
||||
BaseRenderer(const Camera<T> cam, const Shader<T> shader, Buffer<T> buffer) : cam(cam), shader(shader), buffer(buffer) {
|
||||
}
|
||||
|
||||
virtual ~BaseRenderer() {}
|
||||
|
||||
virtual void render_mesh(const RenderInput<T> input) = 0;
|
||||
virtual void render_mesh_proj(const RenderInput<T> input, const Camera<T> proj, const float* pattern, float d_alpha, float d_beta) = 0;
|
||||
};
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct RenderFunctor {
|
||||
const Camera<T> cam;
|
||||
const Shader<T> shader;
|
||||
Buffer<T> buffer;
|
||||
|
||||
RenderFunctor(const Camera<T> cam, const Shader<T> shader, Buffer<T> buffer) : cam(cam), shader(shader), buffer(buffer) {}
|
||||
};
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct RenderMeshFunctor : public RenderFunctor<T> {
|
||||
const RenderInput<T> input;
|
||||
|
||||
RenderMeshFunctor(const RenderInput<T> input, const Shader<T> shader, const Camera<T> cam, Buffer<T> buffer) : RenderFunctor<T>(cam, shader,buffer), input(input) {
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(const int idx) {
|
||||
int h = idx / this->cam.width;
|
||||
int w = idx % this->cam.width;
|
||||
|
||||
T orig[3];
|
||||
this->cam.get_C(orig);
|
||||
T dir[3];
|
||||
this->cam.to_ray(h, w, dir);
|
||||
|
||||
int face_idx;
|
||||
T t, tu, tv;
|
||||
bool valid = ray_triangle_mesh_intersect_3d(orig, dir, this->input.faces, this->input.n_faces, this->input.verts, &face_idx, &t, &tu, &tv);
|
||||
|
||||
if(this->buffer.depth != nullptr) {
|
||||
this->buffer.depth[idx] = valid ? t : -1;
|
||||
}
|
||||
|
||||
if(!valid) {
|
||||
if(this->buffer.color != nullptr) {
|
||||
this->buffer.color[idx * 3 + 0] = 0;
|
||||
this->buffer.color[idx * 3 + 1] = 0;
|
||||
this->buffer.color[idx * 3 + 2] = 0;
|
||||
}
|
||||
if(this->buffer.normal != nullptr) {
|
||||
this->buffer.normal[idx * 3 + 0] = 0;
|
||||
this->buffer.normal[idx * 3 + 1] = 0;
|
||||
this->buffer.normal[idx * 3 + 2] = 0;
|
||||
}
|
||||
}
|
||||
else if(this->buffer.normal != nullptr || this->buffer.color != nullptr) {
|
||||
const int* face = input.faces + face_idx * 3;
|
||||
T tw = 1 - tu - tv;
|
||||
|
||||
T norm[3];
|
||||
vec_fill(norm, 0.f);
|
||||
vec_add(1.f, norm, tu, this->input.normals + face[0] * 3, norm);
|
||||
vec_add(1.f, norm, tv, this->input.normals + face[1] * 3, norm);
|
||||
vec_add(1.f, norm, tw, this->input.normals + face[2] * 3, norm);
|
||||
if(vec_dot(norm, dir) > 0) {
|
||||
vec_mul_scalar(norm, -1.f, norm);
|
||||
}
|
||||
|
||||
if(this->buffer.normal != nullptr) {
|
||||
this->buffer.normal[idx * 3 + 0] = norm[0];
|
||||
this->buffer.normal[idx * 3 + 1] = norm[1];
|
||||
this->buffer.normal[idx * 3 + 2] = norm[2];
|
||||
}
|
||||
|
||||
if(this->buffer.color != nullptr) {
|
||||
T color[3];
|
||||
vec_fill(color, 0.f);
|
||||
vec_add(1.f, color, tu, this->input.colors + face[0] * 3, color);
|
||||
vec_add(1.f, color, tv, this->input.colors + face[1] * 3, color);
|
||||
vec_add(1.f, color, tw, this->input.colors + face[2] * 3, color);
|
||||
|
||||
T sp[3];
|
||||
vec_add(1.f, orig, t, dir, sp);
|
||||
T reflectance = this->shader(orig, sp, orig, norm);
|
||||
|
||||
this->buffer.color[idx * 3 + 0] = mmin(1.f, mmax(0.f, reflectance * color[0]));
|
||||
this->buffer.color[idx * 3 + 1] = mmin(1.f, mmax(0.f, reflectance * color[1]));
|
||||
this->buffer.color[idx * 3 + 2] = mmin(1.f, mmax(0.f, reflectance * color[2]));
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int n=3>
|
||||
CPU_GPU_FUNCTION
|
||||
inline void interpolate_linear(const T* im, T x, T y, int height, int width, T* out_vec) {
|
||||
int x1 = int(x);
|
||||
int y1 = int(y);
|
||||
int x2 = x1 + 1;
|
||||
int y2 = y1 + 1;
|
||||
|
||||
T denom = (x2 - x1) * (y2 - y1);
|
||||
T t11 = (x2 - x) * (y2 - y);
|
||||
T t21 = (x - x1) * (y2 - y);
|
||||
T t12 = (x2 - x) * (y - y1);
|
||||
T t22 = (x - x1) * (y - y1);
|
||||
|
||||
x1 = mmin(mmax(x1, int(0)), width-1);
|
||||
x2 = mmin(mmax(x2, int(0)), width-1);
|
||||
y1 = mmin(mmax(y1, int(0)), height-1);
|
||||
y2 = mmin(mmax(y2, int(0)), height-1);
|
||||
|
||||
for(int idx = 0; idx < n; ++idx) {
|
||||
out_vec[idx] = (im[(y1 * width + x1) * 3 + idx] * t11 +
|
||||
im[(y2 * width + x1) * 3 + idx] * t12 +
|
||||
im[(y1 * width + x2) * 3 + idx] * t21 +
|
||||
im[(y2 * width + x2) * 3 + idx] * t22) / denom;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct RenderProjectorFunctor : public RenderFunctor<T> {
|
||||
const RenderInput<T> input;
|
||||
const Camera<T> proj;
|
||||
const float* pattern;
|
||||
const float d_alpha;
|
||||
const float d_beta;
|
||||
|
||||
RenderProjectorFunctor(const RenderInput<T> input, const Shader<T> shader, const Camera<T> cam, const Camera<T> proj, const float* pattern, float d_alpha, float d_beta, Buffer<T> buffer) : RenderFunctor<T>(cam, shader, buffer), input(input), proj(proj), pattern(pattern), d_alpha(d_alpha), d_beta(d_beta) {
|
||||
}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(const int idx) {
|
||||
int h = idx / this->cam.width;
|
||||
int w = idx % this->cam.width;
|
||||
|
||||
T orig[3];
|
||||
this->cam.get_C(orig);
|
||||
T dir[3];
|
||||
this->cam.to_ray(h, w, dir);
|
||||
|
||||
int face_idx;
|
||||
T t, tu, tv;
|
||||
bool valid = ray_triangle_mesh_intersect_3d(orig, dir, this->input.faces, this->input.n_faces, this->input.verts, &face_idx, &t, &tu, &tv);
|
||||
if(this->buffer.depth != nullptr) {
|
||||
this->buffer.depth[idx] = valid ? t : -1;
|
||||
}
|
||||
|
||||
this->buffer.color[idx * 3 + 0] = 0;
|
||||
this->buffer.color[idx * 3 + 1] = 0;
|
||||
this->buffer.color[idx * 3 + 2] = 0;
|
||||
|
||||
if(valid) {
|
||||
if(this->buffer.normal != nullptr) {
|
||||
const int* face = input.faces + face_idx * 3;
|
||||
T tw = 1 - tu - tv;
|
||||
|
||||
T norm[3];
|
||||
vertex_normal_3d(
|
||||
this->input.verts + face[0] * 3,
|
||||
this->input.verts + face[1] * 3,
|
||||
this->input.verts + face[2] * 3,
|
||||
norm);
|
||||
vec_normalize(norm, norm);
|
||||
|
||||
if(vec_dot(norm, dir) > 0) {
|
||||
vec_mul_scalar(norm, -1.f, norm);
|
||||
}
|
||||
|
||||
T color[3];
|
||||
vec_fill(color, 0.f);
|
||||
vec_add(1.f, color, tu, this->input.colors + face[0] * 3, color);
|
||||
vec_add(1.f, color, tv, this->input.colors + face[1] * 3, color);
|
||||
vec_add(1.f, color, tw, this->input.colors + face[2] * 3, color);
|
||||
|
||||
T sp[3];
|
||||
vec_add(1.f, orig, t, dir, sp);
|
||||
T reflectance = this->shader(orig, sp, orig, norm);
|
||||
|
||||
this->buffer.normal[idx * 3 + 0] = mmin(1.f, mmax(0.f, reflectance * color[0]));
|
||||
this->buffer.normal[idx * 3 + 1] = mmin(1.f, mmax(0.f, reflectance * color[1]));
|
||||
this->buffer.normal[idx * 3 + 2] = mmin(1.f, mmax(0.f, reflectance * color[2]));
|
||||
}
|
||||
|
||||
// get 3D point
|
||||
T pt[3];
|
||||
vec_mul_scalar(dir, t, pt);
|
||||
vec_add(orig, pt, pt);
|
||||
|
||||
// get dir from proj
|
||||
T proj_orig[3];
|
||||
proj.get_C(proj_orig);
|
||||
T proj_dir[3];
|
||||
vec_sub(pt, proj_orig, proj_dir);
|
||||
vec_div_scalar(proj_dir, proj_dir[2], proj_dir);
|
||||
|
||||
// check if it hit same tria
|
||||
int p_face_idx;
|
||||
T p_t, p_tu, p_tv;
|
||||
valid = ray_triangle_mesh_intersect_3d(proj_orig, proj_dir, this->input.faces, this->input.n_faces, this->input.verts, &p_face_idx, &p_t, &p_tu, &p_tv);
|
||||
// if(!valid || p_face_idx != face_idx) {
|
||||
// return;
|
||||
// }
|
||||
|
||||
T p_pt[3];
|
||||
vec_mul_scalar(proj_dir, p_t, p_pt);
|
||||
vec_add(proj_orig, p_pt, p_pt);
|
||||
T diff[3];
|
||||
vec_sub(p_pt, pt, diff);
|
||||
if(!valid || vec_norm(diff) > 1e-5) {
|
||||
return;
|
||||
}
|
||||
|
||||
// get uv in proj
|
||||
T u,v,d;
|
||||
proj.to_2d(p_pt, &u,&v,&d);
|
||||
|
||||
// if valid u,v than use it to inpaint
|
||||
if(u >= 0 && v >= 0 && u < this->proj.width && v < this->proj.height) {
|
||||
// int pattern_idx = ((int(v) * this->proj.width) + int(u)) * 3;
|
||||
// this->buffer.color[idx * 3 + 0] = pattern[pattern_idx + 0];
|
||||
// this->buffer.color[idx * 3 + 1] = pattern[pattern_idx + 1];
|
||||
// this->buffer.color[idx * 3 + 2] = pattern[pattern_idx + 2];
|
||||
interpolate_linear(pattern, u, v, this->proj.height, this->proj.width, this->buffer.color + idx * 3);
|
||||
|
||||
// decay based on distance
|
||||
T decay = d_alpha + d_beta * d;
|
||||
decay *= decay;
|
||||
decay = mmax(decay, T(1));
|
||||
vec_div_scalar(this->buffer.color + idx * 3, decay, this->buffer.color + idx * 3);
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
#endif
|
22
renderer/render/render_cpu.cpp
Normal file
22
renderer/render/render_cpu.cpp
Normal file
@ -0,0 +1,22 @@
|
||||
#include <limits>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
#include "render_cpu.h"
|
||||
#include "common_cpu.h"
|
||||
|
||||
template <typename T>
|
||||
void RendererCpu<T>::render_mesh(RenderInput<T> input) {
|
||||
RenderMeshFunctor<T> functor(input, this->shader, this->cam, this->buffer);
|
||||
iterate_omp_cpu(functor, this->cam.num_pixel(), n_threads);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RendererCpu<T>::render_mesh_proj(const RenderInput<T> input, const Camera<T> proj, const float* pattern, float d_alpha, float d_beta) {
|
||||
RenderProjectorFunctor<T> functor(input, this->shader, this->cam, proj, pattern, d_alpha, d_beta, this->buffer);
|
||||
iterate_omp_cpu(functor, this->cam.num_pixel(), this->n_threads);
|
||||
}
|
||||
|
||||
template class RendererCpu<float>;
|
23
renderer/render/render_cpu.h
Normal file
23
renderer/render/render_cpu.h
Normal file
@ -0,0 +1,23 @@
|
||||
#ifndef RENDER_CPU_H
|
||||
#define RENDER_CPU_H
|
||||
|
||||
#include "render.h"
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
class RendererCpu : public BaseRenderer<T> {
|
||||
public:
|
||||
const int n_threads;
|
||||
|
||||
RendererCpu(const Camera<T> cam, const Shader<T> shader, Buffer<T> buffer, int n_threads) : BaseRenderer<T>(cam, shader, buffer), n_threads(n_threads) {
|
||||
}
|
||||
|
||||
virtual ~RendererCpu() {
|
||||
}
|
||||
|
||||
virtual void render_mesh(const RenderInput<T> input);
|
||||
virtual void render_mesh_proj(const RenderInput<T> input, const Camera<T> proj, const float* pattern, float d_alpha, float d_beta);
|
||||
};
|
||||
|
||||
#endif
|
100
renderer/render/render_gpu.cu
Normal file
100
renderer/render/render_gpu.cu
Normal file
@ -0,0 +1,100 @@
|
||||
#include "common_cuda.h"
|
||||
#include "render_gpu.h"
|
||||
|
||||
template <typename T>
|
||||
RendererGpu<T>::RendererGpu(const Camera<T> cam, const Shader<T> shader, Buffer<T> buffer) : BaseRenderer<T>(cam, shader, buffer) {
|
||||
if(buffer.depth != nullptr) {
|
||||
buffer_gpu.depth = device_malloc<T>(cam.num_pixel());
|
||||
}
|
||||
|
||||
if(buffer.color != nullptr) {
|
||||
buffer_gpu.color = device_malloc<T>(cam.num_pixel() * 3);
|
||||
}
|
||||
|
||||
if(buffer.normal != nullptr) {
|
||||
buffer_gpu.normal = device_malloc<T>(cam.num_pixel() * 3);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RendererGpu<T>::~RendererGpu() {
|
||||
device_free(buffer_gpu.depth);
|
||||
device_free(buffer_gpu.color);
|
||||
device_free(buffer_gpu.normal);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::gpu_to_cpu() {
|
||||
if(buffer_gpu.depth != nullptr && this->buffer.depth != nullptr) {
|
||||
device_to_host(buffer_gpu.depth, this->buffer.depth, this->cam.num_pixel());
|
||||
}
|
||||
if(buffer_gpu.color != nullptr && this->buffer.color != nullptr) {
|
||||
device_to_host(buffer_gpu.color, this->buffer.color, this->cam.num_pixel() * 3);
|
||||
}
|
||||
if(buffer_gpu.normal != nullptr && this->buffer.normal != nullptr) {
|
||||
device_to_host(buffer_gpu.normal, this->buffer.normal, this->cam.num_pixel() * 3);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RenderInput<T> RendererGpu<T>::input_to_device(const RenderInput<T> input) {
|
||||
RenderInput<T> input_gpu;
|
||||
input_gpu.n_verts = input.n_verts;
|
||||
input_gpu.n_faces = input.n_faces;
|
||||
|
||||
if(input.verts != nullptr) {
|
||||
input_gpu.verts = host_to_device_malloc(input.verts, input.n_verts * 3);
|
||||
}
|
||||
if(input.colors != nullptr) {
|
||||
input_gpu.colors = host_to_device_malloc(input.colors, input.n_verts * 3);
|
||||
}
|
||||
if(input.normals != nullptr) {
|
||||
input_gpu.normals = host_to_device_malloc(input.normals, input.n_verts * 3);
|
||||
}
|
||||
if(input.faces != nullptr) {
|
||||
input_gpu.faces = host_to_device_malloc(input.faces, input.n_faces * 3);
|
||||
}
|
||||
|
||||
return input_gpu;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::input_free_device(const RenderInput<T> input) {
|
||||
if(input.verts != nullptr) {
|
||||
device_free(input.verts);
|
||||
}
|
||||
if(input.colors != nullptr) {
|
||||
device_free(input.colors);
|
||||
}
|
||||
if(input.normals != nullptr) {
|
||||
device_free(input.normals);
|
||||
}
|
||||
if(input.faces != nullptr) {
|
||||
device_free(input.faces);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::render_mesh(RenderInput<T> input) {
|
||||
RenderInput<T> input_gpu = this->input_to_device(input);
|
||||
RenderMeshFunctor<T> functor(input_gpu, this->shader, this->cam, this->buffer_gpu);
|
||||
iterate_cuda(functor, this->cam.num_pixel());
|
||||
gpu_to_cpu();
|
||||
this->input_free_device(input_gpu);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::render_mesh_proj(const RenderInput<T> input, const Camera<T> proj, const float* pattern, float d_alpha, float d_beta) {
|
||||
RenderInput<T> input_gpu = this->input_to_device(input);
|
||||
float* pattern_gpu = host_to_device_malloc(pattern, proj.num_pixel()*3);
|
||||
|
||||
RenderProjectorFunctor<T> functor(input_gpu, this->shader, this->cam, proj, pattern_gpu, d_alpha, d_beta, this->buffer_gpu);
|
||||
iterate_cuda(functor, this->cam.num_pixel());
|
||||
|
||||
gpu_to_cpu();
|
||||
this->input_free_device(input_gpu);
|
||||
device_free(pattern_gpu);
|
||||
}
|
||||
|
||||
template class RendererGpu<float>;
|
23
renderer/render/render_gpu.h
Normal file
23
renderer/render/render_gpu.h
Normal file
@ -0,0 +1,23 @@
|
||||
#ifndef RENDER_RENDER_GPU_H
|
||||
#define RENDER_RENDER_GPU_H
|
||||
|
||||
#include "render.h"
|
||||
|
||||
template <typename T>
|
||||
class RendererGpu : public BaseRenderer<T> {
|
||||
public:
|
||||
Buffer<T> buffer_gpu;
|
||||
|
||||
RendererGpu(const Camera<T> cam, const Shader<T> shader, Buffer<T> buffer);
|
||||
|
||||
virtual ~RendererGpu();
|
||||
|
||||
virtual void gpu_to_cpu();
|
||||
virtual RenderInput<T> input_to_device(const RenderInput<T> input);
|
||||
virtual void input_free_device(const RenderInput<T> input);
|
||||
|
||||
virtual void render_mesh(const RenderInput<T> input);
|
||||
virtual void render_mesh_proj(const RenderInput<T> input, const Camera<T> proj, const float* pattern, float d_alpha, float d_beta);
|
||||
};
|
||||
|
||||
#endif
|
33
renderer/render/render_gpu_dummy.cpp
Normal file
33
renderer/render/render_gpu_dummy.cpp
Normal file
@ -0,0 +1,33 @@
|
||||
#include "render_gpu.h"
|
||||
|
||||
template <typename T>
|
||||
RendererGpu<T>::RendererGpu(const Camera<T> cam, const Shader<T> shader, Buffer<T> buffer) : BaseRenderer<T>(cam, shader, buffer) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RendererGpu<T>::~RendererGpu() {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::gpu_to_cpu() {}
|
||||
|
||||
template <typename T>
|
||||
RenderInput<T> RendererGpu<T>::input_to_device(const RenderInput<T> input) { return RenderInput<T>(); }
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::input_free_device(const RenderInput<T> input) {
|
||||
throw std::logic_error("Not implemented");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::render_mesh(const RenderInput<T> input) {
|
||||
throw std::logic_error("Not implemented");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RendererGpu<T>::render_mesh_proj(const RenderInput<T> input, const Camera<T> proj, const float* pattern, float d_alpha, float d_beta) {
|
||||
throw std::logic_error("Not implemented");
|
||||
}
|
||||
|
||||
|
||||
template class RendererGpu<float>;
|
35
renderer/render/stdlib_cuda.cu
Normal file
35
renderer/render/stdlib_cuda.cu
Normal file
@ -0,0 +1,35 @@
|
||||
#include "common_cuda.h"
|
||||
#include "stdlib_cuda.h"
|
||||
|
||||
void device_synchronize() {
|
||||
cudaDeviceSynchronize();
|
||||
}
|
||||
|
||||
float* device_malloc_f32(long N) {
|
||||
return device_malloc<float>(N);
|
||||
}
|
||||
int* device_malloc_i32(long N) {
|
||||
return device_malloc<int>(N);
|
||||
}
|
||||
|
||||
void device_free_f32(float* dptr) {
|
||||
device_free(dptr);
|
||||
}
|
||||
void device_free_i32(int* dptr) {
|
||||
device_free(dptr);
|
||||
}
|
||||
|
||||
void device_to_host_f32(const float* dptr, float* hptr, long N) {
|
||||
device_to_host(dptr, hptr, N);
|
||||
}
|
||||
void device_to_host_i32(const int* dptr, int* hptr, long N) {
|
||||
device_to_host(dptr, hptr, N);
|
||||
}
|
||||
|
||||
float* host_to_device_malloc_f32(const float* hptr, long N) {
|
||||
return host_to_device_malloc(hptr, N);
|
||||
}
|
||||
|
||||
int* host_to_device_malloc_i32(const int* hptr, long N) {
|
||||
return host_to_device_malloc(hptr, N);
|
||||
}
|
18
renderer/render/stdlib_cuda.h
Normal file
18
renderer/render/stdlib_cuda.h
Normal file
@ -0,0 +1,18 @@
|
||||
#ifndef STDLIB_CUDA
|
||||
#define STDLIB_CUDA
|
||||
|
||||
void device_synchronize();
|
||||
|
||||
float* device_malloc_f32(long N);
|
||||
int* device_malloc_i32(long N);
|
||||
|
||||
void device_free_f32(float* dptr);
|
||||
void device_free_i32(int* dptr);
|
||||
|
||||
float* host_to_device_malloc_f32(const float* hptr, long N);
|
||||
int* host_to_device_malloc_i32(const int* hptr, long N);
|
||||
|
||||
void device_to_host_f32(const float* dptr, float* hptr, long N);
|
||||
void device_to_host_i32(const int* dptr, int* hptr, long N);
|
||||
|
||||
#endif
|
10
renderer/render/stdlib_cuda_dummy.cpp
Normal file
10
renderer/render/stdlib_cuda_dummy.cpp
Normal file
@ -0,0 +1,10 @@
|
||||
#include "stdlib_cuda.h"
|
||||
|
||||
float* device_malloc_f32(long N) { return nullptr; }
|
||||
int* device_malloc_i32(long N) { return nullptr; }
|
||||
void device_free_f32(float* dptr) {}
|
||||
void device_free_i32(int* dptr) {}
|
||||
float* host_to_device_malloc_f32(const float* hptr, long N) { return nullptr; }
|
||||
int* host_to_device_malloc_i32(const int* hptr, long N) { return nullptr; }
|
||||
void device_to_host_f32(const float* dptr, float* hptr, long N) {}
|
||||
void device_to_host_i32(const int* dptr, int* hptr, long N) {}
|
49
renderer/setup.py
Normal file
49
renderer/setup.py
Normal file
@ -0,0 +1,49 @@
|
||||
from distutils.core import setup
|
||||
from Cython.Build import cythonize
|
||||
from distutils.extension import Extension
|
||||
from Cython.Distutils import build_ext
|
||||
import numpy as np
|
||||
import platform
|
||||
import os
|
||||
import json
|
||||
|
||||
this_dir = os.path.dirname(__file__)
|
||||
|
||||
with open('../config.json') as fp:
|
||||
config = json.load(fp)
|
||||
|
||||
extra_compile_args = ['-O3', '-std=c++11']
|
||||
|
||||
print('using cuda')
|
||||
cuda_lib_dir = config['CUDA_LIBRARY_DIR']
|
||||
cuda_lib = 'cudart'
|
||||
|
||||
sources = ['cyrender.pyx']
|
||||
extra_objects = [
|
||||
os.path.join(this_dir, 'render/render_cpu.cpp.o'),
|
||||
]
|
||||
library_dirs = []
|
||||
libraries = ['m']
|
||||
extra_objects.append(os.path.join(this_dir, 'render/render_gpu.cu.o'))
|
||||
extra_objects.append(os.path.join(this_dir, 'render/stdlib_cuda.cu.o'))
|
||||
library_dirs.append(cuda_lib_dir)
|
||||
libraries.append(cuda_lib)
|
||||
|
||||
setup(
|
||||
name="cyrender",
|
||||
cmdclass= {'build_ext': build_ext},
|
||||
ext_modules=[
|
||||
Extension('cyrender',
|
||||
sources,
|
||||
extra_objects=extra_objects,
|
||||
language='c++',
|
||||
library_dirs=library_dirs,
|
||||
libraries=libraries,
|
||||
include_dirs=[
|
||||
np.get_include(),
|
||||
],
|
||||
extra_compile_args=extra_compile_args,
|
||||
# extra_link_args=extra_link_args
|
||||
)
|
||||
]
|
||||
)
|
6
requirements.txt
Normal file
6
requirements.txt
Normal file
@ -0,0 +1,6 @@
|
||||
cython
|
||||
numpy
|
||||
matplotlib
|
||||
pandas
|
||||
scipy
|
||||
opencv
|
4
torchext/__init__.py
Normal file
4
torchext/__init__.py
Normal file
@ -0,0 +1,4 @@
|
||||
from .dataset import *
|
||||
from .worker import *
|
||||
from .functions import *
|
||||
from .modules import *
|
66
torchext/dataset.py
Normal file
66
torchext/dataset.py
Normal file
@ -0,0 +1,66 @@
|
||||
import torch
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
|
||||
class TestSet(object):
|
||||
def __init__(self, name, dset, test_frequency=1):
|
||||
self.name = name
|
||||
self.dset = dset
|
||||
self.test_frequency = test_frequency
|
||||
|
||||
class TestSets(list):
|
||||
def append(self, name, dset, test_frequency=1):
|
||||
super().append(TestSet(name, dset, test_frequency))
|
||||
|
||||
|
||||
|
||||
class MultiDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, *datasets):
|
||||
self.current_epoch = 0
|
||||
|
||||
self.datasets = []
|
||||
self.cum_n_samples = [0]
|
||||
|
||||
for dataset in datasets:
|
||||
self.append(dataset)
|
||||
|
||||
def append(self, dataset):
|
||||
self.datasets.append(dataset)
|
||||
self.__update_cum_n_samples(dataset)
|
||||
|
||||
def __update_cum_n_samples(self, dataset):
|
||||
n_samples = self.cum_n_samples[-1] + len(dataset)
|
||||
self.cum_n_samples.append(n_samples)
|
||||
|
||||
def dataset_updated(self):
|
||||
self.cum_n_samples = [0]
|
||||
for dset in self.datasets:
|
||||
self.__update_cum_n_samples(dset)
|
||||
|
||||
def __len__(self):
|
||||
return self.cum_n_samples[-1]
|
||||
|
||||
def __getitem__(self, idx):
|
||||
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
|
||||
sidx = idx - self.cum_n_samples[didx]
|
||||
return self.datasets[didx][sidx]
|
||||
|
||||
|
||||
|
||||
class BaseDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, train=True, fix_seed_per_epoch=False):
|
||||
self.current_epoch = 0
|
||||
self.train = train
|
||||
self.fix_seed_per_epoch = fix_seed_per_epoch
|
||||
|
||||
def get_rng(self, idx):
|
||||
rng = np.random.RandomState()
|
||||
if self.train:
|
||||
if self.fix_seed_per_epoch:
|
||||
seed = 1 * len(self) + idx
|
||||
else:
|
||||
seed = (self.current_epoch + 1) * len(self) + idx
|
||||
rng.seed(seed)
|
||||
else:
|
||||
rng.seed(idx)
|
||||
return rng
|
10
torchext/ext/co_types.h
Normal file
10
torchext/ext/co_types.h
Normal file
@ -0,0 +1,10 @@
|
||||
#ifndef TYPES_H
|
||||
#define TYPES_H
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
#define CPU_GPU_FUNCTION __host__ __device__
|
||||
#else
|
||||
#define CPU_GPU_FUNCTION
|
||||
#endif
|
||||
|
||||
#endif
|
135
torchext/ext/common.h
Normal file
135
torchext/ext/common.h
Normal file
@ -0,0 +1,135 @@
|
||||
#ifndef COMMON_H
|
||||
#define COMMON_H
|
||||
|
||||
#include "co_types.h"
|
||||
#include <cmath>
|
||||
#include <algorithm>
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
|
||||
#define DISABLE_COPY_AND_ASSIGN(classname) \
|
||||
private:\
|
||||
classname(const classname&) = delete;\
|
||||
classname& operator=(const classname&) = delete;
|
||||
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void fill(T* arr, int N, T val) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
arr[idx] = val;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void fill_zero(T* arr, int N) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
arr[idx] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T distance_euclidean(const T* q, const T* t, int N) {
|
||||
T out = 0;
|
||||
for(int idx = 0; idx < N; idx++) {
|
||||
T diff = q[idx] - t[idx];
|
||||
out += diff * diff;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
inline T distance_l2(const T* q, const T* t, int N) {
|
||||
T out = distance_euclidean(q, t, N);
|
||||
out = std::sqrt(out);
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
template <typename T>
|
||||
struct FillFunctor {
|
||||
T* arr;
|
||||
const T val;
|
||||
|
||||
FillFunctor(T* arr, const T val) : arr(arr), val(val) {}
|
||||
CPU_GPU_FUNCTION void operator()(const int idx) {
|
||||
arr[idx] = val;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mmin(const T& a, const T& b) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return min(a, b);
|
||||
#else
|
||||
return std::min(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mmax(const T& a, const T& b) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return max(a, b);
|
||||
#else
|
||||
return std::max(a, b);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
T mround(const T& a) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
return round(a);
|
||||
#else
|
||||
return round(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
|
||||
#ifdef __CUDA_ARCH__
|
||||
#if __CUDA_ARCH__ < 600
|
||||
__device__ double atomicAdd(double* address, double val)
|
||||
{
|
||||
unsigned long long int* address_as_ull =
|
||||
(unsigned long long int*)address;
|
||||
unsigned long long int old = *address_as_ull, assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
old = atomicCAS(address_as_ull, assumed,
|
||||
__double_as_longlong(val +
|
||||
__longlong_as_double(assumed)));
|
||||
|
||||
// Note: uses integer comparison to avoid hang in case of NaN (since NaN != NaN)
|
||||
} while (assumed != old);
|
||||
|
||||
return __longlong_as_double(old);
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
|
||||
template <typename T>
|
||||
CPU_GPU_FUNCTION
|
||||
void matomic_add(T* addr, T val) {
|
||||
#ifdef __CUDA_ARCH__
|
||||
atomicAdd(addr, val);
|
||||
#else
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp atomic
|
||||
#endif
|
||||
*addr += val;
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
173
torchext/ext/common_cuda.h
Normal file
173
torchext/ext/common_cuda.h
Normal file
@ -0,0 +1,173 @@
|
||||
#ifndef COMMON_CUDA
|
||||
#define COMMON_CUDA
|
||||
|
||||
#include <cublas_v2.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#define DEBUG 0
|
||||
#define CUDA_DEBUG_DEVICE_SYNC 0
|
||||
|
||||
// cuda check for cudaMalloc and so on
|
||||
#define CUDA_CHECK(condition) \
|
||||
/* Code block avoids redefinition of cudaError_t error */ \
|
||||
do { \
|
||||
if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
|
||||
cudaError_t error = condition; \
|
||||
if(error != cudaSuccess) { \
|
||||
printf("%s in %s at %d\n", cudaGetErrorString(error), __FILE__, __LINE__); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
/// Get error string for error code.
|
||||
/// @param error
|
||||
inline const char* cublasGetErrorString(cublasStatus_t error) {
|
||||
switch (error) {
|
||||
case CUBLAS_STATUS_SUCCESS:
|
||||
return "CUBLAS_STATUS_SUCCESS";
|
||||
case CUBLAS_STATUS_NOT_INITIALIZED:
|
||||
return "CUBLAS_STATUS_NOT_INITIALIZED";
|
||||
case CUBLAS_STATUS_ALLOC_FAILED:
|
||||
return "CUBLAS_STATUS_ALLOC_FAILED";
|
||||
case CUBLAS_STATUS_INVALID_VALUE:
|
||||
return "CUBLAS_STATUS_INVALID_VALUE";
|
||||
case CUBLAS_STATUS_ARCH_MISMATCH:
|
||||
return "CUBLAS_STATUS_ARCH_MISMATCH";
|
||||
case CUBLAS_STATUS_MAPPING_ERROR:
|
||||
return "CUBLAS_STATUS_MAPPING_ERROR";
|
||||
case CUBLAS_STATUS_EXECUTION_FAILED:
|
||||
return "CUBLAS_STATUS_EXECUTION_FAILED";
|
||||
case CUBLAS_STATUS_INTERNAL_ERROR:
|
||||
return "CUBLAS_STATUS_INTERNAL_ERROR";
|
||||
case CUBLAS_STATUS_NOT_SUPPORTED:
|
||||
return "CUBLAS_STATUS_NOT_SUPPORTED";
|
||||
case CUBLAS_STATUS_LICENSE_ERROR:
|
||||
return "CUBLAS_STATUS_LICENSE_ERROR";
|
||||
}
|
||||
return "Unknown cublas status";
|
||||
}
|
||||
|
||||
#define CUBLAS_CHECK(condition) \
|
||||
do { \
|
||||
if(CUDA_DEBUG_DEVICE_SYNC) { cudaDeviceSynchronize(); } \
|
||||
cublasStatus_t status = condition; \
|
||||
if(status != CUBLAS_STATUS_SUCCESS) { \
|
||||
printf("%s in %s at %d\n", cublasGetErrorString(status), __FILE__, __LINE__); \
|
||||
exit(-1); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
// check if there is a error after kernel execution
|
||||
#define CUDA_POST_KERNEL_CHECK \
|
||||
CUDA_CHECK(cudaPeekAtLastError()); \
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
#define CUDA_KERNEL_LOOP(i, n) \
|
||||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); i += blockDim.x * gridDim.x)
|
||||
|
||||
const int CUDA_NUM_THREADS = 1024;
|
||||
|
||||
inline int GET_BLOCKS(const int N, const int N_THREADS=CUDA_NUM_THREADS) {
|
||||
return (N + N_THREADS - 1) / N_THREADS;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* device_malloc(long N) {
|
||||
T* dptr;
|
||||
CUDA_CHECK(cudaMalloc(&dptr, N * sizeof(T)));
|
||||
if(DEBUG) { printf("[DEBUG] device_malloc %p, %ld\n", dptr, N); }
|
||||
return dptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void device_free(T* dptr) {
|
||||
if(DEBUG) { printf("[DEBUG] device_free %p\n", dptr); }
|
||||
CUDA_CHECK(cudaFree(dptr));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void host_to_device(const T* hptr, T* dptr, long N) {
|
||||
if(DEBUG) { printf("[DEBUG] host_to_device %p => %p, %ld\n", hptr, dptr, N); }
|
||||
CUDA_CHECK(cudaMemcpy(dptr, hptr, N * sizeof(T), cudaMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* host_to_device_malloc(const T* hptr, long N) {
|
||||
T* dptr = device_malloc<T>(N);
|
||||
host_to_device(hptr, dptr, N);
|
||||
return dptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void device_to_host(const T* dptr, T* hptr, long N) {
|
||||
if(DEBUG) { printf("[DEBUG] device_to_host %p => %p, %ld\n", dptr, hptr, N); }
|
||||
CUDA_CHECK(cudaMemcpy(hptr, dptr, N * sizeof(T), cudaMemcpyDeviceToHost));
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
T* device_to_host_malloc(const T* dptr, long N) {
|
||||
T* hptr = new T[N];
|
||||
device_to_host(dptr, hptr, N);
|
||||
return hptr;
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
void device_to_device(const T* dptr, T* hptr, long N) {
|
||||
if(DEBUG) { printf("[DEBUG] device_to_device %p => %p, %ld\n", dptr, hptr, N); }
|
||||
CUDA_CHECK(cudaMemcpy(hptr, dptr, N * sizeof(T), cudaMemcpyDeviceToDevice));
|
||||
}
|
||||
|
||||
// https://github.com/parallel-forall/code-samples/blob/master/posts/cuda-aware-mpi-example/src/Device.cu
|
||||
// https://github.com/treecode/Bonsai/blob/master/runtime/profiling/derived_atomic_functions.h
|
||||
__device__ __forceinline__ void atomicMaxF(float * const address, const float value) {
|
||||
if (*address >= value) {
|
||||
return;
|
||||
}
|
||||
|
||||
int * const address_as_i = (int *)address;
|
||||
int old = * address_as_i, assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
if (__int_as_float(assumed) >= value) {
|
||||
break;
|
||||
}
|
||||
|
||||
old = atomicCAS(address_as_i, assumed, __float_as_int(value));
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ void atomicMinF(float * const address, const float value) {
|
||||
if (*address <= value) {
|
||||
return;
|
||||
}
|
||||
|
||||
int * const address_as_i = (int *)address;
|
||||
int old = * address_as_i, assumed;
|
||||
|
||||
do {
|
||||
assumed = old;
|
||||
if (__int_as_float(assumed) <= value) {
|
||||
break;
|
||||
}
|
||||
|
||||
old = atomicCAS(address_as_i, assumed, __float_as_int(value));
|
||||
} while (assumed != old);
|
||||
}
|
||||
|
||||
|
||||
template <typename FunctorT>
|
||||
__global__ void iterate_kernel(FunctorT functor, int N) {
|
||||
CUDA_KERNEL_LOOP(idx, N) {
|
||||
functor(idx);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename FunctorT>
|
||||
void iterate_cuda(FunctorT functor, int N, int N_THREADS=CUDA_NUM_THREADS) {
|
||||
iterate_kernel<<<GET_BLOCKS(N, N_THREADS), N_THREADS>>>(functor, N);
|
||||
CUDA_POST_KERNEL_CHECK;
|
||||
}
|
||||
|
||||
|
||||
#endif
|
347
torchext/ext/ext.h
Normal file
347
torchext/ext/ext.h
Normal file
@ -0,0 +1,347 @@
|
||||
#pragma once
|
||||
|
||||
#include "common.h"
|
||||
|
||||
|
||||
#define CHECK_CUDA(x) AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
|
||||
#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
#define CHECK_INPUT_CPU(x) CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_INPUT_CUDA(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
|
||||
|
||||
|
||||
template <typename T, int dim=3>
|
||||
struct NNFunctor {
|
||||
const T* in0; // nelem0 x dim
|
||||
const T* in1; // nelem1 x dim
|
||||
const long nelem0;
|
||||
const long nelem1;
|
||||
long* out; // nelem0
|
||||
|
||||
NNFunctor(const T* in0, const T* in1, long nelem0, long nelem1, long* out) : in0(in0), in1(in1), nelem0(nelem0), nelem1(nelem1), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long idx0) {
|
||||
// idx0 \in [nelem0]
|
||||
|
||||
const T* vec0 = in0 + idx0 * dim;
|
||||
|
||||
T min_dist = 1e9;
|
||||
long min_arg = -1;
|
||||
for(long idx1 = 0; idx1 < nelem1; ++idx1) {
|
||||
const T* vec1 = in1 + idx1 * dim;
|
||||
T dist = 0;
|
||||
for(long didx = 0; didx < dim; ++didx) {
|
||||
T diff = vec0[didx] - vec1[didx];
|
||||
dist += diff * diff;
|
||||
}
|
||||
|
||||
if(dist < min_dist) {
|
||||
min_dist = dist;
|
||||
min_arg = idx1;
|
||||
}
|
||||
}
|
||||
|
||||
out[idx0] = min_arg;
|
||||
}
|
||||
};
|
||||
|
||||
struct CrossCheckFunctor {
|
||||
const long* in0; // nelem0
|
||||
const long* in1; // nelem1
|
||||
const long nelem0;
|
||||
const long nelem1;
|
||||
uint8_t* out; // nelem0
|
||||
|
||||
CrossCheckFunctor(const long* in0, const long* in1, long nelem0, long nelem1, uint8_t* out) : in0(in0), in1(in1), nelem0(nelem0), nelem1(nelem1), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long idx0) {
|
||||
// idx0 \in [nelem0]
|
||||
int idx1 = in0[idx0];
|
||||
out[idx0] = idx1 >=0 && in1[idx1] >= 0 && idx0 == in1[idx1];
|
||||
// out[idx0] = idx0 == in1[in0[idx0]];
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int dim=3>
|
||||
struct ProjNNFunctor {
|
||||
// xyz0, xyz1 in coord sys of 1
|
||||
const T* xyz0; // bs x height x width x 3
|
||||
const T* xyz1; // bs x height x width x 3
|
||||
const T* K; // 3 x 3
|
||||
const long batch_size;
|
||||
const long height;
|
||||
const long width;
|
||||
const long patch_size;
|
||||
long* out; // bs x height x width
|
||||
|
||||
ProjNNFunctor(const T* xyz0, const T* xyz1, const T* K, long batch_size, long height, long width, long patch_size, long* out)
|
||||
: xyz0(xyz0), xyz1(xyz1), K(K), batch_size(batch_size), height(height), width(width), patch_size(patch_size), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long idx0) {
|
||||
// idx0 \in [0, bs x height x width]
|
||||
|
||||
const long bs = idx0 / (height * width);
|
||||
|
||||
const T x = xyz0[idx0 * 3 + 0];
|
||||
const T y = xyz0[idx0 * 3 + 1];
|
||||
const T z = xyz0[idx0 * 3 + 2];
|
||||
const T d = K[6] * x + K[7] * y + K[8] * z;
|
||||
const T u = (K[0] * x + K[1] * y + K[2] * z) / d;
|
||||
const T v = (K[3] * x + K[4] * y + K[5] * z) / d;
|
||||
|
||||
int u0 = u + 0.5;
|
||||
int v0 = v + 0.5;
|
||||
|
||||
long min_idx1 = -1;
|
||||
T min_dist = 1e9;
|
||||
for(int pidx = 0; pidx < patch_size*patch_size; ++pidx) {
|
||||
int pu = pidx % patch_size;
|
||||
int pv = pidx / patch_size;
|
||||
|
||||
int u1 = u0 + pu - patch_size/2;
|
||||
int v1 = v0 + pv - patch_size/2;
|
||||
|
||||
if(u1 >= 0 && v1 >= 0 && u1 < width && v1 < height) {
|
||||
const long idx1 = (bs * height + v1) * width + u1;
|
||||
const T* xyz1n = xyz1 + idx1 * 3;
|
||||
const T d = (x-xyz1n[0]) * (x-xyz1n[0]) + (y-xyz1n[1]) * (y-xyz1n[1]) + (z-xyz1n[2]) * (z-xyz1n[2]);
|
||||
if(d < min_dist) {
|
||||
min_dist = d;
|
||||
min_idx1 = idx1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out[idx0] = min_idx1;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
template <typename T, int dim=3>
|
||||
struct XCorrVolFunctor {
|
||||
const T* in0; // channels x height x width
|
||||
const T* in1; // channels x height x width
|
||||
const long channels;
|
||||
const long height;
|
||||
const long width;
|
||||
const long n_disps;
|
||||
const long block_size;
|
||||
T* out; // nelem0
|
||||
|
||||
XCorrVolFunctor(const T* in0, const T* in1, long channels, long height, long width, long n_disps, long block_size, T* out) : in0(in0), in1(in1), channels(channels), height(height), width(width), n_disps(n_disps), block_size(block_size), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(long oidx) {
|
||||
// idx0 \in [n_disps x height x width]
|
||||
|
||||
auto d = oidx / (height * width);
|
||||
auto h = (oidx / width) % height;
|
||||
auto w = oidx % width;
|
||||
|
||||
long block_size2 = block_size * block_size;
|
||||
|
||||
T val = 0;
|
||||
for(int c = 0; c < channels; ++c) {
|
||||
// compute means
|
||||
T mu0 = 0;
|
||||
T mu1 = 0;
|
||||
for(int bh = 0; bh < block_size; ++bh) {
|
||||
long h0 = h + bh - block_size / 2;
|
||||
h0 = mmax(long(0), mmin(height-1, h0));
|
||||
for(int bw = 0; bw < block_size; ++bw) {
|
||||
long w0 = w + bw - block_size / 2;
|
||||
long w1 = w0 - d;
|
||||
w0 = mmax(long(0), mmin(width-1, w0));
|
||||
w1 = mmax(long(0), mmin(width-1, w1));
|
||||
long idx0 = (c * height + h0) * width + w0;
|
||||
long idx1 = (c * height + h0) * width + w1;
|
||||
mu0 += in0[idx0] / block_size2;
|
||||
mu1 += in1[idx1] / block_size2;
|
||||
}
|
||||
}
|
||||
|
||||
// compute stds and dot product
|
||||
T sigma0 = 0;
|
||||
T sigma1 = 0;
|
||||
T dot = 0;
|
||||
for(int bh = 0; bh < block_size; ++bh) {
|
||||
long h0 = h + bh - block_size / 2;
|
||||
h0 = mmax(long(0), mmin(height-1, h0));
|
||||
for(int bw = 0; bw < block_size; ++bw) {
|
||||
long w0 = w + bw - block_size / 2;
|
||||
long w1 = w0 - d;
|
||||
w0 = mmax(long(0), mmin(width-1, w0));
|
||||
w1 = mmax(long(0), mmin(width-1, w1));
|
||||
long idx0 = (c * height + h0) * width + w0;
|
||||
long idx1 = (c * height + h0) * width + w1;
|
||||
T v0 = in0[idx0] - mu0;
|
||||
T v1 = in1[idx1] - mu1;
|
||||
|
||||
dot += v0 * v1;
|
||||
sigma0 += v0 * v0;
|
||||
sigma1 += v1 * v1;
|
||||
}
|
||||
}
|
||||
|
||||
T norm = sqrt(sigma0 * sigma1) + 1e-8;
|
||||
val += dot / norm;
|
||||
}
|
||||
|
||||
out[oidx] = val;
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
const int PHOTOMETRIC_LOSS_MSE = 0;
|
||||
const int PHOTOMETRIC_LOSS_SAD = 1;
|
||||
const int PHOTOMETRIC_LOSS_CENSUS_MSE = 2;
|
||||
const int PHOTOMETRIC_LOSS_CENSUS_SAD = 3;
|
||||
|
||||
template <typename T, int type>
|
||||
struct PhotometricLossForward {
|
||||
const T* es; // batch_size x channels x height x width;
|
||||
const T* ta;
|
||||
const int block_size;
|
||||
const int block_size2;
|
||||
const T eps;
|
||||
const int batch_size;
|
||||
const int channels;
|
||||
const int height;
|
||||
const int width;
|
||||
T* out; // batch_size x channels x height x width;
|
||||
|
||||
PhotometricLossForward(const T* es, const T* ta, int block_size, T eps, int batch_size, int channels, int height, int width, T* out) :
|
||||
es(es), ta(ta), block_size(block_size), block_size2(block_size*block_size), eps(eps), batch_size(batch_size), channels(channels), height(height), width(width), out(out) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(int outidx) {
|
||||
// outidx \in [0, batch_size x height x width]
|
||||
|
||||
int w = outidx % width;
|
||||
int h = (outidx / width) % height;
|
||||
int n = outidx / (height * width);
|
||||
|
||||
T loss = 0;
|
||||
for(int bidx = 0; bidx < block_size2; ++bidx) {
|
||||
int bh = bidx / block_size;
|
||||
int bw = bidx % block_size;
|
||||
int h0 = h + bh - block_size / 2;
|
||||
int w0 = w + bw - block_size / 2;
|
||||
|
||||
h0 = mmin(height-1, mmax(0, h0));
|
||||
w0 = mmin(width-1, mmax(0, w0));
|
||||
|
||||
for(int c = 0; c < channels; ++c) {
|
||||
int inidx = ((n * channels + c) * height + h0) * width + w0;
|
||||
if(type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_MSE) {
|
||||
T diff = es[inidx] - ta[inidx];
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
loss += diff * diff / block_size2;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
loss += fabs(diff) / block_size2;
|
||||
}
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
int inidxc = ((n * channels + c) * height + h) * width + w;
|
||||
T des = es[inidx] - es[inidxc];
|
||||
T dta = ta[inidx] - ta[inidxc];
|
||||
T h_des = 0.5 * (1 + des / sqrt(des * des + eps));
|
||||
T h_dta = 0.5 * (1 + dta / sqrt(dta * dta + eps));
|
||||
T diff = h_des - h_dta;
|
||||
// printf("%d,%d %d,%d: des=%f, dta=%f, h_des=%f, h_dta=%f, diff=%f\n", h,w, h0,w0, des,dta, h_des,h_dta, diff);
|
||||
// printf("%d,%d %d,%d: h_des=%f = 0.5 * (1 + %f / %f); %f, %f, %f\n", h,w, h0,w0, h_des, des, sqrt(des * des + eps), des*des, des*des+eps, eps);
|
||||
if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
loss += diff * diff / block_size2;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
loss += fabs(diff) / block_size2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
out[outidx] = loss;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T, int type>
|
||||
struct PhotometricLossBackward {
|
||||
const T* es; // batch_size x channels x height x width;
|
||||
const T* ta;
|
||||
const T* grad_out;
|
||||
const int block_size;
|
||||
const int block_size2;
|
||||
const T eps;
|
||||
const int batch_size;
|
||||
const int channels;
|
||||
const int height;
|
||||
const int width;
|
||||
T* grad_in; // batch_size x channels x height x width;
|
||||
|
||||
PhotometricLossBackward(const T* es, const T* ta, const T* grad_out, int block_size, T eps, int batch_size, int channels, int height, int width, T* grad_in) :
|
||||
es(es), ta(ta), grad_out(grad_out), block_size(block_size), block_size2(block_size*block_size), eps(eps), batch_size(batch_size), channels(channels), height(height), width(width), grad_in(grad_in) {}
|
||||
|
||||
CPU_GPU_FUNCTION void operator()(int outidx) {
|
||||
// outidx \in [0, batch_size x height x width]
|
||||
|
||||
int w = outidx % width;
|
||||
int h = (outidx / width) % height;
|
||||
int n = outidx / (height * width);
|
||||
|
||||
for(int bidx = 0; bidx < block_size2; ++bidx) {
|
||||
int bh = bidx / block_size;
|
||||
int bw = bidx % block_size;
|
||||
int h0 = h + bh - block_size / 2;
|
||||
int w0 = w + bw - block_size / 2;
|
||||
|
||||
h0 = mmin(height-1, mmax(0, h0));
|
||||
w0 = mmin(width-1, mmax(0, w0));
|
||||
|
||||
const T go = grad_out[outidx];
|
||||
|
||||
for(int c = 0; c < channels; ++c) {
|
||||
int inidx = ((n * channels + c) * height + h0) * width + w0;
|
||||
if(type == PHOTOMETRIC_LOSS_SAD || type == PHOTOMETRIC_LOSS_MSE) {
|
||||
T diff = es[inidx] - ta[inidx];
|
||||
T grad = 0;
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
grad = 2 * diff;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
grad = diff < 0 ? -1 : (diff > 0 ? 1 : 0);
|
||||
}
|
||||
grad = grad / block_size2 * go;
|
||||
matomic_add(grad_in + inidx, grad);
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD || type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
int inidxc = ((n * channels + c) * height + h) * width + w;
|
||||
T des = es[inidx] - es[inidxc];
|
||||
T dta = ta[inidx] - ta[inidxc];
|
||||
T h_des = 0.5 * (1 + des / sqrt(des * des + eps));
|
||||
T h_dta = 0.5 * (1 + dta / sqrt(dta * dta + eps));
|
||||
T diff = h_des - h_dta;
|
||||
|
||||
T grad_loss = 0;
|
||||
if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
grad_loss = 2 * diff;
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
grad_loss = diff < 0 ? -1 : (diff > 0 ? 1 : 0);
|
||||
}
|
||||
grad_loss = grad_loss / block_size2;
|
||||
|
||||
T tmp = des * des + eps;
|
||||
T grad_heaviside = 0.5 * eps / sqrt(tmp * tmp * tmp);
|
||||
|
||||
T grad = go * grad_loss * grad_heaviside;
|
||||
matomic_add(grad_in + inidx, grad);
|
||||
matomic_add(grad_in + inidxc, -grad);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
198
torchext/ext/ext_cpu.cpp
Normal file
198
torchext/ext/ext_cpu.cpp
Normal file
@ -0,0 +1,198 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ext.h"
|
||||
|
||||
template <typename FunctorT>
|
||||
void iterate_cpu(FunctorT functor, int N) {
|
||||
for(int idx = 0; idx < N; ++idx) {
|
||||
functor(idx);
|
||||
}
|
||||
}
|
||||
|
||||
at::Tensor nn_cpu(at::Tensor in0, at::Tensor in1) {
|
||||
CHECK_INPUT_CPU(in0)
|
||||
CHECK_INPUT_CPU(in1)
|
||||
|
||||
auto nelem0 = in0.size(0);
|
||||
auto nelem1 = in1.size(0);
|
||||
auto dim = in0.size(1);
|
||||
|
||||
AT_ASSERTM(dim == in1.size(1), "in0 and in1 have to be the same shape")
|
||||
AT_ASSERTM(dim == 3, "dim hast to be 3")
|
||||
AT_ASSERTM(in0.dim() == 2, "in0 has to be N0 x 3")
|
||||
AT_ASSERTM(in1.dim() == 2, "in1 has to be N1 x 3")
|
||||
|
||||
auto out = at::empty({nelem0}, torch::CPU(at::kLong));
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "nn", ([&] {
|
||||
iterate_cpu(
|
||||
NNFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), nelem0, nelem1, out.data<long>()),
|
||||
nelem0);
|
||||
}));
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
at::Tensor crosscheck_cpu(at::Tensor in0, at::Tensor in1) {
|
||||
CHECK_INPUT_CPU(in0)
|
||||
CHECK_INPUT_CPU(in1)
|
||||
|
||||
AT_ASSERTM(in0.dim() == 1, "")
|
||||
AT_ASSERTM(in1.dim() == 1, "")
|
||||
|
||||
auto nelem0 = in0.size(0);
|
||||
auto nelem1 = in1.size(0);
|
||||
|
||||
auto out = at::empty({nelem0}, torch::CPU(at::kByte));
|
||||
|
||||
iterate_cpu(
|
||||
CrossCheckFunctor(in0.data<long>(), in1.data<long>(), nelem0, nelem1, out.data<uint8_t>()),
|
||||
nelem0);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
at::Tensor proj_nn_cpu(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size) {
|
||||
CHECK_INPUT_CPU(xyz0)
|
||||
CHECK_INPUT_CPU(xyz1)
|
||||
CHECK_INPUT_CPU(K)
|
||||
|
||||
auto batch_size = xyz0.size(0);
|
||||
auto height = xyz0.size(1);
|
||||
auto width = xyz0.size(2);
|
||||
|
||||
AT_ASSERTM(xyz0.size(0) == xyz1.size(0), "")
|
||||
AT_ASSERTM(xyz0.size(1) == xyz1.size(1), "")
|
||||
AT_ASSERTM(xyz0.size(2) == xyz1.size(2), "")
|
||||
AT_ASSERTM(xyz0.size(3) == xyz1.size(3), "")
|
||||
AT_ASSERTM(xyz0.size(3) == 3, "")
|
||||
AT_ASSERTM(xyz0.dim() == 4, "")
|
||||
AT_ASSERTM(xyz1.dim() == 4, "")
|
||||
|
||||
auto out = at::empty({batch_size, height, width}, torch::CPU(at::kLong));
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz0.scalar_type(), "proj_nn", ([&] {
|
||||
iterate_cpu(
|
||||
ProjNNFunctor<scalar_t>(xyz0.data<scalar_t>(), xyz1.data<scalar_t>(), K.data<scalar_t>(), batch_size, height, width, patch_size, out.data<long>()),
|
||||
batch_size * height * width);
|
||||
}));
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
at::Tensor xcorrvol_cpu(at::Tensor in0, at::Tensor in1, int n_disps, int block_size) {
|
||||
CHECK_INPUT_CPU(in0)
|
||||
CHECK_INPUT_CPU(in1)
|
||||
|
||||
auto channels = in0.size(0);
|
||||
auto height = in0.size(1);
|
||||
auto width = in0.size(2);
|
||||
|
||||
auto out = at::empty({n_disps, height, width}, in0.options());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "xcorrvol", ([&] {
|
||||
iterate_cpu(
|
||||
XCorrVolFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), channels, height, width, n_disps, block_size, out.data<scalar_t>()),
|
||||
n_disps * height * width);
|
||||
}));
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
at::Tensor photometric_loss_forward(at::Tensor es, at::Tensor ta, int block_size, int type, float eps) {
|
||||
CHECK_INPUT_CPU(es)
|
||||
CHECK_INPUT_CPU(ta)
|
||||
|
||||
auto batch_size = es.size(0);
|
||||
auto channels = es.size(1);
|
||||
auto height = es.size(2);
|
||||
auto width = es.size(3);
|
||||
|
||||
auto out = at::empty({batch_size, 1, height, width}, es.options());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_forward_cpu", ([&] {
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
iterate_cpu(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
iterate_cpu(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
iterate_cpu(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
iterate_cpu(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
}));
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
at::Tensor photometric_loss_backward(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps) {
|
||||
CHECK_INPUT_CPU(es)
|
||||
CHECK_INPUT_CPU(ta)
|
||||
CHECK_INPUT_CPU(grad_out)
|
||||
|
||||
auto batch_size = es.size(0);
|
||||
auto channels = es.size(1);
|
||||
auto height = es.size(2);
|
||||
auto width = es.size(3);
|
||||
|
||||
CHECK_INPUT_CPU(ta)
|
||||
auto grad_in = at::zeros({batch_size, channels, height, width}, grad_out.options());
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_backward_cpu", ([&] {
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
iterate_cpu(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
iterate_cpu(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
iterate_cpu(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
iterate_cpu(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
}));
|
||||
|
||||
return grad_in;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("nn_cpu", &nn_cpu, "nn_cpu");
|
||||
m.def("crosscheck_cpu", &crosscheck_cpu, "crosscheck_cpu");
|
||||
m.def("proj_nn_cpu", &proj_nn_cpu, "proj_nn_cpu");
|
||||
|
||||
m.def("xcorrvol_cpu", &xcorrvol_cpu, "xcorrvol_cpu");
|
||||
|
||||
m.def("photometric_loss_forward", &photometric_loss_forward);
|
||||
m.def("photometric_loss_backward", &photometric_loss_backward);
|
||||
}
|
135
torchext/ext/ext_cuda.cpp
Normal file
135
torchext/ext/ext_cuda.cpp
Normal file
@ -0,0 +1,135 @@
|
||||
#include <torch/extension.h>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#include "ext.h"
|
||||
|
||||
void nn_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out);
|
||||
|
||||
at::Tensor nn_cuda(at::Tensor in0, at::Tensor in1) {
|
||||
CHECK_INPUT_CUDA(in0)
|
||||
CHECK_INPUT_CUDA(in1)
|
||||
|
||||
auto nelem0 = in0.size(0);
|
||||
auto dim = in0.size(1);
|
||||
|
||||
AT_ASSERTM(dim == in1.size(1), "in0 and in1 have to be the same shape")
|
||||
AT_ASSERTM(dim == 3, "dim hast to be 3")
|
||||
AT_ASSERTM(in0.dim() == 2, "in0 has to be N0 x 3")
|
||||
AT_ASSERTM(in1.dim() == 2, "in1 has to be N1 x 3")
|
||||
|
||||
auto out = at::empty({nelem0}, torch::CUDA(at::kLong));
|
||||
|
||||
nn_kernel(in0, in1, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
void crosscheck_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out);
|
||||
|
||||
at::Tensor crosscheck_cuda(at::Tensor in0, at::Tensor in1) {
|
||||
CHECK_INPUT_CUDA(in0)
|
||||
CHECK_INPUT_CUDA(in1)
|
||||
|
||||
AT_ASSERTM(in0.dim() == 1, "")
|
||||
AT_ASSERTM(in1.dim() == 1, "")
|
||||
|
||||
auto nelem0 = in0.size(0);
|
||||
auto out = at::empty({nelem0}, torch::CUDA(at::kByte));
|
||||
crosscheck_kernel(in0, in1, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void proj_nn_kernel(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size, at::Tensor out);
|
||||
|
||||
at::Tensor proj_nn_cuda(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size) {
|
||||
CHECK_INPUT_CUDA(xyz0)
|
||||
CHECK_INPUT_CUDA(xyz1)
|
||||
CHECK_INPUT_CUDA(K)
|
||||
|
||||
auto batch_size = xyz0.size(0);
|
||||
auto height = xyz0.size(1);
|
||||
auto width = xyz0.size(2);
|
||||
|
||||
AT_ASSERTM(xyz0.size(0) == xyz1.size(0), "")
|
||||
AT_ASSERTM(xyz0.size(1) == xyz1.size(1), "")
|
||||
AT_ASSERTM(xyz0.size(2) == xyz1.size(2), "")
|
||||
AT_ASSERTM(xyz0.size(3) == xyz1.size(3), "")
|
||||
AT_ASSERTM(xyz0.size(3) == 3, "")
|
||||
AT_ASSERTM(xyz0.dim() == 4, "")
|
||||
AT_ASSERTM(xyz1.dim() == 4, "")
|
||||
|
||||
auto out = at::empty({batch_size, height, width}, torch::CUDA(at::kLong));
|
||||
|
||||
proj_nn_kernel(xyz0, xyz1, K, patch_size, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
void xcorrvol_kernel(at::Tensor in0, at::Tensor in1, int n_disps, int block_size, at::Tensor out);
|
||||
|
||||
at::Tensor xcorrvol_cuda(at::Tensor in0, at::Tensor in1, int n_disps, int block_size) {
|
||||
CHECK_INPUT_CUDA(in0)
|
||||
CHECK_INPUT_CUDA(in1)
|
||||
|
||||
// auto channels = in0.size(0);
|
||||
auto height = in0.size(1);
|
||||
auto width = in0.size(2);
|
||||
|
||||
auto out = at::empty({n_disps, height, width}, in0.options());
|
||||
|
||||
xcorrvol_kernel(in0, in1, n_disps, block_size, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
|
||||
void photometric_loss_forward_kernel(at::Tensor es, at::Tensor ta, int block_size, int type, float eps, at::Tensor out);
|
||||
|
||||
at::Tensor photometric_loss_forward(at::Tensor es, at::Tensor ta, int block_size, int type, float eps) {
|
||||
CHECK_INPUT_CUDA(es)
|
||||
CHECK_INPUT_CUDA(ta)
|
||||
|
||||
auto batch_size = es.size(0);
|
||||
auto height = es.size(2);
|
||||
auto width = es.size(3);
|
||||
|
||||
auto out = at::empty({batch_size, 1, height, width}, es.options());
|
||||
photometric_loss_forward_kernel(es, ta, block_size, type, eps, out);
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
void photometric_loss_backward_kernel(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps, at::Tensor grad_in);
|
||||
|
||||
at::Tensor photometric_loss_backward(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps) {
|
||||
CHECK_INPUT_CUDA(es)
|
||||
CHECK_INPUT_CUDA(ta)
|
||||
CHECK_INPUT_CUDA(grad_out)
|
||||
|
||||
auto batch_size = es.size(0);
|
||||
auto channels = es.size(1);
|
||||
auto height = es.size(2);
|
||||
auto width = es.size(3);
|
||||
|
||||
auto grad_in = at::zeros({batch_size, channels, height, width}, grad_out.options());
|
||||
photometric_loss_backward_kernel(es, ta, grad_out, block_size, type, eps, grad_in);
|
||||
|
||||
return grad_in;
|
||||
}
|
||||
|
||||
|
||||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
m.def("nn_cuda", &nn_cuda, "nn_cuda");
|
||||
m.def("crosscheck_cuda", &crosscheck_cuda, "crosscheck_cuda");
|
||||
m.def("proj_nn_cuda", &proj_nn_cuda, "proj_nn_cuda");
|
||||
|
||||
m.def("xcorrvol_cuda", &xcorrvol_cuda, "xcorrvol_cuda");
|
||||
|
||||
m.def("photometric_loss_forward", &photometric_loss_forward);
|
||||
m.def("photometric_loss_backward", &photometric_loss_backward);
|
||||
}
|
112
torchext/ext/ext_kernel.cu
Normal file
112
torchext/ext/ext_kernel.cu
Normal file
@ -0,0 +1,112 @@
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
#include "ext.h"
|
||||
#include "common_cuda.h"
|
||||
|
||||
void nn_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out) {
|
||||
auto nelem0 = in0.size(0);
|
||||
auto nelem1 = in1.size(0);
|
||||
auto dim = in0.size(1);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "nn", ([&] {
|
||||
iterate_cuda(
|
||||
NNFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), nelem0, nelem1, out.data<long>()),
|
||||
nelem0);
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
void crosscheck_kernel(at::Tensor in0, at::Tensor in1, at::Tensor out) {
|
||||
auto nelem0 = in0.size(0);
|
||||
auto nelem1 = in1.size(0);
|
||||
|
||||
iterate_cuda(
|
||||
CrossCheckFunctor(in0.data<long>(), in1.data<long>(), nelem0, nelem1, out.data<uint8_t>()),
|
||||
nelem0);
|
||||
}
|
||||
|
||||
void proj_nn_kernel(at::Tensor xyz0, at::Tensor xyz1, at::Tensor K, int patch_size, at::Tensor out) {
|
||||
auto batch_size = xyz0.size(0);
|
||||
auto height = xyz0.size(1);
|
||||
auto width = xyz0.size(2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(xyz0.scalar_type(), "proj_nn", ([&] {
|
||||
iterate_cuda(
|
||||
ProjNNFunctor<scalar_t>(xyz0.data<scalar_t>(), xyz1.data<scalar_t>(), K.data<scalar_t>(), batch_size, height, width, patch_size, out.data<long>()),
|
||||
batch_size * height * width);
|
||||
}));
|
||||
}
|
||||
|
||||
void xcorrvol_kernel(at::Tensor in0, at::Tensor in1, int n_disps, int block_size, at::Tensor out) {
|
||||
auto channels = in0.size(0);
|
||||
auto height = in0.size(1);
|
||||
auto width = in0.size(2);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(in0.scalar_type(), "xcorrvol", ([&] {
|
||||
iterate_cuda(
|
||||
XCorrVolFunctor<scalar_t>(in0.data<scalar_t>(), in1.data<scalar_t>(), channels, height, width, n_disps, block_size, out.data<scalar_t>()),
|
||||
n_disps * height * width, 512);
|
||||
}));
|
||||
}
|
||||
|
||||
|
||||
|
||||
void photometric_loss_forward_kernel(at::Tensor es, at::Tensor ta, int block_size, int type, float eps, at::Tensor out) {
|
||||
auto batch_size = es.size(0);
|
||||
auto channels = es.size(1);
|
||||
auto height = es.size(2);
|
||||
auto width = es.size(3);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_forward_cuda", ([&] {
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
iterate_cuda(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
iterate_cuda(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
iterate_cuda(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
iterate_cuda(
|
||||
PhotometricLossForward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, out.data<scalar_t>()),
|
||||
out.numel());
|
||||
}
|
||||
}));
|
||||
}
|
||||
|
||||
void photometric_loss_backward_kernel(at::Tensor es, at::Tensor ta, at::Tensor grad_out, int block_size, int type, float eps, at::Tensor grad_in) {
|
||||
auto batch_size = es.size(0);
|
||||
auto channels = es.size(1);
|
||||
auto height = es.size(2);
|
||||
auto width = es.size(3);
|
||||
|
||||
AT_DISPATCH_FLOATING_TYPES(es.scalar_type(), "photometric_loss_backward_cuda", ([&] {
|
||||
if(type == PHOTOMETRIC_LOSS_MSE) {
|
||||
iterate_cuda(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_SAD) {
|
||||
iterate_cuda(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_MSE) {
|
||||
iterate_cuda(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_MSE>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
else if(type == PHOTOMETRIC_LOSS_CENSUS_SAD) {
|
||||
iterate_cuda(
|
||||
PhotometricLossBackward<scalar_t, PHOTOMETRIC_LOSS_CENSUS_SAD>(es.data<scalar_t>(), ta.data<scalar_t>(), grad_out.data<scalar_t>(), block_size, eps, batch_size, channels, height, width, grad_in.data<scalar_t>()),
|
||||
grad_out.numel());
|
||||
}
|
||||
}));
|
||||
}
|
147
torchext/functions.py
Normal file
147
torchext/functions.py
Normal file
@ -0,0 +1,147 @@
|
||||
import torch
|
||||
from . import ext_cpu
|
||||
from . import ext_cuda
|
||||
|
||||
class NNFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1):
|
||||
args = (in0, in1)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.nn_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.nn_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None
|
||||
|
||||
def nn(in0, in1):
|
||||
return NNFunction.apply(in0, in1)
|
||||
|
||||
|
||||
class CrossCheckFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1):
|
||||
args = (in0, in1)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.crosscheck_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.crosscheck_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None
|
||||
|
||||
def crosscheck(in0, in1):
|
||||
return CrossCheckFunction.apply(in0, in1)
|
||||
|
||||
class ProjNNFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, xyz0, xyz1, K, patch_size):
|
||||
args = (xyz0, xyz1, K, patch_size)
|
||||
if xyz0.is_cuda:
|
||||
out = ext_cuda.proj_nn_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.proj_nn_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None, None, None
|
||||
|
||||
def proj_nn(xyz0, xyz1, K, patch_size):
|
||||
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
|
||||
|
||||
|
||||
|
||||
class XCorrVolFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, in0, in1, n_disps, block_size):
|
||||
args = (in0, in1, n_disps, block_size)
|
||||
if in0.is_cuda:
|
||||
out = ext_cuda.xcorrvol_cuda(*args)
|
||||
else:
|
||||
out = ext_cpu.xcorrvol_cpu(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
return None, None, None, None
|
||||
|
||||
def xcorrvol(in0, in1, n_disps, block_size):
|
||||
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
|
||||
|
||||
|
||||
|
||||
|
||||
class PhotometricLossFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, es, ta, block_size, type, eps):
|
||||
args = (es, ta, block_size, type, eps)
|
||||
ctx.save_for_backward(es, ta)
|
||||
ctx.block_size = block_size
|
||||
ctx.type = type
|
||||
ctx.eps = eps
|
||||
if es.is_cuda:
|
||||
out = ext_cuda.photometric_loss_forward(*args)
|
||||
else:
|
||||
out = ext_cpu.photometric_loss_forward(*args)
|
||||
return out
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, grad_out):
|
||||
es, ta = ctx.saved_tensors
|
||||
block_size = ctx.block_size
|
||||
type = ctx.type
|
||||
eps = ctx.eps
|
||||
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
|
||||
if grad_out.is_cuda:
|
||||
grad_es = ext_cuda.photometric_loss_backward(*args)
|
||||
else:
|
||||
grad_es = ext_cpu.photometric_loss_backward(*args)
|
||||
return grad_es, None, None, None, None
|
||||
|
||||
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
|
||||
type = type.lower()
|
||||
if type == 'mse':
|
||||
type = 0
|
||||
elif type == 'sad':
|
||||
type = 1
|
||||
elif type == 'census_mse':
|
||||
type = 2
|
||||
elif type == 'census_sad':
|
||||
type = 3
|
||||
else:
|
||||
raise Exception('invalid loss type')
|
||||
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
|
||||
|
||||
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
|
||||
type = type.lower()
|
||||
p = block_size // 2
|
||||
es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate')
|
||||
ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate')
|
||||
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
|
||||
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
|
||||
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
|
||||
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
|
||||
if type == 'mse':
|
||||
ref = (es_uf - ta_uf)**2
|
||||
elif type == 'sad':
|
||||
ref = torch.abs(es_uf - ta_uf)
|
||||
elif type == 'census_mse' or type == 'census_sad':
|
||||
des = es_uf - es.unsqueeze(2)
|
||||
dta = ta_uf - ta.unsqueeze(2)
|
||||
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
|
||||
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
|
||||
diff = h_des - h_dta
|
||||
if type == 'census_mse':
|
||||
ref = diff * diff
|
||||
elif type == 'census_sad':
|
||||
ref = torch.abs(diff)
|
||||
else:
|
||||
raise Exception('invalid loss type')
|
||||
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
|
||||
ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2
|
||||
return ref
|
27
torchext/modules.py
Normal file
27
torchext/modules.py
Normal file
@ -0,0 +1,27 @@
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from .functions import *
|
||||
|
||||
class CoordConv2d(torch.nn.Module):
|
||||
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
|
||||
super().__init__()
|
||||
|
||||
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
|
||||
|
||||
self.uv = None
|
||||
|
||||
def forward(self, x):
|
||||
if self.uv is None:
|
||||
height, width = x.shape[2], x.shape[3]
|
||||
u, v = np.meshgrid(range(width), range(height))
|
||||
u = 2 * u / (width - 1) - 1
|
||||
v = 2 * v / (height - 1) - 1
|
||||
uv = np.stack((u, v)).reshape(1, 2, height, width)
|
||||
self.uv = torch.from_numpy( uv.astype(np.float32) )
|
||||
self.uv = self.uv.to(x.device)
|
||||
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
|
||||
xuv = torch.cat((x, uv), dim=1)
|
||||
y = self.conv(xuv)
|
||||
return y
|
24
torchext/setup.py
Normal file
24
torchext/setup.py
Normal file
@ -0,0 +1,24 @@
|
||||
from setuptools import setup
|
||||
from torch.utils.cpp_extension import CppExtension, CUDAExtension, BuildExtension
|
||||
import os
|
||||
|
||||
this_dir = os.path.dirname(os.path.realpath(__file__))
|
||||
|
||||
include_dirs = [
|
||||
]
|
||||
|
||||
nvcc_args = [
|
||||
'-arch=sm_30',
|
||||
'-gencode=arch=compute_30,code=sm_30',
|
||||
'-gencode=arch=compute_35,code=sm_35',
|
||||
]
|
||||
|
||||
setup(
|
||||
name='ext',
|
||||
ext_modules=[
|
||||
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
|
||||
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
|
||||
],
|
||||
cmdclass={'build_ext': BuildExtension},
|
||||
include_dirs=include_dirs
|
||||
)
|
528
torchext/worker.py
Normal file
528
torchext/worker.py
Normal file
@ -0,0 +1,528 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import random
|
||||
import logging
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
import argparse
|
||||
import subprocess
|
||||
import socket
|
||||
import sys
|
||||
import os
|
||||
import gc
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
|
||||
|
||||
class StopWatch(object):
|
||||
def __init__(self):
|
||||
self.timings = OrderedDict()
|
||||
self.starts = {}
|
||||
|
||||
def start(self, name):
|
||||
self.starts[name] = time.time()
|
||||
|
||||
def stop(self, name):
|
||||
if name not in self.timings:
|
||||
self.timings[name] = []
|
||||
self.timings[name].append(time.time() - self.starts[name])
|
||||
|
||||
def get(self, name=None, reduce=np.sum):
|
||||
if name is not None:
|
||||
return reduce(self.timings[name])
|
||||
else:
|
||||
ret = {}
|
||||
for k in self.timings:
|
||||
ret[k] = reduce(self.timings[k])
|
||||
return ret
|
||||
|
||||
def __repr__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
||||
def __str__(self):
|
||||
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
|
||||
|
||||
|
||||
class ETA(object):
|
||||
def __init__(self, length):
|
||||
self.length = length
|
||||
self.start_time = time.time()
|
||||
self.current_idx = 0
|
||||
self.current_time = time.time()
|
||||
|
||||
def update(self, idx):
|
||||
self.current_idx = idx
|
||||
self.current_time = time.time()
|
||||
|
||||
def get_elapsed_time(self):
|
||||
return self.current_time - self.start_time
|
||||
|
||||
def get_item_time(self):
|
||||
return self.get_elapsed_time() / (self.current_idx + 1)
|
||||
|
||||
def get_remaining_time(self):
|
||||
return self.get_item_time() * (self.length - self.current_idx + 1)
|
||||
|
||||
def format_time(self, seconds):
|
||||
minutes, seconds = divmod(seconds, 60)
|
||||
hours, minutes = divmod(minutes, 60)
|
||||
hours = int(hours)
|
||||
minutes = int(minutes)
|
||||
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
|
||||
|
||||
def get_elapsed_time_str(self):
|
||||
return self.format_time(self.get_elapsed_time())
|
||||
|
||||
def get_remaining_time_str(self):
|
||||
return self.format_time(self.get_remaining_time())
|
||||
|
||||
class Worker(object):
|
||||
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
|
||||
self.out_root = Path(out_root)
|
||||
self.experiment_name = experiment_name
|
||||
self.epochs = epochs
|
||||
self.seed = seed
|
||||
self.train_batch_size = train_batch_size
|
||||
self.test_batch_size = test_batch_size
|
||||
self.num_workers = num_workers
|
||||
self.save_frequency = save_frequency
|
||||
self.train_device = train_device
|
||||
self.test_device = test_device
|
||||
self.max_train_iter = max_train_iter
|
||||
|
||||
self.errs_list=[]
|
||||
|
||||
self.setup_experiment()
|
||||
|
||||
def setup_experiment(self):
|
||||
self.exp_out_root = self.out_root / self.experiment_name
|
||||
self.exp_out_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if logging.root: del logging.root.handlers[:]
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
handlers=[
|
||||
logging.FileHandler( str(self.exp_out_root / 'train.log') ),
|
||||
logging.StreamHandler()
|
||||
],
|
||||
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
|
||||
)
|
||||
|
||||
logging.info('='*80)
|
||||
logging.info(f'Start of experiment: {self.experiment_name}')
|
||||
logging.info(socket.gethostname())
|
||||
self.log_datetime()
|
||||
logging.info('='*80)
|
||||
|
||||
self.metric_path = self.exp_out_root / 'metrics.json'
|
||||
if self.metric_path.exists():
|
||||
with open(str(self.metric_path), 'r') as fp:
|
||||
self.metric_data = json.load(fp)
|
||||
else:
|
||||
self.metric_data = {}
|
||||
|
||||
self.init_seed()
|
||||
|
||||
def metric_add_train(self, epoch, key, val):
|
||||
epoch = str(epoch)
|
||||
key = str(key)
|
||||
if epoch not in self.metric_data:
|
||||
self.metric_data[epoch] = {}
|
||||
if 'train' not in self.metric_data[epoch]:
|
||||
self.metric_data[epoch]['train'] = {}
|
||||
self.metric_data[epoch]['train'][key] = val
|
||||
|
||||
def metric_add_test(self, epoch, set_idx, key, val):
|
||||
epoch = str(epoch)
|
||||
set_idx = str(set_idx)
|
||||
key = str(key)
|
||||
if epoch not in self.metric_data:
|
||||
self.metric_data[epoch] = {}
|
||||
if 'test' not in self.metric_data[epoch]:
|
||||
self.metric_data[epoch]['test'] = {}
|
||||
if set_idx not in self.metric_data[epoch]['test']:
|
||||
self.metric_data[epoch]['test'][set_idx] = {}
|
||||
self.metric_data[epoch]['test'][set_idx][key] = val
|
||||
|
||||
def metric_save(self):
|
||||
with open(str(self.metric_path), 'w') as fp:
|
||||
json.dump(self.metric_data, fp, indent=2)
|
||||
|
||||
def init_seed(self, seed=None):
|
||||
if seed is not None:
|
||||
self.seed = seed
|
||||
logging.info(f'Set seed to {self.seed}')
|
||||
np.random.seed(self.seed)
|
||||
random.seed(self.seed)
|
||||
torch.manual_seed(self.seed)
|
||||
torch.cuda.manual_seed(self.seed)
|
||||
|
||||
def log_datetime(self):
|
||||
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
|
||||
def mem_report(self):
|
||||
for obj in gc.get_objects():
|
||||
if torch.is_tensor(obj):
|
||||
print(type(obj), obj.shape)
|
||||
|
||||
def get_net_path(self, epoch, root=None):
|
||||
if root is None:
|
||||
root = self.exp_out_root
|
||||
return root / f'net_{epoch:04d}.params'
|
||||
|
||||
def get_do_parser_cmds(self):
|
||||
return ['retrain', 'resume', 'retest', 'test_init']
|
||||
|
||||
def get_do_parser(self):
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
|
||||
parser.add_argument('--epoch', type=int, default=-1)
|
||||
return parser
|
||||
|
||||
def do_cmd(self, args, net, optimizer, scheduler=None):
|
||||
if args.cmd == 'retrain':
|
||||
self.train(net, optimizer, resume=False, scheduler=scheduler)
|
||||
elif args.cmd == 'resume':
|
||||
self.train(net, optimizer, resume=True, scheduler=scheduler)
|
||||
elif args.cmd == 'retest':
|
||||
self.retest(net, epoch=args.epoch)
|
||||
elif args.cmd == 'test_init':
|
||||
test_sets = self.get_test_sets()
|
||||
self.test(-1, net, test_sets)
|
||||
else:
|
||||
raise Exception('invalid cmd')
|
||||
|
||||
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
|
||||
parser = self.get_do_parser()
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
if load_net_optimizer is not None and args.cmd not in ['schedule']:
|
||||
net, optimizer = load_net_optimizer()
|
||||
|
||||
self.do_cmd(args, net, optimizer, scheduler=scheduler)
|
||||
|
||||
def retest(self, net, epoch=-1):
|
||||
if epoch < 0:
|
||||
epochs = range(self.epochs)
|
||||
else:
|
||||
epochs = [epoch]
|
||||
|
||||
test_sets = self.get_test_sets()
|
||||
|
||||
for epoch in epochs:
|
||||
net_path = self.get_net_path(epoch)
|
||||
if net_path.exists():
|
||||
state_dict = torch.load(str(net_path))
|
||||
net.load_state_dict(state_dict)
|
||||
self.test(epoch, net, test_sets)
|
||||
|
||||
def format_err_str(self, errs, div=1):
|
||||
err = sum(errs)
|
||||
if len(errs) > 1:
|
||||
err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
|
||||
else:
|
||||
err_str = f'{err/div:0.4f}'
|
||||
return err_str
|
||||
|
||||
def write_err_img(self):
|
||||
err_img_path = self.exp_out_root / 'errs.png'
|
||||
fig = plt.figure(figsize=(16,16))
|
||||
lines=[]
|
||||
for idx,errs in enumerate(self.errs_list):
|
||||
line,=plt.plot(range(len(errs)), errs, label=f'error{idx}')
|
||||
lines.append(line)
|
||||
plt.tight_layout()
|
||||
plt.legend(handles=lines)
|
||||
plt.savefig(str(err_img_path))
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def callback_train_new_epoch(self, epoch, net, optimizer):
|
||||
pass
|
||||
|
||||
def train(self, net, optimizer, resume=False, scheduler=None):
|
||||
logging.info('='*80)
|
||||
logging.info('Start training')
|
||||
self.log_datetime()
|
||||
logging.info('='*80)
|
||||
|
||||
train_set = self.get_train_set()
|
||||
test_sets = self.get_test_sets()
|
||||
|
||||
net = net.to(self.train_device)
|
||||
|
||||
epoch = 0
|
||||
min_err = {ts.name: 1e9 for ts in test_sets}
|
||||
|
||||
state_path = self.exp_out_root / 'state.dict'
|
||||
if resume and state_path.exists():
|
||||
logging.info('='*80)
|
||||
logging.info(f'Loading state from {state_path}')
|
||||
logging.info('='*80)
|
||||
state = torch.load(str(state_path))
|
||||
epoch = state['epoch'] + 1
|
||||
if 'min_err' in state:
|
||||
min_err = state['min_err']
|
||||
|
||||
curr_state = net.state_dict()
|
||||
curr_state.update(state['state_dict'])
|
||||
net.load_state_dict(curr_state)
|
||||
|
||||
|
||||
try:
|
||||
optimizer.load_state_dict(state['optimizer'])
|
||||
except:
|
||||
logging.info('Warning: cannot load optimizer from state_dict')
|
||||
pass
|
||||
if 'cpu_rng_state' in state:
|
||||
torch.set_rng_state(state['cpu_rng_state'])
|
||||
if 'gpu_rng_state' in state:
|
||||
torch.cuda.set_rng_state(state['gpu_rng_state'])
|
||||
|
||||
for epoch in range(epoch, self.epochs):
|
||||
self.callback_train_new_epoch(epoch, net, optimizer)
|
||||
|
||||
# train epoch
|
||||
self.train_epoch(epoch, net, optimizer, train_set)
|
||||
|
||||
# test epoch
|
||||
errs = self.test(epoch, net, test_sets)
|
||||
|
||||
if (epoch + 1) % self.save_frequency == 0:
|
||||
net = net.to(self.train_device)
|
||||
|
||||
# store state
|
||||
state_dict = {
|
||||
'epoch': epoch,
|
||||
'min_err': min_err,
|
||||
'state_dict': net.state_dict(),
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'cpu_rng_state': torch.get_rng_state(),
|
||||
'gpu_rng_state': torch.cuda.get_rng_state(),
|
||||
}
|
||||
logging.info(f'save state to {state_path}')
|
||||
state_path = self.exp_out_root / 'state.dict'
|
||||
torch.save(state_dict, str(state_path))
|
||||
|
||||
for test_set_name in errs:
|
||||
err = sum(errs[test_set_name])
|
||||
if err < min_err[test_set_name]:
|
||||
min_err[test_set_name] = err
|
||||
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
|
||||
logging.info(f'save state to {state_path}')
|
||||
torch.save(state_dict, str(state_path))
|
||||
|
||||
# store network
|
||||
net_path = self.get_net_path(epoch)
|
||||
logging.info(f'save network to {net_path}')
|
||||
torch.save(net.state_dict(), str(net_path))
|
||||
|
||||
if scheduler is not None:
|
||||
scheduler.step()
|
||||
|
||||
logging.info('='*80)
|
||||
logging.info('Finished training')
|
||||
self.log_datetime()
|
||||
logging.info('='*80)
|
||||
|
||||
def get_train_set(self):
|
||||
# returns train_set
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_test_sets(self):
|
||||
# returns test_sets
|
||||
raise NotImplementedError()
|
||||
|
||||
def copy_data(self, data, device, requires_grad, train):
|
||||
raise NotImplementedError()
|
||||
|
||||
def net_forward(self, net, train):
|
||||
raise NotImplementedError()
|
||||
|
||||
def loss_forward(self, output, train):
|
||||
raise NotImplementedError()
|
||||
|
||||
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
|
||||
# err = False
|
||||
# for name, param in net.named_parameters():
|
||||
# if not torch.isfinite(param.grad).all():
|
||||
# print(name)
|
||||
# err = True
|
||||
# if err:
|
||||
# import ipdb; ipdb.set_trace()
|
||||
pass
|
||||
|
||||
def callback_train_start(self, epoch):
|
||||
pass
|
||||
|
||||
def callback_train_stop(self, epoch, loss):
|
||||
pass
|
||||
|
||||
def train_epoch(self, epoch, net, optimizer, dset):
|
||||
self.callback_train_start(epoch)
|
||||
stopwatch = StopWatch()
|
||||
|
||||
logging.info('='*80)
|
||||
logging.info('Train epoch %d' % epoch)
|
||||
|
||||
dset.current_epoch = epoch
|
||||
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
|
||||
|
||||
net = net.to(self.train_device)
|
||||
net.train()
|
||||
|
||||
mean_loss = None
|
||||
|
||||
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
|
||||
bar = ETA(length=n_batches)
|
||||
|
||||
stopwatch.start('total')
|
||||
stopwatch.start('data')
|
||||
for batch_idx, data in enumerate(train_loader):
|
||||
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
|
||||
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
|
||||
stopwatch.stop('data')
|
||||
|
||||
optimizer.zero_grad()
|
||||
|
||||
stopwatch.start('forward')
|
||||
output = self.net_forward(net, train=True)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('forward')
|
||||
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=True)
|
||||
if isinstance(errs, dict):
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
masks = []
|
||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||
errs = [errs]
|
||||
err = sum(errs)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('loss')
|
||||
|
||||
stopwatch.start('backward')
|
||||
err.backward()
|
||||
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('backward')
|
||||
|
||||
stopwatch.start('optimizer')
|
||||
optimizer.step()
|
||||
if 'cuda' in self.train_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('optimizer')
|
||||
|
||||
bar.update(batch_idx)
|
||||
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
|
||||
err_str = self.format_err_str(errs)
|
||||
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||
#self.write_err_img()
|
||||
|
||||
|
||||
if mean_loss is None:
|
||||
mean_loss = [0 for e in errs]
|
||||
for erridx, err in enumerate(errs):
|
||||
mean_loss[erridx] += err.item()
|
||||
|
||||
stopwatch.start('data')
|
||||
stopwatch.stop('total')
|
||||
logging.info('timings: %s' % stopwatch)
|
||||
|
||||
mean_loss = [l / len(train_loader) for l in mean_loss]
|
||||
self.callback_train_stop(epoch, mean_loss)
|
||||
self.metric_add_train(epoch, 'loss', mean_loss)
|
||||
|
||||
# save metrics
|
||||
self.metric_save()
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'avg train_loss={err_str}')
|
||||
return mean_loss
|
||||
|
||||
def callback_test_start(self, epoch, set_idx):
|
||||
pass
|
||||
|
||||
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
|
||||
pass
|
||||
|
||||
def callback_test_stop(self, epoch, set_idx, loss):
|
||||
pass
|
||||
|
||||
def test(self, epoch, net, test_sets):
|
||||
errs = {}
|
||||
for test_set_idx, test_set in enumerate(test_sets):
|
||||
if (epoch + 1) % test_set.test_frequency == 0:
|
||||
logging.info('='*80)
|
||||
logging.info(f'testing set {test_set.name}')
|
||||
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
|
||||
errs[test_set.name] = err
|
||||
return errs
|
||||
|
||||
def test_epoch(self, epoch, set_idx, net, dset):
|
||||
logging.info('-'*80)
|
||||
logging.info('Test epoch %d' % epoch)
|
||||
dset.current_epoch = epoch
|
||||
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
|
||||
|
||||
net = net.to(self.test_device)
|
||||
net.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
mean_loss = None
|
||||
|
||||
self.callback_test_start(epoch, set_idx)
|
||||
|
||||
bar = ETA(length=len(test_loader))
|
||||
stopwatch = StopWatch()
|
||||
stopwatch.start('total')
|
||||
stopwatch.start('data')
|
||||
for batch_idx, data in enumerate(test_loader):
|
||||
# if batch_idx == 10: break
|
||||
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
|
||||
stopwatch.stop('data')
|
||||
|
||||
stopwatch.start('forward')
|
||||
output = self.net_forward(net, train=False)
|
||||
if 'cuda' in self.test_device: torch.cuda.synchronize()
|
||||
stopwatch.stop('forward')
|
||||
|
||||
stopwatch.start('loss')
|
||||
errs = self.loss_forward(output, train=False)
|
||||
if isinstance(errs, dict):
|
||||
masks = errs['masks']
|
||||
errs = errs['errs']
|
||||
else:
|
||||
masks = []
|
||||
if not isinstance(errs, list) and not isinstance(errs, tuple):
|
||||
errs = [errs]
|
||||
|
||||
bar.update(batch_idx)
|
||||
if batch_idx % 25 == 0:
|
||||
err_str = self.format_err_str(errs)
|
||||
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
|
||||
|
||||
if mean_loss is None:
|
||||
mean_loss = [0 for e in errs]
|
||||
for erridx, err in enumerate(errs):
|
||||
mean_loss[erridx] += err.item()
|
||||
stopwatch.stop('loss')
|
||||
|
||||
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
|
||||
|
||||
stopwatch.start('data')
|
||||
stopwatch.stop('total')
|
||||
logging.info('timings: %s' % stopwatch)
|
||||
|
||||
mean_loss = [l / len(test_loader) for l in mean_loss]
|
||||
self.callback_test_stop(epoch, set_idx, mean_loss)
|
||||
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
|
||||
|
||||
# save metrics
|
||||
self.metric_save()
|
||||
|
||||
err_str = self.format_err_str(mean_loss)
|
||||
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
|
||||
return mean_loss
|
29
train_val.py
Normal file
29
train_val.py
Normal file
@ -0,0 +1,29 @@
|
||||
import os
|
||||
import torch
|
||||
from model import exp_synph
|
||||
from model import exp_synphge
|
||||
from model import networks
|
||||
from co.args import parse_args
|
||||
|
||||
|
||||
# parse args
|
||||
args = parse_args()
|
||||
|
||||
# loss types
|
||||
if args.loss=='ph':
|
||||
worker = exp_synph.Worker(args)
|
||||
elif args.loss=='phge':
|
||||
worker = exp_synphge.Worker(args)
|
||||
|
||||
# concatenation of original image and lcn image
|
||||
channels_in=2
|
||||
|
||||
# set up network
|
||||
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, output_ms=worker.ms)
|
||||
|
||||
# optimizer
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
|
||||
|
||||
# start the work
|
||||
worker.do(net, optimizer)
|
||||
|
Loading…
Reference in New Issue
Block a user