You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
47 lines
1.5 KiB
47 lines
1.5 KiB
import pickle
|
|
import numpy as np
|
|
import megengine as mge
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
def test_split():
|
|
# Getting back the megengine objects:
|
|
with open('test_data/split_test.pkl', 'rb') as f:
|
|
left_feature, size, axis, lefts = pickle.load(f)
|
|
|
|
left_feature = torch.tensor(left_feature.numpy())
|
|
|
|
# Test Pytorch
|
|
lefts_pytorch = torch.split(left_feature, left_feature.shape[axis]//size, dim=axis)
|
|
|
|
for i, (left_pytorch, left) in enumerate(zip(lefts_pytorch, lefts)):
|
|
|
|
error = np.mean(left_pytorch.numpy()-left.numpy())
|
|
print(f"test_split {i} - Avg. Error: {error}, \n \
|
|
Obtained shape: {left_pytorch.numpy().shape}, Expected shape: {left.numpy().shape}\n")
|
|
|
|
def test_split_list():
|
|
# Getting back the megengine objects:
|
|
with open('test_data/split_test_list.pkl', 'rb') as f:
|
|
fmap1, size, axis, net, inp = pickle.load(f)
|
|
|
|
fmap1 = torch.tensor(fmap1.numpy())
|
|
net = net.numpy()
|
|
inp = inp.numpy()
|
|
|
|
# Test Pytorch
|
|
net_pytorch, inp_pytorch = torch.split(fmap1, [size[0],size[0]], dim=axis)
|
|
|
|
error_net = np.mean(net_pytorch.numpy()-net)
|
|
error_inp = np.mean(inp_pytorch.numpy()-inp)
|
|
print(f"test_split_list (net) - Avg. Error: {error_net}, \n \
|
|
Obtained shape: {net_pytorch.numpy().shape}, Expected shape: {net.shape}\n")
|
|
print(f"test_split_list (inp) - Avg. Error: {error_inp}, \n \
|
|
Obtained shape: {inp_pytorch.numpy().shape}, Expected shape: {inp.shape}\n")
|
|
|
|
|
|
if __name__ == '__main__':
|
|
|
|
test_split()
|
|
test_split_list() |