api_server.py: report timing info, allow on the fly changing of model

main
Nils Koch 2 years ago
parent 7219fdef7c
commit bb15dcd0a1
  1. 54
      api_server.py

@ -1,4 +1,6 @@
import json import json
from datetime import datetime
from typing import Union, Literal
from io import BytesIO from io import BytesIO
import numpy as np import numpy as np
@ -16,17 +18,27 @@ app = FastAPI()
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path) reference_pattern = cv2.imread(reference_pattern_path)
model_path = "train_log/models/latest.pth" device = torch.device('cuda:0')
# model_path = "train_log/models/epoch-100.pth" model = None
device = torch.device('cuda')
model = Model(max_disp=256, mixed_precision=False, test_mode=True) def load_model(epoch):
model = nn.DataParallel(model, device_ids=[device]) global model
# model.load_state_dict(torch.load(model_path), strict=False) epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
state_dict = torch.load(model_path)['state_dict'] model_path = f"train_log/models/{epoch}"
model.load_state_dict(state_dict, strict=True) model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model.to(device) model = nn.DataParallel(model, device_ids=[device])
model.eval() # model.load_state_dict(torch.load(model_path), strict=False)
state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
print(f'loaded model {epoch}')
return model
model = load_model('latest')
class NumpyEncoder(json.JSONEncoder): class NumpyEncoder(json.JSONEncoder):
def default(self, obj): def default(self, obj):
@ -45,10 +57,9 @@ def inference(left, right, model, n_iter=20):
imgL = torch.tensor(imgL.astype("float32")).to(device) imgL = torch.tensor(imgL.astype("float32")).to(device)
imgR = torch.tensor(imgR.astype("float32")).to(device) imgR = torch.tensor(imgR.astype("float32")).to(device)
# Funzt grob
imgR = imgR.transpose(1,2) imgR = imgR.transpose(1,2)
imgL = imgL.transpose(1,2) imgL = imgL.transpose(1,2)
imgL_dw2 = F.interpolate( imgL_dw2 = F.interpolate(
imgL, imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
@ -71,6 +82,15 @@ def inference(left, right, model, n_iter=20):
return pred_disp return pred_disp
@app.post("/model/update/{epoch}")
async def change_model(epoch: Union[int, Literal['latest']]):
global model
print(epoch)
print('updating model')
model = load_model(epoch)
return {'status': 'success'}
@app.put("/ir") @app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)): async def read_ir_input(file: UploadFile = File(...)):
try: try:
@ -89,11 +109,15 @@ async def read_ir_input(file: UploadFile = File(...)):
img = img.transpose((1,2,0)) img = img.transpose((1,2,0))
ref_pat = reference_pattern.transpose((1,2,0)) ref_pat = reference_pattern.transpose((1,2,0))
pred_disp = inference(img, ref_pat, model) start = datetime.now()
pred_disp = inference(img, ref_pat, model, 20)
duration = (datetime.now() - start).total_seconds()
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder) return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, cls=NumpyEncoder)
# return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
@app.get('/') @app.get('/')
def main(): def main():
return {'test': 'abc'} return {'test': 'abc'}

Loading…
Cancel
Save