#1 Added weight conversion to Pytorch

main
Ibai 3 years ago
parent ecd3a40505
commit 1bdbb27240
  1. 26
      convert_weights.py
  2. BIN
      doc/img/output.jpg
  3. 0
      models/.gitkeep
  4. 58
      test_model.py

@ -0,0 +1,26 @@
import copy
import torch
import numpy as np
import megengine as mge
from nets import Model
# Read Megengine parameters
pretrained_dict = mge.load("models/crestereo_eth3d.mge")
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model.eval()
state_dict = model.state_dict()
for key, value in pretrained_dict['state_dict'].items():
print(f"Converting {key}")
# Fix shape mismatch
if value.shape[0] == 1:
value = np.squeeze(value)
state_dict[key] = torch.tensor(value)
output_path = "models/crestereo_eth3d.pth"
torch.save(state_dict, output_path)
print(f"\nModel saved to: {output_path}")

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

@ -1,13 +1,63 @@
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from imread_from_url import imread_from_url
from nets import Model
#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
def inference(left, right, model, n_iter=20):
print("Model Forwarding...")
imgL = left.transpose(2, 0, 1)
imgR = right.transpose(2, 0, 1)
imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :])
imgL = torch.tensor(imgL.astype("float32"))
imgR = torch.tensor(imgR.astype("float32"))
imgL_dw2 = F.interpolate(
imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
imgR_dw2 = F.interpolate(
imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
# print(imgR_dw2.shape)
pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).detach().numpy()
return pred_disp
if __name__ == '__main__':
left_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im2.png")
right_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im6.png")
model_path = "models/crestereo_eth3d.pth"
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model.load_state_dict(torch.load(model_path), strict=True)
model.eval()
t1 = torch.rand(1, 3, 480, 640)
t2 = torch.rand(1, 3, 480, 640)
disp = inference(left_img, right_img, model, n_iter=20)
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
disp_vis = disp_vis.astype("uint8")
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
cv2.imshow("output", disp_vis)
cv2.imwrite("output.jpg", disp_vis)
cv2.waitKey(0)
output = model(t1,t2)
print(output.shape)

Loading…
Cancel
Save