api_server.py: reformat, more customizability
This commit is contained in:
parent
3bc0e7d575
commit
50581efa01
@ -13,13 +13,13 @@ from PIL import Image
|
||||
|
||||
from nets import Model
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
reference_pattern = cv2.imread(reference_pattern_path)
|
||||
iters = 20
|
||||
minimal_data = False
|
||||
device = torch.device('cuda:0')
|
||||
model = None
|
||||
|
||||
|
||||
def load_model(epoch):
|
||||
@ -57,8 +57,8 @@ def inference(left, right, model, n_iter=20):
|
||||
imgL = torch.tensor(imgL.astype("float32")).to(device)
|
||||
imgR = torch.tensor(imgR.astype("float32")).to(device)
|
||||
|
||||
imgR = imgR.transpose(1,2)
|
||||
imgL = imgL.transpose(1,2)
|
||||
imgR = imgR.transpose(1, 2)
|
||||
imgL = imgL.transpose(1, 2)
|
||||
|
||||
imgL_dw2 = F.interpolate(
|
||||
imgL,
|
||||
@ -91,6 +91,18 @@ async def change_model(epoch: Union[int, Literal['latest']]):
|
||||
return {'status': 'success'}
|
||||
|
||||
|
||||
@app.post("/params/iterations/{iterations}")
|
||||
async def set_iterations(iterations: int):
|
||||
global iters
|
||||
iters = iterations
|
||||
|
||||
|
||||
@app.post("/params/minimal_data/{enable}")
|
||||
async def set_minimal_data(enable: bool):
|
||||
global minimal_data
|
||||
minimal_data = enable
|
||||
|
||||
|
||||
@app.put("/ir")
|
||||
async def read_ir_input(file: UploadFile = File(...)):
|
||||
try:
|
||||
@ -104,20 +116,22 @@ async def read_ir_input(file: UploadFile = File(...)):
|
||||
if img.shape == (1024, 1280, 3):
|
||||
diff = (512 - 480) // 2
|
||||
downsampled = cv2.pyrDown(img)
|
||||
img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
|
||||
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
|
||||
|
||||
img = img.transpose((1,2,0))
|
||||
ref_pat = reference_pattern.transpose((1,2,0))
|
||||
img = img.transpose((1, 2, 0))
|
||||
ref_pat = reference_pattern.transpose((1, 2, 0))
|
||||
|
||||
start = datetime.now()
|
||||
pred_disp = inference(img, ref_pat, model, 20)
|
||||
pred_disp = inference(img, ref_pat, model, iters)
|
||||
duration = (datetime.now() - start).total_seconds()
|
||||
|
||||
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder)
|
||||
# return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
|
||||
if minimal_data:
|
||||
return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
|
||||
else:
|
||||
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
|
||||
cls=NumpyEncoder)
|
||||
|
||||
|
||||
@app.get('/')
|
||||
def main():
|
||||
return {'test': 'abc'}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user