general improvements
This commit is contained in:
parent
7633990c81
commit
168516924e
@ -224,7 +224,7 @@ class DispNetS(TimedModule):
|
||||
)
|
||||
|
||||
def crop_like(self, input, ref):
|
||||
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
|
||||
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)), f'Assertion ({input.size(2)} >= {ref.size(2)} and {input.size(3)} >= {ref.size(3)}) failed'
|
||||
return input[:, :, :ref.size(2), :ref.size(3)]
|
||||
|
||||
def tforward(self, x):
|
||||
@ -291,7 +291,8 @@ class DispNetS(TimedModule):
|
||||
|
||||
if self.output_ms:
|
||||
if self.double_head:
|
||||
return (disp1, disp1_d), (disp2, disp2_d), disp3, disp4
|
||||
# NOTE return all tuples for easier handling
|
||||
return (disp1, disp1_d), (disp2, disp2_d), (disp3, disp3), (disp4, disp4)
|
||||
return disp1, disp2, disp3, disp4
|
||||
else:
|
||||
if self.double_head:
|
||||
@ -304,8 +305,8 @@ class DispNetShallow(DispNetS):
|
||||
Edge Decoder based on DispNetS with fewer layers
|
||||
'''
|
||||
|
||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
|
||||
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
|
||||
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, double_head=False):
|
||||
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init, double_head=False)
|
||||
self.mod_name = 'DispNetShallow'
|
||||
conv_planes = [32, 64, 128, 256, 512, 512, 512]
|
||||
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
|
||||
@ -335,6 +336,21 @@ class DispNetShallow(DispNetS):
|
||||
out_iconv1 = self.iconv1(concat1)
|
||||
disp1 = self.predict_disp1(out_iconv1)
|
||||
|
||||
if self.double_head:
|
||||
out_upconv2_d = self.crop_like(self.upconv2(out_iconv3), out_conv1)
|
||||
disp3_up_d = self.crop_like(
|
||||
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
|
||||
concat2_d = torch.cat((out_upconv2_d, out_conv1, disp3_up_d), 1)
|
||||
out_iconv2_d = self.iconv2(concat2_d)
|
||||
disp2_d = self.predict_disp2_double(out_iconv2_d)
|
||||
|
||||
out_upconv1_d = self.crop_like(self.upconv1(out_iconv2), x)
|
||||
disp2_up_d = self.crop_like(
|
||||
torch.nn.functional.interpolate(disp2_d, scale_factor=2, mode='bilinear', align_corners=False), x)
|
||||
concat1_d = torch.cat((out_upconv1_d, disp2_up_d), 1)
|
||||
out_iconv1_d = self.iconv1(concat1_d)
|
||||
disp1_d = self.predict_disp1_double(out_iconv1_d)
|
||||
|
||||
if self.output_ms:
|
||||
return disp1, disp2, disp3
|
||||
else:
|
||||
@ -411,7 +427,6 @@ class RectifiedPatternSimilarityLoss(TimedModule):
|
||||
def tforward(self, disp0, im, std=None):
|
||||
self.pattern = self.pattern.to(disp0.device)
|
||||
self.uv0 = self.uv0.to(disp0.device)
|
||||
|
||||
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
|
||||
uv1 = torch.empty_like(uv0)
|
||||
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
|
||||
@ -512,7 +527,7 @@ class ProjectionBaseLoss(TimedModule):
|
||||
xyz = xyz + t.reshape(bs, 1, 3)
|
||||
|
||||
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
|
||||
uv = torch.bmm(xyz, Kt)
|
||||
uv = torch.bmm(xyz, Kt.float())
|
||||
|
||||
d = uv[:, :, 2:3]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user