CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

47 lines
1.5 KiB

3 years ago
import pickle
import numpy as np
import megengine as mge
import torch
import torch.nn.functional as F
def test_split():
    """Compare ``torch.split`` with the reference chunks stored in the fixture.

    Loads ``test_data/split_test.pkl``, which unpickles to the tuple
    ``(left_feature, size, axis, lefts)``. ``left_feature`` is converted to a
    torch tensor and split along ``axis`` into chunks of length
    ``left_feature.shape[axis] // size``. Each chunk is compared with the
    matching entry of ``lefts``, and one report per pair is printed with the
    mean absolute error and both shapes. Nothing is returned.
    """
    # Getting back the megengine objects:
    # NOTE(review): pickle.load can execute arbitrary code from the file, so
    # this must only be pointed at trusted, locally generated fixtures.
    with open('test_data/split_test.pkl', 'rb') as f:
        left_feature, size, axis, lefts = pickle.load(f)
    left_feature = torch.tensor(left_feature.numpy())

    # Test Pytorch
    chunk_len = left_feature.shape[axis] // size
    lefts_pytorch = torch.split(left_feature, chunk_len, dim=axis)

    # zip() below stops at the shorter sequence, so report a differing chunk
    # count explicitly instead of silently comparing only the common prefix.
    if len(lefts_pytorch) != len(lefts):
        print(f"test_split - chunk count mismatch: "
              f"obtained {len(lefts_pytorch)}, expected {len(lefts)}\n")

    for i, (left_pytorch, left) in enumerate(zip(lefts_pytorch, lefts)):
        obtained = left_pytorch.numpy()
        expected = left.numpy()
        # Use the mean of absolute differences: with a signed mean, positive
        # and negative deviations cancel and mismatching data can report ~0.
        error = np.mean(np.abs(obtained - expected))
        print(f"test_split {i} - Avg. Error: {error}, \n "
              f"Obtained shape: {obtained.shape}, Expected shape: {expected.shape}\n")
def test_split_list():
    """Compare list-based ``torch.split`` with the reference halves in the fixture.

    Loads ``test_data/split_test_list.pkl``, which unpickles to the tuple
    ``(fmap1, size, axis, net, inp)``. ``fmap1`` is converted to a torch
    tensor and split along ``axis`` into two sections of length ``size[0]``
    each. The two sections are compared with ``net`` and ``inp`` respectively,
    and one report per section is printed with the mean absolute error and
    both shapes. Nothing is returned.
    """
    # Getting back the megengine objects:
    # NOTE(review): pickle.load can execute arbitrary code from the file, so
    # this must only be pointed at trusted, locally generated fixtures.
    with open('test_data/split_test_list.pkl', 'rb') as f:
        fmap1, size, axis, net, inp = pickle.load(f)
    fmap1 = torch.tensor(fmap1.numpy())
    net = net.numpy()
    inp = inp.numpy()

    # Test Pytorch
    # Both section lengths are size[0]; torch.split requires the lengths to
    # sum to the dimension, so this only works when
    # fmap1.shape[axis] == 2 * size[0].
    net_pytorch, inp_pytorch = torch.split(fmap1, [size[0], size[0]], dim=axis)

    comparisons = (
        ("net", net_pytorch.numpy(), net),
        ("inp", inp_pytorch.numpy(), inp),
    )
    for label, obtained, expected in comparisons:
        # Use the mean of absolute differences: with a signed mean, positive
        # and negative deviations cancel and mismatching data can report ~0.
        error = np.mean(np.abs(obtained - expected))
        print(f"test_split_list ({label}) - Avg. Error: {error}, \n "
              f"Obtained shape: {obtained.shape}, Expected shape: {expected.shape}\n")
if __name__ == '__main__':
    # When executed as a script, run every comparison in declaration order.
    for check in (test_split, test_split_list):
        check()