api_server.py: reformat, more customizability

This commit is contained in:
Nils Koch 2022-06-02 15:03:15 +02:00
parent 3bc0e7d575
commit 50581efa01

View File

@ -13,13 +13,13 @@ from PIL import Image
from nets import Model
app = FastAPI()
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)
iters = 20
minimal_data = False
device = torch.device('cuda:0')
model = None
def load_model(epoch):
@ -57,8 +57,8 @@ def inference(left, right, model, n_iter=20):
imgL = torch.tensor(imgL.astype("float32")).to(device)
imgR = torch.tensor(imgR.astype("float32")).to(device)
imgR = imgR.transpose(1,2)
imgL = imgL.transpose(1,2)
imgR = imgR.transpose(1, 2)
imgL = imgL.transpose(1, 2)
imgL_dw2 = F.interpolate(
imgL,
@ -91,6 +91,18 @@ async def change_model(epoch: Union[int, Literal['latest']]):
return {'status': 'success'}
@app.post("/params/iterations/{iterations}")
async def set_iterations(iterations: int):
global iters
iters = iterations
@app.post("/params/minimal_data/{enable}")
async def set_minimal_data(enable: bool):
global minimal_data
minimal_data = enable
@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
try:
@ -104,20 +116,22 @@ async def read_ir_input(file: UploadFile = File(...)):
if img.shape == (1024, 1280, 3):
diff = (512 - 480) // 2
downsampled = cv2.pyrDown(img)
img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
img = img.transpose((1,2,0))
ref_pat = reference_pattern.transpose((1,2,0))
img = img.transpose((1, 2, 0))
ref_pat = reference_pattern.transpose((1, 2, 0))
start = datetime.now()
pred_disp = inference(img, ref_pat, model, 20)
pred_disp = inference(img, ref_pat, model, iters)
duration = (datetime.now() - start).total_seconds()
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder)
# return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
if minimal_data:
return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
else:
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
cls=NumpyEncoder)
@app.get('/')
def main():
return {'test': 'abc'}