You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.2 KiB
111 lines
3.2 KiB
from typing import List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

import numpy as np
# NOTE(review): `from cv2 import cv2` only works with some opencv-python
# releases; plain `import cv2` is the portable form — confirm and switch.
from cv2 import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..nets import Model
|
|
|
|
|
|
# FastAPI application; the route handlers in this module are registered on it
# through the @app.get / @app.put decorators.
app = FastAPI()


# TODO
# - load both models, assigning each one its own GPU
# - build routes that accept images, each interpreted by one of the models
# - return the results
#
# - don't forget input validation
# - make parameters (image size etc.) configurable, or detect them automatically?
# - does CTD cope with anything else at all? (presumably: other input sizes)
|
|
|
|
|
|
class Item(BaseModel):
    """Request body for the demo ``PUT /items/{item_id}`` route."""

    name: str
    price: float
    # Optional flag; stays None when the client omits it.
    is_offer: Optional[bool] = None
|
|
|
|
|
|
class IrImage(BaseModel):
    """Request body for ``PUT /ir``: a single image as nested JSON lists.

    The image is an H x W x C nested list of floats. Pydantic cannot build a
    validator for ``np.array`` (a factory function, not a type), so the field
    is declared with JSON-native types; callers must convert it with
    ``np.asarray`` before numeric use.
    """

    # NOTE(review): nested JSON lists are slow to parse and validate for
    # full-size images; a binary upload may be preferable.
    image: List[List[List[float]]]
|
|
|
|
|
|
# Reference pattern image; read_ir_input passes it to the model as the second
# ("right") image for every request.
# NOTE(review): absolute, user-specific path — consider making it configurable.
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)
if reference_pattern is None:
    # cv2.imread reports a missing/unreadable file by returning None rather than
    # raising; fail at startup instead of with an obscure error on each request.
    raise FileNotFoundError(f"could not read reference pattern: {reference_pattern_path}")

# Checkpoint path is resolved relative to the current working directory.
model_path = "../train_log/models/latest.pth"
device = torch.device('cuda:0')

model = Model(max_disp=256, mixed_precision=False, test_mode=True)
# The checkpoint is loaded into the DataParallel wrapper (strict=True), so its
# keys presumably carry the "module." prefix — TODO confirm with the training code.
# DataParallel indexes device_ids[0] when CUDA is available, so an empty list
# raises IndexError; pin it to the GPU chosen above.
model = nn.DataParallel(model, device_ids=[device.index])
# map_location restores every tensor onto `device`, whatever GPU it was saved from.
state_dict = torch.load(model_path, map_location=device)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
|
|
|
|
|
|
def _scale_to_uint8(img):
    """Min-max stretch ``img`` to [0, 255] and truncate it to a numpy uint8 array."""
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    img = np.asarray(img)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # A constant image has no range to stretch: dividing by zero would give
        # NaNs, whose cast to uint8 is undefined. Map everything to 0 instead.
        return np.zeros(img.shape, dtype=np.uint8)
    return ((img - lo) / span * 255.0).astype("uint8")


def normalize_and_colormap(img):
    """Min-max normalise ``img`` and render it with OpenCV's INFERNO colormap.

    Args:
        img: numpy array or torch tensor (tensors are detached and copied to CPU).

    Returns:
        The color image produced by ``cv2.applyColorMap``.
    """
    return cv2.applyColorMap(_scale_to_uint8(img), cv2.COLORMAP_INFERNO)
|
|
|
|
|
|
def _to_nchw_tensor(img, device):
    """Convert one H x W x C image to a 1 x C x H x W float32 tensor on ``device``."""
    arr = np.asarray(img, dtype=np.float32)
    if arr.ndim != 3:
        raise ValueError(f"expected an H x W x C image, got shape {arr.shape}")
    # torch.tensor copies, so the caller's array is never aliased by the model input.
    batch = torch.tensor(np.ascontiguousarray(arr[None]))
    # NHWC -> NCHW: F.interpolate treats the last two dims as spatial.
    return batch.to(device).permute(0, 3, 1, 2)


def inference_ctd(left, right, model, n_iter=20, *, device=None):
    """Run the two-scale model on an image pair and return a uint8 visualisation.

    The model is called once on 2x-downsampled inputs, then at full resolution
    with the coarse prediction passed as ``flow_init``.

    Args:
        left: first image, H x W x C (anything ``np.asarray`` accepts).
        right: second image with the same layout (here the reference pattern).
        model: callable taking ``image1``, ``image2``, ``iters`` and ``flow_init``.
        n_iter: refinement iterations, passed to the model as ``iters``.
        device: torch device for the input tensors; defaults to ``cuda:0``.

    Returns:
        Channel 0 of the final full-resolution prediction, min-max normalised to
        a uint8 array in [0, 255].

    Raises:
        ValueError: if an input image is not 3-dimensional.
        RuntimeError: if the model yields no predictions.
    """
    print("Model Forwarding...")

    if device is None:
        # chosen for convenience
        device = torch.device('cuda:0')

    # Both images must be converted to NCHW (previously only the left one was).
    img_left = _to_nchw_tensor(left, device)
    img_right = _to_nchw_tensor(right, device)

    # Resize both to the left image's half size so the coarse pair stays aligned.
    half_size = (img_left.shape[2] // 2, img_left.shape[3] // 2)
    left_dw2 = F.interpolate(img_left, size=half_size, mode="bilinear", align_corners=True)
    right_dw2 = F.interpolate(img_right, size=half_size, mode="bilinear", align_corners=True)

    with torch.inference_mode():
        pred_flow_dw2 = model(image1=left_dw2, image2=right_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(img_left, img_right, iters=n_iter, flow_init=pred_flow_dw2)

    # NOTE(review): zipping the two outputs and indexing `[:, 0, :, :]` on each
    # element implies the model returns a sequence of 4-D predictions rather
    # than one (N, 2, H, W) tensor — confirm against ..nets.Model. As before,
    # only the last full-resolution entry of the zipped pairs is used.
    pairs = list(zip(pred_flow, pred_flow_dw2))
    if not pairs:
        raise RuntimeError("model returned no predictions")
    final_flow = pairs[-1][0]

    pred_disp = torch.squeeze(final_flow[:, 0, :, :]).cpu().detach().numpy()
    return cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
|
|
|
|
|
@app.put("/ir")
def read_ir_input(ir_image: IrImage):
    """Run the disparity model on an uploaded image against the reference pattern.

    Returns:
        ``{"pred_disp": ...}`` holding the uint8-normalised map as nested lists.

    Raises:
        HTTPException: 422 if the payload is not a rectangular H x W x C array.
    """
    bad_shape = "image must be a rectangular H x W x C array"
    try:
        image = np.asarray(ir_image.image, dtype=np.float32)
    except ValueError as err:
        # Ragged nested lists cannot form a rectangular array.
        raise HTTPException(status_code=422, detail=bad_shape) from err
    if image.ndim != 3:
        raise HTTPException(status_code=422, detail=bad_shape)

    # inference_ctd requires the model as its third positional argument.
    pred_disp = inference_ctd(image, reference_pattern, model)
    # FastAPI cannot JSON-encode numpy arrays; send plain nested lists instead.
    return {"pred_disp": pred_disp.tolist()}
|
|
|
|
|
|
@app.get("/items/{item_id}")
def read_item(item_id: int, q: Optional[str] = None):
    """Echo the path ``item_id`` and the optional ``q`` query parameter."""
    payload = {"item_id": item_id}
    payload["q"] = q
    return payload
|
|
|
|
|
|
@app.put("/items/{item_id}")
def update_item(item_id: int, item: Item):
    """Acknowledge an item update by echoing its price and the path id."""
    price = item.price
    return dict(item_price=price, item_id=item_id)
|
|
|