Adjusted mesgh_grid for ONNX conversion
This commit is contained in:
parent
f0a57bb444
commit
d199d62b8d
@ -81,7 +81,7 @@ class CREStereo(nn.Module):
|
||||
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
|
||||
return zero_flow
|
||||
|
||||
def forward(self, image1, image2, iters=10, flow_init=None, upsample=True, test_mode=False):
|
||||
def forward(self, image1, image2, flow_init, iters=10, upsample=True, test_mode=False):
|
||||
""" Estimate optical flow between pair of frames """
|
||||
|
||||
image1 = 2 * (image1 / 255.0) - 1.0
|
||||
|
@ -12,7 +12,8 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
||||
ygrid = 2*ygrid/(H-1) - 1
|
||||
|
||||
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||
img = F.grid_sample(img, grid, align_corners=True)
|
||||
# img = F.grid_sample(img, grid, align_corners=True)
|
||||
img = bilinear_grid_sample(img, grid, align_corners=True)
|
||||
|
||||
if mask:
|
||||
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||
@ -29,3 +30,79 @@ def manual_pad(x, pady, padx):
|
||||
|
||||
pad = (padx, padx, pady, pady)
|
||||
return F.pad(x.clone().detach(), pad, "replicate")
|
||||
|
||||
# Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160
|
||||
def bilinear_grid_sample(im, grid, align_corners=False):
|
||||
"""Given an input and a flow-field grid, computes the output using input
|
||||
values and pixel locations from grid. Supported only bilinear interpolation
|
||||
method to sample the input pixels.
|
||||
|
||||
Args:
|
||||
im (torch.Tensor): Input feature map, shape (N, C, H, W)
|
||||
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
|
||||
align_corners {bool}: If set to True, the extrema (-1 and 1) are
|
||||
considered as referring to the center points of the input’s
|
||||
corner pixels. If set to False, they are instead considered as
|
||||
referring to the corner points of the input’s corner pixels,
|
||||
making the sampling more resolution agnostic.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
|
||||
"""
|
||||
n, c, h, w = im.shape
|
||||
gn, gh, gw, _ = grid.shape
|
||||
assert n == gn
|
||||
|
||||
x = grid[:, :, :, 0]
|
||||
y = grid[:, :, :, 1]
|
||||
|
||||
if align_corners:
|
||||
x = ((x + 1) / 2) * (w - 1)
|
||||
y = ((y + 1) / 2) * (h - 1)
|
||||
else:
|
||||
x = ((x + 1) * w - 1) / 2
|
||||
y = ((y + 1) * h - 1) / 2
|
||||
|
||||
x = x.view(n, -1)
|
||||
y = y.view(n, -1)
|
||||
|
||||
x0 = torch.floor(x).long()
|
||||
y0 = torch.floor(y).long()
|
||||
x1 = x0 + 1
|
||||
y1 = y0 + 1
|
||||
|
||||
wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
|
||||
wb = ((x1 - x) * (y - y0)).unsqueeze(1)
|
||||
wc = ((x - x0) * (y1 - y)).unsqueeze(1)
|
||||
wd = ((x - x0) * (y - y0)).unsqueeze(1)
|
||||
|
||||
# Apply default for grid_sample function zero padding
|
||||
im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
|
||||
padded_h = h + 2
|
||||
padded_w = w + 2
|
||||
# save points positions after padding
|
||||
x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
|
||||
|
||||
# Clip coordinates to padded image size
|
||||
x0 = torch.where(x0 < 0, torch.tensor(0, device=im.device), x0)
|
||||
x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x0)
|
||||
x1 = torch.where(x1 < 0, torch.tensor(0, device=im.device), x1)
|
||||
x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x1)
|
||||
y0 = torch.where(y0 < 0, torch.tensor(0, device=im.device), y0)
|
||||
y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y0)
|
||||
y1 = torch.where(y1 < 0, torch.tensor(0, device=im.device), y1)
|
||||
y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y1)
|
||||
|
||||
im_padded = im_padded.view(n, c, -1)
|
||||
|
||||
x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
|
||||
x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
|
||||
x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
|
||||
x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
|
||||
|
||||
Ia = torch.gather(im_padded, 2, x0_y0)
|
||||
Ib = torch.gather(im_padded, 2, x0_y1)
|
||||
Ic = torch.gather(im_padded, 2, x1_y0)
|
||||
Id = torch.gather(im_padded, 2, x1_y1)
|
||||
|
||||
return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
|
||||
|
Loading…
Reference in New Issue
Block a user