Initial commit
This commit is contained in:
commit
fc524edd76
2
.gitattributes
vendored
Normal file
2
.gitattributes
vendored
Normal file
@ -0,0 +1,2 @@
|
||||
# Auto detect text files and perform LF normalization
|
||||
* text=auto
|
152
.gitignore
vendored
Normal file
152
.gitignore
vendored
Normal file
@ -0,0 +1,152 @@
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
2
README.md
Normal file
2
README.md
Normal file
@ -0,0 +1,2 @@
|
||||
# CREStereo-Pytorch
|
||||
Non-official Pytorch implementation of the CREStereo(CVPR 2022 Oral).
|
44
function_convertion_tests/test_bilinear_sampler.py
Normal file
44
function_convertion_tests/test_bilinear_sampler.py
Normal file
@ -0,0 +1,44 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import megengine as mge
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
||||
|
||||
""" Wrapper for grid_sample, uses pixel coordinates """
|
||||
H, W = img.shape[-2:]
|
||||
xgrid, ygrid = coords.split([1,1], dim=-1)
|
||||
xgrid = 2*xgrid/(W-1) - 1
|
||||
ygrid = 2*ygrid/(H-1) - 1
|
||||
|
||||
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||
img = F.grid_sample(img, grid, align_corners=True)
|
||||
|
||||
if mask:
|
||||
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||
return img, mask.float()
|
||||
|
||||
return img
|
||||
|
||||
def test_bilinear_sampler():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/bilinear_sampler_test.pickle', 'rb') as f:
|
||||
right_feature_prev, coords, right_feature = pickle.load(f)
|
||||
|
||||
right_feature_prev = torch.tensor(right_feature_prev.numpy())
|
||||
coords = torch.tensor(coords.numpy())
|
||||
right_feature = right_feature.numpy()
|
||||
|
||||
# Test Pytorch
|
||||
right_feature_pytorch = bilinear_sampler(right_feature_prev, coords).numpy()
|
||||
|
||||
error = np.mean(right_feature_pytorch-right_feature)
|
||||
print(f"test_coords_grid - Avg. Error: {error}, \n \
|
||||
Original shape: {coords.numpy().shape},\n \
|
||||
Obtained shape: {right_feature_pytorch.shape}, Expected shape: {right_feature.shape}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_bilinear_sampler()
|
29
function_convertion_tests/test_coords_grid.py
Normal file
29
function_convertion_tests/test_coords_grid.py
Normal file
@ -0,0 +1,29 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import megengine as mge
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def coords_grid(batch, ht, wd, device):
|
||||
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij')
|
||||
coords = torch.stack(coords[::-1], dim=0).float()
|
||||
return coords[None].repeat(batch, 1, 1, 1)
|
||||
|
||||
def test_coords_grid():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/coords_grid_test.pickle', 'rb') as f:
|
||||
batch, ht, wd, coords = pickle.load(f)
|
||||
|
||||
coords = coords.numpy()
|
||||
|
||||
# Test Pytorch
|
||||
coords_pytorch = coords_grid(batch, ht, wd, 'cpu').numpy()
|
||||
|
||||
error = np.mean(coords_pytorch-coords)
|
||||
print(f"test_coords_grid - Avg. Error: {error}, \n \
|
||||
Obtained shape: {coords_pytorch.shape}, Expected shape: {coords.shape}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_coords_grid()
|
BIN
function_convertion_tests/test_data/bilinear_sampler_test.pickle
Normal file
BIN
function_convertion_tests/test_data/bilinear_sampler_test.pickle
Normal file
Binary file not shown.
BIN
function_convertion_tests/test_data/coords_grid_test.pickle
Normal file
BIN
function_convertion_tests/test_data/coords_grid_test.pickle
Normal file
Binary file not shown.
BIN
function_convertion_tests/test_data/manual_pad_test0_4.pickle
Normal file
BIN
function_convertion_tests/test_data/manual_pad_test0_4.pickle
Normal file
Binary file not shown.
BIN
function_convertion_tests/test_data/manual_pad_test1_1.pickle
Normal file
BIN
function_convertion_tests/test_data/manual_pad_test1_1.pickle
Normal file
Binary file not shown.
BIN
function_convertion_tests/test_data/meshgrid_np_test.pkl
Normal file
BIN
function_convertion_tests/test_data/meshgrid_np_test.pkl
Normal file
Binary file not shown.
BIN
function_convertion_tests/test_data/offset_test.pkl
Normal file
BIN
function_convertion_tests/test_data/offset_test.pkl
Normal file
Binary file not shown.
BIN
function_convertion_tests/test_data/split_test.pkl
Normal file
BIN
function_convertion_tests/test_data/split_test.pkl
Normal file
Binary file not shown.
BIN
function_convertion_tests/test_data/split_test_list.pkl
Normal file
BIN
function_convertion_tests/test_data/split_test_list.pkl
Normal file
Binary file not shown.
51
function_convertion_tests/test_manual_pad.py
Normal file
51
function_convertion_tests/test_manual_pad.py
Normal file
@ -0,0 +1,51 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import megengine as mge
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def manual_pad(x, pady, padx):
|
||||
|
||||
pad = (padx, padx, pady, pady)
|
||||
return F.pad(torch.tensor(x), pad, "replicate")
|
||||
|
||||
|
||||
def test_pad_1_1():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/manual_pad_test1_1.pickle', 'rb') as f:
|
||||
right_feature, pady, padx, right_pad = pickle.load(f)
|
||||
|
||||
right_feature = right_feature.numpy()
|
||||
right_pad = right_pad.numpy()
|
||||
|
||||
# Test Pytorch
|
||||
right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy()
|
||||
|
||||
error = np.mean(right_pad_pytorch-right_pad)
|
||||
print(f"test_pad_1_1 - Avg. Error: {error}, \n \
|
||||
Orig. shape: {right_feature.shape}, \n \
|
||||
Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}")
|
||||
|
||||
def test_pad_0_4():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/manual_pad_test0_4.pickle', 'rb') as f:
|
||||
right_feature, pady, padx, right_pad = pickle.load(f)
|
||||
|
||||
right_feature = right_feature.numpy()
|
||||
right_pad = right_pad.numpy()
|
||||
|
||||
# Test Pytorch
|
||||
right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy()
|
||||
|
||||
error = np.mean(right_pad_pytorch-right_pad)
|
||||
print(f"test_pad_0_4 - Avg. Error: {error}, \n \
|
||||
Orig. shape: {right_feature.shape}, \n \
|
||||
Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_pad_1_1()
|
||||
|
||||
test_pad_0_4()
|
30
function_convertion_tests/test_meshgrid.py
Normal file
30
function_convertion_tests/test_meshgrid.py
Normal file
@ -0,0 +1,30 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import megengine as mge
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def test_meshgrid():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/meshgrid_np_test.pkl', 'rb') as f:
|
||||
rx, dilatex, ry, dilatey, x_grid, y_grid = pickle.load(f)
|
||||
|
||||
x_grid = x_grid.numpy()
|
||||
y_grid = y_grid.numpy()
|
||||
|
||||
# Test Pytorch
|
||||
x_grid_pytorch, y_grid_pytorch = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device='cpu'),
|
||||
torch.arange(-ry, ry + 1, dilatey, device='cpu'), indexing='xy')
|
||||
|
||||
|
||||
error_x = np.mean(x_grid_pytorch.numpy()-x_grid)
|
||||
error_y = np.mean(y_grid_pytorch.numpy()-y_grid)
|
||||
print(f"test_meshgrid (X) - Avg. Error: {error_x}, \n \
|
||||
Obtained shape: {x_grid_pytorch.numpy().shape}, Expected shape: {x_grid.shape}")
|
||||
print(f"test_meshgrid (Y) - Avg. Error: {error_y}, \n \
|
||||
Obtained shape: {y_grid_pytorch.numpy().shape}, Expected shape: {y_grid.shape}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_meshgrid()
|
31
function_convertion_tests/test_offset.py
Normal file
31
function_convertion_tests/test_offset.py
Normal file
@ -0,0 +1,31 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import megengine as mge
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def test_offset():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/offset_test.pkl', 'rb') as f:
|
||||
x_grid, y_grid, reshape_shape, transpose_order, expand_size, repeat_size, repeat_axis, offsets = pickle.load(f)
|
||||
|
||||
x_grid = torch.tensor(x_grid.numpy())
|
||||
y_grid = torch.tensor(y_grid.numpy())
|
||||
offsets_mge = offsets.numpy()
|
||||
N = repeat_size
|
||||
|
||||
# Test Pytorch
|
||||
offsets = torch.stack((x_grid, y_grid))
|
||||
offsets = offsets.reshape(2, -1).permute(1, 0)
|
||||
for d in sorted((0, 2, 3)):
|
||||
offsets = offsets.unsqueeze(d)
|
||||
offsets = offsets.repeat_interleave(N, dim=0)
|
||||
|
||||
error = np.mean(offsets.numpy()-offsets_mge)
|
||||
print(f"test_offset - Avg. Error: {error}, \n \
|
||||
Obtained shape: {offsets.numpy().shape}, Expected shape: {offsets_mge.shape}")
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_offset()
|
47
function_convertion_tests/test_split.py
Normal file
47
function_convertion_tests/test_split.py
Normal file
@ -0,0 +1,47 @@
|
||||
import pickle
|
||||
import numpy as np
|
||||
import megengine as mge
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
def test_split():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/split_test.pkl', 'rb') as f:
|
||||
left_feature, size, axis, lefts = pickle.load(f)
|
||||
|
||||
left_feature = torch.tensor(left_feature.numpy())
|
||||
|
||||
# Test Pytorch
|
||||
lefts_pytorch = torch.split(left_feature, left_feature.shape[axis]//size, dim=axis)
|
||||
|
||||
for i, (left_pytorch, left) in enumerate(zip(lefts_pytorch, lefts)):
|
||||
|
||||
error = np.mean(left_pytorch.numpy()-left.numpy())
|
||||
print(f"test_split {i} - Avg. Error: {error}, \n \
|
||||
Obtained shape: {left_pytorch.numpy().shape}, Expected shape: {left.numpy().shape}\n")
|
||||
|
||||
def test_split_list():
|
||||
# Getting back the megengine objects:
|
||||
with open('test_data/split_test_list.pkl', 'rb') as f:
|
||||
fmap1, size, axis, net, inp = pickle.load(f)
|
||||
|
||||
fmap1 = torch.tensor(fmap1.numpy())
|
||||
net = net.numpy()
|
||||
inp = inp.numpy()
|
||||
|
||||
# Test Pytorch
|
||||
net_pytorch, inp_pytorch = torch.split(fmap1, [size[0],size[0]], dim=axis)
|
||||
|
||||
error_net = np.mean(net_pytorch.numpy()-net)
|
||||
error_inp = np.mean(inp_pytorch.numpy()-inp)
|
||||
print(f"test_split_list (net) - Avg. Error: {error_net}, \n \
|
||||
Obtained shape: {net_pytorch.numpy().shape}, Expected shape: {net.shape}\n")
|
||||
print(f"test_split_list (inp) - Avg. Error: {error_inp}, \n \
|
||||
Obtained shape: {inp_pytorch.numpy().shape}, Expected shape: {inp.shape}\n")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
test_split()
|
||||
test_split_list()
|
1
nets/__init__.py
Normal file
1
nets/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .crestereo import CREStereo as Model
|
2
nets/attention/__init__.py
Normal file
2
nets/attention/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
from .transformer import LocalFeatureTransformer
|
||||
from .position_encoding import PositionEncodingSine
|
81
nets/attention/linear_attention.py
Normal file
81
nets/attention/linear_attention.py
Normal file
@ -0,0 +1,81 @@
|
||||
"""
|
||||
Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention"
|
||||
Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch.nn import Module, Dropout
|
||||
|
||||
|
||||
def elu_feature_map(x):
|
||||
return torch.nn.functional.elu(x) + 1
|
||||
|
||||
|
||||
class LinearAttention(Module):
|
||||
def __init__(self, eps=1e-6):
|
||||
super().__init__()
|
||||
self.feature_map = elu_feature_map
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
""" Multi-Head linear attention proposed in "Transformers are RNNs"
|
||||
Args:
|
||||
queries: [N, L, H, D]
|
||||
keys: [N, S, H, D]
|
||||
values: [N, S, H, D]
|
||||
q_mask: [N, L]
|
||||
kv_mask: [N, S]
|
||||
Returns:
|
||||
queried_values: (N, L, H, D)
|
||||
"""
|
||||
Q = self.feature_map(queries)
|
||||
K = self.feature_map(keys)
|
||||
|
||||
# set padded position to zero
|
||||
if q_mask is not None:
|
||||
Q = Q * q_mask[:, :, None, None]
|
||||
if kv_mask is not None:
|
||||
K = K * kv_mask[:, :, None, None]
|
||||
values = values * kv_mask[:, :, None, None]
|
||||
|
||||
v_length = values.size(1)
|
||||
values = values / v_length # prevent fp16 overflow
|
||||
KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V
|
||||
Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps)
|
||||
queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length
|
||||
|
||||
return queried_values.contiguous()
|
||||
|
||||
|
||||
class FullAttention(Module):
|
||||
def __init__(self, use_dropout=False, attention_dropout=0.1):
|
||||
super().__init__()
|
||||
self.use_dropout = use_dropout
|
||||
self.dropout = Dropout(attention_dropout)
|
||||
|
||||
def forward(self, queries, keys, values, q_mask=None, kv_mask=None):
|
||||
""" Multi-head scaled dot-product attention, a.k.a full attention.
|
||||
Args:
|
||||
queries: [N, L, H, D]
|
||||
keys: [N, S, H, D]
|
||||
values: [N, S, H, D]
|
||||
q_mask: [N, L]
|
||||
kv_mask: [N, S]
|
||||
Returns:
|
||||
queried_values: (N, L, H, D)
|
||||
"""
|
||||
|
||||
# Compute the unnormalized attention and apply the masks
|
||||
QK = torch.einsum("nlhd,nshd->nlsh", queries, keys)
|
||||
if kv_mask is not None:
|
||||
QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf'))
|
||||
|
||||
# Compute the attention and the weighted average
|
||||
softmax_temp = 1. / queries.size(3)**.5 # sqrt(D)
|
||||
A = torch.softmax(softmax_temp * QK, dim=2)
|
||||
if self.use_dropout:
|
||||
A = self.dropout(A)
|
||||
|
||||
queried_values = torch.einsum("nlsh,nshd->nlhd", A, values)
|
||||
|
||||
return queried_values.contiguous()
|
42
nets/attention/position_encoding.py
Normal file
42
nets/attention/position_encoding.py
Normal file
@ -0,0 +1,42 @@
|
||||
import math
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class PositionEncodingSine(nn.Module):
|
||||
"""
|
||||
This is a sinusoidal position encoding that generalized to 2-dimensional images
|
||||
"""
|
||||
|
||||
def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True):
|
||||
"""
|
||||
Args:
|
||||
max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels
|
||||
temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41),
|
||||
the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact
|
||||
on the final performance. For now, we keep both impls for backward compatability.
|
||||
We will remove the buggy impl after re-training all variants of our released models.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
pe = torch.zeros((d_model, *max_shape))
|
||||
y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
|
||||
x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
|
||||
if temp_bug_fix:
|
||||
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2)))
|
||||
else: # a buggy implementation (for backward compatability only)
|
||||
div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2))
|
||||
div_term = div_term[:, None, None] # [C//4, 1, 1]
|
||||
pe[0::4, :, :] = torch.sin(x_position * div_term)
|
||||
pe[1::4, :, :] = torch.cos(x_position * div_term)
|
||||
pe[2::4, :, :] = torch.sin(y_position * div_term)
|
||||
pe[3::4, :, :] = torch.cos(y_position * div_term)
|
||||
|
||||
self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W]
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x: [N, C, H, W]
|
||||
"""
|
||||
return x + self.pe[:, :, :x.size(2), :x.size(3)]
|
100
nets/attention/transformer.py
Normal file
100
nets/attention/transformer.py
Normal file
@ -0,0 +1,100 @@
|
||||
import copy
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from .linear_attention import LinearAttention, FullAttention
|
||||
|
||||
#Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py
|
||||
class LoFTREncoderLayer(nn.Module):
|
||||
def __init__(self,
|
||||
d_model,
|
||||
nhead,
|
||||
attention='linear'):
|
||||
super(LoFTREncoderLayer, self).__init__()
|
||||
|
||||
self.dim = d_model // nhead
|
||||
self.nhead = nhead
|
||||
|
||||
# multi-head attention
|
||||
self.q_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.k_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.v_proj = nn.Linear(d_model, d_model, bias=False)
|
||||
self.attention = LinearAttention() if attention == 'linear' else FullAttention()
|
||||
self.merge = nn.Linear(d_model, d_model, bias=False)
|
||||
|
||||
# feed-forward network
|
||||
self.mlp = nn.Sequential(
|
||||
nn.Linear(d_model*2, d_model*2, bias=False),
|
||||
nn.ReLU(True),
|
||||
nn.Linear(d_model*2, d_model, bias=False),
|
||||
)
|
||||
|
||||
# norm and dropout
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2 = nn.LayerNorm(d_model)
|
||||
|
||||
def forward(self, x, source, x_mask=None, source_mask=None):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): [N, L, C]
|
||||
source (torch.Tensor): [N, S, C]
|
||||
x_mask (torch.Tensor): [N, L] (optional)
|
||||
source_mask (torch.Tensor): [N, S] (optional)
|
||||
"""
|
||||
bs = x.size(0)
|
||||
query, key, value = x, source, source
|
||||
|
||||
# multi-head attention
|
||||
query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)]
|
||||
key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)]
|
||||
value = self.v_proj(value).view(bs, -1, self.nhead, self.dim)
|
||||
message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)]
|
||||
message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C]
|
||||
message = self.norm1(message)
|
||||
|
||||
# feed-forward network
|
||||
message = self.mlp(torch.cat([x, message], dim=2))
|
||||
message = self.norm2(message)
|
||||
|
||||
return x + message
|
||||
|
||||
|
||||
class LocalFeatureTransformer(nn.Module):
|
||||
"""A Local Feature Transformer (LoFTR) module."""
|
||||
|
||||
def __init__(self, d_model, nhead, layer_names, attention):
|
||||
super(LocalFeatureTransformer, self).__init__()
|
||||
|
||||
self.d_model = d_model
|
||||
self.nhead = nhead
|
||||
self.layer_names = layer_names
|
||||
encoder_layer = LoFTREncoderLayer(d_model, nhead, attention)
|
||||
self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))])
|
||||
self._reset_parameters()
|
||||
|
||||
def _reset_parameters(self):
|
||||
for p in self.parameters():
|
||||
if p.dim() > 1:
|
||||
nn.init.xavier_uniform_(p)
|
||||
|
||||
def forward(self, feat0, feat1, mask0=None, mask1=None):
|
||||
"""
|
||||
Args:
|
||||
feat0 (torch.Tensor): [N, L, C]
|
||||
feat1 (torch.Tensor): [N, S, C]
|
||||
mask0 (torch.Tensor): [N, L] (optional)
|
||||
mask1 (torch.Tensor): [N, S] (optional)
|
||||
"""
|
||||
|
||||
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
|
||||
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
elif name == 'cross':
|
||||
feat0 = layer(feat0, feat1, mask0, mask1)
|
||||
feat1 = layer(feat1, feat0, mask1, mask0)
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
return feat0, feat1
|
146
nets/corr.py
Normal file
146
nets/corr.py
Normal file
@ -0,0 +1,146 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .utils import bilinear_sampler, coords_grid, manual_pad
|
||||
|
||||
class AGCL:
|
||||
"""
|
||||
Implementation of Adaptive Group Correlation Layer (AGCL).
|
||||
"""
|
||||
|
||||
def __init__(self, fmap1, fmap2, att=None):
|
||||
self.fmap1 = fmap1
|
||||
self.fmap2 = fmap2
|
||||
|
||||
self.att = att
|
||||
|
||||
self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device)
|
||||
|
||||
def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False):
|
||||
if iter_mode:
|
||||
corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch)
|
||||
else:
|
||||
corr = self.corr_att_offset(
|
||||
self.fmap1, self.fmap2, flow, extra_offset, small_patch
|
||||
)
|
||||
return corr
|
||||
|
||||
def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)):
|
||||
|
||||
N, C, H, W = left_feature.shape
|
||||
|
||||
di_y, di_x = dilate[0], dilate[1]
|
||||
pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x
|
||||
|
||||
right_pad = manual_pad(right_feature, pady, padx)
|
||||
|
||||
corr_list = []
|
||||
for h in range(0, pady * 2 + 1, di_y):
|
||||
for w in range(0, padx * 2 + 1, di_x):
|
||||
right_crop = right_pad[:, :, h : h + H, w : w + W]
|
||||
assert right_crop.shape == left_feature.shape
|
||||
corr = torch.mean(left_feature * right_crop, dim=1, keepdims=True)
|
||||
corr_list.append(corr)
|
||||
|
||||
corr_final = torch.cat(corr_list, dim=1)
|
||||
|
||||
return corr_final
|
||||
|
||||
def corr_iter(self, left_feature, right_feature, flow, small_patch):
|
||||
|
||||
coords = self.coords + flow
|
||||
coords = coords.permute(0, 2, 3, 1)
|
||||
right_feature = bilinear_sampler(right_feature, coords)
|
||||
|
||||
if small_patch:
|
||||
psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
|
||||
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
|
||||
else:
|
||||
psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
|
||||
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
|
||||
|
||||
N, C, H, W = left_feature.shape
|
||||
lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
|
||||
rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)
|
||||
|
||||
corrs = []
|
||||
for i in range(len(psize_list)):
|
||||
corr = self.get_correlation(
|
||||
lefts[i], rights[i], psize_list[i], dilate_list[i]
|
||||
)
|
||||
corrs.append(corr)
|
||||
|
||||
final_corr = torch.cat(corrs, dim=1)
|
||||
|
||||
return final_corr
|
||||
|
||||
def corr_att_offset(
|
||||
self, left_feature, right_feature, flow, extra_offset, small_patch
|
||||
):
|
||||
|
||||
N, C, H, W = left_feature.shape
|
||||
|
||||
if self.att is not None:
|
||||
left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c'
|
||||
right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c'
|
||||
# 'n (h w) c -> n c h w'
|
||||
left_feature, right_feature = [
|
||||
x.reshape(N, H, W, C).permute(0, 3, 1, 2)
|
||||
for x in [left_feature, right_feature]
|
||||
]
|
||||
|
||||
lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1)
|
||||
rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1)
|
||||
|
||||
C = C // 4
|
||||
|
||||
if small_patch:
|
||||
psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)]
|
||||
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
|
||||
else:
|
||||
psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)]
|
||||
dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)]
|
||||
|
||||
search_num = 9
|
||||
extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(0, 1, 3, 4, 2) # [N, search_num, 1, 1, 2]
|
||||
|
||||
corrs = []
|
||||
for i in range(len(psize_list)):
|
||||
left_feature, right_feature = lefts[i], rights[i]
|
||||
psize, dilate = psize_list[i], dilate_list[i]
|
||||
|
||||
psizey, psizex = psize[0], psize[1]
|
||||
dilatey, dilatex = dilate[0], dilate[1]
|
||||
|
||||
ry = psizey // 2 * dilatey
|
||||
rx = psizex // 2 * dilatex
|
||||
x_grid, y_grid = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device=self.fmap1.device),
|
||||
torch.arange(-ry, ry + 1, dilatey, device=self.fmap1.device), indexing='xy')
|
||||
|
||||
offsets = torch.stack((x_grid, y_grid))
|
||||
offsets = offsets.reshape(2, -1).permute(1, 0)
|
||||
for d in sorted((0, 2, 3)):
|
||||
offsets = offsets.unsqueeze(d)
|
||||
offsets = offsets.repeat_interleave(N, dim=0)
|
||||
offsets = offsets + extra_offset
|
||||
|
||||
coords = self.coords + flow # [N, 2, H, W]
|
||||
coords = coords.permute(0, 2, 3, 1) # [N, H, W, 2]
|
||||
coords = torch.unsqueeze(coords, 1) + offsets
|
||||
coords = coords.reshape(N, -1, W, 2) # [N, search_num*H, W, 2]
|
||||
|
||||
right_feature = bilinear_sampler(
|
||||
right_feature, coords
|
||||
) # [N, C, search_num*H, W]
|
||||
right_feature = right_feature.reshape(N, C, -1, H, W) # [N, C, search_num, H, W]
|
||||
left_feature = left_feature.unsqueeze(2).repeat_interleave(right_feature.shape[2], dim=2)
|
||||
|
||||
corr = torch.mean(left_feature * right_feature, dim=1)
|
||||
|
||||
corrs.append(corr)
|
||||
|
||||
final_corr = torch.cat(corrs, dim=1)
|
||||
|
||||
return final_corr
|
258
nets/crestereo.py
Normal file
258
nets/crestereo.py
Normal file
@ -0,0 +1,258 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from .update import BasicUpdateBlock
|
||||
from .extractor import BasicEncoder
|
||||
from .corr import AGCL
|
||||
|
||||
from .attention import PositionEncodingSine, LocalFeatureTransformer
|
||||
|
||||
try:
|
||||
autocast = torch.cuda.amp.autocast
|
||||
except:
|
||||
# dummy autocast for PyTorch < 1.6
|
||||
class autocast:
|
||||
def __init__(self, enabled):
|
||||
pass
|
||||
def __enter__(self):
|
||||
pass
|
||||
def __exit__(self, *args):
|
||||
pass
|
||||
|
||||
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
|
||||
class CREStereo(nn.Module):
|
||||
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
|
||||
super(CREStereo, self).__init__()
|
||||
|
||||
self.max_flow = max_disp
|
||||
self.mixed_precision = mixed_precision
|
||||
self.test_mode = test_mode
|
||||
|
||||
self.hidden_dim = 128
|
||||
self.context_dim = 128
|
||||
self.dropout = 0
|
||||
|
||||
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
|
||||
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
||||
|
||||
# loftr
|
||||
self.self_att_fn = LocalFeatureTransformer(
|
||||
d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
|
||||
)
|
||||
self.cross_att_fn = LocalFeatureTransformer(
|
||||
d_model=256, nhead=8, layer_names=["cross"] * 1, attention="linear"
|
||||
)
|
||||
|
||||
# adaptive search
|
||||
self.search_num = 9
|
||||
self.conv_offset_16 = nn.Conv2d(
|
||||
256, self.search_num * 2, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.conv_offset_8 = nn.Conv2d(
|
||||
256, self.search_num * 2, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
self.range_16 = 1
|
||||
self.range_8 = 1
|
||||
|
||||
def freeze_bn(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
m.eval()
|
||||
|
||||
def convex_upsample(self, flow, mask, rate=4):
|
||||
""" Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
|
||||
N, _, H, W = flow.shape
|
||||
# print(flow.shape, mask.shape, rate)
|
||||
mask = mask.view(N, 1, 9, rate, rate, H, W)
|
||||
mask = torch.softmax(mask, dim=2)
|
||||
|
||||
up_flow = F.unfold(rate * flow, [3,3], padding=1)
|
||||
up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
|
||||
|
||||
up_flow = torch.sum(mask * up_flow, dim=2)
|
||||
up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
|
||||
return up_flow.reshape(N, 2, rate*H, rate*W)
|
||||
|
||||
def zero_init(self, fmap):
|
||||
N, C, H, W = fmap.shape
|
||||
_x = torch.zeros([N, 1, H, W], dtype=torch.float32)
|
||||
_y = torch.zeros([N, 1, H, W], dtype=torch.float32)
|
||||
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
|
||||
return zero_flow
|
||||
|
||||
def forward(self, image1, image2, iters=10, flow_init=None, upsample=True, test_mode=False):
|
||||
""" Estimate optical flow between pair of frames """
|
||||
|
||||
image1 = 2 * (image1 / 255.0) - 1.0
|
||||
image2 = 2 * (image2 / 255.0) - 1.0
|
||||
|
||||
image1 = image1.contiguous()
|
||||
image2 = image2.contiguous()
|
||||
|
||||
hdim = self.hidden_dim
|
||||
cdim = self.context_dim
|
||||
|
||||
# run the feature network
|
||||
with autocast(enabled=self.mixed_precision):
|
||||
fmap1, fmap2 = self.fnet([image1, image2])
|
||||
|
||||
fmap1 = fmap1.float()
|
||||
fmap2 = fmap2.float()
|
||||
|
||||
with autocast(enabled=self.mixed_precision):
|
||||
|
||||
# 1/4 -> 1/8
|
||||
# feature
|
||||
fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2)
|
||||
fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2)
|
||||
|
||||
# offset
|
||||
offset_dw8 = self.conv_offset_8(fmap1_dw8)
|
||||
offset_dw8 = self.range_8 * (torch.sigmoid(offset_dw8) - 0.5) * 2.0
|
||||
|
||||
# context
|
||||
net, inp = torch.split(fmap1, [hdim,hdim], dim=1)
|
||||
net = torch.tanh(net)
|
||||
inp = F.relu(inp)
|
||||
net_dw8 = F.avg_pool2d(net, 2, stride=2)
|
||||
inp_dw8 = F.avg_pool2d(inp, 2, stride=2)
|
||||
|
||||
# 1/4 -> 1/16
|
||||
# feature
|
||||
fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4)
|
||||
fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4)
|
||||
offset_dw16 = self.conv_offset_16(fmap1_dw16)
|
||||
offset_dw16 = self.range_16 * (torch.sigmoid(offset_dw16) - 0.5) * 2.0
|
||||
|
||||
# context
|
||||
net_dw16 = F.avg_pool2d(net, 4, stride=4)
|
||||
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
|
||||
|
||||
# positional encoding and self-attention
|
||||
pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
)
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = pos_encoding_fn_small(fmap1_dw16)
|
||||
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = pos_encoding_fn_small(fmap2_dw16)
|
||||
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
|
||||
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||
fmap1_dw16, fmap2_dw16 = [
|
||||
x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
|
||||
for x in [fmap1_dw16, fmap2_dw16]
|
||||
]
|
||||
|
||||
corr_fn = AGCL(fmap1, fmap2)
|
||||
corr_fn_dw8 = AGCL(fmap1_dw8, fmap2_dw8)
|
||||
corr_fn_att_dw16 = AGCL(fmap1_dw16, fmap2_dw16, att=self.cross_att_fn)
|
||||
|
||||
# Cascaded refinement (1/16 + 1/8 + 1/4)
|
||||
predictions = []
|
||||
flow = None
|
||||
flow_up = None
|
||||
if flow_init is not None:
|
||||
scale = fmap1.shape[2] / flow_init.shape[2]
|
||||
flow = -scale * F.interpolate(
|
||||
flow_init,
|
||||
size=(fmap1.shape[2], fmap1.shape[3]),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
else:
|
||||
# zero initialization
|
||||
flow_dw16 = self.zero_init(fmap1_dw16)
|
||||
|
||||
# Recurrent Update Module
|
||||
# RUM: 1/16
|
||||
for itr in range(iters // 2):
|
||||
if itr % 2 == 0:
|
||||
small_patch = False
|
||||
else:
|
||||
small_patch = True
|
||||
|
||||
flow_dw16 = flow_dw16.detach()
|
||||
out_corrs = corr_fn_att_dw16(
|
||||
flow_dw16, offset_dw16, small_patch=small_patch
|
||||
)
|
||||
|
||||
with autocast(enabled=self.mixed_precision):
|
||||
net_dw16, up_mask, delta_flow = self.update_block(
|
||||
net_dw16, inp_dw16, out_corrs, flow_dw16
|
||||
)
|
||||
|
||||
flow_dw16 = flow_dw16 + delta_flow
|
||||
flow = self.convex_upsample(flow_dw16, up_mask, rate=4)
|
||||
flow_up = -4 * F.interpolate(
|
||||
flow,
|
||||
size=(4 * flow.shape[2], 4 * flow.shape[3]),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
predictions.append(flow_up)
|
||||
|
||||
scale = fmap1_dw8.shape[2] / flow.shape[2]
|
||||
flow_dw8 = -scale * F.interpolate(
|
||||
flow,
|
||||
size=(fmap1_dw8.shape[2], fmap1_dw8.shape[3]),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
# RUM: 1/8
|
||||
for itr in range(iters // 2):
|
||||
if itr % 2 == 0:
|
||||
small_patch = False
|
||||
else:
|
||||
small_patch = True
|
||||
|
||||
flow_dw8 = flow_dw8.detach()
|
||||
out_corrs = corr_fn_dw8(flow_dw8, offset_dw8, small_patch=small_patch)
|
||||
|
||||
with autocast(enabled=self.mixed_precision):
|
||||
net_dw8, up_mask, delta_flow = self.update_block(
|
||||
net_dw8, inp_dw8, out_corrs, flow_dw8
|
||||
)
|
||||
|
||||
flow_dw8 = flow_dw8 + delta_flow
|
||||
flow = self.convex_upsample(flow_dw8, up_mask, rate=4)
|
||||
flow_up = -2 * F.interpolate(
|
||||
flow,
|
||||
size=(2 * flow.shape[2], 2 * flow.shape[3]),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
predictions.append(flow_up)
|
||||
|
||||
scale = fmap1.shape[2] / flow.shape[2]
|
||||
flow = -scale * F.interpolate(
|
||||
flow,
|
||||
size=(fmap1.shape[2], fmap1.shape[3]),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
# RUM: 1/4
|
||||
for itr in range(iters):
|
||||
if itr % 2 == 0:
|
||||
small_patch = False
|
||||
else:
|
||||
small_patch = True
|
||||
|
||||
flow = flow.detach()
|
||||
out_corrs = corr_fn(flow, None, small_patch=small_patch, iter_mode=True)
|
||||
|
||||
with autocast(enabled=self.mixed_precision):
|
||||
net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow)
|
||||
|
||||
flow = flow + delta_flow
|
||||
flow_up = -self.convex_upsample(flow, up_mask, rate=4)
|
||||
predictions.append(flow_up)
|
||||
|
||||
if self.test_mode:
|
||||
return flow_up
|
||||
|
||||
return predictions
|
123
nets/extractor.py
Normal file
123
nets/extractor.py
Normal file
@ -0,0 +1,123 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
|
||||
class ResidualBlock(nn.Module):
|
||||
def __init__(self, in_planes, planes, norm_fn='group', stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
num_groups = planes // 8
|
||||
|
||||
if norm_fn == 'group':
|
||||
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
||||
|
||||
elif norm_fn == 'batch':
|
||||
self.norm1 = nn.BatchNorm2d(planes)
|
||||
self.norm2 = nn.BatchNorm2d(planes)
|
||||
self.norm3 = nn.BatchNorm2d(planes)
|
||||
|
||||
elif norm_fn == 'instance':
|
||||
self.norm1 = nn.InstanceNorm2d(planes)
|
||||
self.norm2 = nn.InstanceNorm2d(planes)
|
||||
self.norm3 = nn.InstanceNorm2d(planes)
|
||||
|
||||
elif norm_fn == 'none':
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x+y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Module):
|
||||
def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
|
||||
super(BasicEncoder, self).__init__()
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
if self.norm_fn == 'group':
|
||||
self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
|
||||
|
||||
elif self.norm_fn == 'batch':
|
||||
self.norm1 = nn.BatchNorm2d(64)
|
||||
|
||||
elif self.norm_fn == 'instance':
|
||||
self.norm1 = nn.InstanceNorm2d(64)
|
||||
|
||||
elif self.norm_fn == 'none':
|
||||
self.norm1 = nn.Sequential()
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
||||
self.relu1 = nn.ReLU(inplace=True)
|
||||
|
||||
self.in_planes = 64
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(96, stride=2)
|
||||
self.layer3 = self._make_layer(128, stride=1)
|
||||
|
||||
# output convolution
|
||||
self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
|
||||
|
||||
self.dropout = None
|
||||
if dropout > 0:
|
||||
self.dropout = nn.Dropout2d(p=dropout)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
||||
elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
|
||||
if m.weight is not None:
|
||||
nn.init.constant_(m.weight, 1)
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = (layer1, layer2)
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
# if input is list, combine batch dimension
|
||||
is_list = isinstance(x, tuple) or isinstance(x, list)
|
||||
if is_list:
|
||||
batch_dim = x[0].shape[0]
|
||||
x = torch.cat(x, dim=0)
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
|
||||
if is_list:
|
||||
x = torch.split(x, x.shape[0]//2, dim=0)
|
||||
|
||||
return x
|
91
nets/update.py
Normal file
91
nets/update.py
Normal file
@ -0,0 +1,91 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py
|
||||
class FlowHead(nn.Module):
|
||||
def __init__(self, input_dim=128, hidden_dim=256):
|
||||
super(FlowHead, self).__init__()
|
||||
self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.relu(self.conv1(x)))
|
||||
|
||||
|
||||
class SepConvGRU(nn.Module):
|
||||
def __init__(self, hidden_dim=128, input_dim=192+128):
|
||||
super(SepConvGRU, self).__init__()
|
||||
self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||
self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||
self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
|
||||
|
||||
self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||
self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||
self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
|
||||
|
||||
def forward(self, h, x):
|
||||
# horizontal
|
||||
hx = torch.cat([h, x], dim=1)
|
||||
z = torch.sigmoid(self.convz1(hx))
|
||||
r = torch.sigmoid(self.convr1(hx))
|
||||
q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
|
||||
h = (1-z) * h + z * q
|
||||
|
||||
# vertical
|
||||
hx = torch.cat([h, x], dim=1)
|
||||
z = torch.sigmoid(self.convz2(hx))
|
||||
r = torch.sigmoid(self.convr2(hx))
|
||||
q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
|
||||
h = (1-z) * h + z * q
|
||||
|
||||
return h
|
||||
|
||||
|
||||
class BasicMotionEncoder(nn.Module):
|
||||
def __init__(self, cor_planes):
|
||||
super(BasicMotionEncoder, self).__init__()
|
||||
|
||||
self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
|
||||
self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
|
||||
self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
|
||||
self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
|
||||
self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
|
||||
|
||||
def forward(self, flow, corr):
|
||||
cor = F.relu(self.convc1(corr))
|
||||
cor = F.relu(self.convc2(cor))
|
||||
flo = F.relu(self.convf1(flow))
|
||||
flo = F.relu(self.convf2(flo))
|
||||
|
||||
cor_flo = torch.cat([cor, flo], dim=1)
|
||||
out = F.relu(self.conv(cor_flo))
|
||||
return torch.cat([out, flow], dim=1)
|
||||
|
||||
|
||||
class BasicUpdateBlock(nn.Module):
|
||||
def __init__(self, hidden_dim, cor_planes, mask_size=8):
|
||||
super(BasicUpdateBlock, self).__init__()
|
||||
|
||||
self.encoder = BasicMotionEncoder(cor_planes)
|
||||
self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
|
||||
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
||||
|
||||
self.mask = nn.Sequential(
|
||||
nn.Conv2d(128, 256, 3, padding=1),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
|
||||
|
||||
def forward(self, net, inp, corr, flow, upsample=True):
|
||||
# print(inp.shape, corr.shape, flow.shape)
|
||||
motion_features = self.encoder(flow, corr)
|
||||
# print(motion_features.shape, inp.shape)
|
||||
inp = torch.cat((inp, motion_features), dim=1)
|
||||
|
||||
net = self.gru(net, inp)
|
||||
delta_flow = self.flow_head(net)
|
||||
|
||||
# scale mask to balence gradients
|
||||
mask = .25 * self.mask(net)
|
||||
return net, mask, delta_flow
|
1
nets/utils/__init__.py
Normal file
1
nets/utils/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
from .utils import bilinear_sampler, coords_grid, manual_pad
|
31
nets/utils/utils.py
Normal file
31
nets/utils/utils.py
Normal file
@ -0,0 +1,31 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py
|
||||
|
||||
def bilinear_sampler(img, coords, mode='bilinear', mask=False):
|
||||
""" Wrapper for grid_sample, uses pixel coordinates """
|
||||
H, W = img.shape[-2:]
|
||||
xgrid, ygrid = coords.split([1,1], dim=-1)
|
||||
xgrid = 2*xgrid/(W-1) - 1
|
||||
ygrid = 2*ygrid/(H-1) - 1
|
||||
|
||||
grid = torch.cat([xgrid, ygrid], dim=-1)
|
||||
img = F.grid_sample(img, grid, align_corners=True)
|
||||
|
||||
if mask:
|
||||
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
||||
return img, mask.float()
|
||||
|
||||
return img
|
||||
|
||||
def coords_grid(batch, ht, wd, device):
|
||||
coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij')
|
||||
coords = torch.stack(coords[::-1], dim=0).float()
|
||||
return coords[None].repeat(batch, 1, 1, 1)
|
||||
|
||||
def manual_pad(x, pady, padx):
|
||||
|
||||
pad = (padx, padx, pady, pady)
|
||||
return F.pad(x.clone().detach(), pad, "replicate")
|
15
test_model.py
Normal file
15
test_model.py
Normal file
@ -0,0 +1,15 @@
|
||||
import torch
|
||||
from torchsummary import summary
|
||||
import numpy as np
|
||||
|
||||
from nets import Model
|
||||
|
||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
model.eval()
|
||||
|
||||
t1 = torch.rand(1, 3, 480, 640)
|
||||
t2 = torch.rand(1, 3, 480, 640)
|
||||
|
||||
output = model(t1,t2)
|
||||
print(output.shape)
|
||||
|
Loading…
Reference in New Issue
Block a user