CREStereo Repository for the 'Towards accurate and robust depth estimation' project
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
CREStereo-pytorch-nxt/api_server.py

123 lines
3.5 KiB

import json
from datetime import datetime
from typing import Union, Literal
from io import BytesIO
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from cv2 import cv2
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from nets import Model
app = FastAPI()

# Fixed "right" image of every stereo pair: the projector's reference pattern.
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern = cv2.imread(reference_pattern_path)
# cv2.imread signals failure by returning None rather than raising. Without
# this check the problem would only show up later as an opaque
# AttributeError (None.transpose) inside the /ir handler, so fail fast here.
if reference_pattern is None:
    raise FileNotFoundError(
        f'could not read reference pattern image: {reference_pattern_path}'
    )

# Device the network is wrapped onto and loaded to by load_model().
device = torch.device('cuda:0')
# Currently served network; (re)bound by load_model() and change_model().
model = None
def load_model(epoch):
    """Build the stereo network and load weights from a training checkpoint.

    Args:
        epoch: An epoch number, or the string 'latest'. Epoch N resolves to
            ``train_log/models/epoch-N.pth``; 'latest' resolves to
            ``train_log/models/latest.pth``.

    Returns:
        The ``nn.DataParallel``-wrapped network, moved to the module-level
        ``device`` and switched to eval mode.

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
        KeyError: If the checkpoint has no 'state_dict' entry.
        RuntimeError: If the state dict does not match the architecture
            (loading is strict).

    Side effects:
        On success, also rebinds the module-level ``model`` global. On failure,
        the global is left untouched. A failed reload therefore never replaces
        a working model with one holding randomly initialised weights.
    """
    global model
    checkpoint_name = 'latest.pth' if epoch == 'latest' else f'epoch-{epoch}.pth'
    model_path = f"train_log/models/{checkpoint_name}"

    # Build into a local name so the global is only replaced after loading succeeded.
    network = Model(max_disp=256, mixed_precision=False, test_mode=True)
    network = nn.DataParallel(network, device_ids=[device])
    # map_location lets checkpoints saved from another GPU (or the CPU) load here.
    checkpoint = torch.load(model_path, map_location=device)
    # Loaded into the DataParallel wrapper with strict=True, so the checkpoint
    # keys must carry the wrapper's 'module.' prefix.
    network.load_state_dict(checkpoint['state_dict'], strict=True)
    network.to(device)
    network.eval()

    model = network
    print(f'loaded model {checkpoint_name}')
    return model
# Load the 'latest' checkpoint once, at import time, so the first request can be served immediately.
model = load_model(epoch='latest')
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also serialises NumPy arrays and NumPy scalars.

    Arrays become (nested) Python lists. NumPy scalar types that the stdlib
    encoder rejects (``np.float32``, ``np.int64``, ``np.bool_``, ...) become
    the equivalent Python value. Anything else falls through to the base class,
    which raises ``TypeError``.
    """

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # np.generic is the common base of every NumPy scalar type.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)
def inference(left, right, model, n_iter=20):
    """Run coarse-to-fine stereo inference on a single left/right image pair.

    The model is first run on half-resolution copies of both images. That
    prediction then initialises a second run at full resolution.

    Args:
        left, right: NumPy arrays of identical 3-D shape ``(A, C, B)``. A batch
            axis is prepended and axes 1 and 2 are swapped, so the network
            receives float32 tensors of shape ``(1, C, A, B)``.
        model: Stereo network called as ``model(image1=..., image2=...,
            iters=..., flow_init=...)`` and then positionally for the
            full-resolution pass. Inputs are placed on the device of the
            model's parameters (CUDA if that cannot be determined).
        n_iter: Number of refinement iterations per forward pass.

    Returns:
        2-D NumPy array: channel 0 of the full-resolution prediction for the
        single batch item (used as the disparity map by the caller).
    """
    print("Model Forwarding...")
    target_device = _model_device(model)

    def to_batch(image):
        # (A, C, B) -> (1, A, C, B) -> (1, C, A, B); torch.tensor copies the data.
        array = np.ascontiguousarray(image[None, :, :, :]).astype("float32")
        return torch.tensor(array).to(target_device).transpose(1, 2)

    imgL = to_batch(left)
    imgR = to_batch(right)
    # Both half-resolution images use the left image's size so the pair stays aligned.
    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)

    with torch.inference_mode():
        imgL_dw2 = F.interpolate(imgL, size=half_size, mode="bilinear", align_corners=True)
        imgR_dw2 = F.interpolate(imgR, size=half_size, mode="bilinear", align_corners=True)
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    # Explicit indexing (instead of squeeze) always yields a 2-D map, even when
    # a spatial dimension happens to be 1.
    return pred_flow[0, 0].cpu().numpy()


def _model_device(model):
    """Return the device holding ``model``'s parameters, defaulting to CUDA."""
    try:
        return next(model.parameters()).device
    except (AttributeError, StopIteration):
        return torch.device('cuda')
@app.post("/model/update/{epoch}")
async def change_model(epoch: Union[int, Literal['latest']]):
    """Swap the served model for the checkpoint of the given epoch.

    Args:
        epoch: Epoch number, or 'latest' (path parameter).

    Returns:
        ``{'status': 'success'}`` once the new weights are loaded.

    Raises:
        HTTPException: 404 if no checkpoint file exists for ``epoch``. Any
            other loading error is re-raised unchanged (HTTP 500).

    If loading fails for any reason, the previously served model is restored.
    A bad request therefore never leaves the server running a half-initialised
    network.
    """
    # Local import keeps the file's top-level import block unchanged.
    from fastapi import HTTPException

    global model
    print(epoch)
    print('updating model')
    # NOTE(review): load_model is synchronous, so this async handler blocks the
    # event loop while the checkpoint loads.
    previous_model = model
    try:
        model = load_model(epoch)
    except FileNotFoundError as err:
        model = previous_model
        raise HTTPException(
            status_code=404, detail=f'checkpoint not found: {err.filename}'
        ) from err
    except Exception:
        model = previous_model
        raise
    return {'status': 'success'}
def _prepare_ir_image(img):
    """Bring a decoded upload into the 3-channel (H, W, 3) layout used for inference.

    Single-channel (H, W) images are replicated into three identical channels,
    and a 4-channel image has its last channel dropped. An image of exactly
    1024x1280x3 is halved with ``cv2.pyrDown`` to 512x640 and then cropped by
    16 rows at the top and bottom to 480x640. Other shapes pass through
    unchanged.
    """
    if img.ndim == 2:
        img = np.stack([img, img, img], axis=-1)
    elif img.ndim == 3 and img.shape[2] == 4:
        # Assumes the 4th channel is alpha (e.g. an RGBA PNG) — TODO confirm.
        img = np.ascontiguousarray(img[:, :, :3])
    if img.shape == (1024, 1280, 3):
        margin = (512 - 480) // 2
        downsampled = cv2.pyrDown(img)
        img = downsampled[margin:downsampled.shape[0] - margin, :]
    return img


@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
    """Estimate disparity for an uploaded IR image against the reference pattern.

    The upload is decoded with PIL and preprocessed by ``_prepare_ir_image``.
    It is then passed to ``inference`` as the left image, with the module-level
    ``reference_pattern`` as the right image.

    Returns:
        On success, a JSON *string* with keys 'disp' (inference output),
        'reference' and 'input' (the arrays as passed to ``inference``) and
        'duration' (seconds spent in ``inference``). If the upload cannot be
        decoded, a dict with keys 'error' and 'exception' instead.
    """
    try:
        img = np.array(Image.open(BytesIO(await file.read())))
    except Exception as e:
        # Send the message text; a raw exception object is not JSON data and
        # its message would be lost in the response.
        return {'error': 'couldn\'t read file', 'exception': str(e)}

    img = _prepare_ir_image(img)
    # NOTE(review): (1, 2, 0) turns (H, W, C) into (W, C, H). inference() then
    # feeds (1, C, W, H), i.e. the network sees the image with height and width
    # swapped. Kept as-is since the weights were presumably trained with the
    # same convention — confirm before changing.
    img = img.transpose((1, 2, 0))
    ref_pat = reference_pattern.transpose((1, 2, 0))

    # NOTE(review): inference is synchronous and blocks the event loop.
    start = datetime.now()
    pred_disp = inference(img, ref_pat, model, 20)
    duration = (datetime.now() - start).total_seconds()

    # NOTE(review): returning a str makes FastAPI JSON-encode it again, so
    # clients receive a JSON string that itself contains JSON. Kept for
    # compatibility with existing clients.
    return json.dumps(
        {'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
        cls=NumpyEncoder,
    )
@app.get('/')
def main():
    """Root endpoint: always answers with the same constant JSON payload.

    Presumably used as a quick "is the server up" check.
    """
    payload = dict(test='abc')
    return payload