Adjusted mesgh_grid for ONNX conversion

main
Ibai 3 years ago
parent f0a57bb444
commit d199d62b8d
  1. 2
      nets/crestereo.py
  2. 79
      nets/utils/utils.py

@ -81,7 +81,7 @@ class CREStereo(nn.Module):
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device) zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
return zero_flow return zero_flow
def forward(self, image1, image2, iters=10, flow_init=None, upsample=True, test_mode=False): def forward(self, image1, image2, flow_init, iters=10, upsample=True, test_mode=False):
""" Estimate optical flow between pair of frames """ """ Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0 image1 = 2 * (image1 / 255.0) - 1.0

@ -12,7 +12,8 @@ def bilinear_sampler(img, coords, mode='bilinear', mask=False):
ygrid = 2*ygrid/(H-1) - 1 ygrid = 2*ygrid/(H-1) - 1
grid = torch.cat([xgrid, ygrid], dim=-1) grid = torch.cat([xgrid, ygrid], dim=-1)
img = F.grid_sample(img, grid, align_corners=True) # img = F.grid_sample(img, grid, align_corners=True)
img = bilinear_grid_sample(img, grid, align_corners=True)
if mask: if mask:
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
@ -29,3 +30,79 @@ def manual_pad(x, pady, padx):
pad = (padx, padx, pady, pady) pad = (padx, padx, pady, pady)
return F.pad(x.clone().detach(), pad, "replicate") return F.pad(x.clone().detach(), pad, "replicate")
# Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160
def bilinear_grid_sample(im, grid, align_corners=False):
"""Given an input and a flow-field grid, computes the output using input
values and pixel locations from grid. Supported only bilinear interpolation
method to sample the input pixels.
Args:
im (torch.Tensor): Input feature map, shape (N, C, H, W)
grid (torch.Tensor): Point coordinates, shape (N, Hg, Wg, 2)
align_corners {bool}: If set to True, the extrema (-1 and 1) are
considered as referring to the center points of the inputs
corner pixels. If set to False, they are instead considered as
referring to the corner points of the inputs corner pixels,
making the sampling more resolution agnostic.
Returns:
torch.Tensor: A tensor with sampled points, shape (N, C, Hg, Wg)
"""
n, c, h, w = im.shape
gn, gh, gw, _ = grid.shape
assert n == gn
x = grid[:, :, :, 0]
y = grid[:, :, :, 1]
if align_corners:
x = ((x + 1) / 2) * (w - 1)
y = ((y + 1) / 2) * (h - 1)
else:
x = ((x + 1) * w - 1) / 2
y = ((y + 1) * h - 1) / 2
x = x.view(n, -1)
y = y.view(n, -1)
x0 = torch.floor(x).long()
y0 = torch.floor(y).long()
x1 = x0 + 1
y1 = y0 + 1
wa = ((x1 - x) * (y1 - y)).unsqueeze(1)
wb = ((x1 - x) * (y - y0)).unsqueeze(1)
wc = ((x - x0) * (y1 - y)).unsqueeze(1)
wd = ((x - x0) * (y - y0)).unsqueeze(1)
# Apply default for grid_sample function zero padding
im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
padded_h = h + 2
padded_w = w + 2
# save points positions after padding
x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1
# Clip coordinates to padded image size
x0 = torch.where(x0 < 0, torch.tensor(0, device=im.device), x0)
x0 = torch.where(x0 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x0)
x1 = torch.where(x1 < 0, torch.tensor(0, device=im.device), x1)
x1 = torch.where(x1 > padded_w - 1, torch.tensor(padded_w - 1, device=im.device), x1)
y0 = torch.where(y0 < 0, torch.tensor(0, device=im.device), y0)
y0 = torch.where(y0 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y0)
y1 = torch.where(y1 < 0, torch.tensor(0, device=im.device), y1)
y1 = torch.where(y1 > padded_h - 1, torch.tensor(padded_h - 1, device=im.device), y1)
im_padded = im_padded.view(n, c, -1)
x0_y0 = (x0 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
x0_y1 = (x0 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
x1_y0 = (x1 + y0 * padded_w).unsqueeze(1).expand(-1, c, -1)
x1_y1 = (x1 + y1 * padded_w).unsqueeze(1).expand(-1, c, -1)
Ia = torch.gather(im_padded, 2, x0_y0)
Ib = torch.gather(im_padded, 2, x0_y1)
Ic = torch.gather(im_padded, 2, x1_y0)
Id = torch.gather(im_padded, 2, x1_y1)
return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)

Loading…
Cancel
Save