frontend/__init__.py: loop endlessly, show more info, allow model change
This commit is contained in:
parent
bb15dcd0a1
commit
b614fcfd74
@ -14,6 +14,8 @@ img_dir = '../../usable_imgs/'
|
||||
cv2.namedWindow('Input Image')
|
||||
cv2.namedWindow('Predicted Disparity')
|
||||
|
||||
# epoch 75 ist weird
|
||||
|
||||
|
||||
def normalize_and_colormap(img):
|
||||
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
|
||||
@ -22,31 +24,42 @@ def normalize_and_colormap(img):
|
||||
return ret
|
||||
|
||||
|
||||
for img in os.scandir(img_dir):
|
||||
if 'ir' not in img.path:
|
||||
continue
|
||||
input_img = cv2.imread(img.path)
|
||||
if input_img.shape == (1024, 1280, 3):
|
||||
diff = (512 - 480) // 2
|
||||
downsampled = cv2.pyrDown(input_img)
|
||||
input_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
|
||||
while True:
|
||||
for img in os.scandir(img_dir):
|
||||
start = datetime.now()
|
||||
if 'ir' not in img.path:
|
||||
continue
|
||||
input_img = cv2.imread(img.path)
|
||||
if input_img.shape == (1024, 1280, 3):
|
||||
diff = (512 - 480) // 2
|
||||
downsampled = cv2.pyrDown(input_img)
|
||||
input_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
|
||||
|
||||
openBin = {'file': ('file', open(img.path, 'rb'), 'image/png')}
|
||||
openBin = {'file': ('file', open(img.path, 'rb'), 'image/png')}
|
||||
|
||||
print('sending image')
|
||||
start = datetime.now()
|
||||
r = requests.put(f'{API_URL}/ir', files=openBin)
|
||||
end = datetime.now()
|
||||
print('received response')
|
||||
print(f'processing took {end - start}')
|
||||
r.raise_for_status()
|
||||
print('sending image')
|
||||
r = requests.put(f'{API_URL}/ir', files=openBin)
|
||||
print('received response')
|
||||
r.raise_for_status()
|
||||
data = json.loads(json.loads(r.text))
|
||||
|
||||
# FIXME yuck, don't json the json
|
||||
pred_disp = np.asarray(json.loads(json.loads(r.text))['disp'], dtype='uint8')
|
||||
ref_pat = np.asarray(json.loads(json.loads(r.text))['reference'], dtype='uint8').transpose((2,0,1)).astype('uint8')
|
||||
pred_disp = cv2.transpose(pred_disp)
|
||||
# FIXME yuck, don't json the json
|
||||
pred_disp = np.asarray(data['disp'], dtype='uint8')
|
||||
in_img = np.asarray(data['input'], dtype='uint8').transpose((2,0,1))
|
||||
ref_pat = np.asarray(data['reference'], dtype='uint8').transpose((2,0,1)).astype('uint8')
|
||||
duration = data['duration']
|
||||
pred_disp = cv2.transpose(pred_disp)
|
||||
print(f'inference took {duration}s')
|
||||
print(f'pipeline and transfer took another {(datetime.now() - start).total_seconds() - float(duration)}s\n')
|
||||
|
||||
cv2.imshow('Input Image', input_img)
|
||||
# cv2.imshow('Reference Image', ref_pat)
|
||||
cv2.imshow('Predicted Disparity', normalize_and_colormap(pred_disp))
|
||||
cv2.waitKey()
|
||||
cv2.imshow('Input Image', in_img)
|
||||
cv2.imshow('Reference Image', ref_pat)
|
||||
cv2.imshow('Normalized Predicted Disparity', normalize_and_colormap(pred_disp))
|
||||
cv2.imshow('Predicted Disparity', pred_disp)
|
||||
key = cv2.waitKey()
|
||||
if key == 113:
|
||||
quit()
|
||||
elif key == 101:
|
||||
epoch = input('Enter epoch number or "latest"\n')
|
||||
r = requests.post(f'{API_URL}/model/update/{epoch}')
|
||||
print(r.text)
|
||||
|
Loading…
Reference in New Issue
Block a user