import torch

# Compiled extension modules providing the CPU and CUDA implementations of
# the operations wrapped below.
from . import ext_cpu
from . import ext_cuda

class NNFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, in0, in1):
        args = (in0, in1)
        # Dispatch to the CUDA or CPU extension based on the input device.
        if in0.is_cuda:
            out = ext_cuda.nn_cuda(*args)
        else:
            out = ext_cpu.nn_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # The nearest-neighbour search is non-differentiable; no gradients
        # are propagated to the inputs.
        return None, None


def nn(in0, in1):
    return NNFunction.apply(in0, in1)

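
# The compiled nn op is opaque from this file; the sketch below is a
# hypothetical pure-PyTorch equivalent for orientation only. It assumes
# (B, N, D) point batches and that the op returns, for every point of in0,
# the index of its nearest neighbour in in1 -- assumptions not guaranteed by
# the extension.
def _nn_pytorch_sketch(in0, in1):
    dists = torch.cdist(in0, in1)  # (B, N0, N1) pairwise euclidean distances
    return dists.argmin(dim=2)     # (B, N0) indices into in1
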
class CrossCheckFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, in0, in1):
        args = (in0, in1)
        if in0.is_cuda:
            out = ext_cuda.crosscheck_cuda(*args)
        else:
            out = ext_cpu.crosscheck_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # The consistency check is non-differentiable.
        return None, None


def crosscheck(in0, in1):
    return CrossCheckFunction.apply(in0, in1)

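
# Hypothetical sketch of a cross-check (mutual-consistency test), assuming
# idx01 holds nearest-neighbour indices from set 0 into set 1 and idx10 the
# reverse; the actual inputs of ext_*.crosscheck_* may differ.
def _crosscheck_pytorch_sketch(idx01, idx10):
    # A correspondence is consistent if following it forward and then back
    # returns to the starting index.
    back = torch.gather(idx10, 1, idx01)  # (B, N0) round-trip indices
    return back == torch.arange(idx01.shape[1], device=idx01.device)
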
class ProjNNFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xyz0, xyz1, K, patch_size):
        args = (xyz0, xyz1, K, patch_size)
        if xyz0.is_cuda:
            out = ext_cuda.proj_nn_cuda(*args)
        else:
            out = ext_cpu.proj_nn_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Non-differentiable; one None per forward argument.
        return None, None, None, None


def proj_nn(xyz0, xyz1, K, patch_size):
    return ProjNNFunction.apply(xyz0, xyz1, K, patch_size)

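
# proj_nn matches points via their image projections. As orientation, a
# hypothetical pinhole projection with a (3, 3) intrinsics matrix K is
# sketched below; the extension presumably searches a patch_size x patch_size
# window around such projections (an assumption, not a documented contract).
def _project_pytorch_sketch(xyz, K):
    uvw = xyz @ K.t()                                     # (B, N, 3) homogeneous
    return uvw[..., :2] / uvw[..., 2:3].clamp(min=1e-8)  # (B, N, 2) pixels
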
class XCorrVolFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, in0, in1, n_disps, block_size):
        args = (in0, in1, n_disps, block_size)
        if in0.is_cuda:
            out = ext_cuda.xcorrvol_cuda(*args)
        else:
            out = ext_cpu.xcorrvol_cpu(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # Non-differentiable; one None per forward argument.
        return None, None, None, None


def xcorrvol(in0, in1, n_disps, block_size):
    return XCorrVolFunction.apply(in0, in1, n_disps, block_size)

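
# Minimal sketch of a cross-correlation volume for block_size == 1, assuming
# NCHW inputs and horizontal displacements d in [0, n_disps); the compiled op
# additionally aggregates over block_size x block_size windows (assumed).
def _xcorrvol_pytorch_sketch(in0, in1, n_disps):
    vols = []
    for d in range(n_disps):
        # Shift in1 right by d pixels, zero-padded on the left.
        shifted = torch.nn.functional.pad(in1, (d, 0))[..., :in1.shape[-1]]
        vols.append((in0 * shifted).mean(dim=1))  # correlate over channels
    return torch.stack(vols, dim=1)               # (B, n_disps, H, W)
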
class PhotometricLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, es, ta, block_size, type, eps):
        args = (es, ta, block_size, type, eps)
        # Save the estimate, target and hyper-parameters for backward.
        ctx.save_for_backward(es, ta)
        ctx.block_size = block_size
        ctx.type = type
        ctx.eps = eps
        if es.is_cuda:
            out = ext_cuda.photometric_loss_forward(*args)
        else:
            out = ext_cpu.photometric_loss_forward(*args)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        es, ta = ctx.saved_tensors
        block_size = ctx.block_size
        type = ctx.type
        eps = ctx.eps
        args = (es, ta, grad_out.contiguous(), block_size, type, eps)
        if grad_out.is_cuda:
            grad_es = ext_cuda.photometric_loss_backward(*args)
        else:
            grad_es = ext_cpu.photometric_loss_backward(*args)
        # Only the estimate es receives a gradient; the target ta and the
        # hyper-parameters get None.
        return grad_es, None, None, None, None


def photometric_loss(es, ta, block_size, type='mse', eps=0.1):
    # Map the loss name to the integer code expected by the extension.
    type = type.lower()
    if type == 'mse':
        type = 0
    elif type == 'sad':
        type = 1
    elif type == 'census_mse':
        type = 2
    elif type == 'census_sad':
        type = 3
    else:
        raise ValueError('invalid loss type: {}'.format(type))
    return PhotometricLossFunction.apply(es, ta, block_size, type, eps)

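
# Usage sketch (shapes are assumptions; es and ta are NCHW image batches):
#
#   es = torch.rand(2, 3, 32, 32, requires_grad=True)
#   ta = torch.rand(2, 3, 32, 32)
#   loss = photometric_loss(es, ta, block_size=5, type='census_sad').mean()
#   loss.backward()  # gradients reach es through the custom backward
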
def photometric_loss_pytorch(es, ta, block_size, type='mse', eps=0.1):
    # Pure-PyTorch reference implementation of the photometric loss above.
    type = type.lower()
    p = block_size // 2
    # Replicate-pad so every pixel has a full block_size x block_size window.
    es_pad = torch.nn.functional.pad(es, (p, p, p, p), mode='replicate')
    ta_pad = torch.nn.functional.pad(ta, (p, p, p, p), mode='replicate')
    # Unfold into per-pixel patches: (N, C, block_size**2, H, W).
    es_uf = torch.nn.functional.unfold(es_pad, kernel_size=block_size)
    ta_uf = torch.nn.functional.unfold(ta_pad, kernel_size=block_size)
    es_uf = es_uf.view(es.shape[0], es.shape[1], -1, es.shape[2], es.shape[3])
    ta_uf = ta_uf.view(ta.shape[0], ta.shape[1], -1, ta.shape[2], ta.shape[3])
    if type == 'mse':
        ref = (es_uf - ta_uf) ** 2
    elif type == 'sad':
        ref = torch.abs(es_uf - ta_uf)
    elif type == 'census_mse' or type == 'census_sad':
        # Soft census transform: compare each patch pixel against the patch
        # centre and squash the difference to (0, 1) with a smooth step.
        des = es_uf - es.unsqueeze(2)
        dta = ta_uf - ta.unsqueeze(2)
        h_des = 0.5 * (1 + des / torch.sqrt(des * des + eps))
        h_dta = 0.5 * (1 + dta / torch.sqrt(dta * dta + eps))
        diff = h_des - h_dta
        if type == 'census_mse':
            ref = diff * diff
        elif type == 'census_sad':
            ref = torch.abs(diff)
    else:
        raise ValueError('invalid loss type: {}'.format(type))
    # Sum over channels and patch offsets, normalised by the patch area.
    ref = ref.view(es.shape[0], -1, es.shape[2], es.shape[3])
    ref = torch.sum(ref, dim=1, keepdim=True) / block_size ** 2
    return ref
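

if __name__ == '__main__':
    # Minimal self-check (a sketch): the pure-PyTorch reference should
    # approximate the compiled forward pass on CPU. Shapes and the comparison
    # itself are assumptions, not part of the original module.
    es = torch.rand(1, 3, 16, 16)
    ta = torch.rand(1, 3, 16, 16)
    for t in ('mse', 'sad', 'census_mse', 'census_sad'):
        ref = photometric_loss_pytorch(es, ta, block_size=3, type=t)
        out = photometric_loss(es, ta, block_size=3, type=t)
        print(t, (ref - out).abs().max().item())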