api_server.py: rework so it actually kinda works

results are weirdly poor, but idk
main
Nils Koch 2 years ago
parent 46a6ae44af
commit 1bdf2e7776
  1. 80
      api_server.py

@ -1,40 +1,24 @@
import json
from io import BytesIO from io import BytesIO
from typing import Optional
from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
import numpy as np import numpy as np
from cv2 import cv2
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from nets import Model from cv2 import cv2
from fastapi import FastAPI, File, UploadFile
from PIL import Image from PIL import Image
from nets import Model
app = FastAPI() app = FastAPI()
# TODO
# beide modelle laden, jeweils eine gpu zuweisen
# routen bauen, gegen die man bilder werfen kann, die dann jeweils von einem modell interpretiert werden
# ergebnisse zurueck geben
#
# input validierung nicht vergessen
# paramter (bildgroesse etc.) konfigurierbar machen oder automatisch rausfinden?
# kommt ctd überhaupt mit was anderem klar?
class IrImage(BaseModel):
image: np.array
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path) reference_pattern = cv2.imread(reference_pattern_path)
model_path = "train_log/models/latest.pth" model_path = "train_log/models/latest.pth"
device = torch.device('cuda:0') # model_path = "train_log/models/epoch-100.pth"
device = torch.device('cuda')
model = Model(max_disp=256, mixed_precision=False, test_mode=True) model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model, device_ids=[device]) model = nn.DataParallel(model, device_ids=[device])
@ -44,32 +28,27 @@ model.load_state_dict(state_dict, strict=True)
model.to(device) model.to(device)
model.eval() model.eval()
class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)
def normalize_and_colormap(img):
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
if isinstance(ret, torch.Tensor):
ret = ret.cpu().detach().numpy()
ret = ret.astype("uint8")
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
return ret
def inference(left, right, model, n_iter=20):
def inference_ctd(left, right, model, n_iter=20):
print("Model Forwarding...") print("Model Forwarding...")
# print(left.shape)
# left = left.cpu().detach().numpy()
# imgL = left
# imgR = right.cpu().detach().numpy()
imgL = np.ascontiguousarray(left[None, :, :, :]) imgL = np.ascontiguousarray(left[None, :, :, :])
imgR = np.ascontiguousarray(right[None, :, :, :]) imgR = np.ascontiguousarray(right[None, :, :, :])
# chosen for convenience device = torch.device('cuda')
device = torch.device('cuda:0')
imgL = torch.tensor(imgL.astype("float32")).to(device) imgL = torch.tensor(imgL.astype("float32")).to(device)
imgR = torch.tensor(imgR.astype("float32")).to(device) imgR = torch.tensor(imgR.astype("float32")).to(device)
imgL = imgL.transpose(2, 3).transpose(1, 2)
# Funzt grob
imgR = imgR.transpose(1,2)
imgL = imgL.transpose(1,2)
imgL_dw2 = F.interpolate( imgL_dw2 = F.interpolate(
imgL, imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
@ -82,15 +61,14 @@ def inference_ctd(left, right, model, n_iter=20):
mode="bilinear", mode="bilinear",
align_corners=True, align_corners=True,
) )
with torch.inference_mode(): with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)): pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
return pred_disp_norm return pred_disp
@app.put("/ir") @app.put("/ir")
@ -99,11 +77,21 @@ async def read_ir_input(file: UploadFile = File(...)):
img = np.array(Image.open(BytesIO(await file.read()))) img = np.array(Image.open(BytesIO(await file.read())))
except Exception as e: except Exception as e:
return {'error': 'couldn\'t read file', 'exception': e} return {'error': 'couldn\'t read file', 'exception': e}
print(img.shape)
# img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
if len(img.shape) == 2: if len(img.shape) == 2:
img = np.stack((img for _ in range(3))) img = cv2.merge([img for _ in range(3)])
pred_disp = inference_ctd(np.array(img), reference_pattern, None) if img.shape == (1024, 1280, 3):
return {"pred_disp": pred_disp} diff = (512 - 480) // 2
downsampled = cv2.pyrDown(img)
img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
img = img.transpose((1,2,0))
ref_pat = reference_pattern.transpose((1,2,0))
pred_disp = inference(img, ref_pat, model)
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder)
@app.get('/') @app.get('/')

Loading…
Cancel
Save