You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
108 lines
4.0 KiB
108 lines
4.0 KiB
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
|
|
|
|
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """Sample ``img`` at pixel-space ``coords`` using bilinear interpolation.

    Thin wrapper around ``bilinear_grid_sample``: pixel coordinates are
    rescaled into the normalized [-1, 1] range (``align_corners=True``
    convention, i.e. pixel 0 -> -1 and pixel size-1 -> +1) before sampling.

    Args:
        img: tensor whose last two dimensions are read as (H, W).
        coords: pixel coordinates whose last dimension holds (x, y).
        mode: accepted for call compatibility but ignored; sampling is
            always bilinear.
        mask: when truthy, additionally return a float mask that is 1 where
            the normalized coordinates lie strictly inside (-1, 1).

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is truthy.
    """
    height, width = img.shape[-2:]
    px, py = coords.split([1, 1], dim=-1)

    # Rescale [0, size - 1] pixel indices onto [-1, 1].
    # NOTE(review): a width or height of 1 divides by zero here and yields
    # non-finite coordinates; confirm callers never pass such inputs.
    nx = 2 * px / (width - 1) - 1
    ny = 2 * py / (height - 1) - 1

    # F.grid_sample(img, grid, align_corners=True) was replaced by the
    # manual implementation below.
    sampled = bilinear_grid_sample(img, torch.cat([nx, ny], dim=-1), align_corners=True)

    if not mask:
        return sampled

    inside = (nx > -1) & (ny > -1) & (nx < 1) & (ny < 1)
    return sampled, inside.float()
|
|
|
|
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids.

    Returns a float tensor of shape (batch, 2, ht, wd). Channel 0 holds the
    column (x) index of each pixel and channel 1 the row (y) index. Every
    batch entry is an independent copy (``repeat``, not ``expand``).
    """
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    yy, xx = torch.meshgrid(rows, cols, indexing='ij')
    grid = torch.stack((xx, yy), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
|
|
|
|
def manual_pad(x, pady, padx):
    """Replicate-pad the last two dimensions of ``x``.

    Args:
        x: tensor to pad; the 4-element pad tuple passed to ``F.pad`` applies
            to its last two dimensions (width first, then height).
        pady: number of rows added to both the top and the bottom.
        padx: number of columns added to both the left and the right.

    Returns:
        A new padded tensor that is detached from the autograd graph.
    """
    # ``detach()`` alone cuts gradient flow. The former ``clone()`` made a
    # full extra copy of ``x``, which is not needed because replicate padding
    # writes into a freshly allocated output anyway.
    # NOTE(review): confirm that blocking gradients through this padding is
    # intended for every caller.
    return F.pad(x.detach(), (padx, padx, pady, pady), "replicate")
|
|
|
|
# Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160
|
|
def _clip_index(idx, upper, device):
    """Clip an integer index tensor into the closed range [0, upper].

    Uses two ``torch.where`` calls instead of ``torch.clamp``.
    NOTE(review): presumably kept this way so the traced graph stays
    ONNX-export friendly (see the reference link above
    ``bilinear_grid_sample``); confirm before switching to ``torch.clamp``.
    """
    idx = torch.where(idx < 0, torch.tensor(0, device=device), idx)
    return torch.where(idx > upper, torch.tensor(upper, device=device), idx)


def _gather_corner(flat_im, xi, yi, row_stride, channels):
    """Return ``flat_im[:, :, yi * row_stride + xi]`` for every channel.

    ``flat_im`` has shape (N, C, rows * row_stride); ``xi`` and ``yi`` have
    shape (N, P). The result has shape (N, C, P).
    """
    flat_idx = (xi + yi * row_stride).unsqueeze(1).expand(-1, channels, -1)
    return torch.gather(flat_im, 2, flat_idx)


def bilinear_grid_sample(im, grid, align_corners=False):
    """Bilinearly sample ``im`` at the normalized locations in ``grid``.

    A manual re-implementation of
    ``F.grid_sample(im, grid, mode='bilinear', padding_mode='zeros')``.
    Only bilinear interpolation with zero padding is supported.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W).
        grid (torch.Tensor): Normalized point coordinates in [-1, 1], shape
            (N, Hg, Wg, 2), with (x, y) in the last dimension.
        align_corners (bool): If True, the extrema (-1 and 1) refer to the
            center points of the input's corner pixels. If False, they refer
            to the corner points of the input's corner pixels, making the
            sampling more resolution agnostic.

    Returns:
        torch.Tensor: Sampled values, shape (N, C, Hg, Wg).

    Raises:
        ValueError: If ``im`` and ``grid`` have different batch sizes.
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    # Explicit check rather than ``assert`` so it survives ``python -O``.
    if n != gn:
        raise ValueError(
            f"batch size mismatch: im has {n}, grid has {gn}")

    # Unnormalize [-1, 1] grid values to input pixel coordinates using the
    # same formulas as F.grid_sample for each align_corners setting.
    x = grid[:, :, :, 0]
    y = grid[:, :, :, 1]
    if align_corners:
        x = ((x + 1) / 2) * (w - 1)
        y = ((y + 1) / 2) * (h - 1)
    else:
        x = ((x + 1) * w - 1) / 2
        y = ((y + 1) * h - 1) / 2

    # Flatten sample positions to (N, Hg * Wg).
    x = x.view(n, -1)
    y = y.view(n, -1)

    # Integer corners of the pixel cell that contains each sample point.
    x0 = torch.floor(x).long()
    y0 = torch.floor(y).long()
    x1 = x0 + 1
    y1 = y0 + 1

    # Bilinear weights, shaped (N, 1, Hg * Wg) to broadcast over channels.
    wa = ((x1 - x) * (y1 - y)).unsqueeze(1)  # weight of corner (x0, y0)
    wb = ((x1 - x) * (y - y0)).unsqueeze(1)  # weight of corner (x0, y1)
    wc = ((x - x0) * (y1 - y)).unsqueeze(1)  # weight of corner (x1, y0)
    wd = ((x - x0) * (y - y0)).unsqueeze(1)  # weight of corner (x1, y1)

    # Zero-pad by one pixel on every side, so corners just outside the image
    # read zeros (grid_sample's default padding_mode='zeros').
    im_padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2
    # Shift corner indices into the padded frame.
    x0, x1, y0, y1 = x0 + 1, x1 + 1, y0 + 1, y1 + 1

    # Indices further out than the 1-pixel border are clipped onto that
    # (all-zero) border, so they still contribute nothing to the result.
    x0 = _clip_index(x0, padded_w - 1, im.device)
    x1 = _clip_index(x1, padded_w - 1, im.device)
    y0 = _clip_index(y0, padded_h - 1, im.device)
    y1 = _clip_index(y1, padded_h - 1, im.device)

    im_flat = im_padded.view(n, c, -1)
    Ia = _gather_corner(im_flat, x0, y0, padded_w, c)
    Ib = _gather_corner(im_flat, x0, y1, padded_w, c)
    Ic = _gather_corner(im_flat, x1, y0, padded_w, c)
    Id = _gather_corner(im_flat, x1, y1, padded_w, c)

    return (Ia * wa + Ib * wb + Ic * wc + Id * wd).reshape(n, c, gh, gw)
|
|
|