Reformat $EVERYTHING

This commit is contained in:
CptCaptain 2021-11-15 16:53:30 +01:00
parent 56f2aa7d5d
commit 43df77fb9b
32 changed files with 4171 additions and 3749 deletions

View File

@ -7,8 +7,9 @@
# set matplotlib backend depending on env
import os
import matplotlib
if os.name == 'posix' and "DISPLAY" not in os.environ:
matplotlib.use('Agg')
matplotlib.use('Agg')
from . import geometry
from . import plt

View File

@ -12,14 +12,14 @@ def parse_args():
parser.add_argument('--loss',
help='Train with \'ph\' for the first stage without geometric loss, \
train with \'phge\' for the second stage with geometric loss',
default='ph', choices=['ph','phge'], type=str)
default='ph', choices=['ph', 'phge'], type=str)
parser.add_argument('--data_type',
default='syn', choices=['syn'], type=str)
#
parser.add_argument('--cmd',
help='Start training or test',
parser.add_argument('--cmd',
help='Start training or test',
default='resume', choices=['retrain', 'resume', 'retest', 'test_init'], type=str)
parser.add_argument('--epoch',
parser.add_argument('--epoch',
help='If larger than -1, retest on the specified epoch',
default=-1, type=int)
parser.add_argument('--epochs',
@ -55,7 +55,7 @@ def parse_args():
parser.add_argument('--blend_im',
help='Parameter for adding texture',
default=0.6, type=float)
args = parser.parse_args()
args.exp_name = get_exp_name(args)
@ -66,6 +66,3 @@ def parse_args():
def get_exp_name(args):
name = f"exp_{args.data_type}"
return name

View File

@ -1,19 +1,20 @@
import numpy as np
_color_map_errors = np.array([
[149, 54, 49], #0: log2(x) = -infinity
[180, 117, 69], #0.0625: log2(x) = -4
[209, 173, 116], #0.125: log2(x) = -3
[233, 217, 171], #0.25: log2(x) = -2
[248, 243, 224], #0.5: log2(x) = -1
[144, 224, 254], #1.0: log2(x) = 0
[97, 174, 253], #2.0: log2(x) = 1
[67, 109, 244], #4.0: log2(x) = 2
[39, 48, 215], #8.0: log2(x) = 3
[38, 0, 165], #16.0: log2(x) = 4
[38, 0, 165] #inf: log2(x) = inf
[149, 54, 49], # 0: log2(x) = -infinity
[180, 117, 69], # 0.0625: log2(x) = -4
[209, 173, 116], # 0.125: log2(x) = -3
[233, 217, 171], # 0.25: log2(x) = -2
[248, 243, 224], # 0.5: log2(x) = -1
[144, 224, 254], # 1.0: log2(x) = 0
[97, 174, 253], # 2.0: log2(x) = 1
[67, 109, 244], # 4.0: log2(x) = 2
[39, 48, 215], # 8.0: log2(x) = 3
[38, 0, 165], # 16.0: log2(x) = 4
[38, 0, 165] # inf: log2(x) = inf
]).astype(float)
def color_error_image(errors, scale=1, mask=None, BGR=True):
"""
Color an input error map.
@ -27,31 +28,33 @@ def color_error_image(errors, scale=1, mask=None, BGR=True):
Returns:
colored_errors -- HxWx3 numpy array visualizing the errors
"""
errors_flat = errors.flatten()
errors_color_indices = np.clip(np.log2(errors_flat / scale + 1e-5) + 5, 0, 9)
i0 = np.floor(errors_color_indices).astype(int)
f1 = errors_color_indices - i0.astype(float)
colored_errors_flat = _color_map_errors[i0, :] * (1-f1).reshape(-1,1) + _color_map_errors[i0+1, :] * f1.reshape(-1,1)
colored_errors_flat = _color_map_errors[i0, :] * (1 - f1).reshape(-1, 1) + _color_map_errors[i0 + 1,
:] * f1.reshape(-1, 1)
if mask is not None:
colored_errors_flat[mask.flatten() == 0] = 255
if not BGR:
colored_errors_flat = colored_errors_flat[:,[2,1,0]]
colored_errors_flat = colored_errors_flat[:, [2, 1, 0]]
return colored_errors_flat.reshape(errors.shape[0], errors.shape[1], 3).astype(np.int)
_color_map_depths = np.array([
[0, 0, 0], # 0.000
[0, 0, 255], # 0.114
[255, 0, 0], # 0.299
[255, 0, 255], # 0.413
[0, 255, 0], # 0.587
[0, 255, 255], # 0.701
[255, 255, 0], # 0.886
[255, 255, 255], # 1.000
[255, 255, 255], # 1.000
[0, 0, 0], # 0.000
[0, 0, 255], # 0.114
[255, 0, 0], # 0.299
[255, 0, 255], # 0.413
[0, 255, 0], # 0.587
[0, 255, 255], # 0.701
[255, 255, 0], # 0.886
[255, 255, 255], # 1.000
[255, 255, 255], # 1.000
]).astype(float)
_color_map_bincenters = np.array([
0.0,
@ -62,9 +65,10 @@ _color_map_bincenters = np.array([
0.701,
0.886,
1.000,
2.000, # doesn't make a difference, just strictly higher than 1
2.000, # doesn't make a difference, just strictly higher than 1
])
def color_depth_map(depths, scale=None):
"""
Color an input depth map.
@ -82,12 +86,13 @@ def color_depth_map(depths, scale=None):
values = np.clip(depths.flatten() / scale, 0, 1)
# for each value, figure out where they fit in in the bincenters: what is the last bincenter smaller than this value?
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1,-1)) * np.arange(0,9)).max(axis=1)
lower_bin = ((values.reshape(-1, 1) >= _color_map_bincenters.reshape(1, -1)) * np.arange(0, 9)).max(axis=1)
lower_bin_value = _color_map_bincenters[lower_bin]
higher_bin_value = _color_map_bincenters[lower_bin + 1]
alphas = (values - lower_bin_value) / (higher_bin_value - lower_bin_value)
colors = _color_map_depths[lower_bin] * (1-alphas).reshape(-1,1) + _color_map_depths[lower_bin + 1] * alphas.reshape(-1,1)
colors = _color_map_depths[lower_bin] * (1 - alphas).reshape(-1, 1) + _color_map_depths[
lower_bin + 1] * alphas.reshape(-1, 1)
return colors.reshape(depths.shape[0], depths.shape[1], 3).astype(np.uint8)
#from utils.debug import save_color_numpy
#save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))
# from utils.debug import save_color_numpy
# save_color_numpy(color_depth_map(np.matmul(np.ones((100,1)), np.arange(0,1200).reshape(1,1200)), scale=1000))

File diff suppressed because it is too large Load Diff

View File

@ -2,31 +2,37 @@ import numpy as np
from . import utils
class StopWatch(utils.StopWatch):
def __del__(self):
print('='*80)
print('gtimer:')
total = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.sum).items()])
print(f' [total] {total}')
mean = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.mean).items()])
print(f' [mean] {mean}')
median = ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get(reduce=np.median).items()])
print(f' [median] {median}')
print('='*80)
def __del__(self):
print('=' * 80)
print('gtimer:')
total = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.sum).items()])
print(f' [total] {total}')
mean = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.mean).items()])
print(f' [mean] {mean}')
median = ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get(reduce=np.median).items()])
print(f' [median] {median}')
print('=' * 80)
GTIMER = StopWatch()
def start(name):
GTIMER.start(name)
GTIMER.start(name)
def stop(name):
GTIMER.stop(name)
GTIMER.stop(name)
class Ctx(object):
def __init__(self, name):
self.name = name
def __init__(self, name):
self.name = name
def __enter__(self):
start(self.name)
def __enter__(self):
start(self.name)
def __exit__(self, *args):
stop(self.name)
def __exit__(self, *args):
stop(self.name)

View File

@ -2,266 +2,273 @@ import struct
import numpy as np
import collections
def _write_ply_point(fp, x,y,z, color=None, normal=None, binary=False):
args = [x,y,z]
if color is not None:
args += [int(color[0]), int(color[1]), int(color[2])]
if normal is not None:
args += [normal[0],normal[1],normal[2]]
if binary:
fmt = '<fff'
if color is not None:
fmt = fmt + 'BBB'
if normal is not None:
fmt = fmt + 'fff'
fp.write(struct.pack(fmt, *args))
else:
fmt = '%f %f %f'
if color is not None:
fmt = fmt + ' %d %d %d'
if normal is not None:
fmt = fmt + ' %f %f %f'
fmt += '\n'
fp.write(fmt % tuple(args))
def _write_ply_triangle(fp, i0,i1,i2, binary):
if binary:
fp.write(struct.pack('<Biii', 3,i0,i1,i2))
else:
fp.write('3 %d %d %d\n' % (i0,i1,i2))
def _write_ply_point(fp, x, y, z, color=None, normal=None, binary=False):
args = [x, y, z]
if color is not None:
args += [int(color[0]), int(color[1]), int(color[2])]
if normal is not None:
args += [normal[0], normal[1], normal[2]]
if binary:
fmt = '<fff'
if color is not None:
fmt = fmt + 'BBB'
if normal is not None:
fmt = fmt + 'fff'
fp.write(struct.pack(fmt, *args))
else:
fmt = '%f %f %f'
if color is not None:
fmt = fmt + ' %d %d %d'
if normal is not None:
fmt = fmt + ' %f %f %f'
fmt += '\n'
fp.write(fmt % tuple(args))
def _write_ply_triangle(fp, i0, i1, i2, binary):
if binary:
fp.write(struct.pack('<Biii', 3, i0, i1, i2))
else:
fp.write('3 %d %d %d\n' % (i0, i1, i2))
def _write_ply_header_line(fp, str, binary):
if binary:
fp.write(str.encode())
else:
fp.write(str)
if binary:
fp.write(str.encode())
else:
fp.write(str)
def write_ply(path, verts, trias=None, color=None, normals=None, binary=False):
if verts.shape[1] != 3:
raise Exception('verts has to be of shape Nx3')
if trias is not None and trias.shape[1] != 3:
raise Exception('trias has to be of shape Nx3')
if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3:
raise Exception('color has to be of shape Nx3 or a callable')
if verts.shape[1] != 3:
raise Exception('verts has to be of shape Nx3')
if trias is not None and trias.shape[1] != 3:
raise Exception('trias has to be of shape Nx3')
if color is not None and not callable(color) and not isinstance(color, np.ndarray) and color.shape[1] != 3:
raise Exception('color has to be of shape Nx3 or a callable')
mode = 'wb' if binary else 'w'
with open(path, mode) as fp:
_write_ply_header_line(fp, "ply\n", binary)
if binary:
_write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary)
else:
_write_ply_header_line(fp, "format ascii 1.0\n", binary)
_write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary)
_write_ply_header_line(fp, "property float32 x\n", binary)
_write_ply_header_line(fp, "property float32 y\n", binary)
_write_ply_header_line(fp, "property float32 z\n", binary)
if color is not None:
_write_ply_header_line(fp, "property uchar red\n", binary)
_write_ply_header_line(fp, "property uchar green\n", binary)
_write_ply_header_line(fp, "property uchar blue\n", binary)
if normals is not None:
_write_ply_header_line(fp, "property float32 nx\n", binary)
_write_ply_header_line(fp, "property float32 ny\n", binary)
_write_ply_header_line(fp, "property float32 nz\n", binary)
if trias is not None:
_write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary)
_write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary)
_write_ply_header_line(fp, "end_header\n", binary)
for vidx, v in enumerate(verts):
if color is not None:
if callable(color):
c = color(vidx)
elif color.shape[0] > 1:
c = color[vidx]
mode = 'wb' if binary else 'w'
with open(path, mode) as fp:
_write_ply_header_line(fp, "ply\n", binary)
if binary:
_write_ply_header_line(fp, "format binary_little_endian 1.0\n", binary)
else:
c = color[0]
else:
c = None
if normals is None:
n = None
else:
n = normals[vidx]
_write_ply_point(fp, v[0],v[1],v[2], c, n, binary)
_write_ply_header_line(fp, "format ascii 1.0\n", binary)
_write_ply_header_line(fp, "element vertex %d\n" % (verts.shape[0]), binary)
_write_ply_header_line(fp, "property float32 x\n", binary)
_write_ply_header_line(fp, "property float32 y\n", binary)
_write_ply_header_line(fp, "property float32 z\n", binary)
if color is not None:
_write_ply_header_line(fp, "property uchar red\n", binary)
_write_ply_header_line(fp, "property uchar green\n", binary)
_write_ply_header_line(fp, "property uchar blue\n", binary)
if normals is not None:
_write_ply_header_line(fp, "property float32 nx\n", binary)
_write_ply_header_line(fp, "property float32 ny\n", binary)
_write_ply_header_line(fp, "property float32 nz\n", binary)
if trias is not None:
_write_ply_header_line(fp, "element face %d\n" % (trias.shape[0]), binary)
_write_ply_header_line(fp, "property list uchar int32 vertex_indices\n", binary)
_write_ply_header_line(fp, "end_header\n", binary)
for vidx, v in enumerate(verts):
if color is not None:
if callable(color):
c = color(vidx)
elif color.shape[0] > 1:
c = color[vidx]
else:
c = color[0]
else:
c = None
if normals is None:
n = None
else:
n = normals[vidx]
_write_ply_point(fp, v[0], v[1], v[2], c, n, binary)
if trias is not None:
for t in trias:
_write_ply_triangle(fp, t[0], t[1], t[2], binary)
if trias is not None:
for t in trias:
_write_ply_triangle(fp, t[0],t[1],t[2], binary)
def faces_to_triangles(faces):
new_faces = []
for f in faces:
if f[0] == 3:
new_faces.append([f[1], f[2], f[3]])
elif f[0] == 4:
new_faces.append([f[1], f[2], f[3]])
new_faces.append([f[3], f[4], f[1]])
else:
raise Exception('unknown face count %d', f[0])
return new_faces
new_faces = []
for f in faces:
if f[0] == 3:
new_faces.append([f[1], f[2], f[3]])
elif f[0] == 4:
new_faces.append([f[1], f[2], f[3]])
new_faces.append([f[3], f[4], f[1]])
else:
raise Exception('unknown face count %d', f[0])
return new_faces
def read_ply(path):
with open(path, 'rb') as f:
# parse header
line = f.readline().decode().strip()
if line != 'ply':
raise Exception('Header error')
n_verts = 0
n_faces = 0
vert_types = {}
vert_bin_format = []
vert_bin_len = 0
vert_bin_cols = 0
line = f.readline().decode()
parse_vertex_prop = False
while line.strip() != 'end_header':
if 'format' in line:
if 'ascii' in line:
binary = False
elif 'binary_little_endian' in line:
binary = True
else:
raise Exception('invalid ply format')
if 'element face' in line:
splits = line.strip().split(' ')
n_faces = int(splits[-1])
with open(path, 'rb') as f:
# parse header
line = f.readline().decode().strip()
if line != 'ply':
raise Exception('Header error')
n_verts = 0
n_faces = 0
vert_types = {}
vert_bin_format = []
vert_bin_len = 0
vert_bin_cols = 0
line = f.readline().decode()
parse_vertex_prop = False
if 'element camera' in line:
parse_vertex_prop = False
if 'element vertex' in line:
splits = line.strip().split(' ')
n_verts = int(splits[-1])
parse_vertex_prop = True
if parse_vertex_prop and 'property' in line:
prop = line.strip().split()
if prop[1] == 'float':
vert_bin_format.append('f4')
vert_bin_len += 4
vert_bin_cols += 1
elif prop[1] == 'uchar':
vert_bin_format.append('B')
vert_bin_len += 1
vert_bin_cols += 1
while line.strip() != 'end_header':
if 'format' in line:
if 'ascii' in line:
binary = False
elif 'binary_little_endian' in line:
binary = True
else:
raise Exception('invalid ply format')
if 'element face' in line:
splits = line.strip().split(' ')
n_faces = int(splits[-1])
parse_vertex_prop = False
if 'element camera' in line:
parse_vertex_prop = False
if 'element vertex' in line:
splits = line.strip().split(' ')
n_verts = int(splits[-1])
parse_vertex_prop = True
if parse_vertex_prop and 'property' in line:
prop = line.strip().split()
if prop[1] == 'float':
vert_bin_format.append('f4')
vert_bin_len += 4
vert_bin_cols += 1
elif prop[1] == 'uchar':
vert_bin_format.append('B')
vert_bin_len += 1
vert_bin_cols += 1
else:
raise Exception('invalid property')
vert_types[prop[2]] = len(vert_types)
line = f.readline().decode()
# parse content
if binary:
sz = n_verts * vert_bin_len
fmt = ','.join(vert_bin_format)
verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
verts = verts[0].astype(vert_bin_cols * 'f4,').view(dtype='f4').reshape((n_verts, -1))
faces = []
for idx in range(n_faces):
fmt = '<Biii'
length = struct.calcsize(fmt)
dat = f.read(length)
vals = struct.unpack(fmt, dat)
faces.append(vals)
faces = faces_to_triangles(faces)
faces = np.array(faces, dtype=np.int32)
else:
raise Exception('invalid property')
vert_types[prop[2]] = len(vert_types)
line = f.readline().decode()
verts = []
for idx in range(n_verts):
vals = [float(v) for v in f.readline().decode().strip().split(' ')]
verts.append(vals)
verts = np.array(verts, dtype=np.float32)
faces = []
for idx in range(n_faces):
splits = f.readline().decode().strip().split(' ')
n_face_verts = int(splits[0])
vals = [int(v) for v in splits[0:n_face_verts + 1]]
faces.append(vals)
faces = faces_to_triangles(faces)
faces = np.array(faces, dtype=np.int32)
# parse content
if binary:
sz = n_verts * vert_bin_len
fmt = ','.join(vert_bin_format)
verts = np.ndarray(shape=(1, n_verts), dtype=np.dtype(fmt), buffer=f.read(sz))
verts = verts[0].astype(vert_bin_cols*'f4,').view(dtype='f4').reshape((n_verts,-1))
faces = []
for idx in range(n_faces):
fmt = '<Biii'
length = struct.calcsize(fmt)
dat = f.read(length)
vals = struct.unpack(fmt, dat)
faces.append(vals)
faces = faces_to_triangles(faces)
faces = np.array(faces, dtype=np.int32)
else:
verts = []
for idx in range(n_verts):
vals = [float(v) for v in f.readline().decode().strip().split(' ')]
verts.append(vals)
verts = np.array(verts, dtype=np.float32)
faces = []
for idx in range(n_faces):
splits = f.readline().decode().strip().split(' ')
n_face_verts = int(splits[0])
vals = [int(v) for v in splits[0:n_face_verts+1]]
faces.append(vals)
faces = faces_to_triangles(faces)
faces = np.array(faces, dtype=np.int32)
xyz = None
if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
xyz = verts[:, [vert_types['x'], vert_types['y'], vert_types['z']]]
colors = None
if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
colors = verts[:, [vert_types['red'], vert_types['green'], vert_types['blue']]]
colors /= 255
normals = None
if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
normals = verts[:, [vert_types['nx'], vert_types['ny'], vert_types['nz']]]
xyz = None
if 'x' in vert_types and 'y' in vert_types and 'z' in vert_types:
xyz = verts[:,[vert_types['x'], vert_types['y'], vert_types['z']]]
colors = None
if 'red' in vert_types and 'green' in vert_types and 'blue' in vert_types:
colors = verts[:,[vert_types['red'], vert_types['green'], vert_types['blue']]]
colors /= 255
normals = None
if 'nx' in vert_types and 'ny' in vert_types and 'nz' in vert_types:
normals = verts[:,[vert_types['nx'], vert_types['ny'], vert_types['nz']]]
return xyz, faces, colors, normals
return xyz, faces, colors, normals
def _read_obj_split_f(s):
parts = s.split('/')
vidx = int(parts[0]) - 1
if len(parts) >= 2 and len(parts[1]) > 0:
tidx = int(parts[1]) - 1
else:
tidx = -1
if len(parts) >= 3 and len(parts[2]) > 0:
nidx = int(parts[2]) - 1
else:
nidx = -1
return vidx, tidx, nidx
parts = s.split('/')
vidx = int(parts[0]) - 1
if len(parts) >= 2 and len(parts[1]) > 0:
tidx = int(parts[1]) - 1
else:
tidx = -1
if len(parts) >= 3 and len(parts[2]) > 0:
nidx = int(parts[2]) - 1
else:
nidx = -1
return vidx, tidx, nidx
def read_obj(path):
with open(path, 'r') as fp:
lines = fp.readlines()
with open(path, 'r') as fp:
lines = fp.readlines()
verts = []
colors = []
fnorms = []
fnorm_map = collections.defaultdict(list)
faces = []
for line in lines:
line = line.strip()
if line.startswith('#') or len(line) == 0:
continue
verts = []
colors = []
fnorms = []
fnorm_map = collections.defaultdict(list)
faces = []
for line in lines:
line = line.strip()
if line.startswith('#') or len(line) == 0:
continue
parts = line.split()
if line.startswith('v '):
parts = parts[1:]
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
if len(parts) == 4 or len(parts) == 7:
w = float(parts[3])
x,y,z = x/w, y/w, z/w
verts.append((x,y,z))
if len(parts) >= 6:
r,g,b = float(parts[-3]), float(parts[-2]), float(parts[-1])
rgb.append((r,g,b))
parts = line.split()
if line.startswith('v '):
parts = parts[1:]
x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
if len(parts) == 4 or len(parts) == 7:
w = float(parts[3])
x, y, z = x / w, y / w, z / w
verts.append((x, y, z))
if len(parts) >= 6:
r, g, b = float(parts[-3]), float(parts[-2]), float(parts[-1])
rgb.append((r, g, b))
elif line.startswith('vn '):
parts = parts[1:]
x,y,z = float(parts[0]), float(parts[1]), float(parts[2])
fnorms.append((x,y,z))
elif line.startswith('vn '):
parts = parts[1:]
x, y, z = float(parts[0]), float(parts[1]), float(parts[2])
fnorms.append((x, y, z))
elif line.startswith('f '):
parts = parts[1:]
if len(parts) != 3:
raise Exception('only triangle meshes supported atm')
vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
elif line.startswith('f '):
parts = parts[1:]
if len(parts) != 3:
raise Exception('only triangle meshes supported atm')
vidx0, tidx0, nidx0 = _read_obj_split_f(parts[0])
vidx1, tidx1, nidx1 = _read_obj_split_f(parts[1])
vidx2, tidx2, nidx2 = _read_obj_split_f(parts[2])
faces.append((vidx0, vidx1, vidx2))
if nidx0 >= 0:
fnorm_map[vidx0].append( nidx0 )
if nidx1 >= 0:
fnorm_map[vidx1].append( nidx1 )
if nidx2 >= 0:
fnorm_map[vidx2].append( nidx2 )
faces.append((vidx0, vidx1, vidx2))
if nidx0 >= 0:
fnorm_map[vidx0].append(nidx0)
if nidx1 >= 0:
fnorm_map[vidx1].append(nidx1)
if nidx2 >= 0:
fnorm_map[vidx2].append(nidx2)
verts = np.array(verts)
colors = np.array(colors)
fnorms = np.array(fnorms)
faces = np.array(faces)
# face normals to vertex normals
norms = np.zeros_like(verts)
for vidx in fnorm_map.keys():
ind = fnorm_map[vidx]
norms[vidx] = fnorms[ind].sum(axis=0)
N = np.linalg.norm(norms, axis=1, keepdims=True)
np.divide(norms, N, out=norms, where=N != 0)
verts = np.array(verts)
colors = np.array(colors)
fnorms = np.array(fnorms)
faces = np.array(faces)
return verts, faces, colors, norms
# face normals to vertex normals
norms = np.zeros_like(verts)
for vidx in fnorm_map.keys():
ind = fnorm_map[vidx]
norms[vidx] = fnorms[ind].sum(axis=0)
N = np.linalg.norm(norms, axis=1, keepdims=True)
np.divide(norms, N, out=norms, where=N != 0)
return verts, faces, colors, norms

View File

@ -1,248 +1,260 @@
import numpy as np
from . import geometry
def _process_inputs(estimate, target, mask):
if estimate.shape != target.shape:
raise Exception('estimate and target have to be same shape')
if mask is None:
mask = np.ones(estimate.shape, dtype=np.bool)
else:
mask = mask != 0
if estimate.shape != mask.shape:
raise Exception('estimate and mask have to be same shape')
return estimate, target, mask
if estimate.shape != target.shape:
raise Exception('estimate and target have to be same shape')
if mask is None:
mask = np.ones(estimate.shape, dtype=np.bool)
else:
mask = mask != 0
if estimate.shape != mask.shape:
raise Exception('estimate and mask have to be same shape')
return estimate, target, mask
def mse(estimate, target, mask=None):
estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.sum((estimate[mask] - target[mask])**2) / mask.sum()
return m
estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.sum((estimate[mask] - target[mask]) ** 2) / mask.sum()
return m
def rmse(estimate, target, mask=None):
return np.sqrt(mse(estimate, target, mask))
return np.sqrt(mse(estimate, target, mask))
def mae(estimate, target, mask=None):
estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
return m
estimate, target, mask = _process_inputs(estimate, target, mask)
m = np.abs(estimate[mask] - target[mask]).sum() / mask.sum()
return m
def outlier_fraction(estimate, target, mask=None, threshold=0):
estimate, target, mask = _process_inputs(estimate, target, mask)
diff = np.abs(estimate[mask] - target[mask])
m = (diff > threshold).sum() / mask.sum()
return m
estimate, target, mask = _process_inputs(estimate, target, mask)
diff = np.abs(estimate[mask] - target[mask])
m = (diff > threshold).sum() / mask.sum()
return m
class Metric(object):
def __init__(self, str_prefix=''):
self.str_prefix = str_prefix
self.reset()
def __init__(self, str_prefix=''):
self.str_prefix = str_prefix
self.reset()
def reset(self):
pass
def reset(self):
pass
def add(self, es, ta, ma=None):
pass
def add(self, es, ta, ma=None):
pass
def get(self):
return {}
def get(self):
return {}
def items(self):
return self.get().items()
def items(self):
return self.get().items()
def __str__(self):
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
def __str__(self):
return ', '.join([f'{self.str_prefix}{key}={value:.5f}' for key, value in self.get().items()])
class MultipleMetric(Metric):
def __init__(self, *metrics, **kwargs):
self.metrics = [*metrics]
super().__init__(**kwargs)
def __init__(self, *metrics, **kwargs):
self.metrics = [*metrics]
super().__init__(**kwargs)
def reset(self):
for m in self.metrics:
m.reset()
def reset(self):
for m in self.metrics:
m.reset()
def add(self, es, ta, ma=None):
for m in self.metrics:
m.add(es, ta, ma)
def add(self, es, ta, ma=None):
for m in self.metrics:
m.add(es, ta, ma)
def get(self):
ret = {}
for m in self.metrics:
vals = m.get()
for k in vals:
ret[k] = vals[k]
return ret
def get(self):
ret = {}
for m in self.metrics:
vals = m.get()
for k in vals:
ret[k] = vals[k]
return ret
def __str__(self):
return '\n'.join([str(m) for m in self.metrics])
def __str__(self):
return '\n'.join([str(m) for m in self.metrics])
class BaseDistanceMetric(Metric):
def __init__(self, name='', **kwargs):
super().__init__(**kwargs)
self.name = name
def __init__(self, name='', **kwargs):
super().__init__(**kwargs)
self.name = name
def reset(self):
self.dists = []
def reset(self):
self.dists = []
def add(self, es, ta, ma=None):
pass
def add(self, es, ta, ma=None):
pass
def get(self):
dists = np.hstack(self.dists)
return {
f'dist{self.name}_mean': float(np.mean(dists)),
f'dist{self.name}_std': float(np.std(dists)),
f'dist{self.name}_median': float(np.median(dists)),
f'dist{self.name}_q10': float(np.percentile(dists, 10)),
f'dist{self.name}_q90': float(np.percentile(dists, 90)),
f'dist{self.name}_min': float(np.min(dists)),
f'dist{self.name}_max': float(np.max(dists)),
}
def get(self):
dists = np.hstack(self.dists)
return {
f'dist{self.name}_mean': float(np.mean(dists)),
f'dist{self.name}_std': float(np.std(dists)),
f'dist{self.name}_median': float(np.median(dists)),
f'dist{self.name}_q10': float(np.percentile(dists, 10)),
f'dist{self.name}_q90': float(np.percentile(dists, 90)),
f'dist{self.name}_min': float(np.min(dists)),
f'dist{self.name}_max': float(np.max(dists)),
}
class DistanceMetric(BaseDistanceMetric):
def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'{p}', **kwargs)
self.vec_length = vec_length
self.p = p
def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'{p}', **kwargs)
self.vec_length = vec_length
self.p = p
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
print(es.shape, ta.shape)
raise Exception('es and ta have to be of shape Nxdim')
if ma is not None:
es = es[ma != 0]
ta = ta[ma != 0]
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
self.dists.append(dist)
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
print(es.shape, ta.shape)
raise Exception('es and ta have to be of shape Nxdim')
if ma is not None:
es = es[ma != 0]
ta = ta[ma != 0]
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
self.dists.append( dist )
class OutlierFractionMetric(DistanceMetric):
def __init__(self, thresholds, *args, **kwargs):
super().__init__(*args, **kwargs)
self.thresholds = thresholds
def __init__(self, thresholds, *args, **kwargs):
super().__init__(*args, **kwargs)
self.thresholds = thresholds
def get(self):
dists = np.hstack(self.dists)
ret = {}
for t in self.thresholds:
ma = dists > t
ret[f'of{t}'] = float(ma.sum() / ma.size)
return ret
def get(self):
dists = np.hstack(self.dists)
ret = {}
for t in self.thresholds:
ma = dists > t
ret[f'of{t}'] = float(ma.sum() / ma.size)
return ret
class RelativeDistanceMetric(BaseDistanceMetric):
def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'rel{p}', **kwargs)
self.vec_length = vec_length
self.p = p
def __init__(self, vec_length, p=2, **kwargs):
super().__init__(name=f'rel{p}', **kwargs)
self.vec_length = vec_length
self.p = p
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
raise Exception('es and ta have to be of shape Nxdim')
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
denom = np.linalg.norm(ta, ord=self.p, axis=1)
dist /= denom
if ma is not None:
dist = dist[ma != 0]
self.dists.append(dist)
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != self.vec_length or es.ndim != 2:
raise Exception('es and ta have to be of shape Nxdim')
dist = np.linalg.norm(es - ta, ord=self.p, axis=1)
denom = np.linalg.norm(ta, ord=self.p, axis=1)
dist /= denom
if ma is not None:
dist = dist[ma != 0]
self.dists.append( dist )
class RotmDistanceMetric(BaseDistanceMetric):
def __init__(self, type='identity', **kwargs):
super().__init__(name=type, **kwargs)
self.type = type
def __init__(self, type='identity', **kwargs):
super().__init__(name=type, **kwargs)
self.type = type
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
print(es.shape, ta.shape)
raise Exception('es and ta have to be of shape Nx3x3')
if ma is not None:
raise Exception('mask is not implemented')
if self.type == 'identity':
self.dists.append(geometry.rotm_distance_identity(es, ta))
elif self.type == 'geodesic':
self.dists.append(geometry.rotm_distance_geodesic_unit_sphere(es, ta))
else:
raise Exception('invalid distance type')
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != 3 or es.shape[2] != 3 or es.ndim != 3:
print(es.shape, ta.shape)
raise Exception('es and ta have to be of shape Nx3x3')
if ma is not None:
raise Exception('mask is not implemented')
if self.type == 'identity':
self.dists.append( geometry.rotm_distance_identity(es, ta) )
elif self.type == 'geodesic':
self.dists.append( geometry.rotm_distance_geodesic_unit_sphere(es, ta) )
else:
raise Exception('invalid distance type')
class QuaternionDistanceMetric(BaseDistanceMetric):
def __init__(self, type='angle', **kwargs):
super().__init__(name=type, **kwargs)
self.type = type
def __init__(self, type='angle', **kwargs):
super().__init__(name=type, **kwargs)
self.type = type
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
print(es.shape, ta.shape)
raise Exception('es and ta have to be of shape Nx4')
if ma is not None:
raise Exception('mask is not implemented')
if self.type == 'angle':
self.dists.append( geometry.quat_distance_angle(es, ta) )
elif self.type == 'mineucl':
self.dists.append( geometry.quat_distance_mineucl(es, ta) )
elif self.type == 'normdiff':
self.dists.append( geometry.quat_distance_normdiff(es, ta) )
else:
raise Exception('invalid distance type')
def add(self, es, ta, ma=None):
if es.shape != ta.shape or es.shape[1] != 4 or es.ndim != 2:
print(es.shape, ta.shape)
raise Exception('es and ta have to be of shape Nx4')
if ma is not None:
raise Exception('mask is not implemented')
if self.type == 'angle':
self.dists.append(geometry.quat_distance_angle(es, ta))
elif self.type == 'mineucl':
self.dists.append(geometry.quat_distance_mineucl(es, ta))
elif self.type == 'normdiff':
self.dists.append(geometry.quat_distance_normdiff(es, ta))
else:
raise Exception('invalid distance type')
class BinaryAccuracyMetric(Metric):
def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
self.thresholds = thresholds
super().__init__(**kwargs)
def __init__(self, thresholds=np.linspace(0.0, 1.0, num=101, dtype=np.float64)[:-1], **kwargs):
self.thresholds = thresholds
super().__init__(**kwargs)
def reset(self):
self.tps = [0 for wp in self.thresholds]
self.fps = [0 for wp in self.thresholds]
self.fns = [0 for wp in self.thresholds]
self.tns = [0 for wp in self.thresholds]
self.n_pos = 0
self.n_neg = 0
def reset(self):
self.tps = [0 for wp in self.thresholds]
self.fps = [0 for wp in self.thresholds]
self.fns = [0 for wp in self.thresholds]
self.tns = [0 for wp in self.thresholds]
self.n_pos = 0
self.n_neg = 0
def add(self, es, ta, ma=None):
if ma is not None:
raise Exception('mask is not implemented')
es = es.ravel()
ta = ta.ravel()
if es.shape[0] != ta.shape[0]:
raise Exception('invalid shape of es, or ta')
if es.min() < 0 or es.max() > 1:
raise Exception('estimate has wrong value range')
ta_p = (ta == 1)
ta_n = (ta == 0)
es_p = es[ta_p]
es_n = es[ta_n]
for idx, wp in enumerate(self.thresholds):
wp = np.asscalar(wp)
self.tps[idx] += (es_p > wp).sum()
self.fps[idx] += (es_n > wp).sum()
self.fns[idx] += (es_p <= wp).sum()
self.tns[idx] += (es_n <= wp).sum()
self.n_pos += ta_p.sum()
self.n_neg += ta_n.sum()
def add(self, es, ta, ma=None):
if ma is not None:
raise Exception('mask is not implemented')
es = es.ravel()
ta = ta.ravel()
if es.shape[0] != ta.shape[0]:
raise Exception('invalid shape of es, or ta')
if es.min() < 0 or es.max() > 1:
raise Exception('estimate has wrong value range')
ta_p = (ta == 1)
ta_n = (ta == 0)
es_p = es[ta_p]
es_n = es[ta_n]
for idx, wp in enumerate(self.thresholds):
wp = np.asscalar(wp)
self.tps[idx] += (es_p > wp).sum()
self.fps[idx] += (es_n > wp).sum()
self.fns[idx] += (es_p <= wp).sum()
self.tns[idx] += (es_n <= wp).sum()
self.n_pos += ta_p.sum()
self.n_neg += ta_n.sum()
def get(self):
tps = np.array(self.tps).astype(np.float32)
fps = np.array(self.fps).astype(np.float32)
fns = np.array(self.fns).astype(np.float32)
tns = np.array(self.tns).astype(np.float32)
wp = self.thresholds
def get(self):
tps = np.array(self.tps).astype(np.float32)
fps = np.array(self.fps).astype(np.float32)
fns = np.array(self.fns).astype(np.float32)
tns = np.array(self.tns).astype(np.float32)
wp = self.thresholds
ret = {}
ret = {}
precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs
fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)
precisions = np.divide(tps, tps + fps, out=np.zeros_like(tps), where=tps + fps != 0)
recalls = np.divide(tps, tps + fns, out=np.zeros_like(tps), where=tps + fns != 0) # tprs
fprs = np.divide(fps, fps + tns, out=np.zeros_like(tps), where=fps + tns != 0)
precisions = np.r_[0, precisions, 1]
recalls = np.r_[1, recalls, 0]
fprs = np.r_[1, fprs, 0]
precisions = np.r_[0, precisions, 1]
recalls = np.r_[1, recalls, 0]
fprs = np.r_[1, fprs, 0]
ret['auc'] = float(-np.trapz(recalls, fprs))
ret['prauc'] = float(-np.trapz(precisions, recalls))
ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())
ret['auc'] = float(-np.trapz(recalls, fprs))
ret['prauc'] = float(-np.trapz(precisions, recalls))
ret['ap'] = float(-(np.diff(recalls) * precisions[:-1]).sum())
accuracies = np.divide(tps + tns, tps + tns + fps + fns)
aacc = np.mean(accuracies)
for t in np.linspace(0,1,num=11)[1:-1]:
idx = np.argmin(np.abs(t - wp))
ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])
accuracies = np.divide(tps + tns, tps + tns + fps + fns)
aacc = np.mean(accuracies)
for t in np.linspace(0, 1, num=11)[1:-1]:
idx = np.argmin(np.abs(t - wp))
ret[f'acc{wp[idx]:.2f}'] = float(accuracies[idx])
return ret
return ret

163
co/plt.py
View File

@ -6,94 +6,99 @@ import matplotlib.pyplot as plt
import os
import time
def save(path, remove_axis=False, dpi=300, fig=None):
if fig is None:
fig = plt.gcf()
dirname = os.path.dirname(path)
if dirname != '' and not os.path.exists(dirname):
os.makedirs(dirname)
if remove_axis:
for ax in fig.axes:
ax.axis('off')
ax.margins(0,0)
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
for ax in fig.axes:
ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator())
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
if fig is None:
fig = plt.gcf()
dirname = os.path.dirname(path)
if dirname != '' and not os.path.exists(dirname):
os.makedirs(dirname)
if remove_axis:
for ax in fig.axes:
ax.axis('off')
ax.margins(0, 0)
fig.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
for ax in fig.axes:
ax.xaxis.set_major_locator(plt.NullLocator())
ax.yaxis.set_major_locator(plt.NullLocator())
fig.savefig(path, dpi=dpi, bbox_inches='tight', pad_inches=0)
def color_map(im_, cmap='viridis', vmin=None, vmax=None):
cm = plt.get_cmap(cmap)
im = im_.copy()
if vmin is None:
vmin = np.nanmin(im)
if vmax is None:
vmax = np.nanmax(im)
mask = np.logical_not(np.isfinite(im))
im[mask] = vmin
im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
im = cm(im)
im = im[...,:3]
for c in range(3):
im[mask, c] = 1
return im
cm = plt.get_cmap(cmap)
im = im_.copy()
if vmin is None:
vmin = np.nanmin(im)
if vmax is None:
vmax = np.nanmax(im)
mask = np.logical_not(np.isfinite(im))
im[mask] = vmin
im = (im.clip(vmin, vmax) - vmin) / (vmax - vmin)
im = cm(im)
im = im[..., :3]
for c in range(3):
im[mask, c] = 1
return im
def interactive_legend(leg=None, fig=None, all_axes=True):
if leg is None:
leg = plt.legend()
if fig is None:
fig = plt.gcf()
if all_axes:
axs = fig.get_axes()
else:
axs = [fig.gca()]
# lined = dict()
# lines = ax.lines
# for legline, origline in zip(leg.get_lines(), ax.lines):
# legline.set_picker(5)
# lined[legline] = origline
lined = dict()
for lidx, legline in enumerate(leg.get_lines()):
legline.set_picker(5)
lined[legline] = [ax.lines[lidx] for ax in axs]
def onpick(event):
if event.mouseevent.dblclick:
tmp = [(k,v) for k,v in lined.items()]
if leg is None:
leg = plt.legend()
if fig is None:
fig = plt.gcf()
if all_axes:
axs = fig.get_axes()
else:
tmp = [(event.artist, lined[event.artist])]
axs = [fig.gca()]
for legline, origline in tmp:
for ol in origline:
vis = not ol.get_visible()
ol.set_visible(vis)
if vis:
legline.set_alpha(1.0)
else:
legline.set_alpha(0.2)
fig.canvas.draw()
# lined = dict()
# lines = ax.lines
# for legline, origline in zip(leg.get_lines(), ax.lines):
# legline.set_picker(5)
# lined[legline] = origline
lined = dict()
for lidx, legline in enumerate(leg.get_lines()):
legline.set_picker(5)
lined[legline] = [ax.lines[lidx] for ax in axs]
def onpick(event):
if event.mouseevent.dblclick:
tmp = [(k, v) for k, v in lined.items()]
else:
tmp = [(event.artist, lined[event.artist])]
for legline, origline in tmp:
for ol in origline:
vis = not ol.get_visible()
ol.set_visible(vis)
if vis:
legline.set_alpha(1.0)
else:
legline.set_alpha(0.2)
fig.canvas.draw()
fig.canvas.mpl_connect('pick_event', onpick)
fig.canvas.mpl_connect('pick_event', onpick)
def non_annoying_pause(interval, focus_figure=False):
# https://github.com/matplotlib/matplotlib/issues/11131
backend = mpl.rcParams['backend']
if backend in _interactive_bk:
figManager = _pylab_helpers.Gcf.get_active()
if figManager is not None:
canvas = figManager.canvas
if canvas.figure.stale:
canvas.draw()
if focus_figure:
plt.show(block=False)
canvas.start_event_loop(interval)
return
time.sleep(interval)
# https://github.com/matplotlib/matplotlib/issues/11131
backend = mpl.rcParams['backend']
if backend in _interactive_bk:
figManager = _pylab_helpers.Gcf.get_active()
if figManager is not None:
canvas = figManager.canvas
if canvas.figure.stale:
canvas.draw()
if focus_figure:
plt.show(block=False)
canvas.start_event_loop(interval)
return
time.sleep(interval)
def remove_all_ticks(fig=None):
if fig is None:
fig = plt.gcf()
for ax in fig.axes:
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if fig is None:
fig = plt.gcf()
for ax in fig.axes:
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)

View File

@ -3,55 +3,60 @@ import matplotlib.pyplot as plt
from . import geometry
def image_matrix(ims, bgval=0):
n = ims.shape[0]
m = int( np.ceil(np.sqrt(n)) )
h = ims.shape[1]
w = ims.shape[2]
mat = np.empty((m*h, m*w), dtype=ims.dtype)
mat.fill(bgval)
idx = 0
for r in range(m):
for c in range(m):
if idx < n:
mat[r*h:(r+1)*h, c*w:(c+1)*w] = ims[idx]
idx += 1
return mat
n = ims.shape[0]
m = int(np.ceil(np.sqrt(n)))
h = ims.shape[1]
w = ims.shape[2]
mat = np.empty((m * h, m * w), dtype=ims.dtype)
mat.fill(bgval)
idx = 0
for r in range(m):
for c in range(m):
if idx < n:
mat[r * h:(r + 1) * h, c * w:(c + 1) * w] = ims[idx]
idx += 1
return mat
def image_cat(ims, vertical=False):
offx = [0]
offy = [0]
if vertical:
width = max([im.shape[1] for im in ims])
offx += [0 for im in ims[:-1]]
offy += [im.shape[0] for im in ims[:-1]]
height = sum([im.shape[0] for im in ims])
else:
height = max([im.shape[0] for im in ims])
offx += [im.shape[1] for im in ims[:-1]]
offy += [0 for im in ims[:-1]]
width = sum([im.shape[1] for im in ims])
offx = np.cumsum(offx)
offy = np.cumsum(offy)
offx = [0]
offy = [0]
if vertical:
width = max([im.shape[1] for im in ims])
offx += [0 for im in ims[:-1]]
offy += [im.shape[0] for im in ims[:-1]]
height = sum([im.shape[0] for im in ims])
else:
height = max([im.shape[0] for im in ims])
offx += [im.shape[1] for im in ims[:-1]]
offy += [0 for im in ims[:-1]]
width = sum([im.shape[1] for im in ims])
offx = np.cumsum(offx)
offy = np.cumsum(offy)
im = np.zeros((height,width,*ims[0].shape[2:]), dtype=ims[0].dtype)
for im0, ox, oy in zip(ims, offx, offy):
im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
im = np.zeros((height, width, *ims[0].shape[2:]), dtype=ims[0].dtype)
for im0, ox, oy in zip(ims, offx, offy):
im[oy:oy + im0.shape[0], ox:ox + im0.shape[1]] = im0
return im, offx, offy
return im, offx, offy
def line(li, h, w, ax=None, *args, **kwargs):
if ax is None:
ax = plt.gca()
xs = (-li[2] - li[1] * np.array((0, h-1))) / li[0]
ys = (-li[2] - li[0] * np.array((0, w-1))) / li[1]
pts = np.array([(0,ys[0]), (w-1, ys[1]), (xs[0], 0), (xs[1], h-1)])
pts = pts[np.logical_and(np.logical_and(pts[:,0] >= 0, pts[:,0] < w), np.logical_and(pts[:,1] >= 0, pts[:,1] < h))]
ax.plot(pts[:,0], pts[:,1], *args, **kwargs)
if ax is None:
ax = plt.gca()
xs = (-li[2] - li[1] * np.array((0, h - 1))) / li[0]
ys = (-li[2] - li[0] * np.array((0, w - 1))) / li[1]
pts = np.array([(0, ys[0]), (w - 1, ys[1]), (xs[0], 0), (xs[1], h - 1)])
pts = pts[
np.logical_and(np.logical_and(pts[:, 0] >= 0, pts[:, 0] < w), np.logical_and(pts[:, 1] >= 0, pts[:, 1] < h))]
ax.plot(pts[:, 0], pts[:, 1], *args, **kwargs)
def depthshow(depth, *args, ax=None, **kwargs):
if ax is None:
ax = plt.gca()
d = depth.copy()
d[d < 0] = np.NaN
ax.imshow(d, *args, **kwargs)
if ax is None:
ax = plt.gca()
d = depth.copy()
d[d < 0] = np.NaN
ax.imshow(d, *args, **kwargs)

View File

@ -4,35 +4,45 @@ from mpl_toolkits.mplot3d import Axes3D
from . import geometry
def ax3d(fig=None):
if fig is None:
fig = plt.gcf()
return fig.add_subplot(111, projection='3d')
if fig is None:
fig = plt.gcf()
return fig.add_subplot(111, projection='3d')
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1, label=None, **kwargs):
if ax is None:
ax = plt.gca()
C0 = geometry.translation_to_cameracenter(R, t).ravel()
C1 = C0 + R.T.dot( np.array([[-size],[-size],[3*size]], dtype=np.float32) ).ravel()
C2 = C0 + R.T.dot( np.array([[-size],[+size],[3*size]], dtype=np.float32) ).ravel()
C3 = C0 + R.T.dot( np.array([[+size],[+size],[3*size]], dtype=np.float32) ).ravel()
C4 = C0 + R.T.dot( np.array([[+size],[-size],[3*size]], dtype=np.float32) ).ravel()
if marker_C != '':
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]], [C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle, linewidth=linewidth, **kwargs)
def plot_camera(ax=None, R=np.eye(3), t=np.zeros((3,)), size=25, marker_C='.', color='b', linestyle='-', linewidth=0.1,
label=None, **kwargs):
if ax is None:
ax = plt.gca()
C0 = geometry.translation_to_cameracenter(R, t).ravel()
C1 = C0 + R.T.dot(np.array([[-size], [-size], [3 * size]], dtype=np.float32)).ravel()
C2 = C0 + R.T.dot(np.array([[-size], [+size], [3 * size]], dtype=np.float32)).ravel()
C3 = C0 + R.T.dot(np.array([[+size], [+size], [3 * size]], dtype=np.float32)).ravel()
C4 = C0 + R.T.dot(np.array([[+size], [-size], [3 * size]], dtype=np.float32)).ravel()
if marker_C != '':
ax.plot([C0[0]], [C0[1]], [C0[2]], marker=marker_C, color=color, label=label, **kwargs)
ax.plot([C0[0], C1[0]], [C0[1], C1[1]], [C0[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C0[0], C2[0]], [C0[1], C2[1]], [C0[2], C2[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C0[0], C3[0]], [C0[1], C3[1]], [C0[2], C3[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C0[0], C4[0]], [C0[1], C4[1]], [C0[2], C4[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
ax.plot([C1[0], C2[0], C3[0], C4[0], C1[0]], [C1[1], C2[1], C3[1], C4[1], C1[1]],
[C1[2], C2[2], C3[2], C4[2], C1[2]], color=color, label='_nolegend_', linestyle=linestyle,
linewidth=linewidth, **kwargs)
def axis_equal(ax=None):
if ax is None:
ax = plt.gca()
extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
sz = extents[:,1] - extents[:,0]
centers = np.mean(extents, axis=1)
maxsize = max(abs(sz))
r = maxsize/2
for ctr, dim in zip(centers, 'xyz'):
getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)
if ax is None:
ax = plt.gca()
extents = np.array([getattr(ax, 'get_{}lim'.format(dim))() for dim in 'xyz'])
sz = extents[:, 1] - extents[:, 0]
centers = np.mean(extents, axis=1)
maxsize = max(abs(sz))
r = maxsize / 2
for ctr, dim in zip(centers, 'xyz'):
getattr(ax, 'set_{}lim'.format(dim))(ctr - r, ctr + r)

View File

@ -3,443 +3,453 @@ import pandas as pd
import enum
import itertools
class Table(object):
def __init__(self, n_cols):
self.n_cols = n_cols
self.rows = []
self.aligns = ['r' for c in range(n_cols)]
def __init__(self, n_cols):
self.n_cols = n_cols
self.rows = []
self.aligns = ['r' for c in range(n_cols)]
def get_cell_align(self, r, c):
align = self.rows[r].cells[c].align
if align is None:
return self.aligns[c]
else:
return align
def add_row(self, row):
if row.ncols() != self.n_cols:
raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
self.rows.append(row)
def empty_row(self):
return Row.Empty(self.n_cols)
def expand_rows(self, n_add_cols=1):
if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
self.n_cols += n_add_cols
for row in self.rows:
row.cells.extend([Cell() for cidx in range(n_add_cols)])
def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
if row < 0: row = len(self.rows)
while len(self.rows) < row + len(data):
self.add_row(self.empty_row())
for r in range(len(data)):
cols = data[r]
if col + len(cols) > self.n_cols:
if expand:
self.expand_rows(col + len(cols) - self.n_cols)
def get_cell_align(self, r, c):
align = self.rows[r].cells[c].align
if align is None:
return self.aligns[c]
else:
raise Exception('number of cols does not fit in table')
for c in range(len(cols)):
self.rows[row+r].cells[col+c] = Cell(data[r][c], fmt)
return align
def add_row(self, row):
if row.ncols() != self.n_cols:
raise Exception(f'row has invalid number of cols, {row.ncols()} vs. {self.n_cols}')
self.rows.append(row)
def empty_row(self):
return Row.Empty(self.n_cols)
def expand_rows(self, n_add_cols=1):
if n_add_cols < 0: raise Exception('n_add_cols has to be positive')
self.n_cols += n_add_cols
for row in self.rows:
row.cells.extend([Cell() for cidx in range(n_add_cols)])
def add_block(self, data, row=-1, col=0, fmt=None, expand=False):
if row < 0: row = len(self.rows)
while len(self.rows) < row + len(data):
self.add_row(self.empty_row())
for r in range(len(data)):
cols = data[r]
if col + len(cols) > self.n_cols:
if expand:
self.expand_rows(col + len(cols) - self.n_cols)
else:
raise Exception('number of cols does not fit in table')
for c in range(len(cols)):
self.rows[row + r].cells[col + c] = Cell(data[r][c], fmt)
class Row(object):
def __init__(self, cells, pre_separator=None, post_separator=None):
self.cells = cells
self.pre_separator = pre_separator
self.post_separator = post_separator
def __init__(self, cells, pre_separator=None, post_separator=None):
self.cells = cells
self.pre_separator = pre_separator
self.post_separator = post_separator
@classmethod
def Empty(cls, n_cols):
return Row([Cell() for c in range(n_cols)])
@classmethod
def Empty(cls, n_cols):
return Row([Cell() for c in range(n_cols)])
def add_cell(self, cell):
self.cells.append(cell)
def ncols(self):
return sum([c.span for c in self.cells])
def add_cell(self, cell):
self.cells.append(cell)
def ncols(self):
return sum([c.span for c in self.cells])
class Color(object):
def __init__(self, color=(0,0,0), fmt='rgb'):
if fmt == 'rgb':
self.color = color
elif fmt == 'RGB':
self.color = tuple(c / 255 for c in color)
else:
return Exception('invalid color format')
def __init__(self, color=(0, 0, 0), fmt='rgb'):
if fmt == 'rgb':
self.color = color
elif fmt == 'RGB':
self.color = tuple(c / 255 for c in color)
else:
return Exception('invalid color format')
def as_rgb(self):
return self.color
def as_rgb(self):
return self.color
def as_RGB(self):
return tuple(int(c * 255) for c in self.color)
def as_RGB(self):
return tuple(int(c * 255) for c in self.color)
@classmethod
def rgb(cls, r, g, b):
return Color(color=(r,g,b), fmt='rgb')
@classmethod
def rgb(cls, r, g, b):
return Color(color=(r, g, b), fmt='rgb')
@classmethod
def RGB(cls, r, g, b):
return Color(color=(r,g,b), fmt='RGB')
@classmethod
def RGB(cls, r, g, b):
return Color(color=(r, g, b), fmt='RGB')
class CellFormat(object):
def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False):
self.fmt = fmt
self.fgcolor = fgcolor
self.bgcolor = bgcolor
self.bold = bold
def __init__(self, fmt='%s', fgcolor=None, bgcolor=None, bold=False):
self.fmt = fmt
self.fgcolor = fgcolor
self.bgcolor = bgcolor
self.bold = bold
class Cell(object):
def __init__(self, data=None, fmt=None, span=1, align=None):
self.data = data
if fmt is None:
fmt = CellFormat()
self.fmt = fmt
self.span = span
self.align = align
def __init__(self, data=None, fmt=None, span=1, align=None):
self.data = data
if fmt is None:
fmt = CellFormat()
self.fmt = fmt
self.span = span
self.align = align
def __str__(self):
return self.fmt.fmt % self.data
def __str__(self):
return self.fmt.fmt % self.data
class Separator(enum.Enum):
HEAD = 1
BOTTOM = 2
INNER = 3
HEAD = 1
BOTTOM = 2
INNER = 3
class Renderer(object):
def __init__(self):
pass
def __init__(self):
pass
def cell_str_len(self, cell):
return len(str(cell))
def cell_str_len(self, cell):
return len(str(cell))
def col_widths(self, table):
widths = [0 for c in range(table.n_cols)]
for row in table.rows:
cidx = 0
for cell in row.cells:
if cell.span == 1:
strlen = self.cell_str_len(cell)
widths[cidx] = max(widths[cidx], strlen)
cidx += cell.span
return widths
def col_widths(self, table):
widths = [0 for c in range(table.n_cols)]
for row in table.rows:
cidx = 0
for cell in row.cells:
if cell.span == 1:
strlen = self.cell_str_len(cell)
widths[cidx] = max(widths[cidx], strlen)
cidx += cell.span
return widths
def render(self, table):
raise NotImplementedError('not implemented')
def render(self, table):
raise NotImplementedError('not implemented')
def __call__(self, table):
return self.render(table)
def __call__(self, table):
return self.render(table)
def render_to_file_comment(self):
return ''
def render_to_file_comment(self):
return ''
def render_to_file(self, path, table):
txt = self.render(table)
with open(path, 'w') as fp:
fp.write(txt)
def render_to_file(self, path, table):
txt = self.render(table)
with open(path, 'w') as fp:
fp.write(txt)
class TerminalRenderer(Renderer):
def __init__(self, col_sep=' '):
super().__init__()
self.col_sep = col_sep
def __init__(self, col_sep=' '):
super().__init__()
self.col_sep = col_sep
def render_cell(self, table, row, col, widths):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width:
str = str[:cell_width]
if cell.fmt.bold:
# str = sty.ef.bold + str + sty.rs.bold_dim
# str = sty.ef.bold + str + sty.rs.bold
pass
if cell.fmt.fgcolor is not None:
# color = cell.fmt.fgcolor.as_RGB()
# str = sty.fg(*color) + str + sty.rs.fg
pass
if str_width < cell_width:
n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r':
str = ' '*n_ws + str
elif table.get_cell_align(row, col) == 'l':
str = str + ' '*n_ws
elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1
str = ' '*n_ws0 + str + ' '*n_ws1
if cell.fmt.bgcolor is not None:
# color = cell.fmt.bgcolor.as_RGB()
# str = sty.bg(*color) + str + sty.rs.bg
pass
return str
def render_cell(self, table, row, col, widths):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width:
str = str[:cell_width]
if cell.fmt.bold:
# str = sty.ef.bold + str + sty.rs.bold_dim
# str = sty.ef.bold + str + sty.rs.bold
pass
if cell.fmt.fgcolor is not None:
# color = cell.fmt.fgcolor.as_RGB()
# str = sty.fg(*color) + str + sty.rs.fg
pass
if str_width < cell_width:
n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r':
str = ' ' * n_ws + str
elif table.get_cell_align(row, col) == 'l':
str = str + ' ' * n_ws
elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1
str = ' ' * n_ws0 + str + ' ' * n_ws1
if cell.fmt.bgcolor is not None:
# color = cell.fmt.bgcolor.as_RGB()
# str = sty.bg(*color) + str + sty.rs.bg
pass
return str
def render_separator(self, separator, tab, col_widths, total_width):
if separator == Separator.HEAD:
return '='*total_width
elif separator == Separator.INNER:
return '-'*total_width
elif separator == Separator.BOTTOM:
return '='*total_width
def render_separator(self, separator, tab, col_widths, total_width):
if separator == Separator.HEAD:
return '=' * total_width
elif separator == Separator.INNER:
return '-' * total_width
elif separator == Separator.BOTTOM:
return '=' * total_width
def render(self, table):
widths = self.col_widths(table)
total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1)
lines = []
for ridx, row in enumerate(table.rows):
if row.pre_separator is not None:
sepline = self.render_separator(row.pre_separator, table, widths, total_width)
if len(sepline) > 0:
lines.append(sepline)
line = []
for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx, widths))
lines.append(self.col_sep.join(line))
if row.post_separator is not None:
sepline = self.render_separator(row.post_separator, table, widths, total_width)
if len(sepline) > 0:
lines.append(sepline)
return '\n'.join(lines)
def render(self, table):
widths = self.col_widths(table)
total_width = sum(widths) + len(self.col_sep) * (table.n_cols - 1)
lines = []
for ridx, row in enumerate(table.rows):
if row.pre_separator is not None:
sepline = self.render_separator(row.pre_separator, table, widths, total_width)
if len(sepline) > 0:
lines.append(sepline)
line = []
for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx, widths))
lines.append(self.col_sep.join(line))
if row.post_separator is not None:
sepline = self.render_separator(row.post_separator, table, widths, total_width)
if len(sepline) > 0:
lines.append(sepline)
return '\n'.join(lines)
class MarkdownRenderer(TerminalRenderer):
def __init__(self):
super().__init__(col_sep='|')
self.printed_color_warning = False
def __init__(self):
super().__init__(col_sep='|')
self.printed_color_warning = False
def print_color_warning(self):
if not self.printed_color_warning:
print('[WARNING] MarkdownRenderer does not support color yet')
self.printed_color_warning = True
def print_color_warning(self):
if not self.printed_color_warning:
print('[WARNING] MarkdownRenderer does not support color yet')
self.printed_color_warning = True
def cell_str_len(self, cell):
strlen = len(str(cell))
if cell.fmt.bold:
strlen += 4
strlen = max(5, strlen)
return strlen
def cell_str_len(self, cell):
strlen = len(str(cell))
if cell.fmt.bold:
strlen += 4
strlen = max(5, strlen)
return strlen
def render_cell(self, table, row, col, widths):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
if cell.fmt.bold:
str = f'**{str}**'
def render_cell(self, table, row, col, widths):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
if cell.fmt.bold:
str = f'**{str}**'
str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col+cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width:
str = str[:cell_width]
else:
n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r':
str = ' '*n_ws + str
elif table.get_cell_align(row, col) == 'l':
str = str + ' '*n_ws
elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1
str = ' '*n_ws0 + str + ' '*n_ws1
str_width = len(str)
cell_width = sum([widths[idx] for idx in range(col, col + cell.span)])
cell_width += len(self.col_sep) * (cell.span - 1)
if len(str) > cell_width:
str = str[:cell_width]
else:
n_ws = (cell_width - str_width)
if table.get_cell_align(row, col) == 'r':
str = ' ' * n_ws + str
elif table.get_cell_align(row, col) == 'l':
str = str + ' ' * n_ws
elif table.get_cell_align(row, col) == 'c':
n_ws1 = n_ws // 2
n_ws0 = n_ws - n_ws1
str = ' ' * n_ws0 + str + ' ' * n_ws1
if col == 0: str = self.col_sep + str
if col == table.n_cols - 1: str += self.col_sep
if col == 0: str = self.col_sep + str
if col == table.n_cols - 1: str += self.col_sep
if cell.fmt.fgcolor is not None:
self.print_color_warning()
if cell.fmt.bgcolor is not None:
self.print_color_warning()
return str
if cell.fmt.fgcolor is not None:
self.print_color_warning()
if cell.fmt.bgcolor is not None:
self.print_color_warning()
return str
def render_separator(self, separator, tab, widths, total_width):
sep = ''
if separator == Separator.INNER:
sep = self.col_sep
for idx, width in enumerate(widths):
csep = '-' * (width - 2)
if tab.get_cell_align(1, idx) == 'r':
csep = '-' + csep + ':'
elif tab.get_cell_align(1, idx) == 'l':
csep = ':' + csep + '-'
elif tab.get_cell_align(1, idx) == 'c':
csep = ':' + csep + ':'
sep += csep + self.col_sep
return sep
def render_separator(self, separator, tab, widths, total_width):
sep = ''
if separator == Separator.INNER:
sep = self.col_sep
for idx, width in enumerate(widths):
csep = '-' * (width - 2)
if tab.get_cell_align(1, idx) == 'r':
csep = '-' + csep + ':'
elif tab.get_cell_align(1, idx) == 'l':
csep = ':' + csep + '-'
elif tab.get_cell_align(1, idx) == 'c':
csep = ':' + csep + ':'
sep += csep + self.col_sep
return sep
class LatexRenderer(Renderer):
def __init__(self):
super().__init__()
def __init__(self):
super().__init__()
def render_cell(self, table, row, col):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
if cell.fmt.bold:
str = '{\\bf '+ str + '}'
if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_rgb()
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
if cell.fmt.bgcolor is not None:
color = cell.fmt.bgcolor.as_rgb()
str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str
align = table.get_cell_align(row, col)
if cell.span != 1 or align != table.aligns[col]:
str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}'
return str
def render_cell(self, table, row, col):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
if cell.fmt.bold:
str = '{\\bf ' + str + '}'
if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_rgb()
str = f'{{\\color[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str + '}'
if cell.fmt.bgcolor is not None:
color = cell.fmt.bgcolor.as_rgb()
str = f'\\cellcolor[rgb]{{{color[0]},{color[1]},{color[2]}}} ' + str
align = table.get_cell_align(row, col)
if cell.span != 1 or align != table.aligns[col]:
str = f'\\multicolumn{{{cell.span}}}{{{align}}}{{{str}}}'
return str
def render_separator(self, separator):
if separator == Separator.HEAD:
return '\\toprule'
elif separator == Separator.INNER:
return '\\midrule'
elif separator == Separator.BOTTOM:
return '\\bottomrule'
def render_separator(self, separator):
if separator == Separator.HEAD:
return '\\toprule'
elif separator == Separator.INNER:
return '\\midrule'
elif separator == Separator.BOTTOM:
return '\\bottomrule'
def render(self, table):
lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}']
for ridx, row in enumerate(table.rows):
if row.pre_separator is not None:
lines.append(self.render_separator(row.pre_separator))
line = []
for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx))
lines.append(' & '.join(line) + ' \\\\')
if row.post_separator is not None:
lines.append(self.render_separator(row.post_separator))
lines.append('\\end{tabular}')
return '\n'.join(lines)
def render(self, table):
lines = ['\\begin{tabular}{' + ''.join(table.aligns) + '}']
for ridx, row in enumerate(table.rows):
if row.pre_separator is not None:
lines.append(self.render_separator(row.pre_separator))
line = []
for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx))
lines.append(' & '.join(line) + ' \\\\')
if row.post_separator is not None:
lines.append(self.render_separator(row.post_separator))
lines.append('\\end{tabular}')
return '\n'.join(lines)
class HtmlRenderer(Renderer):
def __init__(self, html_class='result_table'):
super().__init__()
self.html_class = html_class
def __init__(self, html_class='result_table'):
super().__init__()
self.html_class = html_class
def render_cell(self, table, row, col):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
styles = []
if cell.fmt.bold:
styles.append('font-weight: bold;')
if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_RGB()
styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});')
if cell.fmt.bgcolor is not None:
color = cell.fmt.bgcolor.as_RGB()
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
align = table.get_cell_align(row, col)
if align == 'l': align = 'left'
elif align == 'r': align = 'right'
elif align == 'c': align = 'center'
else: raise Exception('invalid align')
styles.append(f'text-align: {align};')
row = table.rows[row]
if row.pre_separator is not None:
styles.append(f'border-top: {self.render_separator(row.pre_separator)};')
if row.post_separator is not None:
styles.append(f'border-bottom: {self.render_separator(row.post_separator)};')
style = ' '.join(styles)
str = f' <td style="{style}" colspan="{cell.span}">{str}</td>\n'
return str
def render_cell(self, table, row, col):
cell = table.rows[row].cells[col]
str = cell.fmt.fmt % cell.data
styles = []
if cell.fmt.bold:
styles.append('font-weight: bold;')
if cell.fmt.fgcolor is not None:
color = cell.fmt.fgcolor.as_RGB()
styles.append(f'color: rgb({color[0]},{color[1]},{color[2]});')
if cell.fmt.bgcolor is not None:
color = cell.fmt.bgcolor.as_RGB()
styles.append(f'background-color: rgb({color[0]},{color[1]},{color[2]});')
align = table.get_cell_align(row, col)
if align == 'l':
align = 'left'
elif align == 'r':
align = 'right'
elif align == 'c':
align = 'center'
else:
raise Exception('invalid align')
styles.append(f'text-align: {align};')
row = table.rows[row]
if row.pre_separator is not None:
styles.append(f'border-top: {self.render_separator(row.pre_separator)};')
if row.post_separator is not None:
styles.append(f'border-bottom: {self.render_separator(row.post_separator)};')
style = ' '.join(styles)
str = f' <td style="{style}" colspan="{cell.span}">{str}</td>\n'
return str
def render_separator(self, separator):
if separator == Separator.HEAD:
return '1.5pt solid black'
elif separator == Separator.INNER:
return '0.75pt solid black'
elif separator == Separator.BOTTOM:
return '1.5pt solid black'
def render_separator(self, separator):
if separator == Separator.HEAD:
return '1.5pt solid black'
elif separator == Separator.INNER:
return '0.75pt solid black'
elif separator == Separator.BOTTOM:
return '1.5pt solid black'
def render(self, table):
lines = [f'<table width="100%" style="border-collapse: collapse" class={self.html_class}>']
for ridx, row in enumerate(table.rows):
line = [f' <tr>\n']
for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx))
line.append(' </tr>\n')
lines.append(' '.join(line))
lines.append('</table>')
return '\n'.join(lines)
def render(self, table):
lines = [f'<table width="100%" style="border-collapse: collapse" class={self.html_class}>']
for ridx, row in enumerate(table.rows):
line = [f' <tr>\n']
for cidx, cell in enumerate(row.cells):
line.append(self.render_cell(table, ridx, cidx))
line.append(' </tr>\n')
lines.append(' '.join(line))
lines.append('</table>')
return '\n'.join(lines)
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'), best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
rnames = data[rowname].unique()
cnames = data[colname].unique()
tab = Table(1+len(cnames))
def pandas_to_table(rowname, colname, valname, data, val_cell_fmt=CellFormat(fmt='%.4f'),
best_val_cell_fmt=CellFormat(fmt='%.4f', bold=True), best_is_max=[]):
rnames = data[rowname].unique()
cnames = data[colname].unique()
tab = Table(1 + len(cnames))
header = [Cell('', align='r')]
header.extend([Cell(h, align='r') for h in cnames])
header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER)
tab.add_row(header)
for rname in rnames:
cells = [Cell(rname, align='l')]
for cname in cnames:
cdata = data[data[colname] == cname]
if cname in best_is_max:
bestval = cdata[valname].max()
val = cdata[cdata[rowname] == rname][valname].max()
else:
bestval = cdata[valname].min()
val = cdata[cdata[rowname] == rname][valname].min()
if val == bestval:
fmt = best_val_cell_fmt
else:
fmt = val_cell_fmt
cells.append(Cell(val, align='r', fmt=fmt))
tab.add_row(Row(cells))
tab.rows[-1].post_separator = Separator.BOTTOM
return tab
header = [Cell('', align='r')]
header.extend([Cell(h, align='r') for h in cnames])
header = Row(header, pre_separator=Separator.HEAD, post_separator=Separator.INNER)
tab.add_row(header)
for rname in rnames:
cells = [Cell(rname, align='l')]
for cname in cnames:
cdata = data[data[colname] == cname]
if cname in best_is_max:
bestval = cdata[valname].max()
val = cdata[cdata[rowname] == rname][valname].max()
else:
bestval = cdata[valname].min()
val = cdata[cdata[rowname] == rname][valname].min()
if val == bestval:
fmt = best_val_cell_fmt
else:
fmt = val_cell_fmt
cells.append(Cell(val, align='r', fmt=fmt))
tab.add_row(Row(cells))
tab.rows[-1].post_separator = Separator.BOTTOM
return tab
if __name__ == '__main__':
# df = pd.read_pickle('full.df')
# best_is_max = ['movF0.5', 'movF1.0']
# tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max)
# df = pd.read_pickle('full.df')
# best_is_max = ['movF0.5', 'movF1.0']
# tab = pandas_to_table(rowname='method', colname='metric', valname='val', data=df, best_is_max=best_is_max)
# renderer = TerminalRenderer()
# print(renderer(tab))
# renderer = TerminalRenderer()
# print(renderer(tab))
tab = Table(7)
# header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER)
# tab.add_row(header)
# header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER)
# tab.add_row(header2)
tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
tab.rows[-1].post_separator = Separator.INNER
tab.add_block(np.arange(15*7).reshape(15,7))
tab.rows[4].cells[2].fmt = CellFormat(bold=True)
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2,0.6,0.1))
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7,0.1,0.5))
tab.rows[5].cells[3].fmt = CellFormat(bold=True,bgcolor=Color.rgb(0.7,0.1,0.5),fgcolor=Color.rgb(0.1,0.1,0.1))
tab.rows[-1].post_separator = Separator.BOTTOM
tab = Table(7)
# header = Row([Cell('header', span=7, align='c')], pre_separator=Separator.HEAD, post_separator=Separator.INNER)
# tab.add_row(header)
# header2 = Row([Cell('thisisaverylongheader', span=4, align='c'), Cell('vals2', span=3, align='c')], post_separator=Separator.INNER)
# tab.add_row(header2)
tab.add_row(Row([Cell(f'c{c}') for c in range(7)]))
tab.rows[-1].post_separator = Separator.INNER
tab.add_block(np.arange(15 * 7).reshape(15, 7))
tab.rows[4].cells[2].fmt = CellFormat(bold=True)
tab.rows[2].cells[1].fmt = CellFormat(fgcolor=Color.rgb(0.2, 0.6, 0.1))
tab.rows[2].cells[2].fmt = CellFormat(bgcolor=Color.rgb(0.7, 0.1, 0.5))
tab.rows[5].cells[3].fmt = CellFormat(bold=True, bgcolor=Color.rgb(0.7, 0.1, 0.5), fgcolor=Color.rgb(0.1, 0.1, 0.1))
tab.rows[-1].post_separator = Separator.BOTTOM
renderer = TerminalRenderer()
print(renderer(tab))
renderer = MarkdownRenderer()
print(renderer(tab))
renderer = TerminalRenderer()
print(renderer(tab))
renderer = MarkdownRenderer()
print(renderer(tab))
# renderer = HtmlRenderer()
# html_tab = renderer(tab)
# print(html_tab)
# with open('test.html', 'w') as fp:
# fp.write(html_tab)
# renderer = HtmlRenderer()
# html_tab = renderer(tab)
# print(html_tab)
# with open('test.html', 'w') as fp:
# fp.write(html_tab)
# import latex
# import latex
# renderer = LatexRenderer()
# ltx_tab = renderer(tab)
# print(ltx_tab)
# renderer = LatexRenderer()
# ltx_tab = renderer(tab)
# print(ltx_tab)
# with open('test.tex', 'w') as fp:
# latex.write_doc_prefix(fp, document_class='article')
# fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40)
# fp.write('\\begin{table}')
# fp.write(ltx_tab)
# fp.write('\\end{table}')
# fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40)
# latex.write_doc_suffix(fp)
# with open('test.tex', 'w') as fp:
# latex.write_doc_prefix(fp, document_class='article')
# fp.write('this is text that should appear before the table and should be long enough to wrap around.\n'*40)
# fp.write('\\begin{table}')
# fp.write(ltx_tab)
# fp.write('\\end{table}')
# fp.write('this is text that should appear after the table and should be long enough to wrap around.\n'*40)
# latex.write_doc_suffix(fp)

View File

@ -8,6 +8,7 @@ import re
import pickle
import subprocess
def str2bool(v):
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
@ -16,71 +17,74 @@ def str2bool(v):
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
class StopWatch(object):
def __init__(self):
self.timings = OrderedDict()
self.starts = {}
def __init__(self):
self.timings = OrderedDict()
self.starts = {}
def start(self, name):
self.starts[name] = time.time()
def start(self, name):
self.starts[name] = time.time()
def stop(self, name):
if name not in self.timings:
self.timings[name] = []
self.timings[name].append(time.time() - self.starts[name])
def stop(self, name):
if name not in self.timings:
self.timings[name] = []
self.timings[name].append(time.time() - self.starts[name])
def get(self, name=None, reduce=np.sum):
if name is not None:
return reduce(self.timings[name])
else:
ret = {}
for k in self.timings:
ret[k] = reduce(self.timings[k])
return ret
def get(self, name=None, reduce=np.sum):
if name is not None:
return reduce(self.timings[name])
else:
ret = {}
for k in self.timings:
ret[k] = reduce(self.timings[k])
return ret
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
class ETA(object):
def __init__(self, length):
self.length = length
self.start_time = time.time()
self.current_idx = 0
self.current_time = time.time()
def __init__(self, length):
self.length = length
self.start_time = time.time()
self.current_idx = 0
self.current_time = time.time()
def update(self, idx):
self.current_idx = idx
self.current_time = time.time()
def update(self, idx):
self.current_idx = idx
self.current_time = time.time()
def get_elapsed_time(self):
return self.current_time - self.start_time
def get_elapsed_time(self):
return self.current_time - self.start_time
def get_item_time(self):
return self.get_elapsed_time() / (self.current_idx + 1)
def get_item_time(self):
return self.get_elapsed_time() / (self.current_idx + 1)
def get_remaining_time(self):
return self.get_item_time() * (self.length - self.current_idx + 1)
def get_remaining_time(self):
return self.get_item_time() * (self.length - self.current_idx + 1)
def format_time(self, seconds):
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
hours = int(hours)
minutes = int(minutes)
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
def format_time(self, seconds):
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
hours = int(hours)
minutes = int(minutes)
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
def get_elapsed_time_str(self):
return self.format_time(self.get_elapsed_time())
def get_elapsed_time_str(self):
return self.format_time(self.get_elapsed_time())
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
def git_hash(cwd=None):
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
hash = ret.stdout
if hash is not None and 'fatal' not in hash.decode():
return hash.decode().strip()
else:
return None
ret = subprocess.run(['git', 'describe', '--always'], cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
hash = ret.stdout
if hash is not None and 'fatal' not in hash.decode():
return hash.decode().strip()
else:
return None

View File

@ -4,107 +4,109 @@ import cv2
def get_patterns(path='syn', imsizes=[], crop=True):
pattern_size = imsizes[0]
if path == 'syn':
np.random.seed(42)
pattern = np.random.uniform(0,1, size=pattern_size)
pattern = (pattern < 0.1).astype(np.float32)
pattern.reshape(*imsizes[0])
else:
pattern = cv2.imread(path)
pattern = pattern.astype(np.float32)
pattern /= 255
if pattern.ndim == 2:
pattern = np.stack([pattern for idx in range(3)], axis=2)
if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
r0 = (pattern.shape[0] - pattern_size[0]) // 2
c0 = (pattern.shape[1] - pattern_size[1]) // 2
pattern = pattern[r0:r0+imsizes[0][0], c0:c0+imsizes[0][1]]
patterns = []
for imsize in imsizes:
pat = cv2.resize(pattern, (imsize[1],imsize[0]), interpolation=cv2.INTER_LINEAR)
patterns.append(pat)
pattern_size = imsizes[0]
if path == 'syn':
np.random.seed(42)
pattern = np.random.uniform(0, 1, size=pattern_size)
pattern = (pattern < 0.1).astype(np.float32)
pattern.reshape(*imsizes[0])
else:
pattern = cv2.imread(path)
pattern = pattern.astype(np.float32)
pattern /= 255
if pattern.ndim == 2:
pattern = np.stack([pattern for idx in range(3)], axis=2)
if crop and pattern.shape[0] > pattern_size[0] and pattern.shape[1] > pattern_size[1]:
r0 = (pattern.shape[0] - pattern_size[0]) // 2
c0 = (pattern.shape[1] - pattern_size[1]) // 2
pattern = pattern[r0:r0 + imsizes[0][0], c0:c0 + imsizes[0][1]]
patterns = []
for imsize in imsizes:
pat = cv2.resize(pattern, (imsize[1], imsize[0]), interpolation=cv2.INTER_LINEAR)
patterns.append(pat)
return patterns
return patterns
def get_rotation_matrix(v0, v1):
v0 = v0/np.linalg.norm(v0)
v1 = v1/np.linalg.norm(v1)
v = np.cross(v0,v1)
c = np.dot(v0,v1)
s = np.linalg.norm(v)
I = np.eye(3)
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
k = np.matrix(vXStr)
r = I + k + k @ k * ((1 -c)/(s**2))
return np.asarray(r.astype(np.float32))
v0 = v0 / np.linalg.norm(v0)
v1 = v1 / np.linalg.norm(v1)
v = np.cross(v0, v1)
c = np.dot(v0, v1)
s = np.linalg.norm(v)
I = np.eye(3)
vXStr = '{} {} {}; {} {} {}; {} {} {}'.format(0, -v[2], v[1], v[2], 0, -v[0], -v[1], v[0], 0)
k = np.matrix(vXStr)
r = I + k + k @ k * ((1 - c) / (s ** 2))
return np.asarray(r.astype(np.float32))
def augment_image(img,rng,disp=None,grad=None,max_shift=64,max_blur=1.5,max_noise=10.0,max_sp_noise=0.001):
def augment_image(img, rng, disp=None, grad=None, max_shift=64, max_blur=1.5, max_noise=10.0, max_sp_noise=0.001):
# get min/max values of image
min_val = np.min(img)
max_val = np.max(img)
# init augmented image
img_aug = img
# init disparity correction map
disp_aug = disp
grad_aug = grad
# apply affine transformation
if max_shift>1:
if max_shift > 1:
# affine parameters
rows,cols = img.shape
rows, cols = img.shape
shear = 0
shift = 0
shear_correction = 0
if rng.uniform(0,1)<0.75: shear = rng.uniform(-max_shift,max_shift) # shear with 75% probability
else: shift = rng.uniform(0,max_shift) # shift with 25% probability
if shear<0: shear_correction = -shear
if rng.uniform(0, 1) < 0.75:
shear = rng.uniform(-max_shift, max_shift) # shear with 75% probability
else:
shift = rng.uniform(0, max_shift) # shift with 25% probability
if shear < 0: shear_correction = -shear
# affine transformation
a = shear/float(rows)
b = shift+shear_correction
a = shear / float(rows)
b = shift + shear_correction
# warp image
T = np.float32([[1,a,b],[0,1,0]])
img_aug = cv2.warpAffine(img_aug,T,(cols,rows))
T = np.float32([[1, a, b], [0, 1, 0]])
img_aug = cv2.warpAffine(img_aug, T, (cols, rows))
if grad is not None:
grad_aug = cv2.warpAffine(grad,T,(cols,rows))
grad_aug = cv2.warpAffine(grad, T, (cols, rows))
# disparity correction map
col = a*np.array(range(rows))+b
disp_delta = np.tile(col,(cols,1)).transpose()
col = a * np.array(range(rows)) + b
disp_delta = np.tile(col, (cols, 1)).transpose()
if disp is not None:
disp_aug = cv2.warpAffine(disp+disp_delta,T,(cols,rows))
disp_aug = cv2.warpAffine(disp + disp_delta, T, (cols, rows))
# gaussian smoothing
if rng.uniform(0,1)<0.5:
img_aug = cv2.GaussianBlur(img_aug,(5,5),rng.uniform(0.2,max_blur))
if rng.uniform(0, 1) < 0.5:
img_aug = cv2.GaussianBlur(img_aug, (5, 5), rng.uniform(0.2, max_blur))
# per-pixel gaussian noise
img_aug = img_aug + rng.randn(*img_aug.shape)*rng.uniform(0.0,max_noise)/255.0
img_aug = img_aug + rng.randn(*img_aug.shape) * rng.uniform(0.0, max_noise) / 255.0
# salt-and-pepper noise
if rng.uniform(0,1)<0.5:
ratio=rng.uniform(0.0,max_sp_noise)
if rng.uniform(0, 1) < 0.5:
ratio = rng.uniform(0.0, max_sp_noise)
img_shape = img_aug.shape
img_aug = img_aug.flatten()
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = max_val
coord = rng.choice(np.size(img_aug), int(np.size(img_aug)*ratio))
coord = rng.choice(np.size(img_aug), int(np.size(img_aug) * ratio))
img_aug[coord] = min_val
img_aug = np.reshape(img_aug, img_shape)
# clip intensities back to [0,1]
img_aug = np.maximum(img_aug,0.0)
img_aug = np.minimum(img_aug,1.0)
img_aug = np.maximum(img_aug, 0.0)
img_aug = np.minimum(img_aug, 1.0)
# return image
return img_aug, disp_aug, grad_aug

View File

@ -10,261 +10,259 @@ import cv2
import os
import collections
import sys
sys.path.append('../')
import renderer
import co
from commons import get_patterns,get_rotation_matrix
from commons import get_patterns, get_rotation_matrix
from lcn import lcn
def get_objs(shapenet_dir, obj_classes, num_perclass=100):
shapenet = {'chair': '03001627',
'airplane': '02691156',
'car': '02958343',
'watercraft': '04530566'}
shapenet = {'chair': '03001627',
'airplane': '02691156',
'car': '02958343',
'watercraft': '04530566'}
obj_paths = []
for cls in obj_classes:
if cls not in shapenet.keys():
raise Exception('unknown class name')
ids = shapenet[cls]
obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
obj_paths += obj_path[:num_perclass]
print(f'found {len(obj_paths)} object paths')
obj_paths = []
for cls in obj_classes:
if cls not in shapenet.keys():
raise Exception('unknown class name')
ids = shapenet[cls]
obj_path = sorted(Path(f'{shapenet_dir}/{ids}').glob('**/models/*.obj'))
obj_paths += obj_path[:num_perclass]
print(f'found {len(obj_paths)} object paths')
objs = []
for obj_path in obj_paths:
print(f'load {obj_path}')
v, f, _, n = co.io3d.read_obj(obj_path)
diffs = v.max(axis=0) - v.min(axis=0)
v /= (0.5 * diffs.max())
v -= (v.min(axis=0) + 1)
f = f.astype(np.int32)
objs.append((v, f, n))
print(f'loaded {len(objs)} objects')
objs = []
for obj_path in obj_paths:
print(f'load {obj_path}')
v, f, _, n = co.io3d.read_obj(obj_path)
diffs = v.max(axis=0) - v.min(axis=0)
v /= (0.5 * diffs.max())
v -= (v.min(axis=0) + 1)
f = f.astype(np.int32)
objs.append((v,f,n))
print(f'loaded {len(objs)} objects')
return objs
return objs
def get_mesh(rng, min_z=0):
# set up background board
verts, faces, normals, colors = [], [], [], []
v, f, n = co.geometry.xyplane(z=0, interleaved=True)
v[:,2] += -v[:,2].min() + rng.uniform(2,7)
v[:,:2] *= 5e2
v[:,2] = np.mean(v[:,2]) + (v[:,2] - np.mean(v[:,2])) * 5e2
c = np.empty_like(v)
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
verts.append(v)
faces.append(f)
normals.append(n)
colors.append(c)
# randomly sample 4 foreground objects for each scene
for shape_idx in range(4):
v, f, n = objs[rng.randint(0,len(objs))]
v, f, n = v.copy(), f.copy(), n.copy()
s = rng.uniform(0.25, 1)
v *= s
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
v = v @ R.T
n = n @ R.T
v[:,2] += -v[:,2].min() + min_z + rng.uniform(0.5, 3)
v[:,:2] += rng.uniform(-1, 1, size=(1,2))
# set up background board
verts, faces, normals, colors = [], [], [], []
v, f, n = co.geometry.xyplane(z=0, interleaved=True)
v[:, 2] += -v[:, 2].min() + rng.uniform(2, 7)
v[:, :2] *= 5e2
v[:, 2] = np.mean(v[:, 2]) + (v[:, 2] - np.mean(v[:, 2])) * 5e2
c = np.empty_like(v)
c[:] = rng.uniform(0,1, size=(3,)).astype(np.float32)
verts.append(v.astype(np.float32))
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v)
faces.append(f)
normals.append(n)
colors.append(c)
verts, faces = co.geometry.stack_mesh(verts, faces)
normals = np.vstack(normals).astype(np.float32)
colors = np.vstack(colors).astype(np.float32)
return verts, faces, colors, normals
# randomly sample 4 foreground objects for each scene
for shape_idx in range(4):
v, f, n = objs[rng.randint(0, len(objs))]
v, f, n = v.copy(), f.copy(), n.copy()
s = rng.uniform(0.25, 1)
v *= s
R = co.geometry.rotm_from_quat(co.geometry.quat_random(rng=rng))
v = v @ R.T
n = n @ R.T
v[:, 2] += -v[:, 2].min() + min_z + rng.uniform(0.5, 3)
v[:, :2] += rng.uniform(-1, 1, size=(1, 2))
c = np.empty_like(v)
c[:] = rng.uniform(0, 1, size=(3,)).astype(np.float32)
verts.append(v.astype(np.float32))
faces.append(f)
normals.append(n)
colors.append(c)
verts, faces = co.geometry.stack_mesh(verts, faces)
normals = np.vstack(normals).astype(np.float32)
colors = np.vstack(colors).astype(np.float32)
return verts, faces, colors, normals
def create_data(out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length=4):
tic = time.time()
rng = np.random.RandomState()
tic = time.time()
rng = np.random.RandomState()
rng.seed(idx)
rng.seed(idx)
verts, faces, colors, normals = get_mesh(rng)
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
print(f'loading mesh for sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
verts, faces, colors, normals = get_mesh(rng)
data = renderer.PyRenderInput(verts=verts.copy(), colors=colors.copy(), normals=normals.copy(), faces=faces.copy())
print(f'loading mesh for sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
# let the camera point to the center
center = np.array([0, 0, 3], dtype=np.float32)
basevec = np.array([-baseline, 0, 0], dtype=np.float32)
unit = np.array([0, 0, 1], dtype=np.float32)
cam_x_ = rng.uniform(-0.2, 0.2)
cam_y_ = rng.uniform(-0.2, 0.2)
cam_z_ = rng.uniform(-0.2, 0.2)
ret = collections.defaultdict(list)
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1, 0.1), 0, 1)
# capture the same static scene from different view points as a track
for ind in range(track_length):
cam_x = cam_x_ + rng.uniform(-0.1, 0.1)
cam_y = cam_y_ + rng.uniform(-0.1, 0.1)
cam_z = cam_z_ + rng.uniform(-0.1, 0.1)
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
if np.linalg.norm(tcam[0:2]) < 1e-9:
Rcam = np.eye(3, dtype=np.float32)
else:
Rcam = get_rotation_matrix(center, center - tcam)
tproj = tcam + basevec
Rproj = Rcam
ret['R'].append(Rcam)
ret['t'].append(tcam)
cams = []
projs = []
# render the scene at multiple scales
scales = [1, 0.5, 0.25, 0.125]
for scale in scales:
fx = K[0, 0] * scale
fy = K[1, 1] * scale
px = K[0, 2] * scale
py = K[1, 2] * scale
im_height = imsize[0] * scale
im_width = imsize[1] * scale
cams.append(renderer.PyCamera(fx, fy, px, py, Rcam, tcam, im_width, im_height))
projs.append(renderer.PyCamera(fx, fy, px, py, Rproj, tproj, im_width, im_height))
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
fl = K[0, 0] / (2 ** s)
shader = renderer.PyShader(0.5, 1.5, 0.0, 10)
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)
# get the reflected laser pattern $R$
im = pyrenderer.color().copy()
depth = pyrenderer.depth().copy()
disp = baseline * fl / depth
mask = depth > 0
im = np.mean(im, axis=2)
# get the ambient image $A$
ambient = pyrenderer.normal().copy()
ambient = np.mean(ambient, axis=2)
# get the noise free IR image $J$
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
ret[f'ambient{s}'].append(ambient[None].astype(np.float32))
# get the gradient magnitude of the ambient image $|\nabla A|$
ambient = ambient.astype(np.float32)
sobelx = cv2.Sobel(ambient, cv2.CV_32F, 1, 0, ksize=5)
sobely = cv2.Sobel(ambient, cv2.CV_32F, 0, 1, ksize=5)
grad = np.sqrt(sobelx ** 2 + sobely ** 2)
grad = np.maximum(grad - 0.8, 0.0) # parameter
# get the local contract normalized grad LCN($|\nabla A|$)
grad_lcn, grad_std = lcn.normalize(grad, 5, 0.1)
grad_lcn = np.clip(grad_lcn, 0.0, 1.0) # parameter
ret[f'grad{s}'].append(grad_lcn[None].astype(np.float32))
ret[f'im{s}'].append(im[None].astype(np.float32))
ret[f'mask{s}'].append(mask[None].astype(np.float32))
ret[f'disp{s}'].append(disp[None].astype(np.float32))
for key in ret.keys():
ret[key] = np.stack(ret[key], axis=0)
# save to files
out_dir = out_root / f'{idx:08d}'
out_dir.mkdir(exist_ok=True, parents=True)
for k, val in ret.items():
for tidx in range(track_length):
v = val[tidx]
out_path = out_dir / f'{k}_{tidx}.npy'
np.save(out_path, v)
np.save(str(out_dir / 'blend_im.npy'), blend_im_rnd)
print(f'create sample {idx + 1}/{n_samples} took {time.time() - tic}[s]')
# let the camera point to the center
center = np.array([0,0,3], dtype=np.float32)
if __name__ == '__main__':
basevec = np.array([-baseline,0,0], dtype=np.float32)
unit = np.array([0,0,1],dtype=np.float32)
np.random.seed(42)
cam_x_ = rng.uniform(-0.2,0.2)
cam_y_ = rng.uniform(-0.2,0.2)
cam_z_ = rng.uniform(-0.2,0.2)
# output directory
with open('../config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
shapenet_root = config['SHAPENET_ROOT']
ret = collections.defaultdict(list)
blend_im_rnd = np.clip(blend_im + rng.uniform(-0.1,0.1), 0,1)
data_type = 'syn'
out_root = data_root / f'{data_type}'
out_root.mkdir(parents=True, exist_ok=True)
# capture the same static scene from different view points as a track
for ind in range(track_length):
cam_x = cam_x_ + rng.uniform(-0.1,0.1)
cam_y = cam_y_ + rng.uniform(-0.1,0.1)
cam_z = cam_z_ + rng.uniform(-0.1,0.1)
tcam = np.array([cam_x, cam_y, cam_z], dtype=np.float32)
if np.linalg.norm(tcam[0:2])<1e-9:
Rcam = np.eye(3, dtype=np.float32)
start = 0
if len(sys.argv) >= 2 and isinstance(sys.argv[2], int):
start = sys.argv[2]
else:
Rcam = get_rotation_matrix(center, center-tcam)
if sys.argv[2] == '--resume':
try:
start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0
except:
pass
tproj = tcam + basevec
Rproj = Rcam
# load shapenet models
obj_classes = ['chair']
objs = get_objs(shapenet_root, obj_classes)
ret['R'].append(Rcam)
ret['t'].append(tcam)
# camera parameters
imsize = (488, 648)
imsizes = [(imsize[0] // (2 ** s), imsize[1] // (2 ** s)) for s in range(4)]
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0, 0, 1]],
dtype=np.float32)
focal_lengths = [K[0, 0] / (2 ** s) for s in range(4)]
baseline = 0.075
blend_im = 0.6
noise = 0
cams = []
projs = []
# capture the same static scene from different view points as a track
track_length = 4
# render the scene at multiple scales
scales = [1, 0.5, 0.25, 0.125]
# load pattern image
pattern_path = './kinect_pattern.png'
pattern_crop = True
patterns = get_patterns(pattern_path, imsizes, pattern_crop)
for scale in scales:
fx = K[0,0] * scale
fy = K[1,1] * scale
px = K[0,2] * scale
py = K[1,2] * scale
im_height = imsize[0] * scale
im_width = imsize[1] * scale
cams.append( renderer.PyCamera(fx,fy,px,py, Rcam, tcam, im_width, im_height) )
projs.append( renderer.PyCamera(fx,fy,px,py, Rproj, tproj, im_width, im_height) )
# write settings to file
settings = {
'imsizes': imsizes,
'patterns': patterns,
'focal_lengths': focal_lengths,
'baseline': baseline,
'K': K,
}
out_path = out_root / f'settings.pkl'
print(f'write settings to {out_path}')
with open(str(out_path), 'wb') as f:
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
for s, cam, proj, pattern in zip(itertools.count(), cams, projs, patterns):
fl = K[0,0] / (2**s)
shader = renderer.PyShader(0.5,1.5,0.0,10)
pyrenderer = renderer.PyRenderer(cam, shader, engine='gpu')
pyrenderer.mesh_proj(data, proj, pattern, d_alpha=0, d_beta=0.35)
# get the reflected laser pattern $R$
im = pyrenderer.color().copy()
depth = pyrenderer.depth().copy()
disp = baseline * fl / depth
mask = depth > 0
im = np.mean(im, axis=2)
# get the ambient image $A$
ambient = pyrenderer.normal().copy()
ambient = np.mean(ambient, axis=2)
# get the noise free IR image $J$
im = blend_im_rnd * im + (1 - blend_im_rnd) * ambient
ret[f'ambient{s}'].append( ambient[None].astype(np.float32) )
# get the gradient magnitude of the ambient image $|\nabla A|$
ambient = ambient.astype(np.float32)
sobelx = cv2.Sobel(ambient,cv2.CV_32F,1,0,ksize=5)
sobely = cv2.Sobel(ambient,cv2.CV_32F,0,1,ksize=5)
grad = np.sqrt(sobelx**2 + sobely**2)
grad = np.maximum(grad-0.8,0.0) # parameter
# get the local contract normalized grad LCN($|\nabla A|$)
grad_lcn, grad_std = lcn.normalize(grad,5,0.1)
grad_lcn = np.clip(grad_lcn,0.0,1.0) # parameter
ret[f'grad{s}'].append( grad_lcn[None].astype(np.float32))
ret[f'im{s}'].append( im[None].astype(np.float32))
ret[f'mask{s}'].append(mask[None].astype(np.float32))
ret[f'disp{s}'].append(disp[None].astype(np.float32))
for key in ret.keys():
ret[key] = np.stack(ret[key], axis=0)
# save to files
out_dir = out_root / f'{idx:08d}'
out_dir.mkdir(exist_ok=True, parents=True)
for k,val in ret.items():
for tidx in range(track_length):
v = val[tidx]
out_path = out_dir / f'{k}_{tidx}.npy'
np.save(out_path, v)
np.save( str(out_dir /'blend_im.npy'), blend_im_rnd)
print(f'create sample {idx+1}/{n_samples} took {time.time()-tic}[s]')
if __name__=='__main__':
np.random.seed(42)
# output directory
with open('../config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
shapenet_root = config['SHAPENET_ROOT']
data_type = 'syn'
out_root = data_root / f'{data_type}'
out_root.mkdir(parents=True, exist_ok=True)
start = 0
if len(sys.argv) >= 2 and isinstance(sys.argv[2], int):
start = sys.argv[2]
else:
if sys.argv[2] == '--resume':
try:
start = max([int(dir) for dir in os.listdir(out_root) if str.isdigit(dir)]) or 0
except:
pass
# load shapenet models
obj_classes = ['chair']
objs = get_objs(shapenet_root, obj_classes)
# camera parameters
imsize = (488, 648)
imsizes = [(imsize[0]//(2**s), imsize[1]//(2**s)) for s in range(4)]
# K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0 ,0, 1]], dtype=np.float32)
K = np.array([[1929.5936336276382, 0, 113.66561071478046], [0, 1911.2517985448746, 473.70108079885887], [0 ,0, 1]], dtype=np.float32)
focal_lengths = [K[0,0]/(2**s) for s in range(4)]
baseline=0.075
blend_im = 0.6
noise = 0
# capture the same static scene from different view points as a track
track_length = 4
# load pattern image
pattern_path = './kinect_pattern.png'
pattern_crop = True
patterns = get_patterns(pattern_path, imsizes, pattern_crop)
# write settings to file
settings = {
'imsizes': imsizes,
'patterns': patterns,
'focal_lengths': focal_lengths,
'baseline': baseline,
'K': K,
}
out_path = out_root / f'settings.pkl'
print(f'write settings to {out_path}')
with open(str(out_path), 'wb') as f:
pickle.dump(settings, f, pickle.HIGHEST_PROTOCOL)
# start the job
n_samples = 2**10 + 2**13
for idx in range(start, n_samples):
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
create_data(*args)
# start the job
n_samples = 2 ** 10 + 2 ** 13
for idx in range(start, n_samples):
args = (out_root, idx, n_samples, imsize, patterns, K, baseline, blend_im, noise, track_length)
create_data(*args)

View File

@ -21,128 +21,128 @@ from .commons import get_patterns, augment_image
from mpl_toolkits.mplot3d import Axes3D
class TrackSynDataset(torchext.BaseDataset):
'''
Load locally saved synthetic dataset
Please run ./create_syn_data.sh to generate the dataset
'''
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
super().__init__(train=train)
'''
Load locally saved synthetic dataset
Please run ./create_syn_data.sh to generate the dataset
'''
self.settings_path = settings_path
self.sample_paths = sample_paths
self.data_aug = data_aug
self.train = train
self.track_length=track_length
assert(track_length<=4)
def __init__(self, settings_path, sample_paths, track_length=2, train=True, data_aug=False):
super().__init__(train=train)
with open(str(settings_path), 'rb') as f:
settings = pickle.load(f)
self.imsizes = settings['imsizes']
self.patterns = settings['patterns']
self.focal_lengths = settings['focal_lengths']
self.baseline = settings['baseline']
self.K = settings['K']
self.settings_path = settings_path
self.sample_paths = sample_paths
self.data_aug = data_aug
self.train = train
self.track_length = track_length
assert (track_length <= 4)
self.scale = len(self.imsizes)
with open(str(settings_path), 'rb') as f:
settings = pickle.load(f)
self.imsizes = settings['imsizes']
self.patterns = settings['patterns']
self.focal_lengths = settings['focal_lengths']
self.baseline = settings['baseline']
self.K = settings['K']
self.max_shift=0
self.max_blur=0.5
self.max_noise=3.0
self.max_sp_noise=0.0005
self.scale = len(self.imsizes)
def __len__(self):
return len(self.sample_paths)
self.max_shift = 0
self.max_blur = 0.5
self.max_noise = 3.0
self.max_sp_noise = 0.0005
def __getitem__(self, idx):
if not self.train:
rng = self.get_rng(idx)
else:
rng = np.random.RandomState()
sample_path = self.sample_paths[idx]
def __len__(self):
return len(self.sample_paths)
if self.train:
track_ind = np.random.permutation(4)[0:self.track_length]
else:
track_ind = [0]
ret = {}
ret['id'] = idx
# load imgs, at all scales
for sidx in range(len(self.imsizes)):
imgs = []
ambs = []
grads = []
for tidx in track_ind:
imgs.append(np.load(os.path.join(sample_path,f'im{sidx}_{tidx}.npy')))
ambs.append(np.load(os.path.join(sample_path,f'ambient{sidx}_{tidx}.npy')))
grads.append(np.load(os.path.join(sample_path,f'grad{sidx}_{tidx}.npy')))
ret[f'im{sidx}'] = np.stack(imgs, axis=0)
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
ret[f'grad{sidx}'] = np.stack(grads, axis=0)
# load disp and grad only at full resolution
disps = []
R = []
t = []
for tidx in track_ind:
disps.append(np.load(os.path.join(sample_path,f'disp0_{tidx}.npy')))
R.append(np.load(os.path.join(sample_path,f'R_{tidx}.npy')))
t.append(np.load(os.path.join(sample_path,f't_{tidx}.npy')))
ret[f'disp0'] = np.stack(disps, axis=0)
ret['R'] = np.stack(R, axis=0)
ret['t'] = np.stack(t, axis=0)
blend_im = np.load(os.path.join(sample_path,'blend_im.npy'))
ret['blend_im'] = blend_im.astype(np.float32)
#### apply data augmentation at different scales seperately, only work for max_shift=0
if self.data_aug:
for sidx in range(len(self.imsizes)):
if sidx==0:
img = ret[f'im{sidx}']
disp = ret[f'disp{sidx}']
grad = ret[f'grad{sidx}']
img_aug = np.zeros_like(img)
disp_aug = np.zeros_like(img)
grad_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i,0],rng,
disp=disp[i,0],grad=grad[i,0],
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
disp_aug[i] = disp_aug_[None].astype(np.float32)
grad_aug[i] = grad_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug
ret[f'disp{sidx}'] = disp_aug
ret[f'grad{sidx}'] = grad_aug
def __getitem__(self, idx):
if not self.train:
rng = self.get_rng(idx)
else:
img = ret[f'im{sidx}']
img_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, _, _ = augment_image(img[i,0],rng,
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug
rng = np.random.RandomState()
sample_path = self.sample_paths[idx]
if len(track_ind)==1:
for key, val in ret.items():
if key!='blend_im' and key!='id':
ret[key] = val[0]
if self.train:
track_ind = np.random.permutation(4)[0:self.track_length]
else:
track_ind = [0]
ret = {}
ret['id'] = idx
return ret
# load imgs, at all scales
for sidx in range(len(self.imsizes)):
imgs = []
ambs = []
grads = []
for tidx in track_ind:
imgs.append(np.load(os.path.join(sample_path, f'im{sidx}_{tidx}.npy')))
ambs.append(np.load(os.path.join(sample_path, f'ambient{sidx}_{tidx}.npy')))
grads.append(np.load(os.path.join(sample_path, f'grad{sidx}_{tidx}.npy')))
ret[f'im{sidx}'] = np.stack(imgs, axis=0)
ret[f'ambient{sidx}'] = np.stack(ambs, axis=0)
ret[f'grad{sidx}'] = np.stack(grads, axis=0)
def getK(self, sidx=0):
K = self.K.copy() / (2**sidx)
K[2,2] = 1
return K
# load disp and grad only at full resolution
disps = []
R = []
t = []
for tidx in track_ind:
disps.append(np.load(os.path.join(sample_path, f'disp0_{tidx}.npy')))
R.append(np.load(os.path.join(sample_path, f'R_{tidx}.npy')))
t.append(np.load(os.path.join(sample_path, f't_{tidx}.npy')))
ret[f'disp0'] = np.stack(disps, axis=0)
ret['R'] = np.stack(R, axis=0)
ret['t'] = np.stack(t, axis=0)
blend_im = np.load(os.path.join(sample_path, 'blend_im.npy'))
ret['blend_im'] = blend_im.astype(np.float32)
#### apply data augmentation at different scales seperately, only work for max_shift=0
if self.data_aug:
for sidx in range(len(self.imsizes)):
if sidx == 0:
img = ret[f'im{sidx}']
disp = ret[f'disp{sidx}']
grad = ret[f'grad{sidx}']
img_aug = np.zeros_like(img)
disp_aug = np.zeros_like(img)
grad_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, disp_aug_, grad_aug_ = augment_image(img[i, 0], rng,
disp=disp[i, 0], grad=grad[i, 0],
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise,
max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
disp_aug[i] = disp_aug_[None].astype(np.float32)
grad_aug[i] = grad_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug
ret[f'disp{sidx}'] = disp_aug
ret[f'grad{sidx}'] = grad_aug
else:
img = ret[f'im{sidx}']
img_aug = np.zeros_like(img)
for i in range(img.shape[0]):
img_aug_, _, _ = augment_image(img[i, 0], rng,
max_shift=self.max_shift, max_blur=self.max_blur,
max_noise=self.max_noise, max_sp_noise=self.max_sp_noise)
img_aug[i] = img_aug_[None].astype(np.float32)
ret[f'im{sidx}'] = img_aug
if len(track_ind) == 1:
for key, val in ret.items():
if key != 'blend_im' and key != 'id':
ret[key] = val[0]
return ret
def getK(self, sidx=0):
K = self.K.copy() / (2 ** sidx)
K[2, 2] = 1
return K
if __name__ == '__main__':
pass
pass

View File

@ -2,7 +2,7 @@
<!-- Generated by Cython 0.29 -->
<html>
<head>
<meta http-equiv="Content-Type" content="text/html; charset=utf-8" />
<meta http-equiv="Content-Type" content="text/html; charset=utf-8"/>
<title>Cython: lcn.pyx</title>
<style type="text/css">
@ -355,17 +355,23 @@ body.cython { font-family: courier; font-size: 12; }
.cython .vi { color: #19177C } /* Name.Variable.Instance */
.cython .vm { color: #19177C } /* Name.Variable.Magic */
.cython .il { color: #666666 } /* Literal.Number.Integer.Long */
</style>
</head>
<body class="cython">
<p><span style="border-bottom: solid 1px grey;">Generated by Cython 0.29</span></p>
<p>
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br />
<span style="background-color: #FFFF00">Yellow lines</span> hint at Python interaction.<br/>
Click on a line that starts with a "<code>+</code>" to see the C code that Cython generated for it.
</p>
<p>Raw output: <a href="lcn.c">lcn.c</a></p>
<div class="cython"><pre class="cython line score-16" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span class="k">as</span> <span class="nn">np</span></pre>
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
<div class="cython">
<pre class="cython line score-16"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">01</span>: <span class="k">import</span> <span class="nn">numpy</span> <span
class="k">as</span> <span class="nn">np</span></pre>
<pre class='cython code score-16 '> __pyx_t_1 = <span class='pyx_c_api'>__Pyx_Import</span>(__pyx_n_s_numpy, 0, -1);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_np, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
@ -374,22 +380,39 @@ body.cython { font-family: courier; font-size: 12; }
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_test, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 1, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
</pre><pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span class="p">:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
<pre class="cython line score-67" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span class="p">):</span></pre>
<pre class='cython code score-67 '>/* Python wrapper */
</pre>
<pre class="cython line score-0">&#xA0;<span class="">02</span>: <span class="k">cimport</span> <span class="nn">cython</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">03</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">04</span>: <span class="c"># use c square root function</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">05</span>: <span class="k">cdef</span> <span
class="kr">extern</span> <span class="k">from</span> <span class="s">&quot;math.h&quot;</span><span
class="p">:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">06</span>: <span class="nb">float</span> <span class="n">sqrt</span><span
class="p">(</span><span class="nb">float</span> <span class="n">x</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">07</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">08</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">boundscheck</span><span class="p">(</span><span
class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">09</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">wraparound</span><span class="p">(</span><span
class="bp">False</span><span class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">10</span>: <span class="nd">@cython</span><span
class="o">.</span><span class="n">cdivision</span><span class="p">(</span><span class="bp">True</span><span
class="p">)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">11</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">12</span>: <span class="c"># 3 parameters:</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">13</span>: <span class="c"># - float image</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">14</span>: <span class="c"># - kernel size (actually this is the radius, kernel is 2*k+1)</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">15</span>: <span class="c"># - small constant epsilon that is used to avoid division by zero</span></pre>
<pre class="cython line score-67"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">16</span>: <span class="k">def</span> <span class="nf">normalize</span><span
class="p">(</span><span class="nb">float</span><span class="p">[:,</span> <span class="p">:]</span> <span
class="n">img</span><span class="p">,</span> <span class="nb">int</span> <span class="n">kernel_size</span> <span
class="o">=</span> <span class="mf">4</span><span class="p">,</span> <span class="nb">float</span> <span
class="n">epsilon</span> <span class="o">=</span> <span class="mf">0.01</span><span
class="p">):</span></pre>
<pre class='cython code score-67 '>/* Python wrapper */
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds); /*proto*/
static PyMethodDef __pyx_mdef_3lcn_1normalize = {"normalize", (PyCFunction)(void*)(PyCFunctionWithKeywords)__pyx_pw_3lcn_1normalize, METH_VARARGS|METH_KEYWORDS, 0};
static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_args, PyObject *__pyx_kwds) {
@ -434,7 +457,8 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
}
}
if (unlikely(kw_args &gt; 0)) {
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
if (unlikely(<span class='pyx_c_api'>__Pyx_ParseOptionalKeywords</span>(__pyx_kwds, __pyx_pyargnames, 0, values, pos_args, "normalize") &lt; 0)) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
}
} else {
switch (<span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)) {
@ -447,21 +471,27 @@ static PyObject *__pyx_pw_3lcn_1normalize(PyObject *__pyx_self, PyObject *__pyx_
default: goto __pyx_L5_argtuple_error;
}
}
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_v_img = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(values[0], PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_v_img.memview)) __PYX_ERR(0, 16, __pyx_L3_error)</span>
if (values[1]) {
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_v_kernel_size = <span class='pyx_c_api'>__Pyx_PyInt_As_int</span>(values[1]); if (unlikely((__pyx_v_kernel_size == (int)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else {
__pyx_v_kernel_size = ((int)4);
}
if (values[2]) {
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span class='py_c_api'>PyErr_Occurred</span>())) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_v_epsilon = __pyx_<span class='py_c_api'>PyFloat_AsFloat</span>(values[2]); if (unlikely((__pyx_v_epsilon == (float)-1) &amp;&amp; <span
class='py_c_api'>PyErr_Occurred</span>())) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
} else {
__pyx_v_epsilon = ((float)0.01);
}
}
goto __pyx_L4_argument_unpacking_done;
__pyx_L5_argtuple_error:;
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
<span class='pyx_c_api'>__Pyx_RaiseArgtupleInvalid</span>("normalize", 0, 1, 3, <span class='py_macro_api'>PyTuple_GET_SIZE</span>(__pyx_args)); <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L3_error)</span>
__pyx_L3_error:;
<span class='pyx_c_api'>__Pyx_AddTraceback</span>("lcn.normalize", __pyx_clineno, __pyx_lineno, __pyx_filename);
<span class='refnanny'>__Pyx_RefNannyFinishContext</span>();
@ -515,27 +545,49 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
return __pyx_r;
}
/* … */
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
__pyx_tuple__19 = <span class='py_c_api'>PyTuple_Pack</span>(19, __pyx_n_s_img, __pyx_n_s_kernel_size, __pyx_n_s_epsilon, __pyx_n_s_M, __pyx_n_s_N, __pyx_n_s_img_lcn, __pyx_n_s_img_std, __pyx_n_s_img_lcn_view, __pyx_n_s_img_std_view, __pyx_n_s_tmp, __pyx_n_s_mean, __pyx_n_s_stddev, __pyx_n_s_m, __pyx_n_s_n, __pyx_n_s_i, __pyx_n_s_j, __pyx_n_s_ks, __pyx_n_s_eps, __pyx_n_s_num);<span
class='error_goto'> if (unlikely(!__pyx_tuple__19)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_tuple__19);
<span class='refnanny'>__Pyx_GIVEREF</span>(__pyx_tuple__19);
/* … */
__pyx_t_1 = PyCFunction_NewEx(&amp;__pyx_mdef_3lcn_1normalize, NULL, __pyx_n_s_lcn);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_d, __pyx_n_s_normalize, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 16, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
</pre><pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
</pre><pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
<pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_codeobj__20 = (PyObject*)<span class='pyx_c_api'>__Pyx_PyCode_New</span>(3, 0, 19, 0, CO_OPTIMIZED|CO_NEWLOCALS, __pyx_empty_bytes, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_tuple__19, __pyx_empty_tuple, __pyx_empty_tuple, __pyx_kp_s_lcn_pyx, __pyx_n_s_normalize, 16, __pyx_empty_bytes);<span
class='error_goto'> if (unlikely(!__pyx_codeobj__20)) __PYX_ERR(0, 16, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-0">&#xA0;<span class="">17</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">18</span>: <span class="c"># image dimensions</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">19</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">M</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mf">0</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_M = (__pyx_v_img.shape[0]);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">20</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">N</span> <span class="o">=</span> <span class="n">img</span><span class="o">.</span><span
class="n">shape</span><span class="p">[</span><span class="mf">1</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_v_N = (__pyx_v_img.shape[1]);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">21</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">22</span>: <span class="c"># create outputs and output views</span></pre>
<pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">23</span>: <span class="n">img_lcn</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_t_2 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_2)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
@ -559,22 +611,34 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_1, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_1, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_4, __pyx_n_s_dtype, __pyx_t_5) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
__pyx_t_5 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_2, __pyx_t_3, __pyx_t_4);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 23, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
__pyx_v_img_lcn = __pyx_t_5;
__pyx_t_5 = 0;
</pre><pre class="cython line score-46" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-46"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">24</span>: <span class="n">img_std</span> <span class="o">=</span> <span
class="n">np</span><span class="o">.</span><span class="n">zeros</span><span class="p">((</span><span
class="n">M</span><span class="p">,</span> <span class="n">N</span><span class="p">),</span> <span
class="n">dtype</span><span class="o">=</span><span class="n">np</span><span class="o">.</span><span
class="n">float32</span><span class="p">)</span></pre>
<pre class='cython code score-46 '> <span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span
class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
__pyx_t_4 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_zeros);<span
class='error_goto'> if (unlikely(!__pyx_t_4)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_4);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
__pyx_t_5 = <span class='py_c_api'>PyInt_FromSsize_t</span>(__pyx_v_M);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
@ -598,114 +662,236 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_2);
<span class='pyx_c_api'>__Pyx_GetModuleGlobalName</span>(__pyx_t_5, __pyx_n_s_np);<span class='error_goto'> if (unlikely(!__pyx_t_5)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_5);
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_GetAttrStr</span>(__pyx_t_5, __pyx_n_s_float32);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_5); __pyx_t_5 = 0;
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
if (<span class='py_c_api'>PyDict_SetItem</span>(__pyx_t_2, __pyx_n_s_dtype, __pyx_t_1) &lt; 0) <span
class='error_goto'>__PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_1); __pyx_t_1 = 0;
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
__pyx_t_1 = <span class='pyx_c_api'>__Pyx_PyObject_Call</span>(__pyx_t_4, __pyx_t_3, __pyx_t_2);<span
class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 24, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_4); __pyx_t_4 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_3); __pyx_t_3 = 0;
<span class='pyx_macro_api'>__Pyx_DECREF</span>(__pyx_t_2); __pyx_t_2 = 0;
__pyx_v_img_std = __pyx_t_1;
__pyx_t_1 = 0;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span class="n">img_lcn</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">25</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_lcn_view</span> <span class="o">=</span> <span
class="n">img_lcn</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_lcn, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 25, __pyx_L1_error)</span>
__pyx_v_img_lcn_view = __pyx_t_6;
__pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL;
</pre><pre class="cython line score-2" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span class="n">img_std</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
</pre>
<pre class="cython line score-2"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">26</span>: <span class="k">cdef</span> <span class="kt">float</span>[<span class="p">:,</span> <span
class="p">:]</span> <span class="n">img_std_view</span> <span class="o">=</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-2 '> __pyx_t_6 = <span
class='pyx_c_api'>__Pyx_PyObject_to_MemoryviewSlice_dsds_float</span>(__pyx_v_img_std, PyBUF_WRITABLE);<span
class='error_goto'> if (unlikely(!__pyx_t_6.memview)) __PYX_ERR(0, 26, __pyx_L1_error)</span>
__pyx_v_img_std_view = __pyx_t_6;
__pyx_t_6.memview = NULL;
__pyx_t_6.data = NULL;
</pre><pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span class="nf">stddev</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span class="o">**</span><span class="mf">2</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
</pre><pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span class="c"># for all pixels do</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">27</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">28</span>: <span class="c"># temporary c variables</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">29</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">tmp</span><span class="p">,</span> <span class="nf">mean</span><span class="p">,</span> <span
class="nf">stddev</span></pre>
<pre class="cython line score-0">&#xA0;<span class="">30</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">m</span><span class="p">,</span> <span class="nf">n</span><span class="p">,</span> <span
class="nf">i</span><span class="p">,</span> <span class="nf">j</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">31</span>: <span class="k">cdef</span> <span class="kt">Py_ssize_t</span> <span
class="nf">ks</span> <span class="o">=</span> <span class="n">kernel_size</span></pre>
<pre class='cython code score-0 '> __pyx_v_ks = __pyx_v_kernel_size;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">32</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">eps</span> <span class="o">=</span> <span class="n">epsilon</span></pre>
<pre class='cython code score-0 '> __pyx_v_eps = __pyx_v_epsilon;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">33</span>: <span class="k">cdef</span> <span class="kt">float</span> <span
class="nf">num</span> <span class="o">=</span> <span class="p">(</span><span class="n">ks</span><span
class="o">*</span><span class="mf">2</span><span class="o">+</span><span class="mf">1</span><span class="p">)</span><span
class="o">**</span><span class="mf">2</span></pre>
<pre class='cython code score-0 '> __pyx_v_num = __Pyx_pow_Py_ssize_t(((__pyx_v_ks * 2) + 1), 2);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">34</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">35</span>: <span
class="c"># for all pixels do</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">36</span>: <span class="k">for</span> <span class="n">m</span> <span class="ow">in</span> <span
class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span
class="n">M</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_7 = (__pyx_v_M - __pyx_v_ks);
__pyx_t_8 = __pyx_t_7;
for (__pyx_t_9 = __pyx_v_ks; __pyx_t_9 &lt; __pyx_t_8; __pyx_t_9+=1) {
__pyx_v_m = __pyx_t_9;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">37</span>: <span class="k">for</span> <span class="n">n</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="n">ks</span><span
class="p">,</span><span class="n">N</span><span class="o">-</span><span class="n">ks</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_10 = (__pyx_v_N - __pyx_v_ks);
__pyx_t_11 = __pyx_t_10;
for (__pyx_t_12 = __pyx_v_ks; __pyx_t_12 &lt; __pyx_t_11; __pyx_t_12+=1) {
__pyx_v_n = __pyx_t_12;
</pre><pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">38</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">39</span>: <span class="c"># calculate mean</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">40</span>: <span class="n">mean</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">41</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">42</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">43</span>: <span class="n">mean</span> <span class="o">+=</span> <span
class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span></pre>
<pre class='cython code score-0 '> __pyx_t_19 = (__pyx_v_m + __pyx_v_i);
__pyx_t_20 = (__pyx_v_n + __pyx_v_j);
__pyx_v_mean = (__pyx_v_mean + (*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_19 * __pyx_v_img.strides[0]) ) + __pyx_t_20 * __pyx_v_img.strides[1]) ))));
}
}
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
</pre><pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span class="c"># calculate std dev</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">44</span>: <span class="n">mean</span> <span class="o">=</span> <span
class="n">mean</span><span class="o">/</span><span class="n">num</span></pre>
<pre class='cython code score-0 '> __pyx_v_mean = (__pyx_v_mean / __pyx_v_num);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">45</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">46</span>: <span
class="c"># calculate std dev</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">47</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="mf">0</span><span class="p">;</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = 0.0;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">48</span>: <span class="k">for</span> <span class="n">i</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_13 = (__pyx_v_ks + 1);
__pyx_t_14 = __pyx_t_13;
for (__pyx_t_15 = (-__pyx_v_ks); __pyx_t_15 &lt; __pyx_t_14; __pyx_t_15+=1) {
__pyx_v_i = __pyx_t_15;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">49</span>: <span class="k">for</span> <span class="n">j</span> <span
class="ow">in</span> <span class="nb">range</span><span class="p">(</span><span class="o">-</span><span
class="n">ks</span><span class="p">,</span><span class="n">ks</span><span class="o">+</span><span
class="mf">1</span><span class="p">):</span></pre>
<pre class='cython code score-0 '> __pyx_t_16 = (__pyx_v_ks + 1);
__pyx_t_17 = __pyx_t_16;
for (__pyx_t_18 = (-__pyx_v_ks); __pyx_t_18 &lt; __pyx_t_17; __pyx_t_18+=1) {
__pyx_v_j = __pyx_t_18;
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">50</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="n">stddev</span> <span class="o">+</span> <span class="p">(</span><span class="n">img</span><span
class="p">[</span><span class="n">m</span><span class="o">+</span><span class="n">i</span><span
class="p">,</span> <span class="n">n</span><span class="o">+</span><span class="n">j</span><span
class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span
class="o">*</span><span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span
class="o">+</span><span class="n">i</span><span class="p">,</span> <span class="n">n</span><span
class="o">+</span><span class="n">j</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_21 = (__pyx_v_m + __pyx_v_i);
__pyx_t_22 = (__pyx_v_n + __pyx_v_j);
__pyx_t_23 = (__pyx_v_m + __pyx_v_i);
__pyx_t_24 = (__pyx_v_n + __pyx_v_j);
__pyx_v_stddev = (__pyx_v_stddev + (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_21 * __pyx_v_img.strides[0]) ) + __pyx_t_22 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) * ((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_23 * __pyx_v_img.strides[0]) ) + __pyx_t_24 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean)));
}
}
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span class="n">num</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
</pre><pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
<pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">51</span>: <span class="n">stddev</span> <span class="o">=</span> <span
class="n">sqrt</span><span class="p">(</span><span class="n">stddev</span><span class="o">/</span><span
class="n">num</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_v_stddev = sqrt((__pyx_v_stddev / __pyx_v_num));
</pre>
<pre class="cython line score-0">&#xA0;<span class="">52</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">53</span>: <span class="c"># compute normalized image (add epsilon) and std dev image</span></pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">54</span>: <span class="n">img_lcn_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="p">(</span><span class="n">img</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span
class="n">n</span><span class="p">]</span><span class="o">-</span><span class="n">mean</span><span
class="p">)</span><span class="o">/</span><span class="p">(</span><span class="n">stddev</span><span
class="o">+</span><span class="n">eps</span><span class="p">)</span></pre>
<pre class='cython code score-0 '> __pyx_t_25 = __pyx_v_m;
__pyx_t_26 = __pyx_v_n;
__pyx_t_27 = __pyx_v_m;
__pyx_t_28 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_lcn_view.data + __pyx_t_27 * __pyx_v_img_lcn_view.strides[0]) ) + __pyx_t_28 * __pyx_v_img_lcn_view.strides[1]) )) = (((*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img.data + __pyx_t_25 * __pyx_v_img.strides[0]) ) + __pyx_t_26 * __pyx_v_img.strides[1]) ))) - __pyx_v_mean) / (__pyx_v_stddev + __pyx_v_eps));
</pre><pre class="cython line score-0" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span class="n">stddev</span></pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
</pre>
<pre class="cython line score-0"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">55</span>: <span class="n">img_std_view</span><span class="p">[</span><span
class="n">m</span><span class="p">,</span> <span class="n">n</span><span class="p">]</span> <span class="o">=</span> <span
class="n">stddev</span></pre>
<pre class='cython code score-0 '> __pyx_t_29 = __pyx_v_m;
__pyx_t_30 = __pyx_v_n;
*((float *) ( /* dim=1 */ (( /* dim=0 */ (__pyx_v_img_std_view.data + __pyx_t_29 * __pyx_v_img_std_view.strides[0]) ) + __pyx_t_30 * __pyx_v_img_std_view.strides[1]) )) = __pyx_v_stddev;
}
}
</pre><pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre>
<pre class="cython line score-10" onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span class="n">img_std</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
</pre>
<pre class="cython line score-0">&#xA0;<span class="">56</span>: </pre>
<pre class="cython line score-0">&#xA0;<span class="">57</span>: <span class="c"># return both</span></pre>
<pre class="cython line score-10"
onclick="(function(s){s.display=s.display==='block'?'none':'block'})(this.nextElementSibling.style)">+<span
class="">58</span>: <span class="k">return</span> <span class="n">img_lcn</span><span class="p">,</span> <span
class="n">img_std</span></pre>
<pre class='cython code score-10 '> <span class='pyx_macro_api'>__Pyx_XDECREF</span>(__pyx_r);
__pyx_t_1 = <span class='py_c_api'>PyTuple_New</span>(2);<span class='error_goto'> if (unlikely(!__pyx_t_1)) __PYX_ERR(0, 58, __pyx_L1_error)</span>
<span class='refnanny'>__Pyx_GOTREF</span>(__pyx_t_1);
<span class='pyx_macro_api'>__Pyx_INCREF</span>(__pyx_v_img_lcn);
@ -717,4 +903,7 @@ static PyObject *__pyx_pf_3lcn_normalize(CYTHON_UNUSED PyObject *__pyx_self, __P
__pyx_r = __pyx_t_1;
__pyx_t_1 = 0;
goto __pyx_L0;
</pre></div></body></html>
</pre>
</div>
</body>
</html>

View File

@ -2,5 +2,5 @@ from distutils.core import setup
from Cython.Build import cythonize
setup(
ext_modules = cythonize("lcn.pyx",annotate=True)
ext_modules=cythonize("lcn.pyx", annotate=True)
)

View File

@ -5,43 +5,43 @@ from scipy import misc
# load and convert to float
img = misc.imread('img.png')
img = img.astype(np.float32)/255.0
img = img.astype(np.float32) / 255.0
# normalize
img_lcn, img_std = lcn.normalize(img,5,0.05)
img_lcn, img_std = lcn.normalize(img, 5, 0.05)
# normalize to reasonable range between 0 and 1
#img_lcn = img_lcn/3.0
#img_lcn = np.maximum(img_lcn,0.0)
#img_lcn = np.minimum(img_lcn,1.0)
# img_lcn = img_lcn/3.0
# img_lcn = np.maximum(img_lcn,0.0)
# img_lcn = np.minimum(img_lcn,1.0)
# save to file
#misc.imsave('lcn2.png',img_lcn)
# misc.imsave('lcn2.png',img_lcn)
print ("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
print ("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
print("Orig Image: %d x %d (%s), Min: %f, Max: %f" % \
(img.shape[0], img.shape[1], img.dtype, img.min(), img.max()))
print("Norm Image: %d x %d (%s), Min: %f, Max: %f" % \
(img_lcn.shape[0], img_lcn.shape[1], img_lcn.dtype, img_lcn.min(), img_lcn.max()))
# plot original image
plt.figure(1)
img_plot = plt.imshow(img)
img_plot.set_cmap('gray')
plt.clim(0, 1) # fix range
plt.clim(0, 1) # fix range
plt.tight_layout()
# plot normalized image
plt.figure(2)
img_lcn_plot = plt.imshow(img_lcn)
img_lcn_plot.set_cmap('gray')
#plt.clim(0, 1) # fix range
# plt.clim(0, 1) # fix range
plt.tight_layout()
# plot stddev image
plt.figure(3)
img_std_plot = plt.imshow(img_std)
img_std_plot.set_cmap('gray')
#plt.clim(0, 0.1) # fix range
# plt.clim(0, 0.1) # fix range
plt.tight_layout()
plt.show()

View File

@ -11,28 +11,27 @@ import dataset
def get_data(n, row_from, row_to, train):
imsizes = [(256,384)]
focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to-row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n):
print(f'load sample {idx} train={train}')
sample = dset[idx]
ims[idx] = (sample['im0'][0,row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0,row_from:row_to]
return ims, disps
imsizes = [(256, 384)]
focal_lengths = [160]
dset = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)
ims = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.uint8)
disps = np.empty((n, row_to - row_from, imsizes[0][1]), dtype=np.float32)
for idx in range(n):
print(f'load sample {idx} train={train}')
sample = dset[idx]
ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
disps[idx] = sample['disp0'][0, row_from:row_to]
return ims, disps
params = hd.TrainParams(
n_trees=4,
max_tree_depth=,
n_test_split_functions=50,
n_test_thresholds=10,
n_test_samples=4096,
min_samples_to_split=16,
min_samples_for_leaf=8)
n_trees=4,
max_tree_depth=,
n_test_split_functions=50,
n_test_thresholds=10,
n_test_samples=4096,
min_samples_to_split=16,
min_samples_for_leaf=8)
n_disp_bins = 20
depth_switch = 0
@ -45,21 +44,23 @@ n_test_samples = 32
train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)
for tree_depth in [8,10,12,14,16]:
depth_switch = tree_depth - 4
for tree_depth in [8, 10, 12, 14, 16]:
depth_switch = tree_depth - 4
prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True)
prefix = f'td{tree_depth}_ds{depth_switch}'
prefix = Path(f'./forests/{prefix}/')
prefix.mkdir(parents=True, exist_ok=True)
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch, forest_prefix=str(prefix / 'fr'))
es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins, depth_switch=depth_switch,
forest_prefix=str(prefix / 'fr'))
np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es)
np.save(str(prefix / 'ta.npy'), test_disps)
np.save(str(prefix / 'es.npy'), es)
# plt.figure();
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
# plt.show()
# plt.figure();
# plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
# plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
# plt.show()

View File

@ -8,7 +8,6 @@ import os
this_dir = os.path.dirname(__file__)
extra_compile_args = ['-O3', '-std=c++11']
extra_link_args = []
@ -22,24 +21,20 @@ library_dirs = []
libraries = ['m']
setup(
name="hyperdepth",
cmdclass= {'build_ext': build_ext},
ext_modules=[
Extension('hyperdepth',
sources,
extra_objects=extra_objects,
language='c++',
library_dirs=library_dirs,
libraries=libraries,
include_dirs=[
np.get_include(),
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
]
name="hyperdepth",
cmdclass={'build_ext': build_ext},
ext_modules=[
Extension('hyperdepth',
sources,
extra_objects=extra_objects,
language='c++',
library_dirs=library_dirs,
libraries=libraries,
include_dirs=[
np.get_include(),
],
extra_compile_args=extra_compile_args,
extra_link_args=extra_link_args
)
]
)

View File

@ -6,10 +6,13 @@ orig = cv2.imread('disp_orig.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
ta = cv2.imread('disp_ta.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
es = cv2.imread('disp_es.png', cv2.IMREAD_ANYDEPTH).astype(np.float32)
plt.figure()
plt.subplot(2,2,1); plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,2); plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,3); plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2,2,4); plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.subplot(2, 2, 1);
plt.imshow(orig / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 2);
plt.imshow(ta / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 3);
plt.imshow(es / 16, vmin=0, vmax=4, cmap='magma')
plt.subplot(2, 2, 4);
plt.imshow(np.abs(es - ta) / 16, vmin=0, vmax=1, cmap='magma')
plt.show()

View File

@ -12,226 +12,263 @@ import torchext
from model import networks
from data import dataset
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.data_type = args.data_type
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.data_type = args.data_type
self.imsizes = [(480,640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
self.imsizes = [(488, 648)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:]
self.test_paths = sample_paths[:2**8]
self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13
self.eval_w = self.imsizes[0][1]-13-140
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=1)
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=1)
return train_set
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
# initialize photometric loss modules according to image sizes
self.losses = []
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device)
pat,_ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append( networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat) )
# initialize photometric loss modules according to image sizes
self.losses = []
for imsize, pat in zip(test_set.imsizes, test_set.patterns):
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32))
pat = pat.to(self.train_device)
self.lcn_in = self.lcn_in.to(self.train_device)
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
self.losses.append(networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat))
return test_sets
return test_sets
def copy_data(self, data, device, requires_grad, train):
self.lcn_in = self.lcn_in.to(device)
def copy_data(self, data, device, requires_grad, train):
self.lcn_in = self.lcn_in.to(device)
self.data = {}
for key, val in data.items():
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
self.data = {}
for key, val in data.items():
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
# apply lcn to IR input
# concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key:
im = self.data[key]
im_lcn,im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im','std')
self.data[key]=im_cat
self.data[key_std] = im_std.to(device).detach()
# apply lcn to IR input
# concatenate the normalized IR input and the original IR image
if 'im' in key and 'blend' not in key:
im = self.data[key]
im_lcn, im_std = self.lcn_in(im)
im_cat = torch.cat((im_lcn, im), dim=1)
key_std = key.replace('im', 'std')
self.data[key] = im_cat
self.data[key_std] = im_std.to(device).detach()
def net_forward(self, net, train):
out = net(self.data['im0'])
return out
def net_forward(self, net, train):
out = net(self.data['im0'])
return out
def loss_forward(self, out, train):
out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)):
out = [out]
if not(isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge]
def loss_forward(self, out, train):
out, edge = out
if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out]
if not (isinstance(edge, tuple) or isinstance(edge, list)):
edge = [edge]
vals = []
vals = []
# apply photometric loss
for s,l,o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:,0:1,...], self.data[f'std{s}'])
if s == 0:
self.pattern_proj = pattern_proj.detach()
vals.append(val)
# apply photometric loss
for s, l, o in zip(itertools.count(), self.losses, out):
val, pattern_proj = l(o, self.data[f'im{s}'][:, 0:1, ...], self.data[f'std{s}'])
if s == 0:
self.pattern_proj = pattern_proj.detach()
vals.append(val)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0)
if self.dp_weight>0:
vals.append(val * self.dp_weight)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1 - torch.sigmoid(edge[0])
val = self.disparity_loss(out[0], edge0)
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
val = self.edge_loss(e[mask], grad[mask])
else:
val = torch.zeros_like(vals[0])
if s == 0:
self.edge = e.detach()
self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach()
vals.append(val)
# apply edge loss on a subset of training samples
for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids > self.train_edge
if mask.sum() > 0:
val = self.edge_loss(e[mask], grad[mask])
else:
val = torch.zeros_like(vals[0])
if s == 0:
self.edge = e.detach()
self.edge = torch.sigmoid(self.edge)
self.edge_gt = grad.detach()
vals.append(val)
return vals
return vals
def numpy_in_out(self, output):
output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,0:1,...].detach().to('cpu').numpy()
def numpy_in_out(self, output):
output, edge = output
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:, 0:1, ...].detach().to('cpu').numpy()
ma = gt>0
return es, gt, im, ma
ma = gt > 0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt)
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin)
vmax = vmax + 0.2*(vmax-vmin)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0]
pattern_diff = np.abs(im_orig - pattern_proj)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16, 16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
fig = plt.figure(figsize=(16,16))
es_ = co.cmap.color_depth_map(es, scale=vmax)
gt_ = co.cmap.color_depth_map(gt, scale=vmax)
diff_ = co.cmap.color_error_image(diff, BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3, 3, 1)
plt.imshow(es_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3, 3, 2)
plt.imshow(gt_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3, 3, 3)
plt.imshow(diff_[..., [2, 1, 0]])
plt.xticks([])
plt.yticks([])
ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Est. {es.min():.4f}/{es.max():.4f}')
ax = plt.subplot(3,3,2); plt.imshow(gt_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity GT {np.nanmin(gt):.4f}/{np.nanmax(gt):.4f}')
ax = plt.subplot(3,3,3); plt.imshow(diff_[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'Disparity Err. {diff.mean():.5f}')
# plot edges
edge = self.edge.to('cpu').numpy()[0, 0]
edge_gt = self.edge_gt.to('cpu').numpy()[0, 0]
edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3, 3, 4);
plt.imshow(edge, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(edge_gt, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(edge_err, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot edges
edge = self.edge.to('cpu').numpy()[0,0]
edge_gt = self.edge_gt.to('cpu').numpy()[0,0]
edge_err = np.abs(edge - edge_gt)
ax = plt.subplot(3,3,4); plt.imshow(edge, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Est. {edge.min():.5f}/{edge.max():.5f}')
ax = plt.subplot(3,3,5); plt.imshow(edge_gt, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge GT {edge_gt.min():.5f}/{edge_gt.max():.5f}')
ax = plt.subplot(3,3,6); plt.imshow(edge_err, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Edge Err. {edge_err.mean():.5f}')
# plot normalized IR input and warped pattern
ax = plt.subplot(3, 3, 7);
plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3, 3, 8);
plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0, 0]
ax = plt.subplot(3, 3, 9);
plt.imshow(im_std, cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
# plot normalized IR input and warped pattern
ax = plt.subplot(3,3,7); plt.imshow(im, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR input {im.mean():.5f}/{im.std():.5f}')
ax = plt.subplot(3,3,8); plt.imshow(pattern_proj, vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'Warped Pattern {pattern_proj.mean():.5f}/{pattern_proj.std():.5f}')
im_std = self.data['std0'].to('cpu').numpy()[0,0]
ax = plt.subplot(3,3,9); plt.imshow(im_std, cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'IR std {im_std.min():.5f}/{im_std.max():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks=[]):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
es, gt, im, ma = self.numpy_in_out(output)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0, 0], gt[0, 0], im[0, 0], ma[0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
es, gt, im, ma = self.crop_output(es, gt, im, ma)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks=[]):
es, gt, im, ma = self.numpy_in_out(output)
es = es.reshape(-1, 1)
gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[0,0], gt[0,0], im[0,0], ma[0,0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:,:,self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
bs = es.shape[0]
es = np.reshape(es[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[:, :, self.eval_mask], [bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__':
pass
pass

View File

@ -12,287 +12,324 @@ import torchext
from model import networks
from data import dataset
class Worker(torchext.Worker):
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers, train_batch_size=train_batch_size, test_batch_size=test_batch_size, save_frequency=save_frequency, **kwargs)
def __init__(self, args, num_workers=18, train_batch_size=8, test_batch_size=8, save_frequency=1, **kwargs):
super().__init__(args.output_dir, args.exp_name, epochs=args.epochs, num_workers=num_workers,
train_batch_size=train_batch_size, test_batch_size=test_batch_size,
save_frequency=save_frequency, **kwargs)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.ge_weight = args.ge_weight
self.track_length = args.track_length
self.data_type = args.data_type
assert(self.track_length>1)
self.ms = args.ms
self.pattern_path = args.pattern_path
self.lcn_radius = args.lcn_radius
self.dp_weight = args.dp_weight
self.ge_weight = args.ge_weight
self.track_length = args.track_length
self.data_type = args.data_type
assert (self.track_length > 1)
self.imsizes = [(480,640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0]/2), int(self.imsizes[-1][1]/2)))
self.imsizes = [(480, 640)]
for iter in range(3):
self.imsizes.append((int(self.imsizes[-1][0] / 2), int(self.imsizes[-1][1] / 2)))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
with open('config.json') as fp:
config = json.load(fp)
data_root = Path(config['DATA_ROOT'])
self.settings_path = data_root / self.data_type / 'settings.pkl'
sample_paths = sorted((data_root / self.data_type).glob('0*/'))
self.train_paths = sample_paths[2**10:]
self.test_paths = sample_paths[:2**8]
self.train_paths = sample_paths[2 ** 10:]
self.test_paths = sample_paths[:2 ** 8]
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2**8
# supervise the edge encoder with only 2**8 samples
self.train_edge = len(self.train_paths) - 2 ** 8
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
self.lcn_in = networks.LCN(self.lcn_radius, 0.05)
self.disparity_loss = networks.DisparityLoss()
self.edge_loss = torch.nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([0.1]).to(self.train_device))
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0]-13, 140:self.imsizes[0][1]-13]=1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0]-2*13
self.eval_w = self.imsizes[0][1]-13-140
# evaluate in the region where opencv Block Matching has valid values
self.eval_mask = np.zeros(self.imsizes[0])
self.eval_mask[13:self.imsizes[0][0] - 13, 140:self.imsizes[0][1] - 13] = 1
self.eval_mask = self.eval_mask.astype(np.bool)
self.eval_h = self.imsizes[0][0] - 2 * 13
self.eval_w = self.imsizes[0][1] - 13 - 140
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True,
track_length=self.track_length)
return train_set
def get_train_set(self):
train_set = dataset.TrackSynDataset(self.settings_path, self.train_paths, train=True, data_aug=True, track_length=self.track_length)
return train_set
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True,
track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
def get_test_sets(self):
test_sets = torchext.TestSets()
test_set = dataset.TrackSynDataset(self.settings_path, self.test_paths, train=False, data_aug=True, track_length=1)
test_sets.append('simple', test_set, test_frequency=1)
self.ph_losses = []
self.ge_losses = []
self.d2ds = []
self.ph_losses = []
self.ge_losses = []
self.d2ds = []
self.lcn_in = self.lcn_in.to('cuda')
for sidx in range(len(test_set.imsizes)):
imsize = test_set.imsizes[sidx]
pat = test_set.patterns[sidx]
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat, _ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0], imsize[1], pattern=pat)
self.lcn_in = self.lcn_in.to('cuda')
for sidx in range(len(test_set.imsizes)):
imsize = test_set.imsizes[sidx]
pat = test_set.patterns[sidx]
pat = pat.mean(axis=2)
pat = torch.from_numpy(pat[None][None].astype(np.float32)).to('cuda')
pat,_ = self.lcn_in(pat)
pat = torch.cat([pat for idx in range(3)], dim=1)
ph_loss = networks.RectifiedPatternSimilarityLoss(imsize[0],imsize[1], pattern=pat)
K = test_set.getK(sidx)
Ki = np.linalg.inv(K)
K = torch.from_numpy(K)
Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
K = test_set.getK(sidx)
Ki = np.linalg.inv(K)
K = torch.from_numpy(K)
Ki = torch.from_numpy(Ki)
ge_loss = networks.ProjectionDepthSimilarityLoss(K, Ki, imsize[0], imsize[1], clamp=0.1)
self.ph_losses.append(ph_loss)
self.ge_losses.append(ge_loss)
self.ph_losses.append( ph_loss )
self.ge_losses.append( ge_loss )
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append( d2d )
d2d = networks.DispToDepth(float(test_set.focal_lengths[sidx]), float(test_set.baseline))
self.d2ds.append(d2d)
return test_sets
return test_sets
def copy_data(self, data, device, requires_grad, train):
self.data = {}
def copy_data(self, data, device, requires_grad, train):
self.data = {}
self.lcn_in = self.lcn_in.to(device)
for key, val in data.items():
# from
# batch_size x track_length x ...
# to
# track_length x batch_size x ...
if len(val.shape)>2:
if train:
val = val.transpose(0,1)
self.lcn_in = self.lcn_in.to(device)
for key, val in data.items():
# from
# batch_size x track_length x ...
# to
# track_length x batch_size x ...
if len(val.shape) > 2:
if train:
val = val.transpose(0, 1)
else:
val = val.unsqueeze(0)
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
if 'im' in key and 'blend' not in key:
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn, im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im', 'std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
def net_forward(self, net, train):
im0 = self.data['im0']
tl = im0.shape[0]
bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0)
if not (isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else:
val = val.unsqueeze(0)
grad = 'im' in key and requires_grad
self.data[key] = val.to(device).requires_grad_(requires_grad=grad)
if 'im' in key and 'blend' not in key:
im = self.data[key]
tl = im.shape[0]
bs = im.shape[1]
im_lcn,im_std = self.lcn_in(im.contiguous().view(-1, *im.shape[2:]))
key_std = key.replace('im','std')
self.data[key_std] = im_std.view(tl, bs, *im.shape[2:]).to(device)
im_cat = torch.cat((im_lcn.view(tl, bs, *im.shape[2:]), im), dim=2)
self.data[key] = im_cat
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
return out, edge
def net_forward(self, net, train):
im0 = self.data['im0']
tl = im0.shape[0]
bs = im0.shape[1]
im0 = im0.view(-1, *im0.shape[2:])
out, edge = net(im0)
if not(isinstance(out, tuple) or isinstance(out, list)):
out = out.view(tl, bs, *out.shape[1:])
edge = edge.view(tl, bs, *out.shape[1:])
else:
out = [o.view(tl, bs, *o.shape[1:]) for o in out]
edge = [e.view(tl, bs, *e.shape[1:]) for e in edge]
return out, edge
def loss_forward(self, out, train):
out, edge = out
if not (isinstance(out, tuple) or isinstance(out, list)):
out = [out]
vals = []
diffs = []
def loss_forward(self, out, train):
out, edge = out
if not(isinstance(out, tuple) or isinstance(out, list)):
out = [out]
vals = []
diffs = []
# apply photometric loss
for s, l, o in zip(itertools.count(), self.ph_losses, out):
im = self.data[f'im{s}']
im = im.view(-1, *im.shape[2:])
o = o.view(-1, *o.shape[2:])
std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:, 0:1, ...], std)
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
# apply photometric loss
for s,l,o in zip(itertools.count(), self.ph_losses, out):
im = self.data[f'im{s}']
im = im.view(-1, *im.shape[2:])
o = o.view(-1, *o.shape[2:])
std = self.data[f'std{s}']
std = std.view(-1, *std.shape[2:])
val, pattern_proj = l(o, im[:,0:1,...], std)
vals.append(val)
if s == 0:
self.pattern_proj = pattern_proj.detach()
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1 - torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:])
out0 = out[0].view(-1, *out[0].shape[2:])
val = self.disparity_loss(out0, edge0)
if self.dp_weight > 0:
vals.append(val * self.dp_weight)
# apply disparity loss
# 1-edge as ground truth edge if inversed
edge0 = 1-torch.sigmoid(edge[0])
edge0 = edge0.view(-1, *edge0.shape[2:])
out0 = out[0].view(-1, *out[0].shape[2:])
val = self.disparity_loss(out0, edge0)
if self.dp_weight>0:
vals.append(val * self.dp_weight)
# apply edge loss on a subset of training samples
for s, e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}'] < 0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids > self.train_edge
if mask.sum() > 0:
e = e[:, mask, :]
grad = grad[:, mask, :]
e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad)
else:
val = torch.zeros_like(vals[0])
vals.append(val)
# apply edge loss on a subset of training samples
for s,e in zip(itertools.count(), edge):
# inversed ground truth edge where 0 means edge
grad = self.data[f'grad{s}']<0.2
grad = grad.to(torch.float32)
ids = self.data['id']
mask = ids>self.train_edge
if mask.sum()>0:
e = e[:,mask,:]
grad = grad[:,mask,:]
e = e.view(-1, *e.shape[2:])
grad = grad.view(-1, *grad.shape[2:])
val = self.edge_loss(e, grad)
else:
val = torch.zeros_like(vals[0])
vals.append(val)
if train is False:
return vals
if train is False:
return vals
# apply geometric loss
R = self.data['R']
t = self.data['t']
ge_num = self.track_length * (self.track_length - 1) / 2
for sidx in range(len(out)):
d2d = self.d2ds[sidx]
depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0 + 1, depth.shape[0]):
depth0 = depth[tidx0]
R0 = R[tidx0]
t0 = t[tidx0]
depth1 = depth[tidx1]
R1 = R[tidx1]
t1 = t[tidx1]
# apply geometric loss
R = self.data['R']
t = self.data['t']
ge_num = self.track_length * (self.track_length-1) / 2
for sidx in range(len(out)):
d2d = self.d2ds[sidx]
depth = d2d(out[sidx])
ge_loss = self.ge_losses[sidx]
imsize = self.imsizes[sidx]
for tidx0 in range(depth.shape[0]):
for tidx1 in range(tidx0+1, depth.shape[0]):
depth0 = depth[tidx0]
R0 = R[tidx0]
t0 = t[tidx0]
depth1 = depth[tidx1]
R1 = R[tidx1]
t1 = t[tidx1]
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
vals.append(val * self.ge_weight / ge_num)
val = ge_loss(depth0, depth1, R0, t0, R1, t1)
vals.append(val * self.ge_weight / ge_num)
return vals
return vals
def numpy_in_out(self, output):
output, edge = output
if not (isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:, :, 0:1, ...].detach().to('cpu').numpy()
ma = gt > 0
return es, gt, im, ma
def numpy_in_out(self, output):
output, edge = output
if not(isinstance(output, tuple) or isinstance(output, list)):
output = [output]
es = output[0].detach().to('cpu').numpy()
gt = self.data['disp0'].to('cpu').numpy().astype(np.float32)
im = self.data['im0'][:,:,0:1,...].detach().to('cpu').numpy()
ma = gt>0
return es, gt, im, ma
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
def write_img(self, out_path, es, gt, im, ma):
logging.info(f'write img {out_path}')
u_pos, _ = np.meshgrid(range(es.shape[1]), range(es.shape[0]))
diff = np.abs(es - gt)
diff = np.abs(es - gt)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2 * (vmax - vmin)
vmax = vmax + 0.2 * (vmax - vmin)
vmin, vmax = np.nanmin(gt), np.nanmax(gt)
vmin = vmin - 0.2*(vmax-vmin)
vmax = vmax + 0.2*(vmax-vmin)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0, 0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0, 0, 0]
pattern_diff = np.abs(im_orig - pattern_proj)
pattern_proj = self.pattern_proj.to('cpu').numpy()[0,0]
im_orig = self.data['im0'].detach().to('cpu').numpy()[0,0,0]
pattern_diff = np.abs(im_orig - pattern_proj)
fig = plt.figure(figsize=(16, 16))
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
fig = plt.figure(figsize=(16,16))
es0 = co.cmap.color_depth_map(es[0], scale=vmax)
gt0 = co.cmap.color_depth_map(gt[0], scale=vmax)
diff0 = co.cmap.color_error_image(diff[0], BGR=True)
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3, 3, 1);
plt.imshow(es0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3, 3, 2);
plt.imshow(gt0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3, 3, 3);
plt.imshow(diff0[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
# plot disparities, ground truth disparity is shown only for reference
ax = plt.subplot(3,3,1); plt.imshow(es0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Est. {es0.min():.4f}/{es0.max():.4f}')
ax = plt.subplot(3,3,2); plt.imshow(gt0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity GT {np.nanmin(gt0):.4f}/{np.nanmax(gt0):.4f}')
ax = plt.subplot(3,3,3); plt.imshow(diff0[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 Disparity Err. {diff0.mean():.5f}')
# plot disparities of the second frame in the track if exists
if es.shape[0] >= 2:
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3, 3, 4);
plt.imshow(es1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3, 3, 5);
plt.imshow(gt1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3, 3, 6);
plt.imshow(diff1[..., [2, 1, 0]]);
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot disparities of the second frame in the track if exists
if es.shape[0]>=2:
es1 = co.cmap.color_depth_map(es[1], scale=vmax)
gt1 = co.cmap.color_depth_map(gt[1], scale=vmax)
diff1 = co.cmap.color_error_image(diff[1], BGR=True)
ax = plt.subplot(3,3,4); plt.imshow(es1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Est. {es1.min():.4f}/{es1.max():.4f}')
ax = plt.subplot(3,3,5); plt.imshow(gt1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity GT {np.nanmin(gt1):.4f}/{np.nanmax(gt1):.4f}')
ax = plt.subplot(3,3,6); plt.imshow(diff1[...,[2,1,0]]); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 Disparity Err. {diff1.mean():.5f}')
# plot normalized IR inputs
ax = plt.subplot(3, 3, 7);
plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
if es.shape[0] >= 2:
ax = plt.subplot(3, 3, 8);
plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray');
plt.xticks([]);
plt.yticks([]);
ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
# plot normalized IR inputs
ax = plt.subplot(3,3,7); plt.imshow(im[0], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F0 IR input {im[0].mean():.5f}/{im[0].std():.5f}')
if es.shape[0]>=2:
ax = plt.subplot(3,3,8); plt.imshow(im[1], vmin=im.min(), vmax=im.max(), cmap='gray'); plt.xticks([]); plt.yticks([]); ax.set_title(f'F1 IR input {im[1].mean():.5f}/{im[1].std():.5f}')
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
plt.tight_layout()
plt.savefig(str(out_path))
plt.close(fig)
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
masks = [ m.detach().to('cpu').numpy() for m in masks ]
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
if batch_idx % 512 == 0:
out_path = self.exp_out_root / f'train_{epoch:03d}_{batch_idx:04d}.png'
es, gt, im, ma = self.numpy_in_out(output)
masks = [m.detach().to('cpu').numpy() for m in masks]
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_test_start(self, epoch, set_idx):
self.metric = co.metric.MultipleMetric(
co.metric.DistanceMetric(vec_length=1),
co.metric.OutlierFractionMetric(vec_length=1, thresholds=[0.1, 0.5, 1, 2, 5])
)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
es, gt, im, ma = self.numpy_in_out(output)
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
es, gt, im, ma = self.numpy_in_out(output)
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[:,0,0], gt[:,0,0], im[:,0,0], ma[:,0,0])
if batch_idx % 8 == 0:
out_path = self.exp_out_root / f'test_{epoch:03d}_{batch_idx:04d}.png'
self.write_img(out_path, es[:, 0, 0], gt[:, 0, 0], im[:, 0, 0], ma[:, 0, 0])
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es, gt, im, ma = self.crop_output(es, gt, im, ma)
es = es.reshape(-1,1)
gt = gt.reshape(-1,1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
es = es.reshape(-1, 1)
gt = gt.reshape(-1, 1)
ma = ma.ravel()
self.metric.add(es, gt, ma)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def callback_test_stop(self, epoch, set_idx, loss):
logging.info(f'{self.metric}')
for k, v in self.metric.items():
self.metric_add_test(epoch, set_idx, k, v)
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[..., self.eval_mask], [tl * bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
def crop_output(self, es, gt, im, ma):
tl = es.shape[0]
bs = es.shape[1]
es = np.reshape(es[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
gt = np.reshape(gt[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
im = np.reshape(im[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
ma = np.reshape(ma[...,self.eval_mask], [tl*bs, 1, self.eval_h, self.eval_w])
return es, gt, im, ma
if __name__ == '__main__':
pass
pass

View File

@ -8,559 +8,572 @@ import co
class TimedModule(torch.nn.Module):
def __init__(self, mod_name):
super().__init__()
self.mod_name = mod_name
def __init__(self, mod_name):
super().__init__()
self.mod_name = mod_name
def tforward(self, *args, **kwargs):
raise Exception('not implemented')
def tforward(self, *args, **kwargs):
raise Exception('not implemented')
def forward(self, *args, **kwargs):
torch.cuda.synchronize()
with co.gtimer.Ctx(self.mod_name):
x = self.tforward(*args, **kwargs)
torch.cuda.synchronize()
return x
def forward(self, *args, **kwargs):
torch.cuda.synchronize()
with co.gtimer.Ctx(self.mod_name):
x = self.tforward(*args, **kwargs)
torch.cuda.synchronize()
return x
class PosOutput(TimedModule):
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='PosOutput')
self.im_width = im_width
self.im_width = im_width
def __init__(self, channels_in, type, im_height, im_width, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='PosOutput')
self.im_width = im_width
self.im_width = im_width
if type == 'pos':
self.layer = torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
elif type == 'pos_row':
self.layer = torch.nn.Sequential(
MultiLinear(im_height, channels_in, 1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
if type == 'pos':
self.layer = torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
elif type == 'pos_row':
self.layer = torch.nn.Sequential(
MultiLinear(im_height, channels_in, 1),
SigmoidAffine(alpha=alpha, beta=beta, gamma=gamma, offset=offset)
)
self.u_pos = None
self.u_pos = None
def tforward(self, x):
if self.u_pos is None:
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1,1,1,-1)
self.u_pos = self.u_pos.to(x.device)
pos = self.layer(x)
disp = self.u_pos - pos
return disp
def tforward(self, x):
if self.u_pos is None:
self.u_pos = torch.arange(x.shape[3], dtype=torch.float32).view(1, 1, 1, -1)
self.u_pos = self.u_pos.to(x.device)
pos = self.layer(x)
disp = self.u_pos - pos
return disp
class OutputLayerFactory(object):
'''
Define type of output
type options:
linear: apply only conv channel, used for the edge decoder
disp: estimate the disparity
disp_row: independently estimate the disparity per row
pos: estimate the absolute location
pos_row: independently estimate the absolute location per row
'''
def __init__(self, type='disp', params={}):
self.type = type
self.params = params
'''
Define type of output
type options:
linear: apply only conv channel, used for the edge decoder
disp: estimate the disparity
disp_row: independently estimate the disparity per row
pos: estimate the absolute location
pos_row: independently estimate the absolute location per row
'''
def __call__(self, channels_in, imsize):
def __init__(self, type='disp', params={}):
self.type = type
self.params = params
if self.type == 'linear':
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
def __call__(self, channels_in, imsize):
elif self.type == 'disp':
return torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(**self.params)
)
if self.type == 'linear':
return torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1)
elif self.type == 'disp_row':
return torch.nn.Sequential(
MultiLinear(imsize[0], channels_in, 1),
SigmoidAffine(**self.params)
)
elif self.type == 'disp':
return torch.nn.Sequential(
torch.nn.Conv2d(channels_in, 1, kernel_size=3, padding=1),
SigmoidAffine(**self.params)
)
elif self.type == 'pos' or self.type == 'pos_row':
return PosOutput(channels_in, **self.params)
elif self.type == 'disp_row':
return torch.nn.Sequential(
MultiLinear(imsize[0], channels_in, 1),
SigmoidAffine(**self.params)
)
else:
raise Exception('unknown output layer type')
elif self.type == 'pos' or self.type == 'pos_row':
return PosOutput(channels_in, **self.params)
else:
raise Exception('unknown output layer type')
class SigmoidAffine(TimedModule):
def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='SigmoidAffine')
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.offset = offset
def __init__(self, alpha=1, beta=0, gamma=1, offset=0):
super().__init__(mod_name='SigmoidAffine')
self.alpha = alpha
self.beta = beta
self.gamma = gamma
self.offset = offset
def tforward(self, x):
return torch.sigmoid(x/self.gamma - self.offset) * self.alpha + self.beta
def tforward(self, x):
return torch.sigmoid(x / self.gamma - self.offset) * self.alpha + self.beta
class MultiLinear(TimedModule):
def __init__(self, n, channels_in, channels_out):
super().__init__(mod_name='MultiLinear')
self.channels_out = channels_out
self.mods = torch.nn.ModuleList()
for idx in range(n):
self.mods.append(torch.nn.Linear(channels_in, channels_out))
def tforward(self, x):
x = x.permute(2,0,3,1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1,3,0,2) # HxBxWxC => BxCxHxW
return y
def __init__(self, n, channels_in, channels_out):
super().__init__(mod_name='MultiLinear')
self.channels_out = channels_out
self.mods = torch.nn.ModuleList()
for idx in range(n):
self.mods.append(torch.nn.Linear(channels_in, channels_out))
def tforward(self, x):
x = x.permute(2, 0, 3, 1) # BxCxHxW => HxBxWxC
y = x.new_empty(*x.shape[:-1], self.channels_out)
for hidx in range(x.shape[0]):
y[hidx] = self.mods[hidx](x[hidx])
y = y.permute(1, 3, 0, 2) # HxBxWxC => BxCxHxW
return y
class DispNetS(TimedModule):
'''
Disparity Decoder based on DispNetS
'''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS')
'''
Disparity Decoder based on DispNetS
'''
self.output_ms = output_ms
self.coordconv = coordconv
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False,
channel_multiplier=1):
super(DispNetS, self).__init__(mod_name='DispNetS')
conv_planes = channel_multiplier * np.array( [32, 64, 128, 256, 512, 512, 512] )
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
self.output_ms = output_ms
self.coordconv = coordconv
upconv_planes = channel_multiplier * np.array( [512, 512, 256, 128, 64, 32, 16] )
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
conv_planes = channel_multiplier * np.array([32, 64, 128, 256, 512, 512, 512])
self.conv1 = self.downsample_conv(channels_in, conv_planes[0], kernel_size=7)
self.conv2 = self.downsample_conv(conv_planes[0], conv_planes[1], kernel_size=5)
self.conv3 = self.downsample_conv(conv_planes[1], conv_planes[2])
self.conv4 = self.downsample_conv(conv_planes[2], conv_planes[3])
self.conv5 = self.downsample_conv(conv_planes[3], conv_planes[4])
self.conv6 = self.downsample_conv(conv_planes[4], conv_planes[5])
self.conv7 = self.downsample_conv(conv_planes[5], conv_planes[6])
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
upconv_planes = channel_multiplier * np.array([512, 512, 256, 128, 64, 32, 16])
self.upconv7 = self.upconv(conv_planes[6], upconv_planes[0])
self.upconv6 = self.upconv(upconv_planes[0], upconv_planes[1])
self.upconv5 = self.upconv(upconv_planes[1], upconv_planes[2])
self.upconv4 = self.upconv(upconv_planes[2], upconv_planes[3])
self.upconv3 = self.upconv(upconv_planes[3], upconv_planes[4])
self.upconv2 = self.upconv(upconv_planes[4], upconv_planes[5])
self.upconv1 = self.upconv(upconv_planes[5], upconv_planes[6])
if isinstance(output_facs, list):
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
else:
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
self.iconv7 = self.conv(upconv_planes[0] + conv_planes[5], upconv_planes[0])
self.iconv6 = self.conv(upconv_planes[1] + conv_planes[4], upconv_planes[1])
self.iconv5 = self.conv(upconv_planes[2] + conv_planes[3], upconv_planes[2])
self.iconv4 = self.conv(upconv_planes[3] + conv_planes[2], upconv_planes[3])
self.iconv3 = self.conv(1 + upconv_planes[4] + conv_planes[1], upconv_planes[4])
self.iconv2 = self.conv(1 + upconv_planes[5] + conv_planes[0], upconv_planes[5])
self.iconv1 = self.conv(1 + upconv_planes[6], upconv_planes[6])
if isinstance(output_facs, list):
self.predict_disp4 = output_facs[3](upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs[2](upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs[1](upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs[0](upconv_planes[6], imsizes[0])
else:
self.predict_disp4 = output_facs(upconv_planes[3], imsizes[3])
self.predict_disp3 = output_facs(upconv_planes[4], imsizes[2])
self.predict_disp2 = output_facs(upconv_planes[5], imsizes[1])
self.predict_disp1 = output_facs(upconv_planes[6], imsizes[0])
def init_weights(self):
for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
def init_weights(self):
for m in self.modules():
if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.ConvTranspose2d):
torch.nn.init.xavier_uniform_(m.weight, gain=0.1)
if m.bias is not None:
torch.nn.init.zeros_(m.bias)
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
if self.coordconv:
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
else:
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2, padding=(kernel_size-1)//2)
return torch.nn.Sequential(
conv,
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size-1)//2),
torch.nn.ReLU(inplace=True)
)
def downsample_conv(self, in_planes, out_planes, kernel_size=3):
if self.coordconv:
conv = torchext.CoordConv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
else:
conv = torch.nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=2,
padding=(kernel_size - 1) // 2)
return torch.nn.Sequential(
conv,
torch.nn.ReLU(inplace=True),
torch.nn.Conv2d(out_planes, out_planes, kernel_size=kernel_size, padding=(kernel_size - 1) // 2),
torch.nn.ReLU(inplace=True)
)
def conv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True)
)
def conv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.Conv2d(in_planes, out_planes, kernel_size=3, padding=1),
torch.nn.ReLU(inplace=True)
)
def upconv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
torch.nn.ReLU(inplace=True)
)
def upconv(self, in_planes, out_planes):
return torch.nn.Sequential(
torch.nn.ConvTranspose2d(in_planes, out_planes, kernel_size=3, stride=2, padding=1, output_padding=1),
torch.nn.ReLU(inplace=True)
)
def crop_like(self, input, ref):
assert(input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
return input[:, :, :ref.size(2), :ref.size(3)]
def crop_like(self, input, ref):
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
return input[:, :, :ref.size(2), :ref.size(3)]
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
out_conv4 = self.conv4(out_conv3)
out_conv5 = self.conv5(out_conv4)
out_conv6 = self.conv6(out_conv5)
out_conv7 = self.conv7(out_conv6)
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
out_conv4 = self.conv4(out_conv3)
out_conv5 = self.conv5(out_conv4)
out_conv6 = self.conv6(out_conv5)
out_conv7 = self.conv7(out_conv6)
out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
concat7 = torch.cat((out_upconv7, out_conv6), 1)
out_iconv7 = self.iconv7(concat7)
out_upconv7 = self.crop_like(self.upconv7(out_conv7), out_conv6)
concat7 = torch.cat((out_upconv7, out_conv6), 1)
out_iconv7 = self.iconv7(concat7)
out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
concat6 = torch.cat((out_upconv6, out_conv5), 1)
out_iconv6 = self.iconv6(concat6)
out_upconv6 = self.crop_like(self.upconv6(out_iconv7), out_conv5)
concat6 = torch.cat((out_upconv6, out_conv5), 1)
out_iconv6 = self.iconv6(concat6)
out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
concat5 = torch.cat((out_upconv5, out_conv4), 1)
out_iconv5 = self.iconv5(concat5)
out_upconv5 = self.crop_like(self.upconv5(out_iconv6), out_conv4)
concat5 = torch.cat((out_upconv5, out_conv4), 1)
out_iconv5 = self.iconv5(concat5)
out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
concat4 = torch.cat((out_upconv4, out_conv3), 1)
out_iconv4 = self.iconv4(concat4)
disp4 = self.predict_disp4(out_iconv4)
out_upconv4 = self.crop_like(self.upconv4(out_iconv5), out_conv3)
concat4 = torch.cat((out_upconv4, out_conv3), 1)
out_iconv4 = self.iconv4(concat4)
disp4 = self.predict_disp4(out_iconv4)
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
disp4_up = self.crop_like(torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
out_upconv3 = self.crop_like(self.upconv3(out_iconv4), out_conv2)
disp4_up = self.crop_like(
torch.nn.functional.interpolate(disp4, scale_factor=2, mode='bilinear', align_corners=False), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2, disp4_up), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
if self.output_ms:
return disp1, disp2, disp3, disp4
else:
return disp1
if self.output_ms:
return disp1, disp2, disp3, disp4
else:
return disp1
class DispNetShallow(DispNetS):
'''
Edge Decoder based on DispNetS with fewer layers
'''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
self.mod_name = 'DispNetShallow'
conv_planes = [32, 64, 128, 256, 512, 512, 512]
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])
'''
Edge Decoder based on DispNetS with fewer layers
'''
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
self.mod_name = 'DispNetShallow'
conv_planes = [32, 64, 128, 256, 512, 512, 512]
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
self.iconv3 = self.conv(upconv_planes[4] + conv_planes[1], upconv_planes[4])
out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
def tforward(self, x):
out_conv1 = self.conv1(x)
out_conv2 = self.conv2(out_conv1)
out_conv3 = self.conv3(out_conv2)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
out_upconv3 = self.crop_like(self.upconv3(out_conv3), out_conv2)
concat3 = torch.cat((out_upconv3, out_conv2), 1)
out_iconv3 = self.iconv3(concat3)
disp3 = self.predict_disp3(out_iconv3)
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
out_upconv2 = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2 = torch.cat((out_upconv2, out_conv1, disp3_up), 1)
out_iconv2 = self.iconv2(concat2)
disp2 = self.predict_disp2(out_iconv2)
if self.output_ms:
return disp1, disp2, disp3
else:
return disp1
out_upconv1 = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up = self.crop_like(
torch.nn.functional.interpolate(disp2, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1 = torch.cat((out_upconv1, disp2_up), 1)
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
if self.output_ms:
return disp1, disp2, disp3
else:
return disp1
class DispEdgeDecoders(TimedModule):
'''
Disparity Decoder and Edge Decoder
'''
def __init__(self, *args, max_disp=128, **kwargs):
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
'''
Disparity Decoder and Edge Decoder
'''
output_facs = [OutputLayerFactory( type='disp', params={ 'alpha': max_disp/(2**s), 'beta': 0, 'gamma': 1, 'offset': 3}) for s in range(4)]
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
def __init__(self, *args, max_disp=128, **kwargs):
super(DispEdgeDecoders, self).__init__(mod_name='DispEdgeDecoders')
output_facs = [OutputLayerFactory( type='linear' ) for s in range(4)]
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
output_facs = [
OutputLayerFactory(type='disp', params={'alpha': max_disp / (2 ** s), 'beta': 0, 'gamma': 1, 'offset': 3})
for s in range(4)]
self.disp_decoder = DispNetS(*args, output_facs=output_facs, **kwargs)
def tforward(self, x):
disp = self.disp_decoder(x)
edge = self.edge_decoder(x)
return disp, edge
output_facs = [OutputLayerFactory(type='linear') for s in range(4)]
self.edge_decoder = DispNetShallow(*args, output_facs=output_facs, **kwargs)
def tforward(self, x):
disp = self.disp_decoder(x)
edge = self.edge_decoder(x)
return disp, edge
class DispToDepth(TimedModule):
def __init__(self, focal_length, baseline):
super().__init__(mod_name='DispToDepth')
self.baseline_focal_length = baseline * focal_length
def __init__(self, focal_length, baseline):
super().__init__(mod_name='DispToDepth')
self.baseline_focal_length = baseline * focal_length
def tforward(self, disp):
disp = torch.nn.functional.relu(disp) + 1e-12
depth = self.baseline_focal_length / disp
return depth
def tforward(self, disp):
disp = torch.nn.functional.relu(disp) + 1e-12
depth = self.baseline_focal_length / disp
return depth
class PosToDepth(DispToDepth):
def __init__(self, focal_length, baseline, im_height, im_width):
super().__init__(focal_length, baseline)
self.mod_name = 'PosToDepth'
def __init__(self, focal_length, baseline, im_height, im_width):
super().__init__(focal_length, baseline)
self.mod_name = 'PosToDepth'
self.im_height = im_height
self.im_width = im_width
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1,1,1,-1)
def tforward(self, pos):
self.u_pos = self.u_pos.to(pos.device)
disp = self.u_pos - pos
return super().forward(disp)
self.im_height = im_height
self.im_width = im_width
self.u_pos = torch.arange(im_width, dtype=torch.float32).view(1, 1, 1, -1)
def tforward(self, pos):
self.u_pos = self.u_pos.to(pos.device)
disp = self.u_pos - pos
return super().forward(disp)
class RectifiedPatternSimilarityLoss(TimedModule):
'''
Photometric Loss
'''
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
self.im_height = im_height
self.im_width = im_width
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
'''
Photometric Loss
'''
u, v = np.meshgrid(range(im_width), range(im_height))
uv0 = np.stack((u,v), axis=2).reshape(-1,1)
uv0 = uv0.astype(np.float32).reshape(1,-1,2)
self.uv0 = torch.from_numpy(uv0)
def __init__(self, im_height, im_width, pattern, loss_type='census_sad', loss_eps=0.5):
super().__init__(mod_name='RectifiedPatternSimilarityLoss')
self.im_height = im_height
self.im_width = im_width
self.pattern = pattern.mean(dim=1, keepdim=True).contiguous()
self.loss_type = loss_type
self.loss_eps = loss_eps
u, v = np.meshgrid(range(im_width), range(im_height))
uv0 = np.stack((u, v), axis=2).reshape(-1, 1)
uv0 = uv0.astype(np.float32).reshape(1, -1, 2)
self.uv0 = torch.from_numpy(uv0)
def tforward(self, disp0, im, std=None):
self.pattern = self.pattern.to(disp0.device)
self.uv0 = self.uv0.to(disp0.device)
self.loss_type = loss_type
self.loss_eps = loss_eps
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
uv1 = torch.empty_like(uv0)
uv1[...,0] = uv0[...,0] - disp0.contiguous().view(disp0.shape[0],-1)
uv1[...,1] = uv0[...,1]
def tforward(self, disp0, im, std=None):
self.pattern = self.pattern.to(disp0.device)
self.uv0 = self.uv0.to(disp0.device)
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
mask = torch.ones_like(im)
if std is not None:
mask = mask*std
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
uv1 = torch.empty_like(uv0)
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
uv1[..., 1] = uv0[..., 1]
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
pattern = self.pattern.expand(disp0.shape[0], *self.pattern.shape[1:])
pattern_proj = torch.nn.functional.grid_sample(pattern, uv1, padding_mode='border')
mask = torch.ones_like(im)
if std is not None:
mask = mask * std
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
val = (mask * diff).sum() / mask.sum()
return val, pattern_proj
diff = torchext.photometric_loss(pattern_proj.contiguous(), im.contiguous(), 9, self.loss_type, self.loss_eps)
val = (mask*diff).sum() / mask.sum()
return val, pattern_proj
class DisparityLoss(TimedModule):
'''
Disparity Loss
'''
def __init__(self):
super().__init__(mod_name='DisparityLoss')
self.sobel = SobelFilter(norm=False)
'''
Disparity Loss
'''
#if not edge_gt:
self.b0=0.0503428816795
self.b1=1.07274045944
#else:
# self.b0=0.0587115108967
# self.b1=1.51931190491
def __init__(self):
super().__init__(mod_name='DisparityLoss')
self.sobel = SobelFilter(norm=False)
def tforward(self, disp, edge=None):
self.sobel=self.sobel.to(disp.device)
# if not edge_gt:
self.b0 = 0.0503428816795
self.b1 = 1.07274045944
# else:
# self.b0=0.0587115108967
# self.b1=1.51931190491
if edge is not None:
grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
pdf = (1-edge)/self.b0 * torch.exp(-torch.abs(grad)/self.b0) + \
edge/self.b1 * torch.exp(-torch.abs(grad)/self.b1)
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
else:
# on qifeng's data we don't have ambient info
# therefore we supress edge everywhere
grad = self.sobel(disp)
grad = torch.sqrt(grad[:,0:1,...]**2 + grad[:,1:2,...]**2 + 1e-8)
grad= torch.clamp(grad, 0, 1.0)
val = torch.mean(grad)
def tforward(self, disp, edge=None):
self.sobel = self.sobel.to(disp.device)
return val
if edge is not None:
grad = self.sobel(disp)
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
pdf = (1 - edge) / self.b0 * torch.exp(-torch.abs(grad) / self.b0) + \
edge / self.b1 * torch.exp(-torch.abs(grad) / self.b1)
val = torch.mean(-torch.log(pdf.clamp(min=1e-4)))
else:
# on qifeng's data we don't have ambient info
# therefore we supress edge everywhere
grad = self.sobel(disp)
grad = torch.sqrt(grad[:, 0:1, ...] ** 2 + grad[:, 1:2, ...] ** 2 + 1e-8)
grad = torch.clamp(grad, 0, 1.0)
val = torch.mean(grad)
return val
class ProjectionBaseLoss(TimedModule):
'''
Base module of the Geometric Loss
'''
def __init__(self, K, Ki, im_height, im_width):
super().__init__(mod_name='ProjectionBaseLoss')
'''
Base module of the Geometric Loss
'''
self.K = K.view(-1,3,3)
def __init__(self, K, Ki, im_height, im_width):
super().__init__(mod_name='ProjectionBaseLoss')
self.im_height = im_height
self.im_width = im_width
self.K = K.view(-1, 3, 3)
u, v = np.meshgrid(range(im_width), range(im_height))
uv = np.stack((u,v,np.ones_like(u)), axis=2).reshape(-1,3)
self.im_height = im_height
self.im_width = im_width
ray = uv @ Ki.numpy().T
u, v = np.meshgrid(range(im_width), range(im_height))
uv = np.stack((u, v, np.ones_like(u)), axis=2).reshape(-1, 3)
ray = ray.reshape(1,-1,3).astype(np.float32)
self.ray = torch.from_numpy(ray)
ray = uv @ Ki.numpy().T
def transform(self, xyz, R=None, t=None):
if t is not None:
bs = xyz.shape[0]
xyz = xyz - t.reshape(bs,1,3)
if R is not None:
xyz = torch.bmm(xyz, R)
return xyz
ray = ray.reshape(1, -1, 3).astype(np.float32)
self.ray = torch.from_numpy(ray)
def unproject(self, depth, R=None, t=None):
self.ray = self.ray.to(depth.device)
bs = depth.shape[0]
def transform(self, xyz, R=None, t=None):
if t is not None:
bs = xyz.shape[0]
xyz = xyz - t.reshape(bs, 1, 3)
if R is not None:
xyz = torch.bmm(xyz, R)
return xyz
xyz = depth.reshape(bs,-1,1) * self.ray
xyz = self.transform(xyz, R, t)
return xyz
def unproject(self, depth, R=None, t=None):
self.ray = self.ray.to(depth.device)
bs = depth.shape[0]
def project(self, xyz, R, t):
self.K = self.K.to(xyz.device)
bs = xyz.shape[0]
xyz = depth.reshape(bs, -1, 1) * self.ray
xyz = self.transform(xyz, R, t)
return xyz
xyz = torch.bmm(xyz, R.transpose(1,2))
xyz = xyz + t.reshape(bs,1,3)
def project(self, xyz, R, t):
self.K = self.K.to(xyz.device)
bs = xyz.shape[0]
Kt = self.K.transpose(1,2).expand(bs,-1,-1)
uv = torch.bmm(xyz, Kt)
xyz = torch.bmm(xyz, R.transpose(1, 2))
xyz = xyz + t.reshape(bs, 1, 3)
d = uv[:,:,2:3]
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
uv = torch.bmm(xyz, Kt)
# avoid division by zero
uv = uv[:,:,:2] / (torch.nn.functional.relu(d) + 1e-12)
return uv, d
d = uv[:, :, 2:3]
# avoid division by zero
uv = uv[:, :, :2] / (torch.nn.functional.relu(d) + 1e-12)
return uv, d
def tforward(self, depth0, R0, t0, R1, t1):
xyz = self.unproject(depth0, R0, t0)
return self.project(xyz, R1, t1)
def tforward(self, depth0, R0, t0, R1, t1):
xyz = self.unproject(depth0, R0, t0)
return self.project(xyz, R1, t1)
class ProjectionDepthSimilarityLoss(ProjectionBaseLoss):
'''
Geometric Loss
'''
def __init__(self, *args, clamp=-1):
super().__init__(*args)
self.mod_name = 'ProjectionDepthSimilarityLoss'
self.clamp = clamp
'''
Geometric Loss
'''
def fwd(self, depth0, depth1, R0, t0, R1, t1):
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
def __init__(self, *args, clamp=-1):
super().__init__(*args)
self.mod_name = 'ProjectionDepthSimilarityLoss'
self.clamp = clamp
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width-1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height-1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
def fwd(self, depth0, depth1, R0, t0, R1, t1):
uv1, d1 = super().tforward(depth0, R0, t0, R1, t1)
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
uv1[..., 0] = 2 * (uv1[..., 0] / (self.im_width - 1) - 0.5)
uv1[..., 1] = 2 * (uv1[..., 1] / (self.im_height - 1) - 0.5)
uv1 = uv1.view(-1, self.im_height, self.im_width, 2).clone()
diff = torch.abs(d1.view(-1) - depth10.view(-1))
depth10 = torch.nn.functional.grid_sample(depth1, uv1, padding_mode='border')
if self.clamp > 0:
diff = torch.clamp(diff, 0, self.clamp)
diff = torch.abs(d1.view(-1) - depth10.view(-1))
# return diff without clamping for debugging
return diff.mean()
if self.clamp > 0:
diff = torch.clamp(diff, 0, self.clamp)
def tforward(self, depth0, depth1, R0, t0, R1, t1):
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
return l0+l1
# return diff without clamping for debugging
return diff.mean()
def tforward(self, depth0, depth1, R0, t0, R1, t1):
l0 = self.fwd(depth0, depth1, R0, t0, R1, t1)
l1 = self.fwd(depth1, depth0, R1, t1, R0, t0)
return l0 + l1
class LCN(TimedModule):
'''
Local Contract Normalization
'''
def __init__(self, radius, epsilon):
super().__init__(mod_name='LCN')
self.box_conv = torch.nn.Sequential(
torch.nn.ReflectionPad2d(radius),
torch.nn.Conv2d(1, 1, kernel_size=2*radius+1, bias=False)
)
self.box_conv[1].weight.requires_grad=False
self.box_conv[1].weight.fill_(1.)
'''
Local Contract Normalization
'''
self.epsilon = epsilon
self.radius = radius
def __init__(self, radius, epsilon):
super().__init__(mod_name='LCN')
self.box_conv = torch.nn.Sequential(
torch.nn.ReflectionPad2d(radius),
torch.nn.Conv2d(1, 1, kernel_size=2 * radius + 1, bias=False)
)
self.box_conv[1].weight.requires_grad = False
self.box_conv[1].weight.fill_(1.)
def tforward(self, data):
boxs = self.box_conv(data)
self.epsilon = epsilon
self.radius = radius
avgs = boxs / (2*self.radius+1)**2
boxs_n2 = boxs**2
boxs_2n = self.box_conv(data**2)
def tforward(self, data):
boxs = self.box_conv(data)
stds = torch.sqrt(boxs_2n / (2*self.radius+1)**2 - avgs**2 + 1e-6)
stds = stds + self.epsilon
avgs = boxs / (2 * self.radius + 1) ** 2
boxs_n2 = boxs ** 2
boxs_2n = self.box_conv(data ** 2)
return (data - avgs) / stds, stds
stds = torch.sqrt(boxs_2n / (2 * self.radius + 1) ** 2 - avgs ** 2 + 1e-6)
stds = stds + self.epsilon
return (data - avgs) / stds, stds
class SobelFilter(TimedModule):
'''
Sobel Filter
'''
def __init__(self, norm=False):
super(SobelFilter, self).__init__(mod_name='SobelFilter')
kx = np.array([[-5, -4, 0, 4, 5],
[-8, -10, 0, 10, 8],
[-10, -20, 0, 20, 10],
[-8, -10, 0, 10, 8],
[-5, -4, 0, 4, 5]])/240.0
ky = kx.copy().transpose(1,0)
'''
Sobel Filter
'''
self.conv_x=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_x.weight=torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
def __init__(self, norm=False):
super(SobelFilter, self).__init__(mod_name='SobelFilter')
kx = np.array([[-5, -4, 0, 4, 5],
[-8, -10, 0, 10, 8],
[-10, -20, 0, 20, 10],
[-8, -10, 0, 10, 8],
[-5, -4, 0, 4, 5]]) / 240.0
ky = kx.copy().transpose(1, 0)
self.conv_y=torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_y.weight=torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
self.conv_x = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_x.weight = torch.nn.Parameter(torch.from_numpy(kx).float().unsqueeze(0).unsqueeze(0))
self.norm=norm
self.conv_y = torch.nn.Conv2d(1, 1, kernel_size=5, stride=1, padding=0, bias=False)
self.conv_y.weight = torch.nn.Parameter(torch.from_numpy(ky).float().unsqueeze(0).unsqueeze(0))
def tforward(self,x):
x = F.pad(x, (2,2,2,2), "replicate")
gx = self.conv_x(x)
gy = self.conv_y(x)
if self.norm:
return torch.sqrt(gx**2 + gy**2 + 1e-8)
else:
return torch.cat((gx, gy), dim=1)
self.norm = norm
def tforward(self, x):
x = F.pad(x, (2, 2, 2, 2), "replicate")
gx = self.conv_x(x)
gy = self.conv_y(x)
if self.norm:
return torch.sqrt(gx ** 2 + gy ** 2 + 1e-8)
else:
return torch.cat((gx, gy), dim=1)

View File

@ -6,7 +6,9 @@ This repository contains the code for the paper
**[Connecting the Dots: Learning Representations for Active Monocular Depth Estimation](http://www.cvlibs.net/publications/Riegler2019CVPR.pdf)**
<br>
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/), [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/), and [Andreas Geiger](http://www.cvlibs.net/)
[Gernot Riegler](https://griegler.github.io/), [Yiyi Liao](https://yiyiliao.github.io/)
, [Simon Donne](https://avg.is.tuebingen.mpg.de/person/sdonne), [Vladlen Koltun](http://vladlen.info/),
and [Andreas Geiger](http://www.cvlibs.net/)
<br>
[CVPR 2019](http://cvpr2019.thecvf.com/)
@ -24,40 +26,45 @@ If you find this code useful for your research, please cite
}
```
## Dependencies
The network training/evaluation code is based on `Pytorch`.
```
PyTorch>=1.1
Cuda>=10.0
```
Updated on 07.06.2021: The code is now compatible with the latest Pytorch version (1.8).
The other python packages can be installed with `anaconda`:
```
conda install --file requirements.txt
```
### Structured Light Renderer
To train and evaluate our method in a controlled setting, we implemented an structured light renderer.
It can be used to render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable projector location.
To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`.
Afterwards, the renderer can be build by running `make` within the `renderer` directory.
To train and evaluate our method in a controlled setting, we implemented an structured light renderer. It can be used to
render a virtual scene (arbitrary triangle mesh) with the structured light pattern projected from a customizable
projector location. To build it, first make sure the correct `CUDA_LIBRARY_PATH` is set in `config.json`. Afterwards,
the renderer can be build by running `make` within the `renderer` directory.
### PyTorch Extensions
The network training/evaluation code is based on `PyTorch`.
We implemented some custom layers that need to be built in the `torchext` directory.
Simply change into this directory and run
The network training/evaluation code is based on `PyTorch`. We implemented some custom layers that need to be built in
the `torchext` directory. Simply change into this directory and run
```
python setup.py build_ext --inplace
```
### Baseline HyperDepth
As baseline we partially re-implemented the random forest based method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf).
The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`.
To build it change into the directory and run
As baseline we partially re-implemented the random forest based
method [HyperDepth](http://openaccess.thecvf.com/content_cvpr_2016/papers/Fanello_HyperDepth_Learning_Depth_CVPR_2016_paper.pdf)
. The code resided in the `hyperdepth` directory and is implemented in `C++11` with a Python wrapper written in `Cython`
. To build it change into the directory and run
```
python setup.py build_ext --inplace
@ -65,42 +72,59 @@ python setup.py build_ext --inplace
## Running
### Creating Synthetic Data
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by running
To create synthetic data and save it locally, download [ShapeNet V2](https://www.shapenet.org/) and
correct `SHAPENET_ROOT` in `config.json`. Then the data can be generated and saved to `DATA_ROOT` in `config.json` by
running
```
./create_syn_data.sh
```
If you are only interested in evaluating our pre-trained model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a validation set that contains a small amount of images.
If you are only interested in evaluating our pre-trained
model, [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip) is a
validation set that contains a small amount of images.
### Training Network
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train the network on synthetic data for the first stage run
As a first stage, it is recommended to train the disparity decoder and edge decoder without the geometric loss. To train
the network on synthetic data for the first stage run
```
python train_val.py
```
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by running
After the model is pretrained without the geometric loss, the full model can be trained from the initialized weights by
running
```
python train_val.py --loss phge
```
### Evaluating Network
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
To evaluate a specific checkpoint, e.g. the 50th epoch, one can run
```
python train_val.py --cmd retest --epoch 50
```
### Evaluating a Pre-trained Model
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
We provide a model pre-trained using the photometric loss. Once you have prepared the synthetic dataset and
changed `DATA_ROOT` in `config.json`, the pre-trained model can be evaluated on the validation set by running:
```
mkdir -p output
mkdir -p output/exp_syn
wget -O output/exp_syn/net_0099.params https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/net_0099.params
python train_val.py --cmd retest --epoch 99
```
You can also download our validation set from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
## Acknowledgement
You can also download our validation set
from [here (3.7G)](https://s3.eu-central-1.amazonaws.com/avg-projects/connecting_the_dots/val_data.zip).
## Acknowledgement
This work was supported by the Intel Network on Intelligent Systems.

View File

@ -10,7 +10,7 @@ import json
this_dir = os.path.dirname(__file__)
with open('../config.json') as fp:
config = json.load(fp)
config = json.load(fp)
extra_compile_args = ['-O3', '-std=c++11']
@ -20,7 +20,7 @@ cuda_lib = 'cudart'
sources = ['cyrender.pyx']
extra_objects = [
os.path.join(this_dir, 'render/render_cpu.cpp.o'),
os.path.join(this_dir, 'render/render_cpu.cpp.o'),
]
library_dirs = []
libraries = ['m']
@ -30,20 +30,20 @@ library_dirs.append(cuda_lib_dir)
libraries.append(cuda_lib)
setup(
name="cyrender",
cmdclass= {'build_ext': build_ext},
ext_modules=[
Extension('cyrender',
sources,
extra_objects=extra_objects,
language='c++',
library_dirs=library_dirs,
libraries=libraries,
include_dirs=[
np.get_include(),
],
extra_compile_args=extra_compile_args,
# extra_link_args=extra_link_args
)
]
name="cyrender",
cmdclass={'build_ext': build_ext},
ext_modules=[
Extension('cyrender',
sources,
extra_objects=extra_objects,
language='c++',
library_dirs=library_dirs,
libraries=libraries,
include_dirs=[
np.get_include(),
],
extra_compile_args=extra_compile_args,
# extra_link_args=extra_link_args
)
]
)

View File

@ -2,65 +2,65 @@ import torch
import torch.utils.data
import numpy as np
class TestSet(object):
def __init__(self, name, dset, test_frequency=1):
self.name = name
self.dset = dset
self.test_frequency = test_frequency
def __init__(self, name, dset, test_frequency=1):
self.name = name
self.dset = dset
self.test_frequency = test_frequency
class TestSets(list):
def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency))
def append(self, name, dset, test_frequency=1):
super().append(TestSet(name, dset, test_frequency))
class MultiDataset(torch.utils.data.Dataset):
def __init__(self, *datasets):
self.current_epoch = 0
def __init__(self, *datasets):
self.current_epoch = 0
self.datasets = []
self.cum_n_samples = [0]
self.datasets = []
self.cum_n_samples = [0]
for dataset in datasets:
self.append(dataset)
for dataset in datasets:
self.append(dataset)
def append(self, dataset):
self.datasets.append(dataset)
self.__update_cum_n_samples(dataset)
def append(self, dataset):
self.datasets.append(dataset)
self.__update_cum_n_samples(dataset)
def __update_cum_n_samples(self, dataset):
n_samples = self.cum_n_samples[-1] + len(dataset)
self.cum_n_samples.append(n_samples)
def __update_cum_n_samples(self, dataset):
n_samples = self.cum_n_samples[-1] + len(dataset)
self.cum_n_samples.append(n_samples)
def dataset_updated(self):
self.cum_n_samples = [0]
for dset in self.datasets:
self.__update_cum_n_samples(dset)
def dataset_updated(self):
self.cum_n_samples = [0]
for dset in self.datasets:
self.__update_cum_n_samples(dset)
def __len__(self):
return self.cum_n_samples[-1]
def __getitem__(self, idx):
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
sidx = idx - self.cum_n_samples[didx]
return self.datasets[didx][sidx]
def __len__(self):
return self.cum_n_samples[-1]
def __getitem__(self, idx):
didx = np.searchsorted(self.cum_n_samples, idx, side='right') - 1
sidx = idx - self.cum_n_samples[didx]
return self.datasets[didx][sidx]
class BaseDataset(torch.utils.data.Dataset):
def __init__(self, train=True, fix_seed_per_epoch=False):
self.current_epoch = 0
self.train = train
self.fix_seed_per_epoch = fix_seed_per_epoch
def __init__(self, train=True, fix_seed_per_epoch=False):
self.current_epoch = 0
self.train = train
self.fix_seed_per_epoch = fix_seed_per_epoch
def get_rng(self, idx):
rng = np.random.RandomState()
if self.train:
if self.fix_seed_per_epoch:
seed = 1 * len(self) + idx
else:
seed = (self.current_epoch + 1) * len(self) + idx
rng.seed(seed)
else:
rng.seed(idx)
return rng
def get_rng(self, idx):
rng = np.random.RandomState()
if self.train:
if self.fix_seed_per_epoch:
seed = 1 * len(self) + idx
else:
seed = (self.current_epoch + 1) * len(self) + idx
rng.seed(seed)
else:
rng.seed(idx)
return rng

View File

@ -2,146 +2,151 @@ import torch
from . import ext_cpu
from . import ext_cuda
class NNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.nn_cuda(*args)
else:
out = ext_cpu.nn_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None
class NNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.nn_cuda(*args)
else:
out = ext_cpu.nn_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None
def nn(in0, in1):
return NNFunction.apply(in0, in1)
return NNFunction.apply(in0, in1)
class CrossCheckFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.crosscheck_cuda(*args)
else:
out = ext_cpu.crosscheck_cpu(*args)
return out
@staticmethod
def forward(ctx, in0, in1):
args = (in0, in1)
if in0.is_cuda:
out = ext_cuda.crosscheck_cuda(*args)
else:
out = ext_cpu.crosscheck_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None
@staticmethod
def backward(ctx, grad_out):
return None, None
def crosscheck(in0, in1):
return CrossCheckFunction.apply(in0, in1)
return CrossCheckFunction.apply(in0, in1)
class ProjNNFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, xyz0, xyz1, K, patch_size):
args = (xyz0, xyz1, K, patch_size)
if xyz0.is_cuda:
out = ext_cuda.proj_nn_cuda(*args)
else:
out = ext_cpu.proj_nn_cpu(*args)
return out
@staticmethod
def forward(ctx, xyz0, xyz1, K, patch_size):
args = (xyz0, xyz1, K, patch_size)
if xyz0.is_cuda:
out = ext_cuda.proj_nn_cuda(*args)
else:
out = ext_cpu.proj_nn_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
def proj_nn(xyz0, xyz1, K, patch_size):
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)
class XCorrVolFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, in0, in1, n_disps, block_size):
args = (in0, in1, n_disps, block_size)
if in0.is_cuda:
out = ext_cuda.xcorrvol_cuda(*args)
else:
out = ext_cpu.xcorrvol_cpu(*args)
return out
@staticmethod
def forward(ctx, in0, in1, n_disps, block_size):
args = (in0, in1, n_disps, block_size)
if in0.is_cuda:
out = ext_cuda.xcorrvol_cuda(*args)
else:
out = ext_cpu.xcorrvol_cpu(*args)
return out
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
@staticmethod
def backward(ctx, grad_out):
return None, None, None, None
def xcorrvol(in0, in1, n_disps, block_size):
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
return XCorrVolFunction.apply(in0, in1, n_disps, block_size)
class PhotometricLossFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, es, ta, block_size, type, eps):
args = (es, ta, block_size, type, eps)
ctx.save_for_backward(es, ta)
ctx.block_size = block_size
ctx.type = type
ctx.eps = eps
if es.is_cuda:
out = ext_cuda.photometric_loss_forward(*args)
else:
out = ext_cpu.photometric_loss_forward(*args)
return out
@staticmethod
def forward(ctx, es, ta, block_size, type, eps):
args = (es, ta, block_size, type, eps)
ctx.save_for_backward(es, ta)
ctx.block_size = block_size
ctx.type = type
ctx.eps = eps
if es.is_cuda:
out = ext_cuda.photometric_loss_forward(*args)
else:
out = ext_cpu.photometric_loss_forward(*args)
return out
@staticmethod
def backward(ctx, grad_out):
es, ta = ctx.saved_tensors
block_size = ctx.block_size
type = ctx.type
eps = ctx.eps
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
if grad_out.is_cuda:
grad_es = ext_cuda.photometric_loss_backward(*args)
else:
grad_es = ext_cpu.photometric_loss_backward(*args)
return grad_es, None, None, None, None
@staticmethod
def backward(ctx, grad_out):
es, ta = ctx.saved_tensors
block_size = ctx.block_size
type = ctx.type
eps = ctx.eps
args = (es, ta, grad_out.contiguous(), block_size, type, eps)
if grad_out.is_cuda:
grad_es = ext_cuda.photometric_loss_backward(*args)
else:
grad_es = ext_cpu.photometric_loss_backward(*args)
return grad_es, None, None, None, None
def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
type = type.lower()
if type == 'mse':
type = 0
elif type == 'sad':
type = 1
elif type == 'census_mse':
type = 2
elif type == 'census_sad':
type = 3
else:
raise Exception('invalid loss type')
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
type = type.lower()
if type == 'mse':
type = 0
elif type == 'sad':
type = 1
elif type == 'census_mse':
type = 2
elif type == 'census_sad':
type = 3
else:
raise Exception('invalid loss type')
return PhotometricLossFunction.apply(es, ta, block_size, type, eps)
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
type = type.lower()
p = block_size // 2
es_pad = torch.nn.functional.pad(es, (p,p,p,p), mode='replicate')
ta_pad = torch.nn.functional.pad(ta, (p,p,p,p), mode='replicate')
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
if type == 'mse':
ref = (es_uf - ta_uf)**2
elif type == 'sad':
ref = torch.abs(es_uf - ta_uf)
elif type == 'census_mse' or type == 'census_sad':
des = es_uf - es.unsqueeze(2)
dta = ta_uf - ta.unsqueeze(2)
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
diff = h_des - h_dta
if type == 'census_mse':
ref = diff * diff
elif type == 'census_sad':
ref = torch.abs(diff)
else:
raise Exception('invalid loss type')
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
ref = torch.sum(ref, dim=1, keepdim=True) / block_size**2
return ref
type = type.lower()
p = block_size // 2
es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
if type == 'mse':
ref = (es_uf - ta_uf) ** 2
elif type == 'sad':
ref = torch.abs(es_uf - ta_uf)
elif type == 'census_mse' or type == 'census_sad':
des = es_uf - es.unsqueeze(2)
dta = ta_uf - ta.unsqueeze(2)
h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
diff = h_des - h_dta
if type == 'census_mse':
ref = diff * diff
elif type == 'census_sad':
ref = torch.abs(diff)
else:
raise Exception('invalid loss type')
ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
return ref

View File

@ -4,24 +4,26 @@ import numpy as np
from .functions import *
class CoordConv2d(torch.nn.Module):
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
super().__init__()
def __init__(self, channels_in, channels_out, kernel_size, stride, padding):
super().__init__()
self.conv = torch.nn.Conv2d(channels_in+2, channels_out, kernel_size=kernel_size, padding=padding, stride=stride)
self.conv = torch.nn.Conv2d(channels_in + 2, channels_out, kernel_size=kernel_size, padding=padding,
stride=stride)
self.uv = None
self.uv = None
def forward(self, x):
if self.uv is None:
height, width = x.shape[2], x.shape[3]
u, v = np.meshgrid(range(width), range(height))
u = 2 * u / (width - 1) - 1
v = 2 * v / (height - 1) - 1
uv = np.stack((u, v)).reshape(1, 2, height, width)
self.uv = torch.from_numpy( uv.astype(np.float32) )
self.uv = self.uv.to(x.device)
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
xuv = torch.cat((x, uv), dim=1)
y = self.conv(xuv)
return y
def forward(self, x):
if self.uv is None:
height, width = x.shape[2], x.shape[3]
u, v = np.meshgrid(range(width), range(height))
u = 2 * u / (width - 1) - 1
v = 2 * v / (height - 1) - 1
uv = np.stack((u, v)).reshape(1, 2, height, width)
self.uv = torch.from_numpy(uv.astype(np.float32))
self.uv = self.uv.to(x.device)
uv = self.uv.expand(x.shape[0], *self.uv.shape[1:])
xuv = torch.cat((x, uv), dim=1)
y = self.conv(xuv)
return y

View File

@ -11,11 +11,12 @@ nvcc_args = [
]
setup(
name='ext',
ext_modules=[
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'], extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
],
cmdclass={'build_ext': BuildExtension},
include_dirs=include_dirs
name='ext',
ext_modules=[
CppExtension('ext_cpu', ['ext/ext_cpu.cpp']),
CUDAExtension('ext_cuda', ['ext/ext_cuda.cpp', 'ext/ext_kernel.cu'],
extra_compile_args={'cxx': [], 'nvcc': nvcc_args}),
],
cmdclass={'build_ext': BuildExtension},
include_dirs=include_dirs
)

View File

@ -17,512 +17,516 @@ from collections import OrderedDict
class StopWatch(object):
def __init__(self):
self.timings = OrderedDict()
self.starts = {}
def __init__(self):
self.timings = OrderedDict()
self.starts = {}
def start(self, name):
self.starts[name] = time.time()
def start(self, name):
self.starts[name] = time.time()
def stop(self, name):
if name not in self.timings:
self.timings[name] = []
self.timings[name].append(time.time() - self.starts[name])
def stop(self, name):
if name not in self.timings:
self.timings[name] = []
self.timings[name].append(time.time() - self.starts[name])
def get(self, name=None, reduce=np.sum):
if name is not None:
return reduce(self.timings[name])
else:
ret = {}
for k in self.timings:
ret[k] = reduce(self.timings[k])
return ret
def get(self, name=None, reduce=np.sum):
if name is not None:
return reduce(self.timings[name])
else:
ret = {}
for k in self.timings:
ret[k] = reduce(self.timings[k])
return ret
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k,v) for k,v in self.get().items()])
def __repr__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
def __str__(self):
return ', '.join(['%s: %f[s]' % (k, v) for k, v in self.get().items()])
class ETA(object):
def __init__(self, length):
self.length = length
self.start_time = time.time()
self.current_idx = 0
self.current_time = time.time()
def __init__(self, length):
self.length = length
self.start_time = time.time()
self.current_idx = 0
self.current_time = time.time()
def update(self, idx):
self.current_idx = idx
self.current_time = time.time()
def update(self, idx):
self.current_idx = idx
self.current_time = time.time()
def get_elapsed_time(self):
return self.current_time - self.start_time
def get_elapsed_time(self):
return self.current_time - self.start_time
def get_item_time(self):
return self.get_elapsed_time() / (self.current_idx + 1)
def get_item_time(self):
return self.get_elapsed_time() / (self.current_idx + 1)
def get_remaining_time(self):
return self.get_item_time() * (self.length - self.current_idx + 1)
def get_remaining_time(self):
return self.get_item_time() * (self.length - self.current_idx + 1)
def format_time(self, seconds):
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
hours = int(hours)
minutes = int(minutes)
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
def format_time(self, seconds):
minutes, seconds = divmod(seconds, 60)
hours, minutes = divmod(minutes, 60)
hours = int(hours)
minutes = int(minutes)
return f'{hours:02d}:{minutes:02d}:{seconds:05.2f}'
def get_elapsed_time_str(self):
return self.format_time(self.get_elapsed_time())
def get_elapsed_time_str(self):
return self.format_time(self.get_elapsed_time())
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
def get_remaining_time_str(self):
return self.format_time(self.get_remaining_time())
class Worker(object):
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16, num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
self.out_root = Path(out_root)
self.experiment_name = experiment_name
self.epochs = epochs
self.seed = seed
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.num_workers = num_workers
self.save_frequency = save_frequency
self.train_device = train_device
self.test_device = test_device
self.max_train_iter = max_train_iter
def __init__(self, out_root, experiment_name, epochs=10, seed=42, train_batch_size=8, test_batch_size=16,
num_workers=16, save_frequency=1, train_device='cuda:0', test_device='cuda:0', max_train_iter=-1):
self.out_root = Path(out_root)
self.experiment_name = experiment_name
self.epochs = epochs
self.seed = seed
self.train_batch_size = train_batch_size
self.test_batch_size = test_batch_size
self.num_workers = num_workers
self.save_frequency = save_frequency
self.train_device = train_device
self.test_device = test_device
self.max_train_iter = max_train_iter
self.errs_list=[]
self.errs_list = []
self.setup_experiment()
self.setup_experiment()
def setup_experiment(self):
self.exp_out_root = self.out_root / self.experiment_name
self.exp_out_root.mkdir(parents=True, exist_ok=True)
def setup_experiment(self):
self.exp_out_root = self.out_root / self.experiment_name
self.exp_out_root.mkdir(parents=True, exist_ok=True)
if logging.root: del logging.root.handlers[:]
logging.basicConfig(
level=logging.INFO,
handlers=[
logging.FileHandler( str(self.exp_out_root / 'train.log') ),
logging.StreamHandler()
],
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
)
if logging.root: del logging.root.handlers[:]
logging.basicConfig(
level=logging.INFO,
handlers=[
logging.FileHandler(str(self.exp_out_root / 'train.log')),
logging.StreamHandler()
],
format='%(relativeCreated)d:%(levelname)s:%(process)d-%(processName)s: %(message)s'
)
logging.info('='*80)
logging.info(f'Start of experiment: {self.experiment_name}')
logging.info(socket.gethostname())
self.log_datetime()
logging.info('='*80)
logging.info('=' * 80)
logging.info(f'Start of experiment: {self.experiment_name}')
logging.info(socket.gethostname())
self.log_datetime()
logging.info('=' * 80)
self.metric_path = self.exp_out_root / 'metrics.json'
if self.metric_path.exists():
with open(str(self.metric_path), 'r') as fp:
self.metric_data = json.load(fp)
else:
self.metric_data = {}
self.metric_path = self.exp_out_root / 'metrics.json'
if self.metric_path.exists():
with open(str(self.metric_path), 'r') as fp:
self.metric_data = json.load(fp)
else:
self.metric_data = {}
self.init_seed()
self.init_seed()
def metric_add_train(self, epoch, key, val):
epoch = str(epoch)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'train' not in self.metric_data[epoch]:
self.metric_data[epoch]['train'] = {}
self.metric_data[epoch]['train'][key] = val
def metric_add_train(self, epoch, key, val):
epoch = str(epoch)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'train' not in self.metric_data[epoch]:
self.metric_data[epoch]['train'] = {}
self.metric_data[epoch]['train'][key] = val
def metric_add_test(self, epoch, set_idx, key, val):
epoch = str(epoch)
set_idx = str(set_idx)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'test' not in self.metric_data[epoch]:
self.metric_data[epoch]['test'] = {}
if set_idx not in self.metric_data[epoch]['test']:
self.metric_data[epoch]['test'][set_idx] = {}
self.metric_data[epoch]['test'][set_idx][key] = val
def metric_add_test(self, epoch, set_idx, key, val):
epoch = str(epoch)
set_idx = str(set_idx)
key = str(key)
if epoch not in self.metric_data:
self.metric_data[epoch] = {}
if 'test' not in self.metric_data[epoch]:
self.metric_data[epoch]['test'] = {}
if set_idx not in self.metric_data[epoch]['test']:
self.metric_data[epoch]['test'][set_idx] = {}
self.metric_data[epoch]['test'][set_idx][key] = val
def metric_save(self):
with open(str(self.metric_path), 'w') as fp:
json.dump(self.metric_data, fp, indent=2)
def metric_save(self):
with open(str(self.metric_path), 'w') as fp:
json.dump(self.metric_data, fp, indent=2)
def init_seed(self, seed=None):
if seed is not None:
self.seed = seed
logging.info(f'Set seed to {self.seed}')
np.random.seed(self.seed)
random.seed(self.seed)
torch.manual_seed(self.seed)
torch.cuda.manual_seed(self.seed)
def init_seed(self, seed=None):
if seed is not None:
self.seed = seed
logging.info(f'Set seed to {self.seed}')
np.random.seed(self.seed)
random.seed(self.seed)
torch.manual_seed(self.seed)
torch.cuda.manual_seed(self.seed)
def log_datetime(self):
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
def log_datetime(self):
logging.info(datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
def mem_report(self):
for obj in gc.get_objects():
if torch.is_tensor(obj):
print(type(obj), obj.shape)
def mem_report(self):
for obj in gc.get_objects():
if torch.is_tensor(obj):
print(type(obj), obj.shape)
def get_net_path(self, epoch, root=None):
if root is None:
root = self.exp_out_root
return root / f'net_{epoch:04d}.params'
def get_net_path(self, epoch, root=None):
if root is None:
root = self.exp_out_root
return root / f'net_{epoch:04d}.params'
def get_do_parser_cmds(self):
return ['retrain', 'resume', 'retest', 'test_init']
def get_do_parser_cmds(self):
return ['retrain', 'resume', 'retest', 'test_init']
def get_do_parser(self):
parser = argparse.ArgumentParser()
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
parser.add_argument('--epoch', type=int, default=-1)
return parser
def get_do_parser(self):
parser = argparse.ArgumentParser()
parser.add_argument('--cmd', type=str, default='resume', choices=self.get_do_parser_cmds())
parser.add_argument('--epoch', type=int, default=-1)
return parser
def do_cmd(self, args, net, optimizer, scheduler=None):
if args.cmd == 'retrain':
self.train(net, optimizer, resume=False, scheduler=scheduler)
elif args.cmd == 'resume':
self.train(net, optimizer, resume=True, scheduler=scheduler)
elif args.cmd == 'retest':
self.retest(net, epoch=args.epoch)
elif args.cmd == 'test_init':
test_sets = self.get_test_sets()
self.test(-1, net, test_sets)
else:
raise Exception('invalid cmd')
def do_cmd(self, args, net, optimizer, scheduler=None):
if args.cmd == 'retrain':
self.train(net, optimizer, resume=False, scheduler=scheduler)
elif args.cmd == 'resume':
self.train(net, optimizer, resume=True, scheduler=scheduler)
elif args.cmd == 'retest':
self.retest(net, epoch=args.epoch)
elif args.cmd == 'test_init':
test_sets = self.get_test_sets()
self.test(-1, net, test_sets)
else:
raise Exception('invalid cmd')
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
parser = self.get_do_parser()
args, _ = parser.parse_known_args()
def do(self, net, optimizer, load_net_optimizer=None, scheduler=None):
parser = self.get_do_parser()
args, _ = parser.parse_known_args()
if load_net_optimizer is not None and args.cmd not in ['schedule']:
net, optimizer = load_net_optimizer()
if load_net_optimizer is not None and args.cmd not in ['schedule']:
net, optimizer = load_net_optimizer()
self.do_cmd(args, net, optimizer, scheduler=scheduler)
self.do_cmd(args, net, optimizer, scheduler=scheduler)
def retest(self, net, epoch=-1):
if epoch < 0:
epochs = range(self.epochs)
else:
epochs = [epoch]
def retest(self, net, epoch=-1):
if epoch < 0:
epochs = range(self.epochs)
else:
epochs = [epoch]
test_sets = self.get_test_sets()
test_sets = self.get_test_sets()
for epoch in epochs:
net_path = self.get_net_path(epoch)
if net_path.exists():
state_dict = torch.load(str(net_path))
net.load_state_dict(state_dict)
self.test(epoch, net, test_sets)
for epoch in epochs:
net_path = self.get_net_path(epoch)
if net_path.exists():
state_dict = torch.load(str(net_path))
net.load_state_dict(state_dict)
self.test(epoch, net, test_sets)
def format_err_str(self, errs, div=1):
err = sum(errs)
if len(errs) > 1:
err_str = f'{err/div:0.4f}=' + '+'.join([f'{e/div:0.4f}' for e in errs])
else:
err_str = f'{err/div:0.4f}'
return err_str
def format_err_str(self, errs, div=1):
err = sum(errs)
if len(errs) > 1:
err_str = f'{err / div:0.4f}=' + '+'.join([f'{e / div:0.4f}' for e in errs])
else:
err_str = f'{err / div:0.4f}'
return err_str
def write_err_img(self):
err_img_path = self.exp_out_root / 'errs.png'
fig = plt.figure(figsize=(16,16))
lines=[]
for idx,errs in enumerate(self.errs_list):
line,=plt.plot(range(len(errs)), errs, label=f'error{idx}')
lines.append(line)
plt.tight_layout()
plt.legend(handles=lines)
plt.savefig(str(err_img_path))
plt.close(fig)
def write_err_img(self):
err_img_path = self.exp_out_root / 'errs.png'
fig = plt.figure(figsize=(16, 16))
lines = []
for idx, errs in enumerate(self.errs_list):
line, = plt.plot(range(len(errs)), errs, label=f'error{idx}')
lines.append(line)
plt.tight_layout()
plt.legend(handles=lines)
plt.savefig(str(err_img_path))
plt.close(fig)
def callback_train_new_epoch(self, epoch, net, optimizer):
pass
def train(self, net, optimizer, resume=False, scheduler=None):
logging.info('='*80)
logging.info('Start training')
self.log_datetime()
logging.info('='*80)
train_set = self.get_train_set()
test_sets = self.get_test_sets()
net = net.to(self.train_device)
epoch = 0
min_err = {ts.name: 1e9 for ts in test_sets}
state_path = self.exp_out_root / 'state.dict'
if resume and state_path.exists():
logging.info('='*80)
logging.info(f'Loading state from {state_path}')
logging.info('='*80)
state = torch.load(str(state_path))
epoch = state['epoch'] + 1
if 'min_err' in state:
min_err = state['min_err']
curr_state = net.state_dict()
curr_state.update(state['state_dict'])
net.load_state_dict(curr_state)
try:
optimizer.load_state_dict(state['optimizer'])
except:
logging.info('Warning: cannot load optimizer from state_dict')
def callback_train_new_epoch(self, epoch, net, optimizer):
pass
if 'cpu_rng_state' in state:
torch.set_rng_state(state['cpu_rng_state'])
if 'gpu_rng_state' in state:
torch.cuda.set_rng_state(state['gpu_rng_state'])
for epoch in range(epoch, self.epochs):
self.callback_train_new_epoch(epoch, net, optimizer)
def train(self, net, optimizer, resume=False, scheduler=None):
logging.info('=' * 80)
logging.info('Start training')
self.log_datetime()
logging.info('=' * 80)
# train epoch
self.train_epoch(epoch, net, optimizer, train_set)
train_set = self.get_train_set()
test_sets = self.get_test_sets()
# test epoch
errs = self.test(epoch, net, test_sets)
if (epoch + 1) % self.save_frequency == 0:
net = net.to(self.train_device)
# store state
state_dict = {
'epoch': epoch,
'min_err': min_err,
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
'cpu_rng_state': torch.get_rng_state(),
'gpu_rng_state': torch.cuda.get_rng_state(),
}
logging.info(f'save state to {state_path}')
epoch = 0
min_err = {ts.name: 1e9 for ts in test_sets}
state_path = self.exp_out_root / 'state.dict'
torch.save(state_dict, str(state_path))
if resume and state_path.exists():
logging.info('=' * 80)
logging.info(f'Loading state from {state_path}')
logging.info('=' * 80)
state = torch.load(str(state_path))
epoch = state['epoch'] + 1
if 'min_err' in state:
min_err = state['min_err']
for test_set_name in errs:
err = sum(errs[test_set_name])
if err < min_err[test_set_name]:
min_err[test_set_name] = err
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
logging.info(f'save state to {state_path}')
torch.save(state_dict, str(state_path))
curr_state = net.state_dict()
curr_state.update(state['state_dict'])
net.load_state_dict(curr_state)
# store network
net_path = self.get_net_path(epoch)
logging.info(f'save network to {net_path}')
torch.save(net.state_dict(), str(net_path))
try:
optimizer.load_state_dict(state['optimizer'])
except:
logging.info('Warning: cannot load optimizer from state_dict')
pass
if 'cpu_rng_state' in state:
torch.set_rng_state(state['cpu_rng_state'])
if 'gpu_rng_state' in state:
torch.cuda.set_rng_state(state['gpu_rng_state'])
if scheduler is not None:
scheduler.step()
for epoch in range(epoch, self.epochs):
self.callback_train_new_epoch(epoch, net, optimizer)
logging.info('='*80)
logging.info('Finished training')
self.log_datetime()
logging.info('='*80)
# train epoch
self.train_epoch(epoch, net, optimizer, train_set)
def get_train_set(self):
# returns train_set
raise NotImplementedError()
# test epoch
errs = self.test(epoch, net, test_sets)
def get_test_sets(self):
# returns test_sets
raise NotImplementedError()
if (epoch + 1) % self.save_frequency == 0:
net = net.to(self.train_device)
def copy_data(self, data, device, requires_grad, train):
raise NotImplementedError()
# store state
state_dict = {
'epoch': epoch,
'min_err': min_err,
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
'cpu_rng_state': torch.get_rng_state(),
'gpu_rng_state': torch.cuda.get_rng_state(),
}
logging.info(f'save state to {state_path}')
state_path = self.exp_out_root / 'state.dict'
torch.save(state_dict, str(state_path))
def net_forward(self, net, train):
raise NotImplementedError()
for test_set_name in errs:
err = sum(errs[test_set_name])
if err < min_err[test_set_name]:
min_err[test_set_name] = err
state_path = self.exp_out_root / f'state_set{test_set_name}_best.dict'
logging.info(f'save state to {state_path}')
torch.save(state_dict, str(state_path))
def loss_forward(self, output, train):
raise NotImplementedError()
# store network
net_path = self.get_net_path(epoch)
logging.info(f'save network to {net_path}')
torch.save(net.state_dict(), str(net_path))
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
# err = False
# for name, param in net.named_parameters():
# if not torch.isfinite(param.grad).all():
# print(name)
# err = True
# if err:
# import ipdb; ipdb.set_trace()
pass
if scheduler is not None:
scheduler.step()
def callback_train_start(self, epoch):
pass
logging.info('=' * 80)
logging.info('Finished training')
self.log_datetime()
logging.info('=' * 80)
def callback_train_stop(self, epoch, loss):
pass
def get_train_set(self):
# returns train_set
raise NotImplementedError()
def train_epoch(self, epoch, net, optimizer, dset):
self.callback_train_start(epoch)
stopwatch = StopWatch()
def get_test_sets(self):
# returns test_sets
raise NotImplementedError()
logging.info('='*80)
logging.info('Train epoch %d' % epoch)
def copy_data(self, data, device, requires_grad, train):
raise NotImplementedError()
dset.current_epoch = epoch
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True, num_workers=self.num_workers, drop_last=True, pin_memory=False)
def net_forward(self, net, train):
raise NotImplementedError()
net = net.to(self.train_device)
net.train()
def loss_forward(self, output, train):
raise NotImplementedError()
mean_loss = None
def callback_train_post_backward(self, net, errs, output, epoch, batch_idx, masks):
# err = False
# for name, param in net.named_parameters():
# if not torch.isfinite(param.grad).all():
# print(name)
# err = True
# if err:
# import ipdb; ipdb.set_trace()
pass
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
bar = ETA(length=n_batches)
def callback_train_start(self, epoch):
pass
stopwatch.start('total')
stopwatch.start('data')
for batch_idx, data in enumerate(train_loader):
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
stopwatch.stop('data')
def callback_train_stop(self, epoch, loss):
pass
optimizer.zero_grad()
def train_epoch(self, epoch, net, optimizer, dset):
self.callback_train_start(epoch)
stopwatch = StopWatch()
stopwatch.start('forward')
output = self.net_forward(net, train=True)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('forward')
logging.info('=' * 80)
logging.info('Train epoch %d' % epoch)
stopwatch.start('loss')
errs = self.loss_forward(output, train=True)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
err = sum(errs)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('loss')
dset.current_epoch = epoch
train_loader = torch.utils.data.DataLoader(dset, batch_size=self.train_batch_size, shuffle=True,
num_workers=self.num_workers, drop_last=True, pin_memory=False)
stopwatch.start('backward')
err.backward()
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('backward')
net = net.to(self.train_device)
net.train()
stopwatch.start('optimizer')
optimizer.step()
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('optimizer')
mean_loss = None
bar.update(batch_idx)
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
err_str = self.format_err_str(errs)
logging.info(f'train e{epoch}: {batch_idx+1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
#self.write_err_img()
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
mean_loss = [l / len(train_loader) for l in mean_loss]
self.callback_train_stop(epoch, mean_loss)
self.metric_add_train(epoch, 'loss', mean_loss)
# save metrics
self.metric_save()
err_str = self.format_err_str(mean_loss)
logging.info(f'avg train_loss={err_str}')
return mean_loss
def callback_test_start(self, epoch, set_idx):
pass
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
pass
def callback_test_stop(self, epoch, set_idx, loss):
pass
def test(self, epoch, net, test_sets):
errs = {}
for test_set_idx, test_set in enumerate(test_sets):
if (epoch + 1) % test_set.test_frequency == 0:
logging.info('='*80)
logging.info(f'testing set {test_set.name}')
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
errs[test_set.name] = err
return errs
def test_epoch(self, epoch, set_idx, net, dset):
logging.info('-'*80)
logging.info('Test epoch %d' % epoch)
dset.current_epoch = epoch
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False, num_workers=self.num_workers, drop_last=False, pin_memory=False)
net = net.to(self.test_device)
net.eval()
with torch.no_grad():
mean_loss = None
self.callback_test_start(epoch, set_idx)
bar = ETA(length=len(test_loader))
stopwatch = StopWatch()
stopwatch.start('total')
stopwatch.start('data')
for batch_idx, data in enumerate(test_loader):
# if batch_idx == 10: break
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
stopwatch.stop('data')
stopwatch.start('forward')
output = self.net_forward(net, train=False)
if 'cuda' in self.test_device: torch.cuda.synchronize()
stopwatch.stop('forward')
stopwatch.start('loss')
errs = self.loss_forward(output, train=False)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
bar.update(batch_idx)
if batch_idx % 25 == 0:
err_str = self.format_err_str(errs)
logging.info(f'test e{epoch}: {batch_idx+1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.stop('loss')
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
n_batches = self.max_train_iter if self.max_train_iter > 0 else len(train_loader)
bar = ETA(length=n_batches)
stopwatch.start('total')
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
for batch_idx, data in enumerate(train_loader):
if self.max_train_iter > 0 and batch_idx > self.max_train_iter: break
self.copy_data(data, device=self.train_device, requires_grad=True, train=True)
stopwatch.stop('data')
mean_loss = [l / len(test_loader) for l in mean_loss]
self.callback_test_stop(epoch, set_idx, mean_loss)
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
optimizer.zero_grad()
# save metrics
self.metric_save()
stopwatch.start('forward')
output = self.net_forward(net, train=True)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('forward')
err_str = self.format_err_str(mean_loss)
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
return mean_loss
stopwatch.start('loss')
errs = self.loss_forward(output, train=True)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
err = sum(errs)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('loss')
stopwatch.start('backward')
err.backward()
self.callback_train_post_backward(net, errs, output, epoch, batch_idx, masks)
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('backward')
stopwatch.start('optimizer')
optimizer.step()
if 'cuda' in self.train_device: torch.cuda.synchronize()
stopwatch.stop('optimizer')
bar.update(batch_idx)
if (epoch <= 1 and batch_idx < 128) or batch_idx % 16 == 0:
err_str = self.format_err_str(errs)
logging.info(
f'train e{epoch}: {batch_idx + 1}/{len(train_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
# self.write_err_img()
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
mean_loss = [l / len(train_loader) for l in mean_loss]
self.callback_train_stop(epoch, mean_loss)
self.metric_add_train(epoch, 'loss', mean_loss)
# save metrics
self.metric_save()
err_str = self.format_err_str(mean_loss)
logging.info(f'avg train_loss={err_str}')
return mean_loss
def callback_test_start(self, epoch, set_idx):
pass
def callback_test_add(self, epoch, set_idx, batch_idx, n_batches, output, masks):
pass
def callback_test_stop(self, epoch, set_idx, loss):
pass
def test(self, epoch, net, test_sets):
errs = {}
for test_set_idx, test_set in enumerate(test_sets):
if (epoch + 1) % test_set.test_frequency == 0:
logging.info('=' * 80)
logging.info(f'testing set {test_set.name}')
err = self.test_epoch(epoch, test_set_idx, net, test_set.dset)
errs[test_set.name] = err
return errs
def test_epoch(self, epoch, set_idx, net, dset):
logging.info('-' * 80)
logging.info('Test epoch %d' % epoch)
dset.current_epoch = epoch
test_loader = torch.utils.data.DataLoader(dset, batch_size=self.test_batch_size, shuffle=False,
num_workers=self.num_workers, drop_last=False, pin_memory=False)
net = net.to(self.test_device)
net.eval()
with torch.no_grad():
mean_loss = None
self.callback_test_start(epoch, set_idx)
bar = ETA(length=len(test_loader))
stopwatch = StopWatch()
stopwatch.start('total')
stopwatch.start('data')
for batch_idx, data in enumerate(test_loader):
# if batch_idx == 10: break
self.copy_data(data, device=self.test_device, requires_grad=False, train=False)
stopwatch.stop('data')
stopwatch.start('forward')
output = self.net_forward(net, train=False)
if 'cuda' in self.test_device: torch.cuda.synchronize()
stopwatch.stop('forward')
stopwatch.start('loss')
errs = self.loss_forward(output, train=False)
if isinstance(errs, dict):
masks = errs['masks']
errs = errs['errs']
else:
masks = []
if not isinstance(errs, list) and not isinstance(errs, tuple):
errs = [errs]
bar.update(batch_idx)
if batch_idx % 25 == 0:
err_str = self.format_err_str(errs)
logging.info(
f'test e{epoch}: {batch_idx + 1}/{len(test_loader)}: loss={err_str} | {bar.get_elapsed_time_str()} / {bar.get_remaining_time_str()}')
if mean_loss is None:
mean_loss = [0 for e in errs]
for erridx, err in enumerate(errs):
mean_loss[erridx] += err.item()
stopwatch.stop('loss')
self.callback_test_add(epoch, set_idx, batch_idx, len(test_loader), output, masks)
stopwatch.start('data')
stopwatch.stop('total')
logging.info('timings: %s' % stopwatch)
mean_loss = [l / len(test_loader) for l in mean_loss]
self.callback_test_stop(epoch, set_idx, mean_loss)
self.metric_add_test(epoch, set_idx, 'loss', mean_loss)
# save metrics
self.metric_save()
err_str = self.format_err_str(mean_loss)
logging.info(f'test epoch {epoch}: avg test_loss={err_str}')
return mean_loss

View File

@ -5,25 +5,24 @@ from model import exp_synphge
from model import networks
from co.args import parse_args
# parse args
args = parse_args()
# loss types
if args.loss=='ph':
worker = exp_synph.Worker(args)
elif args.loss=='phge':
worker = exp_synphge.Worker(args)
if args.loss == 'ph':
worker = exp_synph.Worker(args)
elif args.loss == 'phge':
worker = exp_synphge.Worker(args)
# concatenation of original image and lcn image
channels_in=2
channels_in = 2
# set up network
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes, output_ms=worker.ms)
net = networks.DispEdgeDecoders(channels_in=channels_in, max_disp=args.max_disp, imsizes=worker.imsizes,
output_ms=worker.ms)
# optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4)
# start the work
worker.do(net, optimizer)