CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

109 lines
4.0 KiB

3 years ago
import torch
import torch.nn.functional as F
import numpy as np
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """Wrapper around ``bilinear_grid_sample`` that takes pixel coordinates.

    Args:
        img: Tensor whose last two dimensions are (H, W).
        coords: Tensor whose last dimension has size 2 and holds the
            (x, y) pixel coordinates to sample, in that order.
        mode: Not read by this function; kept so existing callers that
            pass it keep working.
        mask: When true, additionally return a float mask that is 1 where
            the normalised coordinate lies strictly inside (-1, 1) on both
            axes and 0 elsewhere.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is true.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)

    # Normalise pixel coordinates to [-1, 1] (align_corners=True convention).
    # max(..., 1) avoids a division by zero when a spatial dimension has a
    # single pixel; for larger sizes the denominator is unchanged.
    xgrid = 2 * xgrid / max(W - 1, 1) - 1
    ygrid = 2 * ygrid / max(H - 1, 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # Used in place of F.grid_sample(img, grid, align_corners=True).
    img = bilinear_grid_sample(img, grid, align_corners=True)

    if mask:
        # Separate name so the boolean flag parameter is not rebound.
        valid = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, valid.float()
    return img
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids.

    Returns a float tensor of shape (batch, 2, ht, wd) where channel 0
    holds the x (column) index and channel 1 the y (row) index of each
    position.
    """
    row_idx = torch.arange(ht, device=device)
    col_idx = torch.arange(wd, device=device)
    ys, xs = torch.meshgrid(row_idx, col_idx, indexing='ij')
    # x first, then y.
    grid = torch.stack((xs, ys), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
def manual_pad(x, pady, padx):
    """Replicate-pad a detached copy of ``x``.

    The last dimension is padded by ``padx`` on each side and the
    second-to-last by ``pady`` on each side. The input is cloned and
    detached first, so the result is not connected to ``x``'s graph.
    """
    detached = x.clone().detach()
    return F.pad(detached, (padx, padx, pady, pady), mode="replicate")
# Ref: https://zenn.dev/pinto0309/scraps/7d4032067d0160
def bilinear_grid_sample(im, grid, align_corners=False):
    """Bilinearly sample ``im`` at the normalised locations given by ``grid``.

    Only bilinear interpolation is implemented. Samples that fall outside
    the input read from a one-pixel border of zeros, which reproduces the
    zero-padding default of ``grid_sample``.

    Args:
        im (torch.Tensor): Input feature map, shape (N, C, H, W).
        grid (torch.Tensor): Sampling locations, shape (N, Hg, Wg, 2), with
            x in ``grid[..., 0]`` and y in ``grid[..., 1]``.
        align_corners (bool): If True, -1 and 1 refer to the centres of the
            input's corner pixels; if False, they refer to the outer edges
            of those pixels.

    Returns:
        torch.Tensor: Sampled values, shape (N, C, Hg, Wg).
    """
    n, c, h, w = im.shape
    gn, gh, gw, _ = grid.shape
    assert n == gn

    # Convert normalised coordinates to fractional pixel positions.
    gx = grid[:, :, :, 0]
    gy = grid[:, :, :, 1]
    if align_corners:
        px = ((gx + 1) / 2) * (w - 1)
        py = ((gy + 1) / 2) * (h - 1)
    else:
        px = ((gx + 1) * w - 1) / 2
        py = ((gy + 1) * h - 1) / 2
    px = px.view(n, -1)
    py = py.view(n, -1)

    # Integer corners of the cell surrounding each sample point.
    left = torch.floor(px).long()
    top = torch.floor(py).long()
    right = left + 1
    bottom = top + 1

    # Interpolation weights: each corner is weighted by the area of the
    # rectangle spanned by the sample point and the opposite corner.
    w_left_top = ((right - px) * (bottom - py)).unsqueeze(1)
    w_left_bottom = ((right - px) * (py - top)).unsqueeze(1)
    w_right_top = ((px - left) * (bottom - py)).unsqueeze(1)
    w_right_bottom = ((px - left) * (py - top)).unsqueeze(1)

    # A one-pixel zero border stands in for zero padding.
    padded = torch.nn.functional.pad(im, pad=[1, 1, 1, 1], mode='constant', value=0)
    padded_h = h + 2
    padded_w = w + 2

    def _clip(index, upper):
        # Clamp indices to [0, upper] using torch.where only.
        index = torch.where(index < 0, torch.tensor(0, device=im.device), index)
        return torch.where(index > upper, torch.tensor(upper, device=im.device), index)

    # Shift by one for the border, then clamp into the padded image so
    # far-away samples land on the zero border.
    left = _clip(left + 1, padded_w - 1)
    right = _clip(right + 1, padded_w - 1)
    top = _clip(top + 1, padded_h - 1)
    bottom = _clip(bottom + 1, padded_h - 1)

    flat = padded.view(n, c, -1)

    def _fetch(col, row):
        # Gather every channel's value at the flattened (row, col) position.
        index = (col + row * padded_w).unsqueeze(1).expand(-1, c, -1)
        return torch.gather(flat, 2, index)

    v_left_top = _fetch(left, top)
    v_left_bottom = _fetch(left, bottom)
    v_right_top = _fetch(right, top)
    v_right_bottom = _fetch(right, bottom)

    blended = (v_left_top * w_left_top + v_left_bottom * w_left_bottom
               + v_right_top * w_right_top + v_right_bottom * w_right_bottom)
    return blended.reshape(n, c, gh, gw)