CREStereo Repository for the 'Towards accurate and robust depth estimation' project
CREStereo-pytorch-nxt/api_server.py
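A small FastAPI server around a trained CREStereo model: an uploaded IR image is paired with a fixed Kinect reference pattern, run through the network, and the resulting disparity map is returned. A minimal client sketch follows the code.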


from io import BytesIO

from fastapi import FastAPI, File, UploadFile
from pydantic import BaseModel
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from nets import Model
app = FastAPI()
# TODO
# - load both models and assign each its own GPU
# - build routes that images can be sent to, each of which is handled by one model
# - return the results
#
# - don't forget input validation
# - make parameters (image size etc.) configurable, or detect them automatically?
# - does ctd cope with anything else at all?
class IrImage(BaseModel):
    image: np.ndarray  # was np.array, which is a function, not a type
    class Config:
        arbitrary_types_allowed = True  # pydantic cannot validate raw ndarrays otherwise
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)

model_path = "../train_log/models/latest.pth"
device = torch.device('cuda:0')

model = Model(max_disp=256, mixed_precision=False, test_mode=True)
# Wrap in DataParallel with an empty device list so the 'module.' prefix in the
# checkpoint keys matches; with no device_ids, DataParallel just calls the module.
model = nn.DataParallel(model, device_ids=[])
state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
def normalize_and_colormap(img):
    """Scale img to [0, 255] and apply the inferno colormap for visualization."""
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
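# Usage sketch (assumption: called on a disparity map such as the one
# produced by inference_ctd below; the output path is hypothetical):
#   vis = normalize_and_colormap(pred_disp)  # H x W x 3 uint8, BGR
#   cv2.imwrite("disparity_vis.png", vis)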
def inference_ctd(left, right, model, n_iter=20):
    """Run CREStereo on a stereo pair: a coarse pass at half resolution
    seeds the full-resolution pass, whose x-flow is returned as disparity."""
    print("Model Forwarding...")
    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right[None, :, :, :])

    # chosen for convenience
    device = torch.device('cuda:0')
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    # NHWC -> NCHW for both images (the right image was previously left untransposed)
    imgL = imgL.transpose(2, 3).transpose(1, 2)
    imgR = imgR.transpose(2, 3).transpose(1, 2)

    imgL_dw2 = F.interpolate(
        imgL,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    with torch.inference_mode():
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    # Disparity is the horizontal component of the predicted flow.
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
    pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
    return pred_disp_norm
@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
    try:
        img = np.array(Image.open(BytesIO(await file.read())))
    except Exception as e:
        return {'error': "couldn't read file", 'exception': str(e)}

    print(img.shape)
    if len(img.shape) == 2:
        # replicate a single-channel IR image to the 3 channels the network expects
        img = np.stack([img] * 3, axis=-1)
    pred_disp = inference_ctd(img, reference_pattern, model)
    # ndarrays are not JSON-serializable; convert to a nested list for the response
    return {"pred_disp": pred_disp.tolist()}
@app.get('/')
def main():
    return {'test': 'abc'}
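For reference, a minimal client sketch for the PUT /ir route (assumptions: the server runs on localhost:8000, the `requests` package is installed, and `ir_frame.png` is a placeholder filename):

# client sketch -- not part of api_server.py
import requests

with open("ir_frame.png", "rb") as f:  # placeholder input image
    resp = requests.put("http://localhost:8000/ir", files={"file": f})
resp.raise_for_status()
pred_disp = resp.json()["pred_disp"]  # disparity map as a nested list (H x W)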