general improvements

This commit is contained in:
Cpt.Captain 2022-02-22 13:35:24 +01:00
parent 7633990c81
commit 168516924e

View File

@ -224,7 +224,7 @@ class DispNetS(TimedModule):
)
def crop_like(self, input, ref):
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3))
assert (input.size(2) >= ref.size(2) and input.size(3) >= ref.size(3)), f'Assertion ({input.size(2)} >= {ref.size(2)} and {input.size(3)} >= {ref.size(3)}) failed'
return input[:, :, :ref.size(2), :ref.size(3)]
def tforward(self, x):
@ -291,7 +291,8 @@ class DispNetS(TimedModule):
if self.output_ms:
if self.double_head:
return (disp1, disp1_d), (disp2, disp2_d), disp3, disp4
# NOTE return all tuples for easier handling
return (disp1, disp1_d), (disp2, disp2_d), (disp3, disp3), (disp4, disp4)
return disp1, disp2, disp3, disp4
else:
if self.double_head:
@ -304,8 +305,8 @@ class DispNetShallow(DispNetS):
Edge Decoder based on DispNetS with fewer layers
'''
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init)
def __init__(self, channels_in, imsizes, output_facs, output_ms=True, coordconv=False, weight_init=False, double_head=False):
super(DispNetShallow, self).__init__(channels_in, imsizes, output_facs, output_ms, coordconv, weight_init, double_head=False)
self.mod_name = 'DispNetShallow'
conv_planes = [32, 64, 128, 256, 512, 512, 512]
upconv_planes = [512, 512, 256, 128, 64, 32, 16]
@ -335,6 +336,21 @@ class DispNetShallow(DispNetS):
out_iconv1 = self.iconv1(concat1)
disp1 = self.predict_disp1(out_iconv1)
if self.double_head:
out_upconv2_d = self.crop_like(self.upconv2(out_iconv3), out_conv1)
disp3_up_d = self.crop_like(
torch.nn.functional.interpolate(disp3, scale_factor=2, mode='bilinear', align_corners=False), out_conv1)
concat2_d = torch.cat((out_upconv2_d, out_conv1, disp3_up_d), 1)
out_iconv2_d = self.iconv2(concat2_d)
disp2_d = self.predict_disp2_double(out_iconv2_d)
out_upconv1_d = self.crop_like(self.upconv1(out_iconv2), x)
disp2_up_d = self.crop_like(
torch.nn.functional.interpolate(disp2_d, scale_factor=2, mode='bilinear', align_corners=False), x)
concat1_d = torch.cat((out_upconv1_d, disp2_up_d), 1)
out_iconv1_d = self.iconv1(concat1_d)
disp1_d = self.predict_disp1_double(out_iconv1_d)
if self.output_ms:
return disp1, disp2, disp3
else:
@ -411,7 +427,6 @@ class RectifiedPatternSimilarityLoss(TimedModule):
def tforward(self, disp0, im, std=None):
self.pattern = self.pattern.to(disp0.device)
self.uv0 = self.uv0.to(disp0.device)
uv0 = self.uv0.expand(disp0.shape[0], *self.uv0.shape[1:])
uv1 = torch.empty_like(uv0)
uv1[..., 0] = uv0[..., 0] - disp0.contiguous().view(disp0.shape[0], -1)
@ -512,7 +527,7 @@ class ProjectionBaseLoss(TimedModule):
xyz = xyz + t.reshape(bs, 1, 3)
Kt = self.K.transpose(1, 2).expand(bs, -1, -1)
uv = torch.bmm(xyz, Kt)
uv = torch.bmm(xyz, Kt.float())
d = uv[:, :, 2:3]