api_server.py: rework so it actually kinda works
results are weirdly poor, but idk
This commit is contained in:
parent
46a6ae44af
commit
1bdf2e7776
@ -1,40 +1,24 @@
|
||||
import json
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
import numpy as np
|
||||
from cv2 import cv2
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from nets import Model
|
||||
from cv2 import cv2
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from PIL import Image
|
||||
|
||||
from nets import Model
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# TODO
|
||||
# beide modelle laden, jeweils eine gpu zuweisen
|
||||
# routen bauen, gegen die man bilder werfen kann, die dann jeweils von einem modell interpretiert werden
|
||||
# ergebnisse zurueck geben
|
||||
#
|
||||
# input validierung nicht vergessen
|
||||
# parameter (bildgroesse etc.) konfigurierbar machen oder automatisch rausfinden?
|
||||
# kommt ctd überhaupt mit was anderem klar?
|
||||
|
||||
|
||||
|
||||
class IrImage(BaseModel):
    """Request model carrying a single image as a NumPy array.

    NOTE(review): the name suggests an infrared frame; nothing in this
    block constrains shape or dtype -- confirm against the callers.
    """

    # ``np.array`` is a factory function, not a type; the array type that
    # can be used as an annotation is ``np.ndarray``.
    image: np.ndarray

    class Config:
        # pydantic ships no validator for ndarray, so it has to be accepted
        # as an opaque (isinstance-checked) type.
        arbitrary_types_allowed = True
|
||||
|
||||
|
||||
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
reference_pattern = cv2.imread(reference_pattern_path)
|
||||
model_path = "train_log/models/latest.pth"
|
||||
device = torch.device('cuda:0')
|
||||
# model_path = "train_log/models/epoch-100.pth"
|
||||
device = torch.device('cuda')
|
||||
|
||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
model = nn.DataParallel(model, device_ids=[device])
|
||||
@ -44,32 +28,27 @@ model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
|
||||
def normalize_and_colormap(img):
    """Min-max scale *img* to 0..255 and render it with the INFERNO colormap.

    Args:
        img: A NumPy array or ``torch.Tensor`` exposing ``min()``/``max()``.
            Tensors are moved to the CPU and converted to NumPy before the
            colormap is applied.

    Returns:
        The result of ``cv2.applyColorMap`` on the scaled ``uint8`` image.
    """
    lowest = img.min()
    value_range = img.max() - lowest

    if value_range == 0:
        # A constant image has no dynamic range; dividing by zero would
        # yield NaNs whose uint8 cast is undefined, so map it to all zeros.
        ret = (img - lowest) * 0.0
    else:
        ret = (img - lowest) / value_range * 255.0

    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()

    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
|
||||
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy arrays and NumPy scalars.

    Pass it as ``cls=NumpyEncoder`` to ``json.dumps``.
    """

    def default(self, obj):
        """Return a JSON-serialisable equivalent of *obj*.

        Args:
            obj: A value the stock encoder could not serialise.

        Returns:
            A nested list for ``np.ndarray`` input, or the matching Python
            scalar for NumPy scalar input.

        Raises:
            TypeError: If *obj* is neither of the above (raised by the base
                class implementation).
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # NumPy scalars (e.g. float32, int64, bool_) are not handled by
            # the stock encoder; unwrap them to plain Python values.
            return obj.item()
        return super().default(obj)
|
||||
|
||||
|
||||
def inference_ctd(left, right, model, n_iter=20):
|
||||
def inference(left, right, model, n_iter=20):
|
||||
print("Model Forwarding...")
|
||||
# print(left.shape)
|
||||
# left = left.cpu().detach().numpy()
|
||||
# imgL = left
|
||||
# imgR = right.cpu().detach().numpy()
|
||||
imgL = np.ascontiguousarray(left[None, :, :, :])
|
||||
imgR = np.ascontiguousarray(right[None, :, :, :])
|
||||
|
||||
# chosen for convenience
|
||||
device = torch.device('cuda:0')
|
||||
device = torch.device('cuda')
|
||||
|
||||
imgL = torch.tensor(imgL.astype("float32")).to(device)
|
||||
imgR = torch.tensor(imgR.astype("float32")).to(device)
|
||||
imgL = imgL.transpose(2, 3).transpose(1, 2)
|
||||
|
||||
# Funzt grob
|
||||
imgR = imgR.transpose(1,2)
|
||||
imgL = imgL.transpose(1,2)
|
||||
|
||||
imgL_dw2 = F.interpolate(
|
||||
imgL,
|
||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||
@ -82,15 +61,14 @@ def inference_ctd(left, right, model, n_iter=20):
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
|
||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
||||
|
||||
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
|
||||
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
|
||||
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
|
||||
|
||||
return pred_disp_norm
|
||||
return pred_disp
|
||||
|
||||
|
||||
@app.put("/ir")
|
||||
@ -99,11 +77,21 @@ async def read_ir_input(file: UploadFile = File(...)):
|
||||
img = np.array(Image.open(BytesIO(await file.read())))
|
||||
except Exception as e:
|
||||
return {'error': 'couldn\'t read file', 'exception': e}
|
||||
print(img.shape)
|
||||
|
||||
# img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||
if len(img.shape) == 2:
|
||||
img = np.stack((img for _ in range(3)))
|
||||
pred_disp = inference_ctd(np.array(img), reference_pattern, None)
|
||||
return {"pred_disp": pred_disp}
|
||||
img = cv2.merge([img for _ in range(3)])
|
||||
if img.shape == (1024, 1280, 3):
|
||||
diff = (512 - 480) // 2
|
||||
downsampled = cv2.pyrDown(img)
|
||||
img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
|
||||
|
||||
img = img.transpose((1,2,0))
|
||||
ref_pat = reference_pattern.transpose((1,2,0))
|
||||
|
||||
pred_disp = inference(img, ref_pat, model)
|
||||
|
||||
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder)
|
||||
|
||||
|
||||
@app.get('/')
|
||||
|
Loading…
Reference in New Issue
Block a user