Initial commit

This commit is contained in:
Ibai 2022-04-08 00:18:27 +09:00
commit fc524edd76
29 changed files with 1279 additions and 0 deletions

2
.gitattributes vendored Normal file
View File

@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto

152
.gitignore vendored Normal file
View File

@ -0,0 +1,152 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

2
README.md Normal file
View File

@ -0,0 +1,2 @@
# CREStereo-Pytorch
Unofficial PyTorch implementation of CREStereo (CVPR 2022 Oral).

View File

@ -0,0 +1,44 @@
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """Wrapper for ``F.grid_sample`` that takes pixel coordinates.

    Args:
        img: Tensor whose last two dimensions are (H, W); it is passed to
            ``F.grid_sample`` as the input.
        coords: Tensor whose last dimension holds (x, y) pixel coordinates.
        mode: Interpolation mode forwarded to ``F.grid_sample``.
        mask: When true, also return a float mask marking coordinates that
            fall strictly inside the normalized range (-1, 1) on both axes.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is true.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates [0, size - 1] onto the [-1, 1] range expected by
    # grid_sample with align_corners=True.
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # Forward ``mode`` so the parameter is actually honoured.
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
def test_bilinear_sampler():
    """Compare the PyTorch ``bilinear_sampler`` with a stored reference.

    Loads ``test_data/bilinear_sampler_test.pickle`` (input feature map,
    sampling coordinates and the expected sampled feature map), runs the
    PyTorch sampler and prints the mean absolute error plus the shapes.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/bilinear_sampler_test.pickle', 'rb') as f:
        right_feature_prev, coords, right_feature = pickle.load(f)

    right_feature_prev = torch.tensor(right_feature_prev.numpy())
    coords = torch.tensor(coords.numpy())
    right_feature = right_feature.numpy()

    # Test Pytorch
    right_feature_pytorch = bilinear_sampler(right_feature_prev, coords).numpy()

    # Use the mean *absolute* error: with a signed mean, positive and
    # negative deviations cancel and a real mismatch can report ~0.
    error = np.mean(np.abs(right_feature_pytorch - right_feature))
    print(
        f"test_bilinear_sampler - Avg. Error: {error}, \n "
        f"Original shape: {coords.numpy().shape},\n "
        f"Obtained shape: {right_feature_pytorch.shape}, Expected shape: {right_feature.shape}"
    )
# Run the comparison when this file is executed as a script.
if __name__ == "__main__":
    test_bilinear_sampler()

View File

@ -0,0 +1,29 @@
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids.

    Returns a float tensor of shape (batch, 2, ht, wd) on ``device``;
    channel 0 holds the column (x) index and channel 1 the row (y) index.
    """
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing='ij')
    # Stack as (x, y), i.e. the reverse of meshgrid's (row, col) order.
    grid = torch.stack((col_idx, row_idx), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
def test_coords_grid():
    """Compare ``coords_grid`` with a stored reference grid.

    Loads ``test_data/coords_grid_test.pickle`` (batch size, height, width
    and the expected grid) and prints the mean absolute error and shapes.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/coords_grid_test.pickle', 'rb') as f:
        batch, ht, wd, coords = pickle.load(f)

    coords = coords.numpy()

    # Test Pytorch
    coords_pytorch = coords_grid(batch, ht, wd, 'cpu').numpy()

    # Mean *absolute* error so opposite-signed deviations cannot cancel.
    error = np.mean(np.abs(coords_pytorch - coords))
    print(
        f"test_coords_grid - Avg. Error: {error}, \n "
        f"Obtained shape: {coords_pytorch.shape}, Expected shape: {coords.shape}"
    )
# Run the comparison when this file is executed as a script.
if __name__ == "__main__":
    test_coords_grid()

Binary file not shown.

Binary file not shown.

View File

@ -0,0 +1,51 @@
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def manual_pad(x, pady, padx):
    """Replicate-pad the last two dimensions of array-like ``x``.

    ``x`` is copied into a new tensor first; ``padx`` is applied on the left
    and right of the last dimension and ``pady`` on both sides of the
    second-to-last dimension.
    """
    as_tensor = torch.tensor(x)
    widths = (padx, padx, pady, pady)  # (left, right, top, bottom)
    return F.pad(as_tensor, widths, mode="replicate")
def test_pad_1_1():
    """Compare ``manual_pad`` with the stored reference in ``manual_pad_test1_1.pickle``.

    The fixture provides the input feature map, the two pad sizes and the
    expected padded result; the mean absolute error and shapes are printed.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/manual_pad_test1_1.pickle', 'rb') as f:
        right_feature, pady, padx, right_pad = pickle.load(f)

    right_feature = right_feature.numpy()
    right_pad = right_pad.numpy()

    # Test Pytorch
    right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy()

    # Mean *absolute* error so opposite-signed deviations cannot cancel.
    error = np.mean(np.abs(right_pad_pytorch - right_pad))
    print(
        f"test_pad_1_1 - Avg. Error: {error}, \n "
        f"Orig. shape: {right_feature.shape}, \n "
        f"Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}"
    )
def test_pad_0_4():
    """Compare ``manual_pad`` with the stored reference in ``manual_pad_test0_4.pickle``.

    The fixture provides the input feature map, the two pad sizes and the
    expected padded result; the mean absolute error and shapes are printed.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/manual_pad_test0_4.pickle', 'rb') as f:
        right_feature, pady, padx, right_pad = pickle.load(f)

    right_feature = right_feature.numpy()
    right_pad = right_pad.numpy()

    # Test Pytorch
    right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy()

    # Mean *absolute* error so opposite-signed deviations cannot cancel.
    error = np.mean(np.abs(right_pad_pytorch - right_pad))
    print(
        f"test_pad_0_4 - Avg. Error: {error}, \n "
        f"Orig. shape: {right_feature.shape}, \n "
        f"Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}"
    )
# Run both padding comparisons when this file is executed as a script.
if __name__ == "__main__":
    for run_check in (test_pad_1_1, test_pad_0_4):
        run_check()

View File

@ -0,0 +1,30 @@
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def test_meshgrid():
    """Compare ``torch.meshgrid(..., indexing='xy')`` with stored reference grids.

    Loads ``test_data/meshgrid_np_test.pkl`` (x/y radii, x/y dilation steps
    and the expected x and y grids) and prints the mean absolute error and
    shapes for each grid.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/meshgrid_np_test.pkl', 'rb') as f:
        rx, dilatex, ry, dilatey, x_grid, y_grid = pickle.load(f)

    x_grid = x_grid.numpy()
    y_grid = y_grid.numpy()

    # Test Pytorch
    x_grid_pytorch, y_grid_pytorch = torch.meshgrid(
        torch.arange(-rx, rx + 1, dilatex, device='cpu'),
        torch.arange(-ry, ry + 1, dilatey, device='cpu'),
        indexing='xy',
    )

    # Mean *absolute* errors so opposite-signed deviations cannot cancel.
    error_x = np.mean(np.abs(x_grid_pytorch.numpy() - x_grid))
    error_y = np.mean(np.abs(y_grid_pytorch.numpy() - y_grid))
    print(
        f"test_meshgrid (X) - Avg. Error: {error_x}, \n "
        f"Obtained shape: {x_grid_pytorch.numpy().shape}, Expected shape: {x_grid.shape}"
    )
    print(
        f"test_meshgrid (Y) - Avg. Error: {error_y}, \n "
        f"Obtained shape: {y_grid_pytorch.numpy().shape}, Expected shape: {y_grid.shape}"
    )
# Run the comparison when this file is executed as a script.
if __name__ == "__main__":
    test_meshgrid()

View File

@ -0,0 +1,31 @@
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def test_offset():
    """Compare the PyTorch offset construction with a stored reference.

    Loads ``test_data/offset_test.pkl``, rebuilds the offsets tensor from the
    stored x/y grids (stack, flatten, transpose, add singleton axes, repeat
    along the batch axis) and prints the mean absolute error and shapes.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/offset_test.pkl', 'rb') as f:
        x_grid, y_grid, reshape_shape, transpose_order, expand_size, repeat_size, repeat_axis, offsets = pickle.load(f)

    x_grid = torch.tensor(x_grid.numpy())
    y_grid = torch.tensor(y_grid.numpy())
    offsets_mge = offsets.numpy()
    N = repeat_size

    # Test Pytorch
    offsets = torch.stack((x_grid, y_grid))
    offsets = offsets.reshape(2, -1).permute(1, 0)
    # Insert singleton axes at positions 0, 2 and 3 (in increasing order).
    for d in sorted((0, 2, 3)):
        offsets = offsets.unsqueeze(d)
    offsets = offsets.repeat_interleave(N, dim=0)

    # Mean *absolute* error so opposite-signed deviations cannot cancel.
    error = np.mean(np.abs(offsets.numpy() - offsets_mge))
    print(
        f"test_offset - Avg. Error: {error}, \n "
        f"Obtained shape: {offsets.numpy().shape}, Expected shape: {offsets_mge.shape}"
    )
# Run the comparison when this file is executed as a script.
if __name__ == "__main__":
    test_offset()

View File

@ -0,0 +1,47 @@
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def test_split():
    """Compare ``torch.split`` into equal chunks with stored reference chunks.

    Loads ``test_data/split_test.pkl`` (feature map, number of chunks, axis
    and the expected chunks) and prints, per chunk, the mean absolute error
    and shapes.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/split_test.pkl', 'rb') as f:
        left_feature, size, axis, lefts = pickle.load(f)

    left_feature = torch.tensor(left_feature.numpy())

    # Test Pytorch
    # torch.split takes the chunk length, so convert the chunk count.
    lefts_pytorch = torch.split(left_feature, left_feature.shape[axis] // size, dim=axis)

    for i, (left_pytorch, left) in enumerate(zip(lefts_pytorch, lefts)):
        # Mean *absolute* error so opposite-signed deviations cannot cancel.
        error = np.mean(np.abs(left_pytorch.numpy() - left.numpy()))
        print(
            f"test_split {i} - Avg. Error: {error}, \n "
            f"Obtained shape: {left_pytorch.numpy().shape}, Expected shape: {left.numpy().shape}\n"
        )
def test_split_list():
    """Compare a two-way ``torch.split`` with stored reference tensors.

    Loads ``test_data/split_test_list.pkl`` (feature map, split description,
    axis and the two expected parts) and prints the mean absolute error and
    shapes for each part.
    """
    # Getting back the megengine objects.
    # NOTE: pickle executes code on load; only use trusted fixture files.
    with open('test_data/split_test_list.pkl', 'rb') as f:
        fmap1, size, axis, net, inp = pickle.load(f)

    fmap1 = torch.tensor(fmap1.numpy())
    net = net.numpy()
    inp = inp.numpy()

    # Test Pytorch
    # NOTE(review): both parts use size[0]; this presumably relies on the
    # axis being exactly 2 * size[0] long -- confirm against the fixture.
    net_pytorch, inp_pytorch = torch.split(fmap1, [size[0], size[0]], dim=axis)

    # Mean *absolute* errors so opposite-signed deviations cannot cancel.
    error_net = np.mean(np.abs(net_pytorch.numpy() - net))
    error_inp = np.mean(np.abs(inp_pytorch.numpy() - inp))
    print(
        f"test_split_list (net) - Avg. Error: {error_net}, \n "
        f"Obtained shape: {net_pytorch.numpy().shape}, Expected shape: {net.shape}\n"
    )
    print(
        f"test_split_list (inp) - Avg. Error: {error_inp}, \n "
        f"Obtained shape: {inp_pytorch.numpy().shape}, Expected shape: {inp.shape}\n"
    )
# Run both split comparisons when this file is executed as a script.
if __name__ == "__main__":
    for run_check in (test_split, test_split_list):
        run_check()

1
nets/__init__.py Normal file
View File

@ -0,0 +1 @@
from .crestereo import CREStereo as Model

View File

@ -0,0 +1,2 @@
from .transformer import LocalFeatureTransformer
from .position_encoding import PositionEncodingSine

View File

@ -0,0 +1,81 @@
"""
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
"""
import torch
from torch.nn import Module, Dropout
def elu_feature_map(x):
    """Return ``elu(x) + 1``, the feature map applied to queries and keys."""
    activated = torch.nn.functional.elu(x)
    return activated + 1
class LinearAttention(Module):
    """Multi-head linear attention using the ``elu(x) + 1`` feature map."""

    def __init__(self, eps=1e-6):
        super().__init__()
        self.feature_map = elu_feature_map
        # Added to the normaliser's denominator to avoid division by zero.
        self.eps = eps

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """Multi-head linear attention proposed in "Transformers are RNNs".

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L]
            kv_mask: [N, S]
        Returns:
            queried_values: (N, L, H, D)
        """
        phi_q = self.feature_map(queries)
        phi_k = self.feature_map(keys)

        # Zero out padded positions.
        if q_mask is not None:
            phi_q = phi_q * q_mask[:, :, None, None]
        if kv_mask is not None:
            key_keep = kv_mask[:, :, None, None]
            phi_k = phi_k * key_keep
            values = values * key_keep

        # Scale the values down by the source length (prevents fp16
        # overflow); the factor is multiplied back in at the end.
        src_len = values.size(1)
        scaled_values = values / src_len

        kv = torch.einsum("nshd,nshv->nhdv", phi_k, scaled_values)  # (S,D)' @ S,V
        key_sum = phi_k.sum(dim=1)
        normalizer = 1 / (torch.einsum("nlhd,nhd->nlh", phi_q, key_sum) + self.eps)
        attended = torch.einsum("nlhd,nhdv,nlh->nlhv", phi_q, kv, normalizer) * src_len

        return attended.contiguous()
class FullAttention(Module):
    """Multi-head scaled dot-product ("full") attention with optional dropout."""

    def __init__(self, use_dropout=False, attention_dropout=0.1):
        super().__init__()
        self.use_dropout = use_dropout
        self.dropout = Dropout(attention_dropout)

    def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
        """Multi-head scaled dot-product attention, a.k.a full attention.

        Masking is only applied when ``kv_mask`` is given; ``q_mask`` is then
        optional and treated as all-valid when omitted.

        Args:
            queries: [N, L, H, D]
            keys: [N, S, H, D]
            values: [N, S, H, D]
            q_mask: [N, L] boolean mask (optional)
            kv_mask: [N, S] boolean mask (optional)
        Returns:
            queried_values: (N, L, H, D)
        """
        # Compute the unnormalized attention and apply the masks
        QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
        if kv_mask is not None:
            valid = kv_mask[:, None, :, None]
            # A missing q_mask means every query position is valid; the
            # key mask then simply broadcasts over the query axis.
            if q_mask is not None:
                valid = q_mask[:, :, None, None] * valid
            QK.masked_fill_(~valid, float('-inf'))

        # Compute the attention and the weighted average
        softmax_temp = 1. / queries.size(3)**.5  # 1 / sqrt(D)
        A = torch.softmax(softmax_temp * QK, dim=2)
        if self.use_dropout:
            A = self.dropout(A)

        queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)

        return queried_values.contiguous()

View File

@ -0,0 +1,42 @@
import math
import torch
from torch import nn
class PositionEncodingSine(nn.Module):
    """
    Sinusoidal position encoding generalized to 2-dimensional images.
    """

    def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
        """
        Args:
            d_model (int): number of channels of the encoding; channels are
                filled in groups of four (sin x, cos x, sin y, cos y).
            max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
            temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
                the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
                on the final performance. For now, we keep both impls for backward compatibility.
                We will remove the buggy impl after re-training all variants of our released models.
        """
        super().__init__()

        pe = torch.zeros((d_model, *max_shape))
        # 1-based row / column indices, each of shape [1, H, W].
        y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
        x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
        if temp_bug_fix:
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
        else:  # a buggy implementation (for backward compatibility only)
            div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
        div_term = div_term[:, None, None]  # [C//4, 1, 1]
        pe[0::4, :, :] = torch.sin(x_position * div_term)
        pe[1::4, :, :] = torch.cos(x_position * div_term)
        pe[2::4, :, :] = torch.sin(y_position * div_term)
        pe[3::4, :, :] = torch.cos(y_position * div_term)

        # Non-persistent: the table is recomputed on construction and is not
        # stored in the state dict.
        self.register_buffer('pe', pe.unsqueeze(0), persistent=False)  # [1, C, H, W]

    def forward(self, x):
        """
        Args:
            x: [N, C, H, W]
        Returns:
            ``x`` plus the encoding cropped to the spatial size of ``x``.
        """
        pe = self.pe[:, :, :x.size(2), :x.size(3)]
        # The module may have been created on a different device than ``x``
        # (e.g. built on the fly on CPU); ``.to`` is a no-op otherwise.
        return x + pe.to(x.device)

View File

@ -0,0 +1,100 @@
import copy
import torch
import torch.nn as nn
from .linear_attention import LinearAttention, FullAttention
#Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
class LoFTREncoderLayer(nn.Module):
    """One attention + feed-forward encoder layer with a residual connection."""

    def __init__(self, d_model, nhead, attention='linear'):
        super().__init__()

        self.dim = d_model // nhead  # channels per head
        self.nhead = nhead

        # multi-head attention (bias-free projections)
        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        if attention == 'linear':
            self.attention = LinearAttention()
        else:
            self.attention = FullAttention()
        self.merge = nn.Linear(d_model, d_model, bias=False)

        # feed-forward network; its input is the concatenation [x, message]
        wide = d_model * 2
        self.mlp = nn.Sequential(
            nn.Linear(wide, wide, bias=False),
            nn.ReLU(True),
            nn.Linear(wide, d_model, bias=False),
        )

        # layer norms applied after the attention and after the MLP
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x, source, x_mask=None, source_mask=None):
        """
        Args:
            x (torch.Tensor): [N, L, C]
            source (torch.Tensor): [N, S, C]
            x_mask (torch.Tensor): [N, L] (optional)
            source_mask (torch.Tensor): [N, S] (optional)
        Returns:
            torch.Tensor: ``x`` plus the computed message.
        """
        batch = x.size(0)
        heads, head_dim = self.nhead, self.dim

        # Queries come from ``x``; keys and values from ``source``.
        q = self.q_proj(x).view(batch, -1, heads, head_dim)       # [N, L, (H, D)]
        k = self.k_proj(source).view(batch, -1, heads, head_dim)  # [N, S, (H, D)]
        v = self.v_proj(source).view(batch, -1, heads, head_dim)

        attended = self.attention(q, k, v, q_mask=x_mask, kv_mask=source_mask)  # [N, L, (H, D)]
        attended = self.merge(attended.view(batch, -1, heads * head_dim))       # [N, L, C]
        attended = self.norm1(attended)

        # feed-forward network on the concatenated input and message
        update = self.mlp(torch.cat([x, attended], dim=2))
        update = self.norm2(update)

        return x + update
class LocalFeatureTransformer(nn.Module):
    """A Local Feature Transformer (LoFTR) module."""

    def __init__(self, d_model, nhead, layer_names, attention):
        super().__init__()

        self.d_model = d_model
        self.nhead = nhead
        self.layer_names = layer_names
        # One independent copy of the template layer per entry in layer_names.
        template = LoFTREncoderLayer(d_model, nhead, attention)
        self.layers = nn.ModuleList([copy.deepcopy(template) for _ in self.layer_names])
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialise every parameter with more than one dimension."""
        for param in self.parameters():
            if param.dim() <= 1:
                continue
            nn.init.xavier_uniform_(param)

    def forward(self, feat0, feat1, mask0=None, mask1=None):
        """
        Args:
            feat0 (torch.Tensor): [N, L, C]
            feat1 (torch.Tensor): [N, S, C]
            mask0 (torch.Tensor): [N, L] (optional)
            mask1 (torch.Tensor): [N, S] (optional)
        Returns:
            The updated ``(feat0, feat1)`` pair.
        Raises:
            KeyError: if a layer name is neither 'self' nor 'cross'.
        """
        assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"

        for layer, kind in zip(self.layers, self.layer_names):
            if kind == 'self':
                feat0 = layer(feat0, feat0, mask0, mask0)
                feat1 = layer(feat1, feat1, mask1, mask1)
                continue
            if kind == 'cross':
                feat0 = layer(feat0, feat1, mask0, mask1)
                # feat1 attends to the already-updated feat0.
                feat1 = layer(feat1, feat0, mask1, mask0)
                continue
            raise KeyError

        return feat0, feat1

146
nets/corr.py Normal file
View File

@ -0,0 +1,146 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils import bilinear_sampler, coords_grid, manual_pad
class AGCL:
    """
    Implementation of Adaptive Group Correlation Layer (AGCL).

    Holds a pair of feature maps and, when called, returns correlations
    between the first map and the second map sampled at flow-displaced
    positions. The channels are split into four groups, each correlated
    over its own search window.
    """

    def __init__(self, fmap1, fmap2, att=None):
        """
        Args:
            fmap1: first feature map, [N, C, H, W].
            fmap2: second feature map, same shape as ``fmap1``.
            att: optional callable applied to both maps (flattened to
                [N, H*W, C]) before correlation in ``corr_att_offset``.
        """
        self.fmap1 = fmap1
        self.fmap2 = fmap2

        self.att = att

        self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device)

    def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
        """Return the correlation for ``flow``.

        ``iter_mode`` selects the fixed-window path (``corr_iter``, which
        ignores ``extra_offset``); otherwise the offset-based path
        (``corr_att_offset``) is used.
        """
        if iter_mode:
            return self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
        return self.corr_att_offset(
            self.fmap1, self.fmap2, flow, extra_offset, small_patch
        )

    @staticmethod
    def _search_windows(small_patch):
        """Return per-group (patch size, dilation) lists for the four groups."""
        if small_patch:
            psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
        else:
            psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
        dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
        return psize_list, dilate_list

    def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
        """Correlate ``left_feature`` with shifted crops of ``right_feature``.

        The right map is replicate-padded and one channel-mean correlation
        map is produced per window position.

        Returns:
            Tensor of shape [N, psize[0] * psize[1], H, W].
        """
        N, C, H, W = left_feature.shape

        di_y, di_x = dilate[0], dilate[1]
        pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x

        right_pad = manual_pad(right_feature, pady, padx)

        corr_list = []
        for h in range(0, pady * 2 + 1, di_y):
            for w in range(0, padx * 2 + 1, di_x):
                right_crop = right_pad[:, :, h : h + H, w : w + W]
                assert right_crop.shape == left_feature.shape
                corr = torch.mean(left_feature * right_crop, dim=1, keepdim=True)
                corr_list.append(corr)

        corr_final = torch.cat(corr_list, dim=1)

        return corr_final

    def corr_iter(self, left_feature, right_feature, flow, small_patch):
        """Warp the right map by ``flow`` and correlate each channel group."""
        coords = self.coords + flow
        coords = coords.permute(0, 2, 3, 1)
        right_feature = bilinear_sampler(right_feature, coords)

        psize_list, dilate_list = self._search_windows(small_patch)

        lefts = torch.split(left_feature, left_feature.shape[1] // 4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1] // 4, dim=1)

        corrs = []
        for i in range(len(psize_list)):
            corr = self.get_correlation(
                lefts[i], rights[i], psize_list[i], dilate_list[i]
            )
            corrs.append(corr)

        final_corr = torch.cat(corrs, dim=1)

        return final_corr

    def corr_att_offset(
        self, left_feature, right_feature, flow, extra_offset, small_patch
    ):
        """Correlate using window offsets shifted by ``extra_offset``.

        ``extra_offset`` is reshaped to [N, 9, 2, H, W], i.e. it must hold
        one (x, y) offset per search position and pixel.
        """
        N, C, H, W = left_feature.shape

        if self.att is not None:
            left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)    # 'n c h w -> n (h w) c'
            right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C)  # 'n c h w -> n (h w) c'
            # Apply the attention module; without this call the reshapes
            # round-trip the features unchanged and ``att`` is never used.
            left_feature, right_feature = self.att(left_feature, right_feature)
            # 'n (h w) c -> n c h w'
            left_feature, right_feature = [
                x.reshape(N, H, W, C).permute(0, 3, 1, 2)
                for x in [left_feature, right_feature]
            ]

        lefts = torch.split(left_feature, left_feature.shape[1] // 4, dim=1)
        rights = torch.split(right_feature, right_feature.shape[1] // 4, dim=1)

        C = C // 4  # channels per group

        psize_list, dilate_list = self._search_windows(small_patch)

        search_num = 9
        extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(0, 1, 3, 4, 2)  # [N, search_num, H, W, 2]

        corrs = []
        for i in range(len(psize_list)):
            left_feature, right_feature = lefts[i], rights[i]
            psize, dilate = psize_list[i], dilate_list[i]

            psizey, psizex = psize[0], psize[1]
            dilatey, dilatex = dilate[0], dilate[1]

            ry = psizey // 2 * dilatey
            rx = psizex // 2 * dilatex
            x_grid, y_grid = torch.meshgrid(
                torch.arange(-rx, rx + 1, dilatex, device=self.fmap1.device),
                torch.arange(-ry, ry + 1, dilatey, device=self.fmap1.device),
                indexing='xy',
            )

            offsets = torch.stack((x_grid, y_grid))
            offsets = offsets.reshape(2, -1).permute(1, 0)  # [search_num, 2]
            # [1, search_num, 1, 1, 2]; broadcasting against extra_offset
            # covers the batch and spatial axes.
            offsets = offsets[None, :, None, None, :]
            offsets = offsets + extra_offset

            coords = self.coords + flow  # [N, 2, H, W]
            coords = coords.permute(0, 2, 3, 1)  # [N, H, W, 2]
            coords = torch.unsqueeze(coords, 1) + offsets
            coords = coords.reshape(N, -1, W, 2)  # [N, search_num*H, W, 2]

            right_feature = bilinear_sampler(
                right_feature, coords
            )  # [N, C, search_num*H, W]
            right_feature = right_feature.reshape(N, C, -1, H, W)  # [N, C, search_num, H, W]
            # [N, C, 1, H, W] broadcasts over the search axis.
            left_feature = left_feature.unsqueeze(2)

            corr = torch.mean(left_feature * right_feature, dim=1)

            corrs.append(corr)

        final_corr = torch.cat(corrs, dim=1)

        return final_corr

258
nets/crestereo.py Normal file
View File

@ -0,0 +1,258 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from .update import BasicUpdateBlock
from .extractor import BasicEncoder
from .corr import AGCL
from .attention import PositionEncodingSine, LocalFeatureTransformer
try:
    autocast = torch.cuda.amp.autocast
except AttributeError:
    # Only a missing ``torch.cuda.amp`` should trigger the fallback; a bare
    # ``except`` would also swallow unrelated failures.
    # dummy autocast for PyTorch < 1.6
    class autocast:
        """No-op context manager with the same call shape as the real autocast."""

        def __init__(self, enabled):
            pass

        def __enter__(self):
            pass

        def __exit__(self, *args):
            pass
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
class CREStereo(nn.Module):
    """Stereo matching network with cascaded recurrent refinement.

    Features from ``BasicEncoder`` are pooled to two coarser scales and the
    flow estimate is refined by the shared ``BasicUpdateBlock`` at 1/16, 1/8
    and 1/4 resolution in turn (scales as named by the in-code comments).
    """

    def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
        """
        Args:
            max_disp: stored as ``self.max_flow``.
            mixed_precision: enables ``autocast`` around the network calls.
            test_mode: when true, ``forward`` returns only the final
                prediction instead of the list of all predictions.
        """
        super(CREStereo, self).__init__()

        self.max_flow = max_disp
        self.mixed_precision = mixed_precision
        self.test_mode = test_mode

        self.hidden_dim = 128
        self.context_dim = 128
        self.dropout = 0

        self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
        self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)

        # loftr
        self.self_att_fn = LocalFeatureTransformer(
            d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
        )
        self.cross_att_fn = LocalFeatureTransformer(
            d_model=256, nhead=8, layer_names=["cross"] * 1, attention="linear"
        )

        # adaptive search
        self.search_num = 9
        self.conv_offset_16 = nn.Conv2d(
            256, self.search_num * 2, kernel_size=3, stride=1, padding=1
        )
        self.conv_offset_8 = nn.Conv2d(
            256, self.search_num * 2, kernel_size=3, stride=1, padding=1
        )
        self.range_16 = 1
        self.range_8 = 1

    def freeze_bn(self):
        """Put every ``BatchNorm2d`` submodule into eval mode."""
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def convex_upsample(self, flow, mask, rate=4):
        """Upsample a flow field by ``rate`` using a convex combination.

        Args:
            flow: [N, 2, H, W] flow field.
            mask: tensor viewable as [N, 1, 9, rate, rate, H, W]; it is
                softmax-normalised over the 9 neighbours.
            rate: upsampling factor.

        Returns:
            Tensor of shape [N, 2, rate*H, rate*W].
        """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, rate, rate, H, W)
        mask = torch.softmax(mask, dim=2)

        # Gather each pixel's 3x3 neighbourhood of (scaled) flow values.
        up_flow = F.unfold(rate * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, rate * H, rate * W)

    def zero_init(self, fmap):
        """Return a float32 all-zero flow [N, 2, H, W] on ``fmap``'s device."""
        N, _, H, W = fmap.shape
        # Allocate directly on the target device instead of CPU-then-copy.
        return torch.zeros((N, 2, H, W), dtype=torch.float32, device=fmap.device)

    def forward(self, image1, image2, iters=10, flow_init=None, upsample=True, test_mode=False):
        """Estimate the flow between a pair of images.

        Args:
            image1, image2: image batches with values in [0, 255]; they are
                rescaled to [-1, 1] here.
            iters: iterations at the 1/4 stage; the 1/16 and 1/8 stages run
                ``iters // 2`` iterations each. Without ``flow_init`` at
                least 2 are needed so the coarse stages produce a flow.
            flow_init: optional initial flow; when given, the 1/16 and 1/8
                stages are skipped.
            upsample: unused; kept for interface compatibility.
            test_mode: when true (or when ``self.test_mode`` is set), return
                only the final prediction.

        Returns:
            The final upsampled prediction in test mode, otherwise the list
            of all intermediate predictions.
        """
        image1 = 2 * (image1 / 255.0) - 1.0
        image2 = 2 * (image2 / 255.0) - 1.0

        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        # run the feature network
        with autocast(enabled=self.mixed_precision):
            fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        # NOTE(review): extent of this autocast scope should be confirmed
        # against the upstream file; it only matters with mixed_precision.
        with autocast(enabled=self.mixed_precision):

            # 1/4 -> 1/8
            # feature
            fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
            fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)

            # offset, squashed into (-range_8, range_8)
            offset_dw8 = self.conv_offset_8(fmap1_dw8)
            offset_dw8 = self.range_8 * (torch.sigmoid(offset_dw8) - 0.5) * 2.0

            # context: hidden state and input are both split off fmap1
            net, inp = torch.split(fmap1, [hdim, cdim], dim=1)
            net = torch.tanh(net)
            inp = F.relu(inp)
            net_dw8 = F.avg_pool2d(net, 2, stride=2)
            inp_dw8 = F.avg_pool2d(inp, 2, stride=2)

            # 1/4 -> 1/16
            # feature
            fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
            fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)
            offset_dw16 = self.conv_offset_16(fmap1_dw16)
            offset_dw16 = self.range_16 * (torch.sigmoid(offset_dw16) - 0.5) * 2.0

            # context
            net_dw16 = F.avg_pool2d(net, 4, stride=4)
            inp_dw16 = F.avg_pool2d(inp, 4, stride=4)

            # positional encoding and self-attention
            # The encoding is created on the fly, so move it to the feature
            # device; otherwise its CPU buffer breaks GPU inference.
            pos_encoding_fn_small = PositionEncodingSine(
                d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
            ).to(fmap1_dw16.device)
            # 'n c h w -> n (h w) c'
            x_tmp = pos_encoding_fn_small(fmap1_dw16)
            fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
            # 'n c h w -> n (h w) c'
            x_tmp = pos_encoding_fn_small(fmap2_dw16)
            fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])

            fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
            # 'n (h w) c -> n c h w'
            fmap1_dw16, fmap2_dw16 = [
                x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
                for x in [fmap1_dw16, fmap2_dw16]
            ]

        corr_fn = AGCL(fmap1, fmap2)
        corr_fn_dw8 = AGCL(fmap1_dw8, fmap2_dw8)
        corr_fn_att_dw16 = AGCL(fmap1_dw16, fmap2_dw16, att=self.cross_att_fn)

        # Cascaded refinement (1/16 + 1/8 + 1/4)
        predictions = []
        flow = None
        flow_up = None
        if flow_init is not None:
            scale = fmap1.shape[2] / flow_init.shape[2]
            flow = -scale * F.interpolate(
                flow_init,
                size=(fmap1.shape[2], fmap1.shape[3]),
                mode="bilinear",
                align_corners=True,
            )
        else:
            # zero initialization
            flow_dw16 = self.zero_init(fmap1_dw16)

            # Recurrent Update Module
            # RUM: 1/16
            for itr in range(iters // 2):
                # Alternate between the wide and the small search patch.
                small_patch = itr % 2 != 0

                flow_dw16 = flow_dw16.detach()
                out_corrs = corr_fn_att_dw16(
                    flow_dw16, offset_dw16, small_patch=small_patch
                )

                with autocast(enabled=self.mixed_precision):
                    net_dw16, up_mask, delta_flow = self.update_block(
                        net_dw16, inp_dw16, out_corrs, flow_dw16
                    )

                flow_dw16 = flow_dw16 + delta_flow
                flow = self.convex_upsample(flow_dw16, up_mask, rate=4)
                flow_up = -4 * F.interpolate(
                    flow,
                    size=(4 * flow.shape[2], 4 * flow.shape[3]),
                    mode="bilinear",
                    align_corners=True,
                )
                predictions.append(flow_up)

            scale = fmap1_dw8.shape[2] / flow.shape[2]
            flow_dw8 = -scale * F.interpolate(
                flow,
                size=(fmap1_dw8.shape[2], fmap1_dw8.shape[3]),
                mode="bilinear",
                align_corners=True,
            )

            # RUM: 1/8
            for itr in range(iters // 2):
                small_patch = itr % 2 != 0

                flow_dw8 = flow_dw8.detach()
                out_corrs = corr_fn_dw8(flow_dw8, offset_dw8, small_patch=small_patch)

                with autocast(enabled=self.mixed_precision):
                    net_dw8, up_mask, delta_flow = self.update_block(
                        net_dw8, inp_dw8, out_corrs, flow_dw8
                    )

                flow_dw8 = flow_dw8 + delta_flow
                flow = self.convex_upsample(flow_dw8, up_mask, rate=4)
                flow_up = -2 * F.interpolate(
                    flow,
                    size=(2 * flow.shape[2], 2 * flow.shape[3]),
                    mode="bilinear",
                    align_corners=True,
                )
                predictions.append(flow_up)

            scale = fmap1.shape[2] / flow.shape[2]
            flow = -scale * F.interpolate(
                flow,
                size=(fmap1.shape[2], fmap1.shape[3]),
                mode="bilinear",
                align_corners=True,
            )

        # RUM: 1/4
        for itr in range(iters):
            small_patch = itr % 2 != 0

            flow = flow.detach()
            out_corrs = corr_fn(flow, None, small_patch=small_patch, iter_mode=True)

            with autocast(enabled=self.mixed_precision):
                net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow)

            flow = flow + delta_flow
            flow_up = -self.convex_upsample(flow, up_mask, rate=4)
            predictions.append(flow_up)

        # Honour the per-call flag as well as the constructor setting.
        if self.test_mode or test_mode:
            return flow_up

        return predictions

123
nets/extractor.py Normal file
View File

@ -0,0 +1,123 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
class ResidualBlock(nn.Module):
    """Two 3x3 conv + norm + ReLU stages added to a 1x1-conv projection of the input."""

    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        groups = planes // 8

        # Lazily-evaluated constructors so only the selected norm is built.
        norm_factories = {
            'group': lambda: nn.GroupNorm(num_groups=groups, num_channels=planes),
            'batch': lambda: nn.BatchNorm2d(planes),
            'instance': lambda: nn.InstanceNorm2d(planes),
            'none': lambda: nn.Sequential(),
        }
        make_norm = norm_factories.get(norm_fn)
        if make_norm is not None:
            self.norm1 = make_norm()
            self.norm2 = make_norm()
            self.norm3 = make_norm()

        # The shortcut is always projected (1x1 conv with the same stride)
        # and normalised with norm3.
        shortcut_conv = nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride)
        self.downsample = nn.Sequential(shortcut_conv, self.norm3)

    def forward(self, x):
        """Return ``relu(downsample(x) + residual(x))``."""
        residual = self.conv1(x)
        residual = self.relu(self.norm1(residual))
        residual = self.conv2(residual)
        residual = self.relu(self.norm2(residual))

        shortcut = self.downsample(x)

        return self.relu(shortcut + residual)
class BasicEncoder(nn.Module):
    """Convolutional feature encoder built from ``ResidualBlock`` stages.

    The stem conv and ``layer2`` each use stride 2, so the output has 1/4 of
    the input's spatial resolution.
    """

    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
        """
        Args:
            output_dim: number of output channels.
            norm_fn: one of 'group', 'batch', 'instance' or 'none'.
            dropout: ``Dropout2d`` probability on the output; disabled
                when not positive.
        """
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(64)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(64)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        # Running input width consumed and updated by _make_layer.
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(96, stride=2)
        self.layer3 = self._make_layer(128, stride=1)

        # output convolution
        self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)

        self.dropout = None
        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                # Norm layers without affine parameters have None here.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        """Build two stacked residual blocks; only the first one is strided."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode a tensor, or a list/tuple of equally-sized batches.

        A list/tuple is concatenated along the batch axis, encoded in one
        pass and split back, so a tuple with one feature tensor per input is
        returned; a plain tensor yields a plain tensor.
        """
        # if input is list, combine batch dimension
        is_list = isinstance(x, (tuple, list))
        if is_list:
            batch_dim = x[0].shape[0]
            x = torch.cat(x, dim=0)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        if self.dropout is not None:
            x = self.dropout(x)

        if is_list:
            # Split by the per-input batch size rather than "half": same
            # result for two inputs, and correct for any number of inputs.
            x = torch.split(x, batch_dim, dim=0)

        return x

91
nets/update.py Normal file
View File

@ -0,0 +1,91 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py
class FlowHead(nn.Module):
    """Two-layer 3x3 conv head mapping features to a 2-channel output."""

    def __init__(self, input_dim=128, hidden_dim=256):
        super().__init__()
        self.conv1 = nn.Conv2d(input_dim, hidden_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(hidden_dim, 2, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> ReLU -> conv; spatial size is preserved."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
class SepConvGRU(nn.Module):
    """Convolutional GRU applying a horizontal (1x5) then a vertical (5x1) update."""

    def __init__(self, hidden_dim=128, input_dim=192+128):
        super().__init__()
        joint = hidden_dim + input_dim  # channels of cat([h, x])

        # horizontal gates: update (z), reset (r), candidate (q)
        self.convz1 = nn.Conv2d(joint, hidden_dim, (1, 5), padding=(0, 2))
        self.convr1 = nn.Conv2d(joint, hidden_dim, (1, 5), padding=(0, 2))
        self.convq1 = nn.Conv2d(joint, hidden_dim, (1, 5), padding=(0, 2))

        # vertical gates
        self.convz2 = nn.Conv2d(joint, hidden_dim, (5, 1), padding=(2, 0))
        self.convr2 = nn.Conv2d(joint, hidden_dim, (5, 1), padding=(2, 0))
        self.convq2 = nn.Conv2d(joint, hidden_dim, (5, 1), padding=(2, 0))

    @staticmethod
    def _gru_step(h, x, conv_z, conv_r, conv_q):
        """Run one GRU update of ``h`` given input ``x`` and the three gate convs."""
        joint = torch.cat([h, x], dim=1)
        update_gate = torch.sigmoid(conv_z(joint))
        reset_gate = torch.sigmoid(conv_r(joint))
        candidate = torch.tanh(conv_q(torch.cat([reset_gate * h, x], dim=1)))
        return (1 - update_gate) * h + update_gate * candidate

    def forward(self, h, x):
        """Return the hidden state after the horizontal and vertical passes."""
        # horizontal
        h = self._gru_step(h, x, self.convz1, self.convr1, self.convq1)
        # vertical
        h = self._gru_step(h, x, self.convz2, self.convr2, self.convq2)
        return h
class BasicMotionEncoder(nn.Module):
    """Encode correlation and flow into a 128-channel motion feature map."""

    def __init__(self, cor_planes):
        super().__init__()
        # correlation branch
        self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
        self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
        # flow branch
        self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        # fusion: 126 channels, so appending the 2-channel flow gives 128
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        """Return the fused features with the raw flow appended as the last 2 channels."""
        corr_feat = F.relu(self.convc2(F.relu(self.convc1(corr))))
        flow_feat = F.relu(self.convf2(F.relu(self.convf1(flow))))

        fused = torch.cat([corr_feat, flow_feat], dim=1)
        fused = F.relu(self.conv(fused))
        return torch.cat([fused, flow], dim=1)
class BasicUpdateBlock(nn.Module):
    """Recurrent update step: motion encoder, GRU, flow head and upsampling-mask head."""

    def __init__(self, hidden_dim, cor_planes, mask_size=8):
        """
        Args:
            hidden_dim: channels of the GRU hidden state ``net``.
            cor_planes: channels of the correlation input.
            mask_size: upsampling factor; the mask head outputs
                ``mask_size**2 * 9`` channels.
        """
        super(BasicUpdateBlock, self).__init__()
        self.encoder = BasicMotionEncoder(cor_planes)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        # The mask head consumes ``net`` (hidden_dim channels), so size its
        # input from hidden_dim rather than a hard-coded 128.
        self.mask = nn.Sequential(
            nn.Conv2d(hidden_dim, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, mask_size**2 * 9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        """Run one update.

        Args:
            net: hidden state.
            inp: context features, concatenated with the motion features
                to form the GRU input.
            corr: correlation features.
            flow: current flow estimate.
            upsample: unused; kept for interface compatibility.

        Returns:
            Tuple ``(net, mask, delta_flow)``: the new hidden state, the
            upsampling mask logits and the flow increment.
        """
        motion_features = self.encoder(flow, corr)
        inp = torch.cat((inp, motion_features), dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow

1
nets/utils/__init__.py Normal file
View File

@ -0,0 +1 @@
from .utils import bilinear_sampler, coords_grid, manual_pad

31
nets/utils/utils.py Normal file
View File

@ -0,0 +1,31 @@
import torch
import torch.nn.functional as F
import numpy as np
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """Wrapper for ``F.grid_sample`` that takes pixel coordinates.

    Args:
        img: Tensor whose last two dimensions are (H, W); it is passed to
            ``F.grid_sample`` as the input.
        coords: Tensor whose last dimension holds (x, y) pixel coordinates.
        mode: Interpolation mode forwarded to ``F.grid_sample``.
        mask: When true, also return a float mask marking coordinates that
            fall strictly inside the normalized range (-1, 1) on both axes.

    Returns:
        The sampled tensor, or ``(sampled, mask)`` when ``mask`` is true.
    """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    # Map pixel coordinates [0, size - 1] onto the [-1, 1] range expected by
    # grid_sample with align_corners=True.
    xgrid = 2 * xgrid / (W - 1) - 1
    ygrid = 2 * ygrid / (H - 1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    # Forward ``mode`` so the parameter is actually honoured.
    img = F.grid_sample(img, grid, mode=mode, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img
def coords_grid(batch, ht, wd, device):
    """Build a batch of pixel-coordinate grids.

    Returns a float tensor of shape (batch, 2, ht, wd) on ``device``;
    channel 0 holds the column (x) index and channel 1 the row (y) index.
    """
    rows = torch.arange(ht, device=device)
    cols = torch.arange(wd, device=device)
    row_idx, col_idx = torch.meshgrid(rows, cols, indexing='ij')
    # Stack as (x, y), i.e. the reverse of meshgrid's (row, col) order.
    grid = torch.stack((col_idx, row_idx), dim=0).float()
    return grid.unsqueeze(0).repeat(batch, 1, 1, 1)
def manual_pad(x, pady, padx):
    """Replicate-pad the last two dimensions of tensor ``x``.

    ``padx`` is applied on the left and right of the last dimension and
    ``pady`` on both sides of the second-to-last dimension.

    The input is padded directly: ``F.pad`` does not modify its input, so no
    defensive copy is needed, and detaching it would cut gradient flow
    through the padded features during training.
    """
    pad = (padx, padx, pady, pady)  # (left, right, top, bottom)
    return F.pad(x, pad, mode="replicate")

15
test_model.py Normal file
View File

@ -0,0 +1,15 @@
import torch
from torchsummary import summary
import numpy as np
from nets import Model
# Smoke test: build the model and run one forward pass on random inputs.
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model.eval()

t1 = torch.rand(1, 3, 480, 640)
t2 = torch.rand(1, 3, 480, 640)

# Inference only: disable autograd so no graph is built or kept in memory.
with torch.no_grad():
    output = model(t1, t2)
print(output.shape)