"""Interactive client for the disparity-inference server.

Cycles through the images in ``img_dir`` whose file name contains ``'ir'``,
uploads each one to the server at ``API_URL``, and shows the predicted
disparity (plus input/reference images when the server sends them).

Keys while an image is displayed:
    q -- quit
    e -- switch the server's model epoch (prompted on stdin)
    m -- toggle the server's minimal-data mode
    any other key -- next image
"""
import json
import os
import sys
from datetime import datetime

# NOTE: plain `import cv2`; the old `from cv2 import cv2` form raises
# ImportError on opencv-python >= 4.6.
import cv2
import numpy as np
import requests

API_URL = 'http://127.0.0.1:8000'
img_dir = '../../usable_imgs/'

cv2.namedWindow('Input Image')
cv2.namedWindow('Predicted Disparity')

# epoch 75 is weird


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that serializes numpy arrays as nested lists."""

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)


def normalize_and_colormap(img):
    """Stretch ``img`` to the 0-255 range and apply the INFERNO colormap.

    A constant image (max == min) would otherwise divide by zero and produce
    NaNs; it is mapped to all zeros instead.

    Returns:
        The color-mapped uint8 image produced by ``cv2.applyColorMap``.
    """
    lo, hi = img.min(), img.max()
    span = float(hi) - float(lo)
    if span == 0:
        ret = np.zeros(img.shape, dtype='uint8')
    else:
        # Work in float so non-uint8 inputs cannot overflow during the subtraction.
        ret = ((img.astype(np.float64) - lo) / span * 255.0).astype('uint8')
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)


def change_epoch():
    """Prompt for an epoch number (or "latest") and ask the server to load it."""
    epoch = input('Enter epoch number or "latest"\n').strip()
    r = requests.post(f'{API_URL}/model/update/{epoch}')
    print(r.text)


def extract_data(data):
    """Unpack a decoded server response into numpy arrays.

    Returns:
        ``(pred_disp, duration)`` when ``data`` has no ``'input'`` key,
        otherwise ``(pred_disp, in_img, ref_pat, duration)``. ``duration`` is
        converted to ``float``.
    """
    # FIXME yuck, don't json the json
    duration = float(data['duration'])

    # get result and rotate 90 deg
    pred_disp = cv2.transpose(np.asarray(data['disp'], dtype='uint8'))

    if 'input' not in data:
        return pred_disp, duration

    # transpose((2, 0, 1)) moves the last axis to the front.
    # NOTE(review): cv2.imshow expects (H, W, C); confirm the server's axis
    # order makes this transpose correct.
    in_img = np.asarray(data['input'], dtype='uint8').transpose((2, 0, 1))
    ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2, 0, 1))

    return pred_disp, in_img, ref_pat, duration


def downsize_input_img(img_path=None):
    """Write the image to 'buffer.png', shrinking 1280x1024 images to 640x480.

    A (1024, 1280, 3) image is halved with ``cv2.pyrDown`` to (512, 640, 3) and
    then 16 rows are cropped from top and bottom, giving (480, 640, 3). Images
    of any other shape are written unchanged.

    Args:
        img_path: image file to read. Defaults to ``img.path`` of the
            module-level loop variable ``img`` (the original implicit behavior).

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    if img_path is None:
        img_path = img.path
    input_img = cv2.imread(img_path)
    if input_img is None:
        # cv2.imread signals failure by returning None instead of raising.
        raise ValueError(f'cv2.imread could not read {img_path!r}')
    if input_img.shape == (1024, 1280, 3):
        diff = (512 - 480) // 2
        downsampled = cv2.pyrDown(input_img)
        input_img = downsampled[diff:downsampled.shape[0] - diff, :]

    cv2.imwrite('buffer.png', input_img)


def put_image(img_path):
    """Upload the PNG at ``img_path`` to the ``/ir`` endpoint and return the decoded reply.

    The response body is JSON whose payload is itself a JSON-encoded string,
    hence the double ``json.loads``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    # The context manager closes the file even if the request raises.
    with open(img_path, 'rb') as img_file:
        files = {'file': ('file', img_file, 'image/png')}
        print('sending image')
        r = requests.put(f'{API_URL}/ir', files=files)
    print('received response')
    r.raise_for_status()
    return json.loads(json.loads(r.text))


def change_minimal_data(enabled):
    """Ask the server to set minimal-data mode to ``not enabled``.

    Args:
        enabled: the current minimal-data state (the caller passes True when
            the last response carried no input image).
    """
    r = requests.post(f'{API_URL}/params/minimal_data/{not enabled}')
    print(r.text)
    if not enabled:
        # Minimal mode is being switched on, so input/reference images will no
        # longer be sent; close their windows. destroyWindow raises cv2.error on
        # some backends when the window does not exist (e.g. never created or
        # already closed), which is harmless here.
        for name in ('Input Image', 'Reference Image'):
            try:
                cv2.destroyWindow(name)
            except cv2.error:
                pass


if __name__ == '__main__':
    while True:
        shown_any = False
        with os.scandir(img_dir) as entries:
            for img in entries:
                start = datetime.now()
                # Filter on the file name, not img.path: the path also contains
                # img_dir, which would match every file if the directory did.
                if not img.is_file() or 'ir' not in img.name:
                    continue

                # alternatively: use img.path for native size
                try:
                    downsize_input_img(img.path)
                except ValueError as err:
                    print(err)
                    continue
                data = put_image('buffer.png')

                if 'input' in data:
                    pred_disp, in_img, ref_pat, duration = extract_data(data)
                else:
                    pred_disp, duration = extract_data(data)

                print(f'inference took {duration:1.4f}s')
                print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration):1.4f}s')
                print(f"Pred. Disparity: \n\t{pred_disp.min():.2f}/{pred_disp.max():.2f}\n")

                if 'input' in data:
                    cv2.imshow('Input Image', in_img)
                    cv2.imshow('Reference Image', ref_pat)

                cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
                cv2.imshow('Predicted Disparity', pred_disp)
                shown_any = True

                key = cv2.waitKey()  # no delay argument: wait for a key press

                if key == ord('q'):
                    sys.exit()
                elif key == ord('e'):
                    change_epoch()
                elif key == ord('m'):
                    change_minimal_data('input' not in data)

        if not shown_any:
            # Without this guard an empty or non-matching img_dir busy-loops forever.
            sys.exit(f'no readable images with "ir" in their name found in {img_dir}')