diff --git a/nets/attention/position_encoding.py b/nets/attention/position_encoding.py
index 3ad49f1..6ead877 100644
--- a/nets/attention/position_encoding.py
+++ b/nets/attention/position_encoding.py
@@ -18,7 +18,6 @@ class PositionEncodingSine(nn.Module):
                 We will remove the buggy impl after re-training all variants of our released models.
         """
         super().__init__()
-
         pe = torch.zeros((d_model, *max_shape))
         y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0)
         x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0)
@@ -39,4 +38,4 @@ class PositionEncodingSine(nn.Module):
         Args:
             x: [N, C, H, W]
         """
-        return x + self.pe[:, :, :x.size(2), :x.size(3)]
\ No newline at end of file
+        return x + self.pe[:, :, :x.size(2), :x.size(3)].to(x.device)
\ No newline at end of file
diff --git a/test_model.py b/test_model.py
index 90e3b8a..7030a6d 100644
--- a/test_model.py
+++ b/test_model.py
@@ -6,17 +6,19 @@ from imread_from_url import imread_from_url
 
 from nets import Model
 
+device = 'cuda'
+
 #Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
 def inference(left, right, model, n_iter=20):
-
+
     print("Model Forwarding...")
     imgL = left.transpose(2, 0, 1)
     imgR = right.transpose(2, 0, 1)
     imgL = np.ascontiguousarray(imgL[None, :, :, :])
     imgR = np.ascontiguousarray(imgR[None, :, :, :])
 
-    imgL = torch.tensor(imgL.astype("float32"))
-    imgR = torch.tensor(imgR.astype("float32"))
+    imgL = torch.tensor(imgL.astype("float32")).to(device)
+    imgR = torch.tensor(imgR.astype("float32")).to(device)
 
     imgL_dw2 = F.interpolate(
         imgL,
@@ -35,26 +37,32 @@
     pred_flow_dw2 = model(imgL_dw2, imgR_dw2, iters=n_iter, flow_init=None)
     pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
-    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).detach().numpy()
+    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
 
     return pred_disp
 
 if __name__ == '__main__':
 
-    left_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im2.png")
-    right_img = imread_from_url("https://vision.middlebury.edu/stereo/data/scenes2003/newdata/cones/im6.png")
+    left_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/left.png")
+    right_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/right.png")
+
+    # Resize the images in case the GPU memory overflows
+    eval_h, eval_w = (240, 426)
+    imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
+    imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
 
     model_path = "models/crestereo_eth3d.pth"
 
     model = Model(max_disp=256, mixed_precision=False, test_mode=True)
     model.load_state_dict(torch.load(model_path), strict=True)
+    model.to(device)
     model.eval()
 
-    disp = inference(left_img, right_img, model, n_iter=20)
+    disp = inference(imgL, imgR, model, n_iter=20)
 
     disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
     disp_vis = disp_vis.astype("uint8")
     disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
 
-    disp_vis = cv2.resize(disp_vis, left_img.shape[1::-1])
+    left_img = cv2.resize(left_img, disp_vis.shape[1::-1])
     combined_img = np.hstack((left_img, disp_vis))
     cv2.imshow("output", combined_img)
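
Note: the patch above hard-codes device = 'cuda', so test_model.py will crash on machines without a CUDA-capable GPU. A minimal sketch of an alternative (not part of the commit) that falls back to the CPU, using torch.cuda.is_available():

    import torch

    # Pick the GPU when one is available, otherwise run on the CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # The rest of the script is unchanged, e.g.:
    # imgL = torch.tensor(imgL.astype("float32")).to(device)
    # model.to(device)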
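
Note: forward() in position_encoding.py now calls .to(x.device) on every pass, which suggests pe is stored as a plain tensor attribute that model.to(device) does not move. A sketch of an alternative, assuming pe is built in __init__ as shown above: register pe as a buffer so it follows the module across devices and the per-call copy disappears.

    import torch
    import torch.nn as nn

    class PositionEncodingSine(nn.Module):
        def __init__(self, d_model, max_shape=(256, 256)):
            super().__init__()
            pe = torch.zeros((d_model, *max_shape))
            # ... fill pe with the sine/cosine terms as in the file above ...
            # Buffers registered this way are moved by model.to(device);
            # persistent=False keeps pe out of the state_dict.
            self.register_buffer('pe', pe.unsqueeze(0), persistent=False)

        def forward(self, x):
            # x: [N, C, H, W]; pe already lives on the module's device.
            return x + self.pe[:, :, :x.size(2), :x.size(3)]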