frontend/__init__.py: major refactor, still wip
now uses async to pipeline images for inference, triggered by signals, generates pointclouds from depth and displays them
This commit is contained in:
parent
1eefa2847b
commit
ed80c3056f
@ -1,4 +1,11 @@
|
||||
import requests
|
||||
import signal
|
||||
from time import sleep
|
||||
|
||||
# import requests
|
||||
import httpx as requests
|
||||
import asyncio
|
||||
import open3d as o3d
|
||||
|
||||
from cv2 import cv2
|
||||
import numpy as np
|
||||
import json
|
||||
@ -13,6 +20,27 @@ img_dir = '../../usable_imgs/'
|
||||
cv2.namedWindow('Input Image')
|
||||
cv2.namedWindow('Predicted Disparity')
|
||||
|
||||
vis = o3d.visualization.VisualizerWithKeyCallback()
|
||||
viscont = o3d.visualization.ViewControl()
|
||||
# vis.register_key_callback(99)
|
||||
vis.create_window()
|
||||
|
||||
K = np.array([[567.6, 0, 324.7], [0, 570.2, 250.1], [0, 0, 1]], dtype=np.float32)
|
||||
|
||||
# temporal_init = requests.get(f'{API_URL}/temporal_init')
|
||||
|
||||
good_models = [260, 183]
|
||||
interesting = [214, ]
|
||||
# new ganz gut bei ca 175
|
||||
verbose = False
|
||||
|
||||
running_tasks = set()
|
||||
minimal_data = False
|
||||
|
||||
with open('frontend.pid', 'w+') as f:
|
||||
print('writing pid')
|
||||
f.write(str(os.getpid()))
|
||||
|
||||
|
||||
# epoch 75 ist weird
|
||||
|
||||
@ -24,6 +52,34 @@ class NumpyEncoder(json.JSONEncoder):
|
||||
return json.JSONEncoder.default(self, obj)
|
||||
|
||||
|
||||
def update_vis(*args):
|
||||
vis.poll_events()
|
||||
vis.update_renderer()
|
||||
|
||||
|
||||
# signal.signal(signal.SIGALRM, update_vis)
|
||||
# signal.setitimer(signal.ITIMER_REAL, 0.1, 0.1)
|
||||
|
||||
|
||||
def ghetto_lcn(img):
|
||||
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
gray = img
|
||||
|
||||
float_gray = gray.astype(np.float32) / 255.0
|
||||
|
||||
blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
|
||||
num = float_gray - blur
|
||||
|
||||
blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
|
||||
den = cv2.pow(blur, 0.5)
|
||||
|
||||
gray = num / den
|
||||
|
||||
# cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
|
||||
cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
|
||||
return gray
|
||||
|
||||
|
||||
def normalize_and_colormap(img):
|
||||
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
|
||||
ret = ret.astype("uint8")
|
||||
@ -31,10 +87,74 @@ def normalize_and_colormap(img):
|
||||
return ret
|
||||
|
||||
|
||||
def change_epoch():
|
||||
epoch = input('Enter epoch number or "latest"\n')
|
||||
def reproject(disparity_img):
|
||||
print('reprojecting')
|
||||
baseline = 0.075
|
||||
depth_img = baseline * K[0][0] / disparity_img
|
||||
pointcloud = o3d.geometry.PointCloud()
|
||||
intrinsics = o3d.pybind.camera.PinholeCameraIntrinsic()
|
||||
print('setting intrinsics')
|
||||
intrinsics.set_intrinsics(width=640, height=480, fx=K[0][0], fy=K[1][1], cx=0., cy=0.)
|
||||
# depth = open3d.geometry.Image(depth_img.astype('float32'))
|
||||
rgb = normalize_and_colormap(disparity_img)
|
||||
rgb = o3d.geometry.Image(rgb * 255)
|
||||
print(depth_img.max(), depth_img.min())
|
||||
depth_img = np.log(depth_img + (1 - depth_img.min()) + 1)
|
||||
print(depth_img.max(), depth_img.min())
|
||||
depth = o3d.geometry.Image(depth_img.astype('float32'))
|
||||
|
||||
rgb_depth = o3d.geometry.RGBDImage().create_from_color_and_depth(
|
||||
color=rgb,
|
||||
depth=depth,
|
||||
depth_scale=1,
|
||||
convert_rgb_to_intensity=False,
|
||||
)
|
||||
print('creating pointcloud')
|
||||
# depth = open3d.cpu.pybind.t.geometry.Image(depth_img.astype('float32'))
|
||||
# depth.colorize_depth(1.0, 0., 1.)
|
||||
# print('now really creating pointcloud')
|
||||
# dpcd = pointcloud.create_from_depth_image(
|
||||
# depth=depth,
|
||||
# intrinsic=intrinsics,
|
||||
# )
|
||||
# print(type(depth))
|
||||
|
||||
pcd = pointcloud.create_from_rgbd_image(
|
||||
image=rgb_depth,
|
||||
intrinsic=intrinsics,
|
||||
# project_valid_depth_only=False,
|
||||
)
|
||||
|
||||
flip_transform = [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]
|
||||
# dpcd.paint_uniform_color(np.asarray([0.5, 0.4, 0.25]))
|
||||
pcd.transform(flip_transform)
|
||||
# dpcd.transform(flip_transform)
|
||||
pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
|
||||
print('drawing pointcloud')
|
||||
# vis.clear_geometries()
|
||||
# vis.add_geometry(pcd, reset_bounding_box=True)
|
||||
# viscont.rotate(250, 500)
|
||||
vis.update_geometry(pcd)
|
||||
vis.poll_events()
|
||||
vis.update_renderer()
|
||||
|
||||
# vis.run()
|
||||
# o3d.visualization.draw(geometry=[rgb_depth])
|
||||
# o3d.visualization.draw([dpcd])
|
||||
|
||||
|
||||
def change_epoch(epoch: int = None):
|
||||
if epoch is None:
|
||||
epoch = input('Enter epoch number or "latest"\n')
|
||||
r = requests.post(f'{API_URL}/model/update/{epoch}')
|
||||
print(r.text)
|
||||
# print(r.text)
|
||||
|
||||
|
||||
def change_reference():
|
||||
r = requests.post(f'{API_URL}/params/update_reference')
|
||||
print(r.json()['status'])
|
||||
if r.json()['status'] == 'finished':
|
||||
change_reference()
|
||||
|
||||
|
||||
def extract_data(data):
|
||||
@ -42,72 +162,190 @@ def extract_data(data):
|
||||
duration = data['duration']
|
||||
|
||||
# get result and rotate 90 deg
|
||||
pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8'))
|
||||
# pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8'))
|
||||
raw_disp = np.asarray(data['disp'])
|
||||
# print(raw_disp.min(), raw_disp.max())
|
||||
if raw_disp.min() < 0:
|
||||
# print('Negative disparity detected. shifting...')
|
||||
raw_disp = raw_disp - raw_disp.min()
|
||||
if raw_disp.max() > 255:
|
||||
# print('Excessive disparity detected. scaling...')
|
||||
raw_disp = raw_disp / (raw_disp.max() / 255)
|
||||
pred_disp = np.asarray(raw_disp, dtype='uint8')
|
||||
|
||||
if 'input' not in data:
|
||||
# if 'input' not in data:
|
||||
if len(data) == 2:
|
||||
return pred_disp, duration
|
||||
|
||||
in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1))
|
||||
ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1))
|
||||
ref_pat = data.get('reference', None)
|
||||
in_img = np.asarray(data['input'], dtype='uint8') # .transpose((2, 0, 1))
|
||||
if ref_pat:
|
||||
ref_pat = np.asarray(ref_pat, dtype='uint8') # .transpose((2, 0, 1))
|
||||
return pred_disp, in_img, ref_pat, duration
|
||||
|
||||
|
||||
def downsize_input_img():
|
||||
input_img = cv2.imread(img.path)
|
||||
def downsize_input_img(path):
|
||||
input_img = None
|
||||
while input_img is None:
|
||||
input_img = cv2.imread(path)
|
||||
if input_img.shape == (1024, 1280, 3):
|
||||
diff = (512 - 480) // 2
|
||||
downsampled = cv2.pyrDown(input_img)
|
||||
input_img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
|
||||
# print(input_img.shape)
|
||||
input_img = cv2.normalize(input_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||
# input_img = ghetto_lcn(input_img)
|
||||
|
||||
cv2.imwrite('buffer.png', input_img)
|
||||
|
||||
|
||||
def put_image(img_path):
|
||||
async def put_image(img_path):
|
||||
openBin = {'file': ('file', open(img_path, 'rb'), 'image/png')}
|
||||
print('sending image')
|
||||
r = requests.put(f'{API_URL}/ir', files=openBin)
|
||||
print('received response')
|
||||
if verbose:
|
||||
print('sending image')
|
||||
async with requests.AsyncClient() as client:
|
||||
r = await client.put(f'{API_URL}/ir', files=openBin)
|
||||
if verbose:
|
||||
print('received response')
|
||||
r.raise_for_status()
|
||||
data = json.loads(json.loads(r.text))
|
||||
return data
|
||||
|
||||
|
||||
def change_minimal_data(enabled):
|
||||
r = requests.post(f'{API_URL}/params/minimal_data/{not enabled}')
|
||||
def change_minimal_data(current: bool = None):
|
||||
global minimal_data
|
||||
if current is None:
|
||||
current = minimal_data
|
||||
minimal_data = not current
|
||||
r = requests.post(f'{API_URL}/params/minimal_data/{minimal_data}')
|
||||
cv2.destroyWindow('Input Image')
|
||||
cv2.destroyWindow('Reference Image')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
while True:
|
||||
for img in os.scandir(img_dir):
|
||||
start = datetime.now()
|
||||
def change_temporal_init(enabled):
|
||||
global temporal_init
|
||||
r = requests.post(f'{API_URL}/params/temporal_init/{not enabled}')
|
||||
temporal_init = not temporal_init
|
||||
|
||||
|
||||
def handle_keypress(key):
|
||||
if key == 113:
|
||||
quit()
|
||||
elif key == 101:
|
||||
change_epoch()
|
||||
elif key == 109:
|
||||
change_minimal_data()
|
||||
elif key == 116:
|
||||
change_temporal_init(temporal_init)
|
||||
elif key == 99:
|
||||
change_reference()
|
||||
|
||||
|
||||
async def do_inference():
|
||||
start = datetime.now()
|
||||
data = await put_image('buffer.png')
|
||||
in_img = None
|
||||
ref_pat = None
|
||||
if len(data) == 4:
|
||||
pred_disp, in_img, ref_pat, duration = extract_data(data)
|
||||
elif len(data) == 2:
|
||||
pred_disp, duration = extract_data(data)
|
||||
reproject(pred_disp)
|
||||
show_results(duration, in_img, pred_disp, ref_pat, start)
|
||||
# reproject(pred_disp)
|
||||
|
||||
|
||||
def show_results(duration, in_img, pred_disp, ref_pat, start):
|
||||
print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}")
|
||||
if verbose:
|
||||
print(f'inference took {duration:1.4f}s')
|
||||
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s')
|
||||
print(f'total {(datetime.now() - start).total_seconds():1.4f}s')
|
||||
if in_img is not None:
|
||||
cv2.imshow('Input Image', in_img)
|
||||
else:
|
||||
cv2.imshow('Input Image', cv2.imread('buffer.png'))
|
||||
if ref_pat is not None:
|
||||
cv2.imshow('Reference Image', ref_pat)
|
||||
cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
|
||||
cv2.imshow('Predicted Disparity', pred_disp)
|
||||
key = cv2.waitKey(1000)
|
||||
handle_keypress(key)
|
||||
|
||||
|
||||
async def fresh_img():
|
||||
# print('running task')
|
||||
start = datetime.now()
|
||||
print(f'started at {start}')
|
||||
downsize_input_img('kinect_ir.png')
|
||||
await do_inference()
|
||||
print(f'task took {(datetime.now() - start).total_seconds()}')
|
||||
print()
|
||||
|
||||
|
||||
def create_task(*args):
|
||||
global running_tasks
|
||||
# print('received signal')
|
||||
print(f'currently running: {len(running_tasks)} tasks')
|
||||
task = asyncio.create_task(fresh_img())
|
||||
# print(f'created task {task.get_name()}')
|
||||
running_tasks.add(task)
|
||||
task.add_done_callback(running_tasks.discard)
|
||||
# await task
|
||||
# return task
|
||||
|
||||
|
||||
signal.signal(signal.SIGUSR1, create_task)
|
||||
|
||||
|
||||
async def run_test(img_dir, iterate_checkpoints):
|
||||
img_dir = list(os.scandir(img_dir))
|
||||
for epoch in range(175, 270):
|
||||
if iterate_checkpoints:
|
||||
change_epoch(epoch)
|
||||
print()
|
||||
print(f'loaded epoch {epoch}')
|
||||
for img in img_dir:
|
||||
if 'ir' not in img.path:
|
||||
continue
|
||||
|
||||
# alternatively: use img.path for native size
|
||||
downsize_input_img()
|
||||
downsize_input_img(img.path)
|
||||
|
||||
data = put_image('buffer.png')
|
||||
if 'input' in data:
|
||||
pred_disp, in_img, ref_pat, duration = extract_data(data)
|
||||
else:
|
||||
pred_disp, duration = extract_data(data)
|
||||
# asyncio.run(do_inference())
|
||||
await do_inference()
|
||||
await asyncio.sleep(10)
|
||||
|
||||
print(f'inference took {duration:1.4f}s')
|
||||
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s')
|
||||
print(f"Pred. Disparity: \n\t{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}\n")
|
||||
|
||||
if 'input' in data:
|
||||
cv2.imshow('Input Image', in_img)
|
||||
cv2.imshow('Reference Image', ref_pat)
|
||||
cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
|
||||
cv2.imshow('Predicted Disparity', pred_disp)
|
||||
key = cv2.waitKey()
|
||||
async def main():
|
||||
use_live_data = True
|
||||
iterate_checkpoints = False
|
||||
# change_epoch(good_models[1])
|
||||
# change_epoch('latest')
|
||||
o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Info)
|
||||
|
||||
if key == 113:
|
||||
quit()
|
||||
elif key == 101:
|
||||
change_epoch()
|
||||
elif key == 109:
|
||||
change_minimal_data('input' not in data)
|
||||
change_epoch(150)
|
||||
change_minimal_data(False)
|
||||
await asyncio.sleep(50000)
|
||||
|
||||
# signal.signal(signal.SIGBUS, lambda x: print('received sigbus'))
|
||||
# loop = asyncio.get_running_loop()
|
||||
# loop.run_forever()
|
||||
# loop = asyncio.get_event_loop()
|
||||
while True:
|
||||
# create_task()
|
||||
# await asyncio.sleep(0.1)
|
||||
# await run_test(img_dir, iterate_checkpoints)
|
||||
await asyncio.sleep(50000)
|
||||
# print('[main] slept')
|
||||
# if use_live_data:
|
||||
# signal.pause()
|
||||
|
||||
# else:
|
||||
# await run_test(img_dir, iterate_checkpoints)
|
||||
|
||||
|
||||
#
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(main())
|
||||
|
Loading…
Reference in New Issue
Block a user