CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

352 lines
10 KiB

import signal
from time import sleep
# import requests
import httpx as requests
import asyncio
import open3d as o3d
from cv2 import cv2
import numpy as np
import json
import os
from datetime import datetime
API_URL = 'http://127.0.0.1:8000'
img_dir = '../../usable_imgs/'

# GUI setup: OpenCV preview windows and an Open3D window for the point cloud.
cv2.namedWindow('Input Image')
cv2.namedWindow('Predicted Disparity')
vis = o3d.visualization.VisualizerWithKeyCallback()
viscont = o3d.visualization.ViewControl()
# vis.register_key_callback(99)
vis.create_window()

# Pinhole intrinsic matrix: fx, fy on the diagonal, (cx, cy) in the last column.
K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0, 0, 1]], dtype=np.float32)

# Local mirror of the backend's temporal-init setting, toggled with the 't' key.
# It used to be created only by the commented-out request below, so pressing
# 't' raised NameError. NOTE(review): the initial value is an assumption about
# the server's default -- TODO confirm (e.g. via GET {API_URL}/temporal_init).
# temporal_init = requests.get(f'{API_URL}/temporal_init')
temporal_init = False

good_models = [260, 183]
interesting = [214, ]
# new: quite good at around 175
verbose = False
running_tasks = set()
minimal_data = False

# Record our PID, presumably so an external process can send SIGUSR1 to this
# frontend to trigger processing of a new frame.
with open('frontend.pid', 'w+') as f:
    print('writing pid')
    f.write(str(os.getpid()))
# epoch 75 is weird
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises numpy arrays and numpy scalars.

    Arrays become (nested) lists via ``ndarray.tolist()``. Numpy scalar types
    that the stdlib encoder rejects (e.g. ``np.float32``, ``np.int64``) become
    the equivalent Python scalar via ``.item()``. Everything else is delegated
    to the base class, which raises ``TypeError`` for unsupported objects.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)
def update_vis(*args):
    """Pump the Open3D event queue and redraw the point-cloud window.

    Any positional arguments are accepted and ignored, so the function can be
    installed directly as a ``signal`` handler (see the disabled timer below).
    """
    for refresh_step in (vis.poll_events, vis.update_renderer):
        refresh_step()


# signal.signal(signal.SIGALRM, update_vis)
# signal.setitimer(signal.ITIMER_REAL, 0.1, 0.1)
def ghetto_lcn(img):
    """Crude local contrast normalisation.

    Subtracts a Gaussian local mean (sigma=2), divides by a Gaussian-weighted
    local standard deviation (sigma=20), then min-max rescales the result to
    [0, 255].

    :param img: image treated as already single-channel (no colour conversion
        is done); presumably 8-bit, since it is divided by 255.
    :return: float32 image rescaled to [0, 255].
    """
    # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    float_gray = img.astype(np.float32) / 255.0
    local_mean = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
    num = float_gray - local_mean
    local_var = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
    den = cv2.pow(local_var, 0.5)
    # Perfectly flat regions have zero local deviation, and 0/0 produced NaN
    # that poisoned the min-max normalisation. Leave those pixels at 0.
    gray = np.divide(num, den, out=np.zeros_like(num), where=den > 0)
    # cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
    cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
    return gray
def normalize_and_colormap(img):
    """Min-max stretch *img* to 0..255 and apply OpenCV's INFERNO colour map.

    :param img: numeric image of any dtype.
    :return: 3-channel uint8 BGR image as produced by ``cv2.applyColorMap``.
        A constant image maps to the colour of value 0.
    """
    lo, hi = img.min(), img.max()
    value_range = float(hi) - float(lo)
    if value_range > 0:
        ret = (img - lo) / value_range * 255.0
    else:
        # Constant image: the old formula divided by zero and cast NaN to uint8.
        ret = np.zeros(img.shape, dtype=np.float64)
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
def reproject(disparity_img):
    """Back-project a disparity map into a coloured point cloud and display it.

    Depth is ``baseline * fx / disparity``, using fx from the module-level
    intrinsics ``K``. It is then log-compressed, shifted so its minimum maps
    to log(2), purely for visualisation. Colours are the INFERNO-mapped
    disparity. The resulting cloud replaces whatever the module-level Open3D
    visualiser ``vis`` currently shows.

    :param disparity_img: 2-D disparity map (``extract_data`` yields uint8).
        Pixels with disparity <= 0 are treated as invalid and not projected.
    """
    print('reprojecting')
    baseline = 0.075
    disparity = np.asarray(disparity_img, dtype=np.float32)
    height, width = disparity.shape[:2]
    valid = disparity > 0

    # Zero disparity used to be divided by directly, producing inf (and a
    # RuntimeWarning). Invalid pixels now get depth 0, which Open3D skips.
    depth_img = np.zeros_like(disparity)
    np.divide(baseline * K[0][0], disparity, out=depth_img, where=valid)

    print('setting intrinsics')
    intrinsics = o3d.camera.PinholeCameraIntrinsic()
    # Use the real principal point from K; cx=cy=0 shifted the whole cloud.
    intrinsics.set_intrinsics(width=width, height=height,
                              fx=K[0][0], fy=K[1][1], cx=K[0][2], cy=K[1][2])

    # normalize_and_colormap already returns uint8 in 0..255, in BGR order.
    # Open3D expects RGB, and the old `rgb * 255` overflowed uint8,
    # which inverted the colours.
    bgr = normalize_and_colormap(disparity_img)
    rgb = o3d.geometry.Image(np.ascontiguousarray(bgr[..., ::-1]))

    if valid.any():
        valid_depth = depth_img[valid]
        print(valid_depth.max(), valid_depth.min())
        valid_depth = np.log(valid_depth + (1 - valid_depth.min()) + 1)
        depth_img[valid] = valid_depth
        print(valid_depth.max(), valid_depth.min())
    depth = o3d.geometry.Image(depth_img.astype('float32'))
    rgb_depth = o3d.geometry.RGBDImage.create_from_color_and_depth(
        color=rgb,
        depth=depth,
        depth_scale=1,
        convert_rgb_to_intensity=False,
    )

    print('creating pointcloud')
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(
        image=rgb_depth,
        intrinsic=intrinsics,
    )
    # Flip the Y and Z axes (the usual Open3D idiom to avoid an upside-down cloud).
    flip_transform = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
    pcd.transform(flip_transform)
    # remove_statistical_outlier returns a new (cloud, indices) pair instead of
    # filtering in place; the result used to be thrown away.
    pcd, _ = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)

    print('drawing pointcloud')
    # update_geometry() only refreshes geometries that were added to the
    # visualiser, and this cloud is a new object on every call, so nothing was
    # ever drawn. Swap it in instead. Fit the view only on the first frame so
    # later user navigation is kept.
    first_frame = not getattr(reproject, '_geometry_added', False)
    vis.clear_geometries()
    vis.add_geometry(pcd, reset_bounding_box=first_frame)
    reproject._geometry_added = True
    vis.poll_events()
    vis.update_renderer()
def change_epoch(epoch=None):
    """Ask the backend to load the checkpoint for *epoch*.

    :param epoch: epoch number, or any path segment the backend accepts (the
        prompt suggests ``"latest"``). If ``None``, the value is read from
        stdin.
    """
    if epoch is None:
        epoch = input('Enter epoch number or "latest"\n').strip()
    r = requests.post(f'{API_URL}/model/update/{epoch}')
    # The response used to be ignored, so a rejected epoch (typo, missing
    # checkpoint) failed silently. Report it without stopping the frontend.
    if r.status_code >= 400:
        print(f'failed to load epoch {epoch!r}: HTTP {r.status_code} {r.text}')
def change_reference():
    """Ask the backend to update its reference pattern.

    The request is repeated for as long as the reply's ``status`` field is
    ``'finished'``; each status is printed.
    """
    # Iterative rather than recursive: the old version parsed the JSON twice
    # per reply and grew the call stack on every repeat.
    while True:
        status = requests.post(f'{API_URL}/params/update_reference').json()['status']
        print(status)
        if status != 'finished':
            break
def extract_data(data):
    """Unpack a decoded inference response into numpy arrays.

    The disparity in ``data['disp']`` is shifted so its minimum is 0 if it has
    negative values, scaled so its maximum is 255 if it exceeds 255, and then
    cast (truncating) to uint8.

    :param data: response dict with at least ``'duration'`` and ``'disp'``.
        Full (non-minimal) responses also carry ``'input'`` and optionally
        ``'reference'``.
    :return: ``(pred_disp, duration)`` for responses without ``'input'``,
        otherwise ``(pred_disp, in_img, ref_pat, duration)``. ``ref_pat`` is
        ``None`` when no reference was sent.
    """
    # FIXME yuck, don't json the json
    duration = data['duration']
    raw_disp = np.asarray(data['disp'])
    if raw_disp.size:  # min()/max() raise on an empty array
        if raw_disp.min() < 0:
            raw_disp = raw_disp - raw_disp.min()
        if raw_disp.max() > 255:
            raw_disp = raw_disp / (raw_disp.max() / 255)
    pred_disp = np.asarray(raw_disp, dtype='uint8')
    # Decide by content, not key count: `len(data) == 2` broke as soon as a
    # minimal response carried any extra field.
    if 'input' not in data:
        return pred_disp, duration
    in_img = np.asarray(data['input'], dtype='uint8')
    ref_pat = data.get('reference')
    if ref_pat is not None:
        ref_pat = np.asarray(ref_pat, dtype='uint8')
    return pred_disp, in_img, ref_pat, duration
def downsize_input_img(path):
    """Load an image, shrink 1280x1024 frames to 640x480, and write 'buffer.png'.

    Retries until ``cv2.imread`` succeeds, pausing briefly between attempts.
    Presumably this is because the file is produced by another process and may
    not be readable yet. 1280x1024 three-channel images are halved with
    ``cv2.pyrDown`` (to 640x512) and cropped by 16 rows top and bottom to
    640x480. The image is then min-max stretched to 8-bit and saved.

    :param path: path of the image to read.
    """
    input_img = cv2.imread(path)
    while input_img is None:
        # This used to spin at 100% CPU; back off a little between attempts.
        sleep(0.01)
        input_img = cv2.imread(path)
    if input_img.shape == (1024, 1280, 3):
        downsampled = cv2.pyrDown(input_img)
        margin = (downsampled.shape[0] - 480) // 2
        input_img = downsampled[margin:downsampled.shape[0] - margin, :]
    input_img = cv2.normalize(input_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    # input_img = ghetto_lcn(input_img)
    cv2.imwrite('buffer.png', input_img)
async def put_image(img_path):
    """Upload *img_path* to the backend's ``/ir`` endpoint and return the reply.

    :param img_path: path of the PNG file to send.
    :return: the response body decoded twice with ``json.loads``. The body is
        a JSON string that itself contains JSON.
    :raises httpx.HTTPStatusError: if the backend answers with an error status.
    """
    # Use a context manager so the file is closed after the upload; the old
    # version leaked one file descriptor per frame.
    with open(img_path, 'rb') as img_file:
        files = {'file': ('file', img_file, 'image/png')}
        if verbose:
            print('sending image')
        async with requests.AsyncClient() as client:
            r = await client.put(f'{API_URL}/ir', files=files)
    if verbose:
        print('received response')
    r.raise_for_status()
    data = json.loads(json.loads(r.text))
    return data
def change_minimal_data(current: bool = None):
    """Toggle the backend's minimal-data mode and close the image windows.

    :param current: the state to toggle *from*. Defaults to the module-level
        ``minimal_data`` flag. ``not current`` is stored in ``minimal_data``
        and sent to ``/params/minimal_data/{value}``.
    """
    global minimal_data
    if current is None:
        current = minimal_data
    minimal_data = not current
    requests.post(f'{API_URL}/params/minimal_data/{minimal_data}')
    # The windows are recreated by the next cv2.imshow. 'Reference Image' does
    # not exist until a reference has been shown (e.g. when called at startup),
    # and some HighGUI backends raise cv2.error when destroying an unknown window.
    for window in ('Input Image', 'Reference Image'):
        try:
            cv2.destroyWindow(window)
        except cv2.error:
            pass
def change_temporal_init(enabled):
    """Toggle the backend's temporal-initialisation setting.

    Sends ``not enabled`` to ``/params/temporal_init/{value}`` and stores the
    same value in the module-level ``temporal_init`` flag.

    :param enabled: the current state being toggled away from.
    """
    global temporal_init
    requests.post(f'{API_URL}/params/temporal_init/{not enabled}')
    # Derive the new state from the argument rather than the global:
    # `not temporal_init` raised NameError when the global had never been
    # assigned, and could drift from the value actually sent.
    temporal_init = not enabled
def handle_keypress(key):
    """Dispatch a key code returned by ``cv2.waitKey`` to its action.

    Bindings:
    - ``q`` exits.
    - ``e`` prompts for and loads an epoch.
    - ``m`` toggles minimal-data mode.
    - ``t`` toggles temporal initialisation.
    - ``c`` updates the reference pattern.

    Any other code is ignored, including -1, which ``waitKey`` returns when no
    key was pressed. The ``t`` binding reads the module-level
    ``temporal_init`` flag, which must be defined at module scope.
    """
    # waitKey can report modifier/extended bits above the low byte on some
    # platforms; mask them off before comparing against ASCII codes.
    key &= 0xFF
    if key == ord('q'):
        # quit() comes from the site module and is absent under `python -S`.
        raise SystemExit
    if key == ord('e'):
        change_epoch()
    elif key == ord('m'):
        change_minimal_data()
    elif key == ord('t'):
        change_temporal_init(temporal_init)
    elif key == ord('c'):
        change_reference()
async def do_inference():
    """Run one inference round trip on 'buffer.png' and display the results.

    Uploads the buffered image, unpacks the reply with ``extract_data``,
    reprojects the disparity into the point-cloud view, and shows the 2-D
    results.
    """
    start = datetime.now()
    data = await put_image('buffer.png')
    unpacked = extract_data(data)
    # Branch on what extract_data returned instead of on len(data). A reply
    # with any other key count used to leave pred_disp unbound
    # (UnboundLocalError).
    if len(unpacked) == 4:
        pred_disp, in_img, ref_pat, duration = unpacked
    else:
        pred_disp, duration = unpacked
        in_img = ref_pat = None
    reproject(pred_disp)
    show_results(duration, in_img, pred_disp, ref_pat, start)
def show_results(duration, in_img, pred_disp, ref_pat, start):
    """Print disparity range and timings, show the images, and handle a key press.

    Waits up to 1000 ms in ``cv2.waitKey`` and forwards the result to
    ``handle_keypress``.

    :param duration: server-side inference time in seconds (a number, or
        presumably a numeric string -- it is passed through ``float``).
    :param in_img: input image echoed by the server, or ``None`` to show
        'buffer.png' instead.
    :param pred_disp: predicted disparity (uint8 when produced by ``extract_data``).
    :param ref_pat: reference pattern image, or ``None`` to skip that window.
    :param start: ``datetime`` at which the request was issued.
    """
    print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}")
    if verbose:
        # Convert once and sample the clock once so the three figures agree.
        # `{duration:1.4f}` used to fail for string durations even though the
        # next line already called float(duration).
        duration = float(duration)
        total = (datetime.now() - start).total_seconds()
        print(f'inference took {duration:1.4f}s')
        print(f'pipeline and transfer took another {total - duration:1.4f}s')
        print(f'total {total:1.4f}s')
    if in_img is not None:
        cv2.imshow('Input Image', in_img)
    else:
        cv2.imshow('Input Image', cv2.imread('buffer.png'))
    if ref_pat is not None:
        cv2.imshow('Reference Image', ref_pat)
    cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
    cv2.imshow('Predicted Disparity', pred_disp)
    key = cv2.waitKey(1000)
    handle_keypress(key)
async def fresh_img():
    """Process 'kinect_ir.png' end to end and report how long it took.

    The image is downsized into 'buffer.png' and then ``do_inference`` is
    awaited. Start time and elapsed seconds are printed.
    """
    t0 = datetime.now()
    print(f'started at {t0}')
    downsize_input_img('kinect_ir.png')
    await do_inference()
    elapsed = (datetime.now() - t0).total_seconds()
    print(f'task took {elapsed}')
    print()
def _spawn_fresh_img_task():
    """Start a ``fresh_img`` task on the running loop and keep a reference to it.

    ``running_tasks`` holds a strong reference until the task finishes, as
    asyncio requires, so the task is not garbage-collected mid-flight.
    """
    task = asyncio.create_task(fresh_img())
    running_tasks.add(task)
    task.add_done_callback(running_tasks.discard)


def create_task(*args):
    """SIGUSR1 handler: schedule processing of a fresh frame.

    The signal number and frame arguments are ignored.

    A handler installed with ``signal.signal`` can run while the event loop is
    blocked in ``select()`` (``main`` sleeps for very long periods). A plain
    ``asyncio.create_task`` there only queues the task without waking the
    loop. ``call_soon_threadsafe`` wakes the loop so the task starts promptly.
    """
    print(f'currently running: {len(running_tasks)} tasks')
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # The signal arrived before asyncio.run() started or after it ended.
        print('no running event loop; ignoring signal')
        return
    loop.call_soon_threadsafe(_spawn_fresh_img_task)


signal.signal(signal.SIGUSR1, create_task)
async def run_test(img_dir, iterate_checkpoints, epochs=range(175, 270)):
    """Run inference over the IR images in *img_dir*, once per epoch value.

    In each pass, every directory entry whose file name contains ``'ir'`` is
    downsized into 'buffer.png' and sent through ``do_inference``. The pass
    ends with a 10 s pause.

    :param img_dir: directory containing the test images.
    :param iterate_checkpoints: if true, load checkpoint ``epoch`` before each pass.
    :param epochs: epoch values to iterate over (default 175..269).
    """
    # Sort for a reproducible order; os.scandir yields entries in arbitrary order.
    entries = sorted(os.scandir(img_dir), key=lambda entry: entry.name)
    for epoch in epochs:
        if iterate_checkpoints:
            change_epoch(epoch)
            print()
            print(f'loaded epoch {epoch}')
        for entry in entries:
            # Match on the file name only. Matching the full path counted every
            # file as IR whenever a parent directory name contained "ir".
            if 'ir' not in entry.name:
                continue
            # alternatively: use entry.path for native size
            downsize_input_img(entry.path)
            await do_inference()
        await asyncio.sleep(10)
async def main():
    """Configure the backend, then keep the event loop alive indefinitely.

    Sets Open3D logging to Info, loads checkpoint epoch 150, and toggles
    minimal-data mode away from ``False``. Afterwards it only sleeps; frames
    are processed by tasks that the SIGUSR1 handler schedules.
    """
    o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Info)
    # Alternatives: change_epoch(good_models[1]) or change_epoch('latest').
    change_epoch(150)
    change_minimal_data(False)
    # For offline evaluation instead of live data:
    #     await run_test(img_dir, iterate_checkpoints=False)
    idle_seconds = 50000
    while True:
        await asyncio.sleep(idle_seconds)


if __name__ == '__main__':
    asyncio.run(main())