import json
import os
from io import BytesIO

import cv2  # `from cv2 import cv2` raises ImportError on newer opencv-python releases
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fastapi import FastAPI, File, UploadFile
from PIL import Image

from nets import Model

app = FastAPI()

# Both paths can be overridden through the environment; the defaults are the
# original hard-coded locations.
reference_pattern_path = os.environ.get(
    'REFERENCE_PATTERN_PATH', '/home/nils/kinect_reference_cropped.png'
)
reference_pattern = cv2.imread(reference_pattern_path)
# cv2.imread returns None instead of raising when the file is missing or
# unreadable; fail at startup rather than on every request.
if reference_pattern is None:
    raise FileNotFoundError(
        f"could not read reference pattern image: {reference_pattern_path!r}"
    )

# e.g. MODEL_PATH=train_log/models/epoch-100.pth to serve a specific epoch.
model_path = os.environ.get('MODEL_PATH', "train_log/models/latest.pth")

device = torch.device('cuda')
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model, device_ids=[device])
# The checkpoint's 'state_dict' entry must match the DataParallel-wrapped
# model exactly (strict=True). map_location keeps loading working even if the
# checkpoint was saved from a different CUDA device index.
# torch.load unpickles the file: only load trusted checkpoints.
state_dict = torch.load(model_path, map_location=device)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()


class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes NumPy arrays and scalars."""

    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars (e.g. np.float32) are not JSON-serializable by default.
        if isinstance(obj, np.generic):
            return obj.item()
        return super().default(obj)


def inference(left, right, model, n_iter=20):
    """Run two-pass (coarse-to-fine) disparity estimation on an image pair.

    The model is first run on both images downsampled by 2; that prediction
    is passed as ``flow_init`` to the full-resolution pass.

    Args:
        left: Input image as a 3-D array (see the axis-order NOTE below).
        right: Reference image; must have the same shape as ``left``.
        model: Module called as ``model(image1, image2, iters=..., flow_init=...)``.
        n_iter: Refinement iterations passed to each model call.

    Returns:
        Channel 0 of the full-resolution predicted flow for the single batch
        item, squeezed, as a NumPy array on the CPU.

    Raises:
        ValueError: If ``left`` and ``right`` differ in shape.
    """
    if left.shape != right.shape:
        raise ValueError(
            f"left and right images must have the same shape, "
            f"got {left.shape} and {right.shape}"
        )
    print("Model Forwarding...")
    imgL = np.ascontiguousarray(left[None, :, :, :])
    imgR = np.ascontiguousarray(right[None, :, :, :])
    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)

    # Works roughly. NOTE(review): the caller passes arrays already transposed
    # with (1, 2, 0); combined with this swap of axes 1 and 2, an (H, W, C)
    # image ends up as (1, C, W, H), i.e. height and width are exchanged
    # relative to the usual NCHW layout. Confirm this is intended.
    imgR = imgR.transpose(1, 2)
    imgL = imgL.transpose(1, 2)

    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(imgL, size=half_size, mode="bilinear", align_corners=True)
    imgR_dw2 = F.interpolate(imgR, size=half_size, mode="bilinear", align_corners=True)

    with torch.inference_mode():
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
    pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
    return pred_disp


@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
    """Estimate disparity between an uploaded IR image and the reference pattern.

    Single-channel uploads are replicated to 3 channels. Uploads of exactly
    1280x1024x3 are halved with ``cv2.pyrDown`` (to 640x512) and then
    cropped by 16 rows at top and bottom to 640x480.

    Returns:
        On success, a JSON *string* (serialized with ``NumpyEncoder``) with
        keys 'disp', 'reference' and 'input'. If the upload cannot be decoded
        as an image, a dict with 'error' and 'exception' (message text) keys.
    """
    try:
        img = np.array(Image.open(BytesIO(await file.read())))
    except Exception as e:
        # Return the message text: an Exception instance itself is not
        # JSON-serializable, so the client would otherwise get no detail.
        return {'error': 'couldn\'t read file', 'exception': str(e)}

    # NOTE(review): a min-max normalization to uint8
    # (cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)) is
    # currently disabled; raw pixel values are passed through.
    if img.ndim == 2:
        img = cv2.merge([img, img, img])
    if img.shape == (1024, 1280, 3):
        diff = (512 - 480) // 2
        downsampled = cv2.pyrDown(img)
        img = downsampled[diff:downsampled.shape[0] - diff, :]

    # NOTE(review): (1, 2, 0) turns an (H, W, C) array into (W, C, H); see the
    # axis-order note in `inference`.
    img = img.transpose((1, 2, 0))
    ref_pat = reference_pattern.transpose((1, 2, 0))

    # NOTE(review): this is blocking GPU work inside an async handler, so it
    # stalls the event loop while it runs.
    pred_disp = inference(img, ref_pat, model)

    # NOTE(review): returning a pre-serialized string makes FastAPI encode it
    # again, so clients receive a JSON string that contains JSON. Kept as-is
    # for compatibility with existing clients.
    return json.dumps(
        {'disp': pred_disp, 'reference': ref_pat, 'input': img}, cls=NumpyEncoder
    )


@app.get('/')
def main():
    """Root endpoint returning a fixed payload (presumably a liveness check)."""
    return {'test': 'abc'}