CREStereo Repository for the 'Towards accurate and robust depth estimation' project
CREStereo-pytorch-nxt/api_server.py


import os
import json
from datetime import datetime
from typing import Union, Literal
from io import BytesIO
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
import cv2
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from nets import Model
from train import inference as ctd_inference
app = FastAPI()
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern_path = '/home/nils/kinect_reference_far.png'
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
print(reference_pattern_path)
reference_pattern = cv2.imread(reference_pattern_path)
# shift the reference pattern a few pixels to the left to simulate a further
# backdrop (the translation component is currently zero, i.e. no shift)
trans_mat = np.float32([[1, 0, 0], [0, 1, 0]])
reference_pattern = cv2.warpAffine(
    reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR
)
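# e.g. an actual 5 px shift to the left would use a non-zero x-translation
# (the value 5 is hypothetical):
#   trans_mat = np.float32([[1, 0, -5], [0, 1, 0]])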
iters = 20
minimal_data = True
temporal_init = False
last_img = None
device = torch.device('cuda:0')
def downsize(img):
    # pyrDown halves 1024x1280 to 512x640; cropping 16 rows from the top and
    # bottom then yields the 480x640 resolution the rest of the pipeline uses
    diff = (512 - 480) // 2
    downsampled = cv2.pyrDown(img)
    img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
    return img
if 1024 in reference_pattern.shape:
    reference_pattern = downsize(reference_pattern)
def ghetto_lcn(img):
    # rough local contrast normalization: subtract the local (Gaussian) mean,
    # divide by the local standard deviation, then rescale to [0, 255]
    # gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    gray = img
    float_gray = gray.astype(np.float32) / 255.0
    blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
    num = float_gray - blur
    blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
    den = cv2.pow(blur, 0.5)
    gray = num / den
    # cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
    cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
    return gray
# reference_pattern = ghetto_lcn(reference_pattern)
def load_model(epoch, use_tensorrt=False):
    global model
    epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
    model_path = f"train_log/models/{epoch}"
    model = Model(max_disp=256, mixed_precision=False, test_mode=True)
    # FIXME WIP: workaround for the DataParallel/TensorRT incompatibility
    model = nn.DataParallel(model, device_ids=[device])
    # model.load_state_dict(torch.load(model_path), strict=False)
    state_dict = torch.load(model_path)['state_dict']
    model.load_state_dict(state_dict, strict=True)
    model.to(device)
    model.eval()
    if use_tensorrt:
        # rebuild the model without the DataParallel wrapper, since
        # torch.jit.script cannot script DataParallel modules
        np_model = Model(max_disp=256, mixed_precision=False, test_mode=True)
        np_model.load_state_dict(model.module.state_dict(), strict=True)
        np_model.to(device)
        np_model.eval()
        spec_dict = {
            "inputs": [
                torch_tensorrt.Input(
                    min_shape=[1, 2, 240, 320],
                    max_shape=[1, 2, 480, 640],
                    opt_shape=[1, 2, 480, 640],
                    dtype=torch.int32,
                ),
                torch_tensorrt.Input(
                    min_shape=[1, 2, 240, 320],
                    max_shape=[1, 2, 480, 640],
                    opt_shape=[1, 2, 480, 640],
                    dtype=torch.int32,
                ),
            ],
            "enabled_precisions": {torch.float, torch.half},
            "refit": False,
            "debug": False,
            "device": {
                "device_type": torch_tensorrt.DeviceType.GPU,
                "gpu_id": 0,
                "dla_core": 0,
                "allow_gpu_fallback": True,
            },
            "capability": torch_tensorrt.EngineCapability.default,
            "num_min_timing_iters": 2,
            "num_avg_timing_iters": 1,
        }
        spec = {
            "forward": torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict)
        }
        # trt_model = torch_tensorrt.compile(
        #     np_model,
        #     inputs=[torch_tensorrt.Input((1, 2, 480, 640)),
        #             torch_tensorrt.Input((1, 2, 480, 640))],  # input shape
        #     enabled_precisions={torch_tensorrt.dtype.half},   # run with FP16
        # )
        # trt_dw_model = torch_tensorrt.compile(
        #     np_model,
        #     inputs=[torch_tensorrt.Input((1, 2, 240, 320)),
        #             torch_tensorrt.Input((1, 2, 240, 320))],  # input shape
        #     enabled_precisions={torch_tensorrt.dtype.half},   # run with FP16
        # )
        script_model = torch.jit.script(np_model.eval())
        # script_dw_model = torch.jit.script(trt_dw_model.eval())
        # save the TensorRT-embedded TorchScript
        # torch.jit.save(trt_model, 'trt_torchscript_module.ts')
        # torch.jit.save(trt_dw_model, 'trt_torchscript_dw_module.ts')
        print(script_model)
        print(script_model.forward)
        print(dir(script_model))
        model = torch._C._jit_to_backend("tensorrt", script_model, spec)
    print(f'loaded model {epoch}')
    return model
model = load_model('latest')
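# load_model also accepts an epoch number, which maps to
# train_log/models/epoch-<n>.pth; illustrative sketch (the value 100 is
# hypothetical):
#   model = load_model(100)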
class NumpyEncoder(json.JSONEncoder):
    # serialize numpy arrays as nested lists so json.dumps can handle them
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return json.JSONEncoder.default(self, obj)
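# Minimal usage sketch for NumpyEncoder:
#   json.dumps({'disp': np.zeros((2, 2))}, cls=NumpyEncoder)
#   -> '{"disp": [[0.0, 0.0], [0.0, 0.0]]}'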
def inference(left, right, model, n_iter=20):
    print("Model Forwarding...")
    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right[None, :, :, :])
    device = torch.device('cuda')
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    # NHWC -> NCHW: the network expects channels-first input
    imgL = imgL.permute(0, 3, 1, 2)
    imgR = imgR.permute(0, 3, 1, 2)
    # half-resolution copies for the coarse first pass
    imgL_dw2 = F.interpolate(
        imgL,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
        mode="bilinear",
        align_corners=True,
    )
    with torch.inference_mode():
        # coarse-to-fine: the half-resolution flow initializes the full pass
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
    return pred_disp
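# Illustrative local call (sketch; assumes two 480x640 BGR images on disk,
# the file names are hypothetical):
#   disp = inference(cv2.imread('left.png'), cv2.imread('right.png'), model, n_iter=20)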
def get_reference():
    # yield every reference capture found on disk, one per call to next()
    refs = [ref.path for ref in os.scandir('/home/nils/references/')]
    for ref in refs:
        reference = cv2.imread(ref)
        yield reference
references = get_reference()
@app.post('/params/update_reference')
async def update_reference():
    global references, reference_pattern
    try:
        reference_pattern = downsize(next(references))
        print(reference_pattern.shape)
        return {'status': 'success'}
    except StopIteration:
        # generator exhausted: restart it so the next call cycles again
        references = get_reference()
        return {'status': 'finished'}
@app.post("/model/update/{epoch}")
async def change_model(epoch: Union[int, Literal['latest']]):
global model
print(epoch)
print('updating model')
model = load_model(epoch)
return {'status': 'success'}
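# Illustrative usage (sketch; assumes the server runs on localhost:8000 and
# the 'requests' package is installed):
#   import requests
#   requests.post('http://localhost:8000/model/update/latest')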
@app.post("/params/iterations/{iterations}")
async def set_iterations(iterations: int):
global iters
iters = iterations
@app.post("/params/minimal_data/{enable}")
async def set_minimal_data(enable: bool):
global minimal_data
minimal_data = enable
@app.post("/params/temporal_init/{enable}")
async def set_temporal_init(enable: bool):
global temporal_init
temporal_init = enable
@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
global last_img, minimal_data
try:
img = np.array(Image.open(BytesIO(await file.read())))
except Exception as e:
return {'error': 'couldn\'t read file', 'exception': e}
# img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
if len(img.shape) == 2:
img = cv2.merge([img for _ in range(3)])
if img.shape == (1024, 1280, 3):
img = downsize(img)
# img = img.transpose((1, 2, 0))
# ref_pat = reference_pattern.transpose((1, 2, 0))
ref_pat = reference_pattern
start = datetime.now()
if temporal_init:
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img)
last_img = pred_disp
else:
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False)
# pred_disp = inference(img, ref_pat, model, iters)
duration = (datetime.now() - start).total_seconds()
if minimal_data:
return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
else:
# return json.dumps({'disp': pred_disp, 'input': img, 'duration': duration},
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
cls=NumpyEncoder)
@app.get('/temporal_init')
def get_temporal_init():
    return {'status': 'enabled' if temporal_init else 'disabled'}

@app.get('/')
def main():
    return {'test': 'abc'}
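
# --- Example client (illustrative sketch, not part of the server) ---
# Assumes the server was started with e.g. `uvicorn api_server:app` on
# localhost:8000 and that 'ir.png' is a hypothetical Kinect IR frame; the
# 'requests' dependency is an assumption. Because the /ir endpoint returns
# a JSON-encoded string, the payload has to be decoded a second time.
#
#   import json
#   import requests
#   with open('ir.png', 'rb') as f:
#       resp = requests.put('http://localhost:8000/ir', files={'file': f})
#   payload = json.loads(resp.json())
#   print(payload['duration'], len(payload['disp']))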