CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

113 lines
3.3 KiB

import requests
from cv2 import cv2
import numpy as np
import json
import os
from datetime import datetime
# Base URL of the inference server this client talks to.
API_URL = 'http://127.0.0.1:8000'
# Directory whose entries the __main__ loop scans for images to upload.
img_dir = '../../usable_imgs/'

# Create two display windows up front; the other windows used by the main
# loop are created on first use by cv2.imshow.
for _window_name in ('Input Image', 'Predicted Disparity'):
    cv2.namedWindow(_window_name)
del _window_name  # keep the module namespace as it was

# Note: epoch 75 is weird.
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally understands NumPy arrays and scalars.

    * ``np.ndarray`` values are emitted as (nested) lists via ``tolist()``.
    * NumPy scalar types (``np.int64``, ``np.float32``, ``np.bool_``, ...),
      which the stock encoder rejects, are emitted as the equivalent Python
      scalar via ``.item()``.

    Anything else is delegated to ``json.JSONEncoder.default``, which raises
    ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # e.g. np.int64 is not an int subclass, so json cannot encode it.
            return obj.item()
        return super().default(obj)
def _normalize_to_uint8(img):
    """Linearly stretch *img* to the 0-255 range and return it as uint8.

    A constant image (max == min) has no range to stretch; it is mapped to all
    zeros instead of dividing 0 by 0 and casting the resulting NaNs to uint8.
    """
    img = np.asarray(img, dtype=np.float64)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        return np.zeros(img.shape, dtype=np.uint8)
    # astype truncates toward zero, matching the previous behaviour.
    return ((img - lo) / span * 255.0).astype(np.uint8)


def normalize_and_colormap(img):
    """Min-max normalise *img* to uint8 and apply OpenCV's INFERNO colormap.

    Args:
        img: Single-channel image (any numeric dtype).

    Returns:
        The colour-mapped image produced by ``cv2.applyColorMap``.
    """
    return cv2.applyColorMap(_normalize_to_uint8(img), cv2.COLORMAP_INFERNO)
def change_epoch():
    """Prompt for an epoch and POST it to the server's model-update route.

    The user's answer (an epoch number or "latest") is inserted verbatim into
    ``{API_URL}/model/update/<answer>``; the server's reply body is printed.
    """
    requested = input('Enter epoch number or "latest"\n')
    endpoint = f'{API_URL}/model/update/{requested}'
    reply = requests.post(endpoint)
    print(reply.text)
def extract_data(data):
    """Unpack a decoded server response into NumPy arrays.

    Args:
        data: Mapping with at least ``'duration'`` and ``'disp'`` keys, and
            optionally ``'input'`` and ``'reference'``.

    Returns:
        ``(pred_disp, in_img, ref_pat, duration)`` when ``'input'`` is present,
        otherwise ``(pred_disp, duration)`` (presumably the server's
        minimal-data mode).
    """
    # FIXME: the server double-encodes its JSON (see put_image); fix server-side.
    duration = data['duration']
    # cv2.transpose swaps rows and columns (a mirror across the diagonal),
    # which is not the same as a 90-degree rotation.
    pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8'))
    if 'input' in data:
        # Axis order is permuted to (2, 0, 1); the server's layout is not
        # visible here -- TODO confirm this yields HxWxC for cv2.imshow.
        in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1))
        ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1))
        return pred_disp, in_img, ref_pat, duration
    return pred_disp, duration
def downsize_input_img(img_path=None, out_path='buffer.png'):
    """Prepare an input image for upload and write it to *out_path*.

    Images of shape (1024, 1280, 3) are halved with ``cv2.pyrDown`` (to
    512x640) and then centre-cropped vertically to 480 rows. Images of any
    other shape are written unchanged.

    Args:
        img_path: Image file to read. When omitted, falls back to ``img.path``
            of the module-level ``img`` entry set by the ``__main__`` loop,
            which is what the original zero-argument call relied on.
        out_path: Destination file for the (possibly resized) image.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read *img_path*.
    """
    if img_path is None:
        # Backward compatibility: the main loop calls this without arguments
        # and relies on its global loop variable ``img``.
        img_path = img.path
    input_img = cv2.imread(img_path)
    if input_img is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f'could not read image: {img_path}')
    if input_img.shape == (1024, 1280, 3):
        target_rows = 480
        downsampled = cv2.pyrDown(input_img)  # 1024x1280 -> 512x640
        margin = (downsampled.shape[0] - target_rows) // 2
        input_img = downsampled[margin:downsampled.shape[0] - margin]
    cv2.imwrite(out_path, input_img)
def put_image(img_path):
    """Upload an image to the server's ``/ir`` endpoint and return its reply.

    Args:
        img_path: Path of the image file to send (sent as multipart field
            ``'file'`` with content type ``image/png``).

    Returns:
        The decoded response payload (callers use it as a dict).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status.
    """
    # Use a context manager so the upload's file handle is always closed.
    with open(img_path, 'rb') as img_file:
        files = {'file': ('file', img_file, 'image/png')}
        print('sending image')
        response = requests.put(f'{API_URL}/ir', files=files)
    print('received response')
    response.raise_for_status()
    # FIXME: the body is a JSON string that itself contains JSON, so it has to
    # be decoded twice; decode once after the server stops double-encoding.
    return json.loads(json.loads(response.text))
def change_minimal_data(enabled):
    """Flip the server's ``minimal_data`` parameter and close the input windows.

    Args:
        enabled: Current state of the flag; the server is asked to switch to
            ``not enabled``. (The main loop passes ``'input' not in data``.)
    """
    response = requests.post(f'{API_URL}/params/minimal_data/{not enabled}')
    print(response.text)  # report the server's answer, like change_epoch
    for window in ('Input Image', 'Reference Image'):
        try:
            cv2.destroyWindow(window)
        except cv2.error:
            # Some OpenCV GUI backends may raise if the window was never
            # created (e.g. no full-data frame shown yet); nothing to close.
            pass
if __name__ == '__main__':
    # Keyboard commands read with cv2.waitKey after each frame.
    KEY_QUIT, KEY_EPOCH, KEY_MINIMAL = ord('q'), ord('e'), ord('m')
    while True:
        found_any = False
        # NOTE: ``img`` must remain a module-level name, because
        # downsize_input_img() reads the global ``img`` when called without
        # arguments.
        with os.scandir(img_dir) as entries:
            for img in entries:
                start = datetime.now()
                # Match on the file name only, so directory names containing
                # "ir" (e.g. "dir") do not let every file through.
                if 'ir' not in img.name:
                    continue
                found_any = True
                # alternatively: use img.path for native size
                downsize_input_img()
                data = put_image('buffer.png')
                has_input = 'input' in data
                if has_input:
                    pred_disp, in_img, ref_pat, duration = extract_data(data)
                else:
                    pred_disp, duration = extract_data(data)
                print(f'inference took {duration:1.4f}s')
                overhead = (datetime.now() - start).total_seconds() - float(duration)
                print(f'pipeline and transfer took another {overhead:1.4f}s')
                print(f"Pred. Disparity: \n\t{pred_disp.min():.2f}/{pred_disp.max():.2f}\n")
                if has_input:
                    cv2.imshow('Input Image', in_img)
                    cv2.imshow('Reference Image', ref_pat)
                cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
                cv2.imshow('Predicted Disparity', pred_disp)
                # Mask off modifier bits some backends set in waitKey's result.
                key = cv2.waitKey() & 0xFF
                if key == KEY_QUIT:
                    raise SystemExit
                elif key == KEY_EPOCH:
                    change_epoch()
                elif key == KEY_MINIMAL:
                    change_minimal_data(not has_input)
        if not found_any:
            # Without this, an empty directory would make the outer loop spin
            # forever at full CPU without ever waiting for a key.
            raise SystemExit(f'no image files containing "ir" found in {img_dir}')