CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

124 lines
4.5 KiB

3 years ago
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
3 years ago
from nets import Model
import wandb
import random
from torch.utils.data import DataLoader
from dataset import CTDDataset
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference
# Device identifier handed to nn.DataParallel in the script entry point.
device = "cuda"

# Start the Weights & Biases run that the rest of this script logs into.
wandb.init(entity="cpt-captain", project="crestereo")
def do_infer(left_img, right_img, gt_disp, model):
    """Run inference on one stereo pair and log the results to wandb.

    Calls ``ctd_inference`` with ``n_iter=20`` and its own wandb logging
    switched off, then logs six images in a single ``wandb.log`` call: the
    predicted disparity, its colour-mapped visualisation, the colour-mapped
    ground truth, the colour-mapped absolute error, and both inputs.

    Args:
        left_img: Left input image, forwarded unchanged to ``ctd_inference``.
        right_img: Right input image, forwarded unchanged to ``ctd_inference``.
        gt_disp: Ground-truth disparity. It is reshaped to the prediction's
            shape when the two shapes differ.
        model: Network passed through to ``ctd_inference``.

    Returns:
        None. The only effect is the ``wandb.log`` call.
    """
    disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)

    disp_vis = normalize_and_colormap(disp)
    gt_disp_vis = normalize_and_colormap(gt_disp)

    # Bring the ground truth to the prediction's shape so they can be subtracted.
    if gt_disp.shape != disp.shape:
        gt_disp = gt_disp.reshape(disp.shape)

    # Keep the raw absolute error: the caption should report the error range,
    # not the min/max of the colour-mapped image (the other captions likewise
    # use the raw `disp` / `gt_disp` values).
    abs_err = (gt_disp - disp).abs()
    disp_err = normalize_and_colormap(abs_err)

    # NOTE(review): the positional image argument of each wandb.Image below was
    # not visible in the reviewed source; it is inferred from the dict key and
    # the caption. Confirm against the repository, in particular whether the
    # inputs need converting (e.g. to numpy / channel-last) before logging.
    wandb.log({
        'disp': wandb.Image(
            disp,
            caption=f"Pred. Disparity \n{disp.min():.2f}/{disp.max():.2f}",
        ),
        'disp_vis': wandb.Image(
            disp_vis,
            caption=f"Pred. Disparity \n{disp.min():.2f}/{disp.max():.2f}",
        ),
        'gt_disp_vis': wandb.Image(
            gt_disp_vis,
            caption=f"GT Disparity \n{gt_disp.min():.2f}/{gt_disp.max():.2f}",
        ),
        'disp_err': wandb.Image(
            disp_err,
            caption=f"Disparity Error\n{abs_err.min():.2f}/{abs_err.max():.2f}",
        ),
        'input_left': wandb.Image(
            left_img,
            caption="Input left",
        ),
        'input_right': wandb.Image(
            right_img,
            caption="Input right",
        ),
    })
if __name__ == '__main__':
    # Script entry point: load a checkpoint, then run `do_infer` on every
    # sample of the CTD dataset so the results are logged to wandb.

    # Checkpoint to evaluate (the commented path is an alternative).
    # model_path = "models/crestereo_eth3d.pth"
    model_path = "train_log/models/latest.pth"

    # Reference pattern image handed to the dataset; alternatives kept below.
    # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
    reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
    # reference_pattern_path = '/home/nils/new_reference.png'
    # reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
    # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'

    data_type = 'kinect'
    augment = False

    # Only `args.batch_size` is read from the parsed config in this script.
    args = parse_yaml("cfgs/train.yaml")

    # Record the evaluation setup alongside the logged images.
    wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})

    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    model = nn.DataParallel(model, device_ids=[device])
    # The weights are loaded into the DataParallel wrapper, so with strict=True
    # the checkpoint keys must match the wrapped model's keys exactly.
    state_dict = torch.load(model_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)

    # CTD = True disables the single image pair branch below.
    # NOTE(review): no `else:` was visible in the reviewed source, so the
    # dataset loop further down runs regardless of this flag — confirm intent.
    CTD = True
    if not CTD:
        left_img = cv2.imread("../test_imgs/left.png")
        right_img = cv2.imread("../test_imgs/right.png")

        in_h, in_w = left_img.shape[:2]

        # Resize image in case the GPU memory overflows
        eval_h, eval_w = (in_h, in_w)

        # FIXME borked for some reason, hopefully not very important
        # NOTE(review): `inference` is not among the names imported in the
        # visible part of this file (only `ctd_inference` is), so this call
        # would raise NameError if the branch were enabled.
        imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
        imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
        pred = inference(imgL, imgR, model, n_iter=20)

        # Scale the disparity values by the same factor as the image width.
        t = float(in_w) / float(eval_w)
        disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t

        # Min-max normalise to 0..255 and colour-map for viewing.
        disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
        disp_vis = disp_vis.astype("uint8")
        disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)

        combined_img = np.hstack((left_img, disp_vis))
        # cv2.namedWindow("output", cv2.WINDOW_NORMAL)
        # cv2.imshow("output", combined_img)
        cv2.imwrite("output.jpg", disp_vis)
        # cv2.waitKey(0)

    dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
                         pattern_path=reference_pattern_path, augment=augment)
    dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
                            num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)

    for batch in dataloader:
        # Run and log inference for each sample of the batch individually.
        for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
            # Axis order (0, 1, 2) -> (1, 2, 0) for the right image only.
            # NOTE(review): presumably the dataset yields the two views in
            # different layouts — confirm why `left` needs no such transpose.
            right = right.transpose(0, 2).transpose(0, 1)
            do_infer(left, right, disparity, model)