commit fc524edd767ba3b47794bdbf5483f4a777dfca38 Author: Ibai Date: Fri Apr 8 00:18:27 2022 +0900 Initial commit diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..dfe0770 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,2 @@ +# Auto detect text files and perform LF normalization +* text=auto diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..d9005f2 --- /dev/null +++ b/.gitignore @@ -0,0 +1,152 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintainted in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ diff --git a/README.md b/README.md new file mode 100644 index 0000000..e583792 --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# CREStereo-Pytorch + Non-official Pytorch implementation of the CREStereo(CVPR 2022 Oral). diff --git a/function_convertion_tests/test_bilinear_sampler.py b/function_convertion_tests/test_bilinear_sampler.py new file mode 100644 index 0000000..2e12267 --- /dev/null +++ b/function_convertion_tests/test_bilinear_sampler.py @@ -0,0 +1,44 @@ +import pickle +import numpy as np +import megengine as mge + +import torch +import torch.nn.functional as F + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def test_bilinear_sampler(): + # Getting back the megengine objects: + with open('test_data/bilinear_sampler_test.pickle', 'rb') as f: + right_feature_prev, coords, right_feature = pickle.load(f) + + right_feature_prev = torch.tensor(right_feature_prev.numpy()) + coords = torch.tensor(coords.numpy()) + right_feature = right_feature.numpy() + + # Test Pytorch + right_feature_pytorch = bilinear_sampler(right_feature_prev, coords).numpy() + + error = np.mean(right_feature_pytorch-right_feature) + print(f"test_coords_grid - Avg. Error: {error}, \n \ + Original shape: {coords.numpy().shape},\n \ + Obtained shape: {right_feature_pytorch.shape}, Expected shape: {right_feature.shape}") + +if __name__ == '__main__': + + test_bilinear_sampler() \ No newline at end of file diff --git a/function_convertion_tests/test_coords_grid.py b/function_convertion_tests/test_coords_grid.py new file mode 100644 index 0000000..0f597f4 --- /dev/null +++ b/function_convertion_tests/test_coords_grid.py @@ -0,0 +1,29 @@ +import pickle +import numpy as np +import megengine as mge + +import torch +import torch.nn.functional as F + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + +def test_coords_grid(): + # Getting back the megengine objects: + with open('test_data/coords_grid_test.pickle', 'rb') as f: + batch, ht, wd, coords = pickle.load(f) + + coords = coords.numpy() + + # Test Pytorch + coords_pytorch = coords_grid(batch, ht, wd, 'cpu').numpy() + + error = np.mean(coords_pytorch-coords) + print(f"test_coords_grid - Avg. 
Error: {error}, \n \ + Obtained shape: {coords_pytorch.shape}, Expected shape: {coords.shape}") + +if __name__ == '__main__': + + test_coords_grid() \ No newline at end of file diff --git a/function_convertion_tests/test_data/bilinear_sampler_test.pickle b/function_convertion_tests/test_data/bilinear_sampler_test.pickle new file mode 100644 index 0000000..7156e08 Binary files /dev/null and b/function_convertion_tests/test_data/bilinear_sampler_test.pickle differ diff --git a/function_convertion_tests/test_data/coords_grid_test.pickle b/function_convertion_tests/test_data/coords_grid_test.pickle new file mode 100644 index 0000000..18b3799 Binary files /dev/null and b/function_convertion_tests/test_data/coords_grid_test.pickle differ diff --git a/function_convertion_tests/test_data/manual_pad_test0_4.pickle b/function_convertion_tests/test_data/manual_pad_test0_4.pickle new file mode 100644 index 0000000..37dddcb Binary files /dev/null and b/function_convertion_tests/test_data/manual_pad_test0_4.pickle differ diff --git a/function_convertion_tests/test_data/manual_pad_test1_1.pickle b/function_convertion_tests/test_data/manual_pad_test1_1.pickle new file mode 100644 index 0000000..8fb52a8 Binary files /dev/null and b/function_convertion_tests/test_data/manual_pad_test1_1.pickle differ diff --git a/function_convertion_tests/test_data/meshgrid_np_test.pkl b/function_convertion_tests/test_data/meshgrid_np_test.pkl new file mode 100644 index 0000000..c3f9cb3 Binary files /dev/null and b/function_convertion_tests/test_data/meshgrid_np_test.pkl differ diff --git a/function_convertion_tests/test_data/offset_test.pkl b/function_convertion_tests/test_data/offset_test.pkl new file mode 100644 index 0000000..b92d22e Binary files /dev/null and b/function_convertion_tests/test_data/offset_test.pkl differ diff --git a/function_convertion_tests/test_data/split_test.pkl b/function_convertion_tests/test_data/split_test.pkl new file mode 100644 index 0000000..e5c96a1 Binary files /dev/null and b/function_convertion_tests/test_data/split_test.pkl differ diff --git a/function_convertion_tests/test_data/split_test_list.pkl b/function_convertion_tests/test_data/split_test_list.pkl new file mode 100644 index 0000000..68692b2 Binary files /dev/null and b/function_convertion_tests/test_data/split_test_list.pkl differ diff --git a/function_convertion_tests/test_manual_pad.py b/function_convertion_tests/test_manual_pad.py new file mode 100644 index 0000000..c5a1bbe --- /dev/null +++ b/function_convertion_tests/test_manual_pad.py @@ -0,0 +1,51 @@ +import pickle +import numpy as np +import megengine as mge + +import torch +import torch.nn.functional as F + +def manual_pad(x, pady, padx): + + pad = (padx, padx, pady, pady) + return F.pad(torch.tensor(x), pad, "replicate") + + +def test_pad_1_1(): + # Getting back the megengine objects: + with open('test_data/manual_pad_test1_1.pickle', 'rb') as f: + right_feature, pady, padx, right_pad = pickle.load(f) + + right_feature = right_feature.numpy() + right_pad = right_pad.numpy() + + # Test Pytorch + right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy() + + error = np.mean(right_pad_pytorch-right_pad) + print(f"test_pad_1_1 - Avg. Error: {error}, \n \ + Orig. 
shape: {right_feature.shape}, \n \ + Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}") + +def test_pad_0_4(): + # Getting back the megengine objects: + with open('test_data/manual_pad_test0_4.pickle', 'rb') as f: + right_feature, pady, padx, right_pad = pickle.load(f) + + right_feature = right_feature.numpy() + right_pad = right_pad.numpy() + + # Test Pytorch + right_pad_pytorch = manual_pad(right_feature, pady, padx).numpy() + + error = np.mean(right_pad_pytorch-right_pad) + print(f"test_pad_0_4 - Avg. Error: {error}, \n \ + Orig. shape: {right_feature.shape}, \n \ + Padded shape: {right_pad_pytorch.shape}, Expected shape: {right_pad.shape}") + + +if __name__ == '__main__': + + test_pad_1_1() + + test_pad_0_4() \ No newline at end of file diff --git a/function_convertion_tests/test_meshgrid.py b/function_convertion_tests/test_meshgrid.py new file mode 100644 index 0000000..9883641 --- /dev/null +++ b/function_convertion_tests/test_meshgrid.py @@ -0,0 +1,30 @@ +import pickle +import numpy as np +import megengine as mge + +import torch +import torch.nn.functional as F + +def test_meshgrid(): + # Getting back the megengine objects: + with open('test_data/meshgrid_np_test.pkl', 'rb') as f: + rx, dilatex, ry, dilatey, x_grid, y_grid = pickle.load(f) + + x_grid = x_grid.numpy() + y_grid = y_grid.numpy() + + # Test Pytorch + x_grid_pytorch, y_grid_pytorch = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device='cpu'), + torch.arange(-ry, ry + 1, dilatey, device='cpu'), indexing='xy') + + + error_x = np.mean(x_grid_pytorch.numpy()-x_grid) + error_y = np.mean(y_grid_pytorch.numpy()-y_grid) + print(f"test_meshgrid (X) - Avg. Error: {error_x}, \n \ + Obtained shape: {x_grid_pytorch.numpy().shape}, Expected shape: {x_grid.shape}") + print(f"test_meshgrid (Y) - Avg. Error: {error_y}, \n \ + Obtained shape: {y_grid_pytorch.numpy().shape}, Expected shape: {y_grid.shape}") + +if __name__ == '__main__': + + test_meshgrid() \ No newline at end of file diff --git a/function_convertion_tests/test_offset.py b/function_convertion_tests/test_offset.py new file mode 100644 index 0000000..fb8bf57 --- /dev/null +++ b/function_convertion_tests/test_offset.py @@ -0,0 +1,31 @@ +import pickle +import numpy as np +import megengine as mge + +import torch +import torch.nn.functional as F + +def test_offset(): + # Getting back the megengine objects: + with open('test_data/offset_test.pkl', 'rb') as f: + x_grid, y_grid, reshape_shape, transpose_order, expand_size, repeat_size, repeat_axis, offsets = pickle.load(f) + + x_grid = torch.tensor(x_grid.numpy()) + y_grid = torch.tensor(y_grid.numpy()) + offsets_mge = offsets.numpy() + N = repeat_size + + # Test Pytorch + offsets = torch.stack((x_grid, y_grid)) + offsets = offsets.reshape(2, -1).permute(1, 0) + for d in sorted((0, 2, 3)): + offsets = offsets.unsqueeze(d) + offsets = offsets.repeat_interleave(N, dim=0) + + error = np.mean(offsets.numpy()-offsets_mge) + print(f"test_offset - Avg. 
Error: {error}, \n \ + Obtained shape: {offsets.numpy().shape}, Expected shape: {offsets_mge.shape}") + +if __name__ == '__main__': + + test_offset() \ No newline at end of file diff --git a/function_convertion_tests/test_split.py b/function_convertion_tests/test_split.py new file mode 100644 index 0000000..7022eb8 --- /dev/null +++ b/function_convertion_tests/test_split.py @@ -0,0 +1,47 @@ +import pickle +import numpy as np +import megengine as mge + +import torch +import torch.nn.functional as F + +def test_split(): + # Getting back the megengine objects: + with open('test_data/split_test.pkl', 'rb') as f: + left_feature, size, axis, lefts = pickle.load(f) + + left_feature = torch.tensor(left_feature.numpy()) + + # Test Pytorch + lefts_pytorch = torch.split(left_feature, left_feature.shape[axis]//size, dim=axis) + + for i, (left_pytorch, left) in enumerate(zip(lefts_pytorch, lefts)): + + error = np.mean(left_pytorch.numpy()-left.numpy()) + print(f"test_split {i} - Avg. Error: {error}, \n \ + Obtained shape: {left_pytorch.numpy().shape}, Expected shape: {left.numpy().shape}\n") + +def test_split_list(): + # Getting back the megengine objects: + with open('test_data/split_test_list.pkl', 'rb') as f: + fmap1, size, axis, net, inp = pickle.load(f) + + fmap1 = torch.tensor(fmap1.numpy()) + net = net.numpy() + inp = inp.numpy() + + # Test Pytorch + net_pytorch, inp_pytorch = torch.split(fmap1, [size[0],size[0]], dim=axis) + + error_net = np.mean(net_pytorch.numpy()-net) + error_inp = np.mean(inp_pytorch.numpy()-inp) + print(f"test_split_list (net) - Avg. Error: {error_net}, \n \ + Obtained shape: {net_pytorch.numpy().shape}, Expected shape: {net.shape}\n") + print(f"test_split_list (inp) - Avg. Error: {error_inp}, \n \ + Obtained shape: {inp_pytorch.numpy().shape}, Expected shape: {inp.shape}\n") + + +if __name__ == '__main__': + + test_split() + test_split_list() \ No newline at end of file diff --git a/nets/__init__.py b/nets/__init__.py new file mode 100644 index 0000000..0d0730f --- /dev/null +++ b/nets/__init__.py @@ -0,0 +1 @@ +from .crestereo import CREStereo as Model diff --git a/nets/attention/__init__.py b/nets/attention/__init__.py new file mode 100644 index 0000000..a7f763c --- /dev/null +++ b/nets/attention/__init__.py @@ -0,0 +1,2 @@ +from .transformer import LocalFeatureTransformer +from .position_encoding import PositionEncodingSine diff --git a/nets/attention/linear_attention.py b/nets/attention/linear_attention.py new file mode 100644 index 0000000..61b1b85 --- /dev/null +++ b/nets/attention/linear_attention.py @@ -0,0 +1,81 @@ +""" +Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" +Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py +""" + +import torch +from torch.nn import Module, Dropout + + +def elu_feature_map(x): + return torch.nn.functional.elu(x) + 1 + + +class LinearAttention(Module): + def __init__(self, eps=1e-6): + super().__init__() + self.feature_map = elu_feature_map + self.eps = eps + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-Head linear attention proposed in "Transformers are RNNs" + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + Q = self.feature_map(queries) + K = self.feature_map(keys) + + # set padded position to zero + if q_mask is not None: + Q = Q * q_mask[:, :, None, None] + 
if kv_mask is not None: + K = K * kv_mask[:, :, None, None] + values = values * kv_mask[:, :, None, None] + + v_length = values.size(1) + values = values / v_length # prevent fp16 overflow + KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V + Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) + queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length + + return queried_values.contiguous() + + +class FullAttention(Module): + def __init__(self, use_dropout=False, attention_dropout=0.1): + super().__init__() + self.use_dropout = use_dropout + self.dropout = Dropout(attention_dropout) + + def forward(self, queries, keys, values, q_mask=None, kv_mask=None): + """ Multi-head scaled dot-product attention, a.k.a full attention. + Args: + queries: [N, L, H, D] + keys: [N, S, H, D] + values: [N, S, H, D] + q_mask: [N, L] + kv_mask: [N, S] + Returns: + queried_values: (N, L, H, D) + """ + + # Compute the unnormalized attention and apply the masks + QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) + if kv_mask is not None: + QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) + + # Compute the attention and the weighted average + softmax_temp = 1. / queries.size(3)**.5 # sqrt(D) + A = torch.softmax(softmax_temp * QK, dim=2) + if self.use_dropout: + A = self.dropout(A) + + queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) + + return queried_values.contiguous() \ No newline at end of file diff --git a/nets/attention/position_encoding.py b/nets/attention/position_encoding.py new file mode 100644 index 0000000..3ad49f1 --- /dev/null +++ b/nets/attention/position_encoding.py @@ -0,0 +1,42 @@ +import math +import torch +from torch import nn + + +class PositionEncodingSine(nn.Module): + """ + This is a sinusoidal position encoding that generalized to 2-dimensional images + """ + + def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): + """ + Args: + max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels + temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), + the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact + on the final performance. For now, we keep both impls for backward compatability. + We will remove the buggy impl after re-training all variants of our released models. 
+ """ + super().__init__() + + pe = torch.zeros((d_model, *max_shape)) + y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) + x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) + if temp_bug_fix: + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) + else: # a buggy implementation (for backward compatability only) + div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) + div_term = div_term[:, None, None] # [C//4, 1, 1] + pe[0::4, :, :] = torch.sin(x_position * div_term) + pe[1::4, :, :] = torch.cos(x_position * div_term) + pe[2::4, :, :] = torch.sin(y_position * div_term) + pe[3::4, :, :] = torch.cos(y_position * div_term) + + self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] + + def forward(self, x): + """ + Args: + x: [N, C, H, W] + """ + return x + self.pe[:, :, :x.size(2), :x.size(3)] \ No newline at end of file diff --git a/nets/attention/transformer.py b/nets/attention/transformer.py new file mode 100644 index 0000000..f47c36a --- /dev/null +++ b/nets/attention/transformer.py @@ -0,0 +1,100 @@ +import copy +import torch +import torch.nn as nn +from .linear_attention import LinearAttention, FullAttention + +#Ref: https://github.com/zju3dv/LoFTR/blob/master/src/loftr/loftr_module/transformer.py +class LoFTREncoderLayer(nn.Module): + def __init__(self, + d_model, + nhead, + attention='linear'): + super(LoFTREncoderLayer, self).__init__() + + self.dim = d_model // nhead + self.nhead = nhead + + # multi-head attention + self.q_proj = nn.Linear(d_model, d_model, bias=False) + self.k_proj = nn.Linear(d_model, d_model, bias=False) + self.v_proj = nn.Linear(d_model, d_model, bias=False) + self.attention = LinearAttention() if attention == 'linear' else FullAttention() + self.merge = nn.Linear(d_model, d_model, bias=False) + + # feed-forward network + self.mlp = nn.Sequential( + nn.Linear(d_model*2, d_model*2, bias=False), + nn.ReLU(True), + nn.Linear(d_model*2, d_model, bias=False), + ) + + # norm and dropout + self.norm1 = nn.LayerNorm(d_model) + self.norm2 = nn.LayerNorm(d_model) + + def forward(self, x, source, x_mask=None, source_mask=None): + """ + Args: + x (torch.Tensor): [N, L, C] + source (torch.Tensor): [N, S, C] + x_mask (torch.Tensor): [N, L] (optional) + source_mask (torch.Tensor): [N, S] (optional) + """ + bs = x.size(0) + query, key, value = x, source, source + + # multi-head attention + query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] + key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] + value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) + message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] + message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] + message = self.norm1(message) + + # feed-forward network + message = self.mlp(torch.cat([x, message], dim=2)) + message = self.norm2(message) + + return x + message + + +class LocalFeatureTransformer(nn.Module): + """A Local Feature Transformer (LoFTR) module.""" + + def __init__(self, d_model, nhead, layer_names, attention): + super(LocalFeatureTransformer, self).__init__() + + self.d_model = d_model + self.nhead = nhead + self.layer_names = layer_names + encoder_layer = LoFTREncoderLayer(d_model, nhead, attention) + self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) + self._reset_parameters() + + def 
_reset_parameters(self): + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + + def forward(self, feat0, feat1, mask0=None, mask1=None): + """ + Args: + feat0 (torch.Tensor): [N, L, C] + feat1 (torch.Tensor): [N, S, C] + mask0 (torch.Tensor): [N, L] (optional) + mask1 (torch.Tensor): [N, S] (optional) + """ + + assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" + + for layer, name in zip(self.layers, self.layer_names): + if name == 'self': + feat0 = layer(feat0, feat0, mask0, mask0) + feat1 = layer(feat1, feat1, mask1, mask1) + elif name == 'cross': + feat0 = layer(feat0, feat1, mask0, mask1) + feat1 = layer(feat1, feat0, mask1, mask0) + else: + raise KeyError + + return feat0, feat1 \ No newline at end of file diff --git a/nets/corr.py b/nets/corr.py new file mode 100644 index 0000000..aaa2af9 --- /dev/null +++ b/nets/corr.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .utils import bilinear_sampler, coords_grid, manual_pad + +class AGCL: + """ + Implementation of Adaptive Group Correlation Layer (AGCL). + """ + + def __init__(self, fmap1, fmap2, att=None): + self.fmap1 = fmap1 + self.fmap2 = fmap2 + + self.att = att + + self.coords = coords_grid(fmap1.shape[0], fmap1.shape[2], fmap1.shape[3], fmap1.device) + + def __call__(self, flow, extra_offset, small_patch=False, iter_mode=False): + if iter_mode: + corr = self.corr_iter(self.fmap1, self.fmap2, flow, small_patch) + else: + corr = self.corr_att_offset( + self.fmap1, self.fmap2, flow, extra_offset, small_patch + ) + return corr + + def get_correlation(self, left_feature, right_feature, psize=(3, 3), dilate=(1, 1)): + + N, C, H, W = left_feature.shape + + di_y, di_x = dilate[0], dilate[1] + pady, padx = psize[0] // 2 * di_y, psize[1] // 2 * di_x + + right_pad = manual_pad(right_feature, pady, padx) + + corr_list = [] + for h in range(0, pady * 2 + 1, di_y): + for w in range(0, padx * 2 + 1, di_x): + right_crop = right_pad[:, :, h : h + H, w : w + W] + assert right_crop.shape == left_feature.shape + corr = torch.mean(left_feature * right_crop, dim=1, keepdims=True) + corr_list.append(corr) + + corr_final = torch.cat(corr_list, dim=1) + + return corr_final + + def corr_iter(self, left_feature, right_feature, flow, small_patch): + + coords = self.coords + flow + coords = coords.permute(0, 2, 3, 1) + right_feature = bilinear_sampler(right_feature, coords) + + if small_patch: + psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)] + dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] + else: + psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)] + dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] + + N, C, H, W = left_feature.shape + lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1) + rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1) + + corrs = [] + for i in range(len(psize_list)): + corr = self.get_correlation( + lefts[i], rights[i], psize_list[i], dilate_list[i] + ) + corrs.append(corr) + + final_corr = torch.cat(corrs, dim=1) + + return final_corr + + def corr_att_offset( + self, left_feature, right_feature, flow, extra_offset, small_patch + ): + + N, C, H, W = left_feature.shape + + if self.att is not None: + left_feature = left_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c' + right_feature = right_feature.permute(0, 2, 3, 1).reshape(N, H * W, C) # 'n c h w -> n (h w) c' + # 'n (h w) c -> n c h w' + left_feature, right_feature = [ + x.reshape(N, H, W, 
C).permute(0, 3, 1, 2) + for x in [left_feature, right_feature] + ] + + lefts = torch.split(left_feature, left_feature.shape[1]//4, dim=1) + rights = torch.split(right_feature, right_feature.shape[1]//4, dim=1) + + C = C // 4 + + if small_patch: + psize_list = [(3, 3), (3, 3), (3, 3), (3, 3)] + dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] + else: + psize_list = [(1, 9), (1, 9), (1, 9), (1, 9)] + dilate_list = [(1, 1), (1, 1), (1, 1), (1, 1)] + + search_num = 9 + extra_offset = extra_offset.reshape(N, search_num, 2, H, W).permute(0, 1, 3, 4, 2) # [N, search_num, 1, 1, 2] + + corrs = [] + for i in range(len(psize_list)): + left_feature, right_feature = lefts[i], rights[i] + psize, dilate = psize_list[i], dilate_list[i] + + psizey, psizex = psize[0], psize[1] + dilatey, dilatex = dilate[0], dilate[1] + + ry = psizey // 2 * dilatey + rx = psizex // 2 * dilatex + x_grid, y_grid = torch.meshgrid(torch.arange(-rx, rx + 1, dilatex, device=self.fmap1.device), + torch.arange(-ry, ry + 1, dilatey, device=self.fmap1.device), indexing='xy') + + offsets = torch.stack((x_grid, y_grid)) + offsets = offsets.reshape(2, -1).permute(1, 0) + for d in sorted((0, 2, 3)): + offsets = offsets.unsqueeze(d) + offsets = offsets.repeat_interleave(N, dim=0) + offsets = offsets + extra_offset + + coords = self.coords + flow # [N, 2, H, W] + coords = coords.permute(0, 2, 3, 1) # [N, H, W, 2] + coords = torch.unsqueeze(coords, 1) + offsets + coords = coords.reshape(N, -1, W, 2) # [N, search_num*H, W, 2] + + right_feature = bilinear_sampler( + right_feature, coords + ) # [N, C, search_num*H, W] + right_feature = right_feature.reshape(N, C, -1, H, W) # [N, C, search_num, H, W] + left_feature = left_feature.unsqueeze(2).repeat_interleave(right_feature.shape[2], dim=2) + + corr = torch.mean(left_feature * right_feature, dim=1) + + corrs.append(corr) + + final_corr = torch.cat(corrs, dim=1) + + return final_corr diff --git a/nets/crestereo.py b/nets/crestereo.py new file mode 100644 index 0000000..d685d4c --- /dev/null +++ b/nets/crestereo.py @@ -0,0 +1,258 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock +from .extractor import BasicEncoder +from .corr import AGCL + +from .attention import PositionEncodingSine, LocalFeatureTransformer + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + +#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py +class CREStereo(nn.Module): + def __init__(self, max_disp=192, mixed_precision=False, test_mode=False): + super(CREStereo, self).__init__() + + self.max_flow = max_disp + self.mixed_precision = mixed_precision + self.test_mode = test_mode + + self.hidden_dim = 128 + self.context_dim = 128 + self.dropout = 0 + + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout) + self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4) + + # loftr + self.self_att_fn = LocalFeatureTransformer( + d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear" + ) + self.cross_att_fn = LocalFeatureTransformer( + d_model=256, nhead=8, layer_names=["cross"] * 1, attention="linear" + ) + + # adaptive search + self.search_num = 9 + self.conv_offset_16 = nn.Conv2d( + 256, self.search_num * 2, kernel_size=3, stride=1, padding=1 + ) + self.conv_offset_8 = nn.Conv2d( + 256, self.search_num * 2, 
kernel_size=3, stride=1, padding=1 + ) + self.range_16 = 1 + self.range_8 = 1 + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def convex_upsample(self, flow, mask, rate=4): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + # print(flow.shape, mask.shape, rate) + mask = mask.view(N, 1, 9, rate, rate, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(rate * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, rate*H, rate*W) + + def zero_init(self, fmap): + N, C, H, W = fmap.shape + _x = torch.zeros([N, 1, H, W], dtype=torch.float32) + _y = torch.zeros([N, 1, H, W], dtype=torch.float32) + zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device) + return zero_flow + + def forward(self, image1, image2, iters=10, flow_init=None, upsample=True, test_mode=False): + """ Estimate optical flow between pair of frames """ + + image1 = 2 * (image1 / 255.0) - 1.0 + image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = self.context_dim + + # run the feature network + with autocast(enabled=self.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + + with autocast(enabled=self.mixed_precision): + + # 1/4 -> 1/8 + # feature + fmap1_dw8 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2_dw8 = F.avg_pool2d(fmap2, 2, stride=2) + + # offset + offset_dw8 = self.conv_offset_8(fmap1_dw8) + offset_dw8 = self.range_8 * (torch.sigmoid(offset_dw8) - 0.5) * 2.0 + + # context + net, inp = torch.split(fmap1, [hdim,hdim], dim=1) + net = torch.tanh(net) + inp = F.relu(inp) + net_dw8 = F.avg_pool2d(net, 2, stride=2) + inp_dw8 = F.avg_pool2d(inp, 2, stride=2) + + # 1/4 -> 1/16 + # feature + fmap1_dw16 = F.avg_pool2d(fmap1, 4, stride=4) + fmap2_dw16 = F.avg_pool2d(fmap2, 4, stride=4) + offset_dw16 = self.conv_offset_16(fmap1_dw16) + offset_dw16 = self.range_16 * (torch.sigmoid(offset_dw16) - 0.5) * 2.0 + + # context + net_dw16 = F.avg_pool2d(net, 4, stride=4) + inp_dw16 = F.avg_pool2d(inp, 4, stride=4) + + # positional encoding and self-attention + pos_encoding_fn_small = PositionEncodingSine( + d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16) + ) + # 'n c h w -> n (h w) c' + x_tmp = pos_encoding_fn_small(fmap1_dw16) + fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]) + # 'n c h w -> n (h w) c' + x_tmp = pos_encoding_fn_small(fmap2_dw16) + fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]) + + fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16) + fmap1_dw16, fmap2_dw16 = [ + x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2) + for x in [fmap1_dw16, fmap2_dw16] + ] + + corr_fn = AGCL(fmap1, fmap2) + corr_fn_dw8 = AGCL(fmap1_dw8, fmap2_dw8) + corr_fn_att_dw16 = AGCL(fmap1_dw16, fmap2_dw16, att=self.cross_att_fn) + + # Cascaded refinement (1/16 + 1/8 + 1/4) + predictions = [] + flow = None + flow_up = None + if flow_init is not None: + scale = fmap1.shape[2] / flow_init.shape[2] + flow = -scale * F.interpolate( + flow_init, + size=(fmap1.shape[2], fmap1.shape[3]), + mode="bilinear", + align_corners=True, + ) + else: + # zero initialization + 
flow_dw16 = self.zero_init(fmap1_dw16) + + # Recurrent Update Module + # RUM: 1/16 + for itr in range(iters // 2): + if itr % 2 == 0: + small_patch = False + else: + small_patch = True + + flow_dw16 = flow_dw16.detach() + out_corrs = corr_fn_att_dw16( + flow_dw16, offset_dw16, small_patch=small_patch + ) + + with autocast(enabled=self.mixed_precision): + net_dw16, up_mask, delta_flow = self.update_block( + net_dw16, inp_dw16, out_corrs, flow_dw16 + ) + + flow_dw16 = flow_dw16 + delta_flow + flow = self.convex_upsample(flow_dw16, up_mask, rate=4) + flow_up = -4 * F.interpolate( + flow, + size=(4 * flow.shape[2], 4 * flow.shape[3]), + mode="bilinear", + align_corners=True, + ) + predictions.append(flow_up) + + scale = fmap1_dw8.shape[2] / flow.shape[2] + flow_dw8 = -scale * F.interpolate( + flow, + size=(fmap1_dw8.shape[2], fmap1_dw8.shape[3]), + mode="bilinear", + align_corners=True, + ) + + # RUM: 1/8 + for itr in range(iters // 2): + if itr % 2 == 0: + small_patch = False + else: + small_patch = True + + flow_dw8 = flow_dw8.detach() + out_corrs = corr_fn_dw8(flow_dw8, offset_dw8, small_patch=small_patch) + + with autocast(enabled=self.mixed_precision): + net_dw8, up_mask, delta_flow = self.update_block( + net_dw8, inp_dw8, out_corrs, flow_dw8 + ) + + flow_dw8 = flow_dw8 + delta_flow + flow = self.convex_upsample(flow_dw8, up_mask, rate=4) + flow_up = -2 * F.interpolate( + flow, + size=(2 * flow.shape[2], 2 * flow.shape[3]), + mode="bilinear", + align_corners=True, + ) + predictions.append(flow_up) + + scale = fmap1.shape[2] / flow.shape[2] + flow = -scale * F.interpolate( + flow, + size=(fmap1.shape[2], fmap1.shape[3]), + mode="bilinear", + align_corners=True, + ) + + # RUM: 1/4 + for itr in range(iters): + if itr % 2 == 0: + small_patch = False + else: + small_patch = True + + flow = flow.detach() + out_corrs = corr_fn(flow, None, small_patch=small_patch, iter_mode=True) + + with autocast(enabled=self.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, out_corrs, flow) + + flow = flow + delta_flow + flow_up = -self.convex_upsample(flow, up_mask, rate=4) + predictions.append(flow_up) + + if self.test_mode: + return flow_up + + return predictions diff --git a/nets/extractor.py b/nets/extractor.py new file mode 100644 index 0000000..0faf510 --- /dev/null +++ b/nets/extractor.py @@ -0,0 +1,123 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + + self.downsample = 
nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + x = self.downsample(x) + + return self.relu(x+y) + + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=1) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, x.shape[0]//2, dim=0) + + return x \ No newline at end of file diff --git a/nets/update.py b/nets/update.py new file mode 100644 index 0000000..401d504 --- /dev/null +++ b/nets/update.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/update.py +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + def forward(self, h, x): + # horizontal + hx = 
torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + + +class BasicMotionEncoder(nn.Module): + def __init__(self, cor_planes): + super(BasicMotionEncoder, self).__init__() + + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + + +class BasicUpdateBlock(nn.Module): + def __init__(self, hidden_dim, cor_planes, mask_size=8): + super(BasicUpdateBlock, self).__init__() + + self.encoder = BasicMotionEncoder(cor_planes) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, mask_size**2 *9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + # print(inp.shape, corr.shape, flow.shape) + motion_features = self.encoder(flow, corr) + # print(motion_features.shape, inp.shape) + inp = torch.cat((inp, motion_features), dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow diff --git a/nets/utils/__init__.py b/nets/utils/__init__.py new file mode 100644 index 0000000..8cc08e5 --- /dev/null +++ b/nets/utils/__init__.py @@ -0,0 +1 @@ +from .utils import bilinear_sampler, coords_grid, manual_pad \ No newline at end of file diff --git a/nets/utils/utils.py b/nets/utils/utils.py new file mode 100644 index 0000000..96075ca --- /dev/null +++ b/nets/utils/utils.py @@ -0,0 +1,31 @@ +import torch +import torch.nn.functional as F +import numpy as np + +#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/utils/utils.py + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + +def coords_grid(batch, ht, wd, device): + coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device), indexing='ij') + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + +def manual_pad(x, pady, padx): + + pad = (padx, padx, pady, pady) + return F.pad(x.clone().detach(), pad, "replicate") diff --git a/test_model.py b/test_model.py new file mode 100644 index 0000000..6a3d971 --- /dev/null +++ b/test_model.py @@ -0,0 +1,15 @@ +import torch +from torchsummary import summary +import numpy as np + +from nets import Model 
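+# test_mode=True makes forward() return only the final up-sampled flow
+# (see nets/crestereo.py); the random tensors below stand in for a left/right
+# image pair in the [0, 255] value range that forward() expects.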
+
+model = Model(max_disp=256, mixed_precision=False, test_mode=True)
+model.eval()
+
+t1 = torch.rand(1, 3, 480, 640)
+t2 = torch.rand(1, 3, 480, 640)
+
+output = model(t1, t2)
+print(output.shape)
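
The script above exercises only the single-pass, zero-initialized path. The forward() signature in nets/crestereo.py also accepts a flow_init argument, which allows a coarse-to-fine inference scheme: run the network once at half resolution and feed the result back as the initialization for a full-resolution pass. The sketch below illustrates that usage; the helper name estimate_disparity, the iteration count, and the input conventions (float RGB tensors in the [0, 255] range, height and width assumed divisible by 32) are illustrative choices, not part of this commit.

import torch
import torch.nn.functional as F

from nets import Model


def estimate_disparity(left, right, n_iter=20, device="cpu"):
    # left/right: float tensors of shape [1, 3, H, W], values in [0, 255],
    # with H and W assumed divisible by 32 so both passes divide cleanly
    # through the 1/16-resolution attention branch.
    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    model = model.to(device).eval()
    left, right = left.to(device), right.to(device)

    with torch.no_grad():
        # Pass 1: half resolution, zero-initialized flow.
        h, w = left.shape[2], left.shape[3]
        left_dw2 = F.interpolate(left, size=(h // 2, w // 2), mode="bilinear", align_corners=True)
        right_dw2 = F.interpolate(right, size=(h // 2, w // 2), mode="bilinear", align_corners=True)
        flow_init = model(left_dw2, right_dw2, iters=n_iter)

        # Pass 2: full resolution, seeded with the half-resolution prediction;
        # forward() rescales flow_init internally to the 1/4 feature grid.
        flow = model(left, right, iters=n_iter, flow_init=flow_init)

    # For rectified pairs only the horizontal flow component carries disparity.
    return flow[:, 0, :, :].squeeze(0)

For rectified stereo input, taking the x component of the predicted flow as the disparity map is the usual convention; the vertical component is discarded.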