test_model.py: reformat
This commit is contained in:
parent
17bf30fa2a
commit
9740e5d647
@ -17,10 +17,8 @@ device = 'cuda'
|
|||||||
wandb.init(project="crestereo", entity="cpt-captain")
|
wandb.init(project="crestereo", entity="cpt-captain")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
|
# Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
|
||||||
def inference(left, right, model, n_iter=20):
|
def inference(left, right, model, n_iter=20):
|
||||||
|
|
||||||
print("Model Forwarding...")
|
print("Model Forwarding...")
|
||||||
imgL = left.transpose(2, 0, 1)
|
imgL = left.transpose(2, 0, 1)
|
||||||
imgR = right.transpose(2, 0, 1)
|
imgR = right.transpose(2, 0, 1)
|
||||||
@ -53,7 +51,6 @@ def inference(left, right, model, n_iter=20):
|
|||||||
|
|
||||||
|
|
||||||
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
|
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
|
||||||
|
|
||||||
print("Model Forwarding...")
|
print("Model Forwarding...")
|
||||||
# print(left.shape)
|
# print(left.shape)
|
||||||
left = left.cpu().detach().numpy()
|
left = left.cpu().detach().numpy()
|
||||||
@ -111,13 +108,12 @@ def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
|
|||||||
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
|
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
|
||||||
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")
|
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1, 2, 0).astype('uint8'),
|
||||||
|
caption="Input Right")
|
||||||
|
|
||||||
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
|
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
|
||||||
|
|
||||||
|
|
||||||
disp_error = gt_disp - disp
|
disp_error = gt_disp - disp
|
||||||
log['disp_error'] = wandb.Image(
|
log['disp_error'] = wandb.Image(
|
||||||
normalize_and_colormap(disp_error),
|
normalize_and_colormap(disp_error),
|
||||||
@ -178,7 +174,6 @@ def do_infer(left_img, right_img, gt_disp, model):
|
|||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
# model_path = "models/crestereo_eth3d.pth"
|
# model_path = "models/crestereo_eth3d.pth"
|
||||||
model_path = "train_log/models/latest.pth"
|
model_path = "train_log/models/latest.pth"
|
||||||
@ -233,7 +228,8 @@ if __name__ == '__main__':
|
|||||||
# cv2.waitKey(0)
|
# cv2.waitKey(0)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, pattern_path=reference_pattern_path, augment=augment)
|
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
|
||||||
|
pattern_path=reference_pattern_path, augment=augment)
|
||||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||||
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||||
for batch in dataloader:
|
for batch in dataloader:
|
||||||
@ -245,4 +241,3 @@ if __name__ == '__main__':
|
|||||||
imgR = right.cpu().detach().numpy()
|
imgR = right.cpu().detach().numpy()
|
||||||
gt_disp = disparity
|
gt_disp = disparity
|
||||||
do_infer(left_img, right_img, gt_disp, model)
|
do_infer(left_img, right_img, gt_disp, model)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user