change a bunch of stuff, add wip lightning implementation
This commit is contained in:
parent
11959eef61
commit
63da24f429
181
api_server.py
181
api_server.py
@ -1,3 +1,4 @@
|
||||
import os
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Union, Literal
|
||||
@ -7,32 +8,149 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch_tensorrt
|
||||
from cv2 import cv2
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from PIL import Image
|
||||
|
||||
|
||||
from nets import Model
|
||||
|
||||
from train import inference as ctd_inference
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
reference_pattern_path = '/home/nils/kinect_reference_far.png'
|
||||
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
|
||||
print(reference_pattern_path)
|
||||
reference_pattern = cv2.imread(reference_pattern_path)
|
||||
|
||||
# shift reference pattern a few pixels to the left to simulate further backdrop
|
||||
trans_mat = np.float32([[1, 0, 0], [0, 1, 0]])
|
||||
reference_pattern = cv2.warpAffine(
|
||||
reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR
|
||||
)
|
||||
|
||||
iters = 20
|
||||
minimal_data = False
|
||||
minimal_data = True
|
||||
temporal_init = False
|
||||
last_img = None
|
||||
device = torch.device('cuda:0')
|
||||
|
||||
|
||||
def load_model(epoch):
|
||||
def downsize(img):
|
||||
diff = (512 - 480) // 2
|
||||
downsampled = cv2.pyrDown(img)
|
||||
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
|
||||
return img
|
||||
|
||||
|
||||
if 1024 in reference_pattern.shape:
|
||||
reference_pattern = downsize(reference_pattern)
|
||||
|
||||
|
||||
def ghetto_lcn(img):
|
||||
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
|
||||
gray = img
|
||||
|
||||
float_gray = gray.astype(np.float32) / 255.0
|
||||
|
||||
blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
|
||||
num = float_gray - blur
|
||||
|
||||
blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
|
||||
den = cv2.pow(blur, 0.5)
|
||||
|
||||
gray = num / den
|
||||
|
||||
# cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
|
||||
cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
|
||||
return gray
|
||||
|
||||
|
||||
# reference_pattern = ghetto_lcn(reference_pattern)
|
||||
|
||||
|
||||
def load_model(epoch, use_tensorrt=False):
|
||||
global model
|
||||
epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
|
||||
model_path = f"train_log/models/{epoch}"
|
||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
# FIXME WIP Workaround Dataparallel TensorRT incompatibility
|
||||
model = nn.DataParallel(model, device_ids=[device])
|
||||
# model.load_state_dict(torch.load(model_path), strict=False)
|
||||
state_dict = torch.load(model_path)['state_dict']
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
if use_tensorrt:
|
||||
np_model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
np_model.load_state_dict(model.module.state_dict(), strict=True)
|
||||
np_model.to(device)
|
||||
np_model.eval()
|
||||
|
||||
spec_dict = {
|
||||
"inputs": [
|
||||
torch_tensorrt.Input(
|
||||
min_shape=[1, 2, 240, 320],
|
||||
max_shape=[1, 2, 480, 640],
|
||||
opt_shape=[1, 2, 480, 640],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
torch_tensorrt.Input(
|
||||
min_shape=[1, 2, 240, 320],
|
||||
max_shape=[1, 2, 480, 640],
|
||||
opt_shape=[1, 2, 480, 640],
|
||||
dtype=torch.int32,
|
||||
),
|
||||
],
|
||||
"enabled_precisions": {torch.float, torch.half},
|
||||
"refit": False,
|
||||
"debug": False,
|
||||
"device": {
|
||||
"device_type": torch_tensorrt.DeviceType.GPU,
|
||||
"gpu_id": 0,
|
||||
"dla_core": 0,
|
||||
"allow_gpu_fallback": True
|
||||
},
|
||||
"capability": torch_tensorrt.EngineCapability.default,
|
||||
"num_min_timing_iters": 2,
|
||||
"num_avg_timing_iters": 1,
|
||||
}
|
||||
spec = {
|
||||
"forward":
|
||||
torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict)
|
||||
}
|
||||
|
||||
# trt_model = torch_tensorrt.compile(np_model ,
|
||||
# inputs=torch_tensorrt.Input(
|
||||
# min_shape=[1, 2, 240, 320],
|
||||
# max_shape=[1, 2, 480, 640],
|
||||
# opt_shape=[1, 2, 480, 640],
|
||||
# dtype=torch.int32,
|
||||
# inputs = [torch_tensorrt.Input((1, 2, 480, 640)), torch_tensorrt.Input((1, 2, 480, 640))], # input shape
|
||||
# enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16
|
||||
# )
|
||||
# trt_dw_model = torch_tensorrt.compile(np_model ,
|
||||
# inputs = [torch_tensorrt.Input((1, 2, 240, 320)), torch_tensorrt.Input((1, 2, 240, 320))], # input shape
|
||||
# enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16
|
||||
# )
|
||||
|
||||
script_model = torch.jit.script(np_model.eval())
|
||||
# script_dw_model = torch.jit.script(trt_dw_model.eval())
|
||||
|
||||
# save the TensorRT embedded Torchscript
|
||||
# torch.jit.save(trt_model, 'trt_torchscript_module.ts')
|
||||
# torch.jit.save(trt_dw_model, 'trt_torchscript_dw_module.ts')
|
||||
print(script_model)
|
||||
print(script_model.forward)
|
||||
print(script_model.forward())
|
||||
print(dir(script_model))
|
||||
|
||||
model = torch._C._jit_to_backend("tensorrt", script_model, spec)
|
||||
|
||||
print(f'loaded model {epoch}')
|
||||
return model
|
||||
|
||||
@ -74,14 +192,37 @@ def inference(left, right, model, n_iter=20):
|
||||
)
|
||||
|
||||
with torch.inference_mode():
|
||||
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
|
||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
||||
# pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
|
||||
# pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
||||
pred_flow_dw2 = model.forward(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
|
||||
pred_flow = model.forward(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
||||
|
||||
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
|
||||
|
||||
return pred_disp
|
||||
|
||||
|
||||
def get_reference():
|
||||
refs = [ref.path for ref in os.scandir('/home/nils/references/')]
|
||||
for ref in refs:
|
||||
reference = cv2.imread(ref)
|
||||
yield reference
|
||||
|
||||
references = get_reference()
|
||||
|
||||
|
||||
@app.post('/params/update_reference')
|
||||
async def update_reference():
|
||||
global references, reference_pattern
|
||||
try:
|
||||
reference_pattern = downsize(next(references))
|
||||
print(reference_pattern.shape)
|
||||
return {'status': 'success'}
|
||||
except StopIteration:
|
||||
references = get_reference()
|
||||
return {'status': 'finished'}
|
||||
|
||||
|
||||
@app.post("/model/update/{epoch}")
|
||||
async def change_model(epoch: Union[int, Literal['latest']]):
|
||||
global model
|
||||
@ -103,8 +244,15 @@ async def set_minimal_data(enable: bool):
|
||||
minimal_data = enable
|
||||
|
||||
|
||||
@app.post("/params/temporal_init/{enable}")
|
||||
async def set_temporal_init(enable: bool):
|
||||
global temporal_init
|
||||
temporal_init = enable
|
||||
|
||||
|
||||
@app.put("/ir")
|
||||
async def read_ir_input(file: UploadFile = File(...)):
|
||||
global last_img, minimal_data
|
||||
try:
|
||||
img = np.array(Image.open(BytesIO(await file.read())))
|
||||
except Exception as e:
|
||||
@ -114,24 +262,35 @@ async def read_ir_input(file: UploadFile = File(...)):
|
||||
if len(img.shape) == 2:
|
||||
img = cv2.merge([img for _ in range(3)])
|
||||
if img.shape == (1024, 1280, 3):
|
||||
diff = (512 - 480) // 2
|
||||
downsampled = cv2.pyrDown(img)
|
||||
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
|
||||
img = downsize(img)
|
||||
|
||||
img = img.transpose((1, 2, 0))
|
||||
ref_pat = reference_pattern.transpose((1, 2, 0))
|
||||
# img = img.transpose((1, 2, 0))
|
||||
# ref_pat = reference_pattern.transpose((1, 2, 0))
|
||||
ref_pat = reference_pattern
|
||||
|
||||
start = datetime.now()
|
||||
pred_disp = inference(img, ref_pat, model, iters)
|
||||
if temporal_init:
|
||||
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img)
|
||||
last_img = pred_disp
|
||||
else:
|
||||
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False)
|
||||
# pred_disp = inference(img, ref_pat, model, iters)
|
||||
duration = (datetime.now() - start).total_seconds()
|
||||
|
||||
if minimal_data:
|
||||
return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
|
||||
else:
|
||||
# return json.dumps({'disp': pred_disp, 'input': img, 'duration': duration},
|
||||
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
|
||||
cls=NumpyEncoder)
|
||||
|
||||
|
||||
@app.get('/temporal_init')
|
||||
def get_temporal_init():
|
||||
return {'status': 'enabled' if temporal_init else 'disabled'}
|
||||
|
||||
|
||||
|
||||
@app.get('/')
|
||||
def main():
|
||||
return {'test': 'abc'}
|
||||
|
@ -4,18 +4,22 @@ base_lr: 4.0e-4
|
||||
|
||||
nr_gpus: 3
|
||||
batch_size: 4
|
||||
n_total_epoch: 600
|
||||
n_total_epoch: 300
|
||||
minibatch_per_epoch: 500
|
||||
|
||||
loadmodel: ~
|
||||
log_dir: "./train_log"
|
||||
log_dir_lightning: "./train_log_lightning"
|
||||
model_save_freq_epoch: 1
|
||||
|
||||
max_disp: 256
|
||||
image_width: 640
|
||||
image_height: 480
|
||||
# training_data_path: "./stereo_trainset/crestereo"
|
||||
training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
|
||||
pattern_attention: true
|
||||
dataset: "blender"
|
||||
# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
|
||||
training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
|
||||
|
||||
log_level: "logging.INFO"
|
||||
|
||||
|
111
dataset.py
111
dataset.py
@ -17,7 +17,7 @@ class Augmentor:
|
||||
scale_min=0.6,
|
||||
scale_max=1.0,
|
||||
seed=0,
|
||||
):
|
||||
):
|
||||
super().__init__()
|
||||
self.image_height = image_height
|
||||
self.image_width = image_width
|
||||
@ -234,12 +234,16 @@ class CREStereoDataset(Dataset):
|
||||
|
||||
|
||||
class CTDDataset(Dataset):
|
||||
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False):
|
||||
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False):
|
||||
super().__init__()
|
||||
self.rng = np.random.RandomState(0)
|
||||
self.augment = augment
|
||||
self.blur = blur
|
||||
self.imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
|
||||
imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
|
||||
if test_set:
|
||||
self.imgs = imgs[:int(split * len(imgs))]
|
||||
else:
|
||||
self.imgs = imgs[int(split * len(imgs)):]
|
||||
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
if resize_pattern and self.pattern.shape != (480, 640, 3):
|
||||
@ -271,6 +275,10 @@ class CTDDataset(Dataset):
|
||||
|
||||
# read img, disp
|
||||
left_img = np.load(left_path)
|
||||
|
||||
if left_img.dtype == 'float32':
|
||||
left_img = (left_img * 255).astype('uint8')
|
||||
|
||||
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
|
||||
|
||||
right_img = self.pattern
|
||||
@ -307,3 +315,100 @@ class CTDDataset(Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
|
||||
|
||||
class BlenderDataset(CTDDataset):
|
||||
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False):
|
||||
super().__init__(root, pattern_path)
|
||||
self.use_lightning = use_lightning
|
||||
imgs = [f for f in glob.glob(f"{root}/im_*.png", recursive=True) if not 'depth0001' in f]
|
||||
if test_set:
|
||||
self.imgs = imgs[:int(split * len(imgs))]
|
||||
else:
|
||||
self.imgs = imgs[int(split * len(imgs)):]
|
||||
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
|
||||
|
||||
if resize_pattern and self.pattern.shape != (480, 640, 3):
|
||||
downsampled = cv2.pyrDown(self.pattern)
|
||||
diff = (downsampled.shape[0] - 480) // 2
|
||||
self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
|
||||
|
||||
self.augmentor = Augmentor(
|
||||
image_height=480,
|
||||
image_width=640,
|
||||
max_disp=256,
|
||||
scale_min=0.6,
|
||||
scale_max=1.0,
|
||||
seed=0,
|
||||
)
|
||||
|
||||
def __getitem__(self, index):
|
||||
# find path
|
||||
left_path = self.imgs[index]
|
||||
left_disp_path = left_path.split('.')[0] + '_depth0001.png'
|
||||
|
||||
# read img, disp
|
||||
left_img = cv2.imread(left_path)
|
||||
|
||||
if left_img.dtype == 'float32':
|
||||
left_img = (left_img * 255).astype('uint8')
|
||||
|
||||
if left_img.shape != (480, 640, 3):
|
||||
downsampled = cv2.pyrDown(left_img)
|
||||
diff = (downsampled.shape[0] - 480) // 2
|
||||
left_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
|
||||
if left_img.shape[-1] != 3:
|
||||
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
|
||||
|
||||
right_img = self.pattern
|
||||
left_disp = self.get_disp(left_disp_path)
|
||||
|
||||
if False: # self.rng.binomial(1, 0.5):
|
||||
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
|
||||
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
|
||||
left_disp[left_disp == np.inf] = 0
|
||||
|
||||
if self.blur:
|
||||
kernel_size = random.sample([1,3,5,7,9], 1)[0]
|
||||
kernel = (kernel_size, kernel_size)
|
||||
left_img = cv2.GaussianBlur(left_img, kernel, 0)
|
||||
|
||||
# augmentation
|
||||
if not self.augment:
|
||||
_left_img, _right_img, _left_disp, disp_mask = self.augmentor(
|
||||
left_img, right_img, left_disp
|
||||
)
|
||||
else:
|
||||
left_img, right_img, left_disp, disp_mask = self.augmentor(
|
||||
left_img, right_img, left_disp
|
||||
)
|
||||
|
||||
if not self.use_lightning:
|
||||
right_img = right_img.transpose((2, 0, 1)).astype("uint8")
|
||||
return {
|
||||
"left": left_img,
|
||||
"right": right_img,
|
||||
"disparity": left_disp,
|
||||
"mask": disp_mask,
|
||||
}
|
||||
|
||||
right_img = right_img.transpose((2, 0, 1)).astype("uint8")
|
||||
left_img = left_img.transpose((2, 0, 1)).astype("uint8")
|
||||
return left_img, right_img, left_disp, disp_mask
|
||||
|
||||
def get_disp(self, path):
|
||||
baseline = 0.075 # meters
|
||||
fl = 560. # as per CTD
|
||||
depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
downsampled = cv2.pyrDown(depth)
|
||||
diff = (downsampled.shape[0] - 480) // 2
|
||||
depth = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
|
||||
# disp = np.load(path).transpose(1,2,0)
|
||||
# disp = baseline * fl / depth
|
||||
# return disp.astype(np.float32) / 32
|
||||
# FIXME temporarily increase disparity until new data with better depth values is generated
|
||||
# higher values seem to speedup convergence, but introduce much stronger artifacting
|
||||
# mystery_factor = 150
|
||||
mystery_factor = 1
|
||||
disp = (baseline * fl * mystery_factor) / depth
|
||||
return disp.astype(np.float32)
|
||||
|
@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module):
|
||||
"""
|
||||
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
|
||||
|
||||
for layer, name in zip(self.layers, self.layer_names):
|
||||
|
||||
# NOTE Workaround for non statically determinable zip
|
||||
# for layer, name in zip(self.layers, self.layer_names):
|
||||
# layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers))
|
||||
# layer_zip = []
|
||||
# for i, layer in enumerate(self.layers):
|
||||
# layer_zip.append((layer, self.layer_names[i]))
|
||||
|
||||
for i, layer in enumerate(self.layers):
|
||||
name = self.layer_names[i]
|
||||
if name == 'self':
|
||||
feat0 = layer(feat0, feat0, mask0, mask0)
|
||||
feat1 = layer(feat1, feat1, mask1, mask1)
|
||||
@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module):
|
||||
else:
|
||||
raise KeyError
|
||||
|
||||
return feat0, feat1
|
||||
return feat0, feat1
|
||||
|
@ -36,6 +36,12 @@ class CREStereo(nn.Module):
|
||||
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
|
||||
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
||||
|
||||
# # NOTE Position_encoding as workaround for TensorRt
|
||||
# image1_shape = [1, 2, 480, 640]
|
||||
# self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
# )
|
||||
|
||||
# loftr
|
||||
self.self_att_fn = LocalFeatureTransformer(
|
||||
d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
|
||||
@ -81,7 +87,7 @@ class CREStereo(nn.Module):
|
||||
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
|
||||
return zero_flow
|
||||
|
||||
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
|
||||
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
|
||||
""" Estimate optical flow between pair of frames """
|
||||
|
||||
image1 = 2 * (image1 / 255.0) - 1.0
|
||||
@ -130,17 +136,22 @@ class CREStereo(nn.Module):
|
||||
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
|
||||
|
||||
# positional encoding and self-attention
|
||||
pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
)
|
||||
# pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
# )
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = pos_encoding_fn_small(fmap1_dw16)
|
||||
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
||||
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = pos_encoding_fn_small(fmap2_dw16)
|
||||
x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
|
||||
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
|
||||
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||
# FIXME experimental ! no self-attention for pattern
|
||||
if not self_attend_right:
|
||||
fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||
else:
|
||||
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
|
||||
|
||||
fmap1_dw16, fmap2_dw16 = [
|
||||
x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
|
||||
for x in [fmap1_dw16, fmap2_dw16]
|
||||
@ -258,3 +269,4 @@ class CREStereo(nn.Module):
|
||||
return flow_up
|
||||
|
||||
return predictions
|
||||
|
||||
|
@ -1,6 +1,8 @@
|
||||
from typing import List
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch import Tensor
|
||||
|
||||
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
|
||||
class ResidualBlock(nn.Module):
|
||||
@ -96,28 +98,43 @@ class BasicEncoder(nn.Module):
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x: List[Tensor]):
|
||||
# NOTE always assume list, otherwise TensorRT is sad
|
||||
# batch_dim = x[0].shape[0]
|
||||
# x_tensor = torch.cat(list(x), dim=0)
|
||||
|
||||
# if input is list, combine batch dimension
|
||||
is_list = isinstance(x, tuple) or isinstance(x, list)
|
||||
if is_list:
|
||||
batch_dim = x[0].shape[0]
|
||||
x = torch.cat(x, dim=0)
|
||||
x_tensor = torch.cat(x, dim=0)
|
||||
else:
|
||||
x_tensor = x
|
||||
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
print()
|
||||
print()
|
||||
print(x_tensor.shape)
|
||||
print()
|
||||
print()
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x_tensor = self.conv1(x_tensor)
|
||||
x_tensor = self.norm1(x_tensor)
|
||||
x_tensor = self.relu1(x_tensor)
|
||||
|
||||
x = self.conv2(x)
|
||||
x_tensor = self.layer1(x_tensor)
|
||||
x_tensor = self.layer2(x_tensor)
|
||||
x_tensor = self.layer3(x_tensor)
|
||||
|
||||
x_tensor = self.conv2(x_tensor)
|
||||
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
x_tensor = self.dropout(x_tensor)
|
||||
|
||||
if is_list:
|
||||
x = torch.split(x, x.shape[0]//2, dim=0)
|
||||
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
|
||||
return x_list
|
||||
|
||||
return x
|
||||
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
|
||||
return x_list
|
||||
|
||||
# return list(x)
|
||||
|
@ -77,7 +77,7 @@ class BasicUpdateBlock(nn.Module):
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
|
||||
|
||||
def forward(self, net, inp, corr, flow, upsample=True):
|
||||
def forward(self, net, inp, corr, flow, upsample: bool=True):
|
||||
# print(inp.shape, corr.shape, flow.shape)
|
||||
motion_features = self.encoder(flow, corr)
|
||||
# print(motion_features.shape, inp.shape)
|
||||
|
109
test_model.py
109
test_model.py
@ -3,6 +3,7 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
|
||||
from nets import Model
|
||||
|
||||
@ -16,17 +17,20 @@ device = 'cuda'
|
||||
wandb.init(project="crestereo", entity="cpt-captain")
|
||||
|
||||
|
||||
def do_infer(left_img, right_img, gt_disp, model):
|
||||
disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)
|
||||
def do_infer(left_img, right_img, gt_disp, model, attend_pattern=True):
|
||||
disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False, attend_pattern=attend_pattern)
|
||||
|
||||
disp_vis = normalize_and_colormap(disp)
|
||||
gt_disp_vis = normalize_and_colormap(gt_disp)
|
||||
if gt_disp.shape != disp.shape:
|
||||
gt_disp = gt_disp.reshape(disp.shape)
|
||||
disp_err = gt_disp - disp
|
||||
disp_err = normalize_and_colormap(disp_err.abs())
|
||||
# gt_disp_vis = normalize_and_colormap(gt_disp)
|
||||
# if gt_disp.shape != disp.shape:
|
||||
# gt_disp = gt_disp.reshape(disp.shape)
|
||||
# disp_err = gt_disp - disp
|
||||
# disp_err = normalize_and_colormap(disp_err.abs())
|
||||
if isinstance(left_img, torch.Tensor):
|
||||
left_img = left_img.cpu().detach().numpy().astype('uint8')
|
||||
right_img = right_img.cpu().detach().numpy().astype('uint8')
|
||||
|
||||
wandb.log({
|
||||
results = {
|
||||
'disp': wandb.Image(
|
||||
disp,
|
||||
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
|
||||
@ -35,41 +39,60 @@ def do_infer(left_img, right_img, gt_disp, model):
|
||||
disp_vis,
|
||||
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
|
||||
),
|
||||
'gt_disp_vis': wandb.Image(
|
||||
gt_disp_vis,
|
||||
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
|
||||
),
|
||||
'disp_err': wandb.Image(
|
||||
disp_err,
|
||||
caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
|
||||
),
|
||||
# 'disp_err': wandb.Image(
|
||||
# disp_err,
|
||||
# caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
|
||||
# ),
|
||||
'input_left': wandb.Image(
|
||||
left_img.cpu().detach().numpy().astype('uint8'),
|
||||
left_img,
|
||||
caption=f"Input left",
|
||||
),
|
||||
'input_right': wandb.Image(
|
||||
right_img.cpu().detach().numpy().astype('uint8'),
|
||||
right_img,
|
||||
caption=f"Input right",
|
||||
),
|
||||
})
|
||||
}
|
||||
|
||||
if gt_disp is not None:
|
||||
print('logging gt')
|
||||
print(f'gt: {gt_disp.max()}/{gt_disp.min()}/{gt_disp.mean()}')
|
||||
gt_disp_vis = normalize_and_colormap(gt_disp)
|
||||
results.update({
|
||||
'gt_disp_vis': wandb.Image(
|
||||
gt_disp_vis,
|
||||
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
|
||||
)})
|
||||
wandb.log(results)
|
||||
|
||||
|
||||
def downsample(img, half_height_out=480):
|
||||
downsampled = cv2.pyrDown(img)
|
||||
diff = (downsampled.shape[0] - half_height_out) // 2
|
||||
return downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model_path = "models/crestereo_eth3d.pth"
|
||||
model_path = "train_log/models/latest.pth"
|
||||
# model_path = "train_log/models/epoch-120.pth"
|
||||
# model_path = "train_log/models/epoch-250.pth"
|
||||
|
||||
print(model_path)
|
||||
# reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
|
||||
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
# reference_pattern_path = '/home/nils/new_reference.png'
|
||||
# reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
|
||||
reference_pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
|
||||
# reference_pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
||||
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
# reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
|
||||
|
||||
data_type = 'kinect'
|
||||
# data_type = 'kinect'
|
||||
data_type = 'blender'
|
||||
augment = False
|
||||
|
||||
args = parse_yaml("cfgs/train.yaml")
|
||||
|
||||
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
|
||||
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment, 'data_type': data_type, 'pattern_self_attention': args.pattern_attention})
|
||||
|
||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
model = nn.DataParallel(model, device_ids=[device])
|
||||
@ -78,16 +101,32 @@ if __name__ == '__main__':
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
|
||||
pattern_path=reference_pattern_path, augment=augment)
|
||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||
for batch in dataloader:
|
||||
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
|
||||
right = right.transpose(0, 2).transpose(0, 1)
|
||||
left_img = left
|
||||
imgL = left.cpu().detach().numpy()
|
||||
right_img = right
|
||||
imgR = right.cpu().detach().numpy()
|
||||
gt_disp = disparity
|
||||
do_infer(left_img, right_img, gt_disp, model)
|
||||
# dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
|
||||
# pattern_path=reference_pattern_path, augment=augment)
|
||||
# dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||
# num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||
|
||||
# for batch in dataloader:
|
||||
gt_disp = None
|
||||
right = downsample(cv2.imread(reference_pattern_path))
|
||||
|
||||
if data_type == 'blender':
|
||||
img_path = '/media/Data1/connecting_the_dots_data/blender_renders/data/'
|
||||
elif data_type == 'kinect':
|
||||
img_path = '/home/nils/kinect_pngs/ir/'
|
||||
|
||||
for img in sorted(list(entry for entry in os.scandir(img_path) if 'depth' not in entry.name), key=lambda x:x.name)[:25]:
|
||||
print(img.path)
|
||||
if data_type == 'blender':
|
||||
baseline = 0.075 # meters
|
||||
fl = 560. # as per CTD
|
||||
|
||||
gt_path = img.path.rsplit('.')[0] + '_depth0001.png'
|
||||
gt_depth = downsample(cv2.imread(gt_path))
|
||||
|
||||
mystery_factor = 35 # we don't get reasonable disparities due to incorrect depth scaling (or something like that)
|
||||
gt_disp = (baseline * fl * mystery_factor) / gt_depth
|
||||
|
||||
left = downsample(cv2.imread(img.path))
|
||||
|
||||
do_infer(left, right, gt_disp, model, attend_pattern=args.pattern_attention)
|
||||
|
191
train.py
191
train.py
@ -9,7 +9,7 @@ import yaml
|
||||
|
||||
from nets import Model
|
||||
# from dataset import CREStereoDataset
|
||||
from dataset import CREStereoDataset, CTDDataset
|
||||
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@ -32,14 +32,18 @@ def normalize_and_colormap(img):
|
||||
return ret
|
||||
|
||||
|
||||
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
|
||||
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
|
||||
|
||||
print("Model Forwarding...")
|
||||
left = left.cpu().detach().numpy()
|
||||
if isinstance(left, torch.Tensor):
|
||||
left = left.cpu().detach().numpy()
|
||||
imgR = right.cpu().detach().numpy()
|
||||
imgL = left
|
||||
imgR = right.cpu().detach().numpy()
|
||||
imgR = right
|
||||
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
||||
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
||||
|
||||
flow_init = None
|
||||
|
||||
# chosen for convenience
|
||||
device = torch.device('cuda:0')
|
||||
@ -55,19 +59,54 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
|
||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
).clamp(min=0, max=255)
|
||||
imgR_dw2 = F.interpolate(
|
||||
imgR,
|
||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
)
|
||||
).clamp(min=0, max=255)
|
||||
if last_img is not None:
|
||||
print('using flow_initialization')
|
||||
print(last_img.shape)
|
||||
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
|
||||
print(last_img.max(), last_img.min())
|
||||
if last_img.min() < 0:
|
||||
# print('Negative disparity detected. shifting...')
|
||||
last_img = last_img - last_img.min()
|
||||
if last_img.max() > 255:
|
||||
# print('Excessive disparity detected. scaling...')
|
||||
last_img = last_img / (last_img.max() / 255)
|
||||
|
||||
|
||||
last_img = np.dstack([last_img, last_img])
|
||||
# last_img = np.dstack([last_img, last_img, last_img])
|
||||
last_img = np.dstack([last_img])
|
||||
last_img = last_img.reshape((1, 2, 480, 640))
|
||||
# print(last_img.shape)
|
||||
# print(last_img.dtype)
|
||||
# print(last_img.max(), last_img.min())
|
||||
flow_init = torch.tensor(last_img.astype("float32")).to(device)
|
||||
# flow_init = F.interpolate(
|
||||
# last_img,
|
||||
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
|
||||
# mode="bilinear",
|
||||
# align_corners=True,
|
||||
# )
|
||||
with torch.inference_mode():
|
||||
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
|
||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
|
||||
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
|
||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
|
||||
pf_base = pred_flow
|
||||
if isinstance(pf_base, list):
|
||||
pf_base = pred_flow[0]
|
||||
pf = torch.squeeze(pf_base[:, 0, :, :]).cpu().detach().numpy()
|
||||
print('pred_flow max min')
|
||||
print(pf.max(), pf.min())
|
||||
|
||||
|
||||
if not wandb_log:
|
||||
if test:
|
||||
return pred_flow
|
||||
return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
|
||||
|
||||
log = {}
|
||||
@ -96,30 +135,36 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
|
||||
np.array([pred_disp.reshape(480, 640)]),
|
||||
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
)
|
||||
log[f'pred_norm_{i}'] = wandb.Image(
|
||||
np.array([pred_disp_norm.reshape(480, 640)]),
|
||||
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
)
|
||||
# log[f'pred_norm_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_norm.reshape(480, 640)]),
|
||||
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
# )
|
||||
|
||||
log[f'pred_dw2_{i}'] = wandb.Image(
|
||||
np.array([pred_disp_dw2.reshape(240, 320)]),
|
||||
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
)
|
||||
log[f'pred_dw2_norm_{i}'] = wandb.Image(
|
||||
np.array([pred_disp_dw2_norm.reshape(240, 320)]),
|
||||
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
)
|
||||
# log[f'pred_dw2_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_dw2.reshape(240, 320)]),
|
||||
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
# )
|
||||
# log[f'pred_dw2_norm_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_dw2_norm.reshape(240, 320)]),
|
||||
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
# )
|
||||
|
||||
|
||||
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
|
||||
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")
|
||||
input_right = right.cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
|
||||
if input_right.shape != (480, 640, 3):
|
||||
input_right.transpose(1,2,0)
|
||||
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
|
||||
|
||||
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
|
||||
|
||||
gt_disp = gt_disp.cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
|
||||
disp = disp.cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
|
||||
|
||||
disp_error = gt_disp - disp
|
||||
log['disp_error'] = wandb.Image(
|
||||
normalize_and_colormap(disp_error.abs()),
|
||||
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}",
|
||||
normalize_and_colormap(abs(disp_error)),
|
||||
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
|
||||
)
|
||||
|
||||
|
||||
@ -129,6 +174,7 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
|
||||
)
|
||||
|
||||
wandb.log(log)
|
||||
return pred_flow
|
||||
|
||||
|
||||
def parse_yaml(file_path: str) -> namedtuple:
|
||||
@ -172,12 +218,25 @@ def adjust_learning_rate(optimizer, epoch):
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
|
||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
||||
'''
|
||||
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
|
||||
flow_preds[0]: (B, 2, H, W)
|
||||
flow_gt: (B, 2, H, W)
|
||||
'''
|
||||
if test:
|
||||
# print('sequence loss')
|
||||
if valid.shape != (2, 480, 640):
|
||||
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
|
||||
# print(valid.shape)
|
||||
#valid = torch.stack([valid, valid])
|
||||
# print(valid.shape)
|
||||
if valid.shape != (2, 480, 640):
|
||||
valid = valid.transpose(0,1)
|
||||
# print(valid.shape)
|
||||
# print(valid.shape)
|
||||
# print(flow_preds[0].shape)
|
||||
# print(flow_gt.shape)
|
||||
n_predictions = len(flow_preds)
|
||||
flow_loss = 0.0
|
||||
|
||||
@ -260,20 +319,41 @@ def main(args):
|
||||
start_iters = 0
|
||||
|
||||
# pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
|
||||
pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
# pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
# pattern_path = '/home/nils/kinect_reference_far.png'
|
||||
# pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
||||
pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
|
||||
# pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
|
||||
# datasets
|
||||
# dataset = CREStereoDataset(args.training_data_path)
|
||||
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
|
||||
# dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
|
||||
# test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
|
||||
if args.dataset == 'blender':
|
||||
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
||||
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path)
|
||||
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True)
|
||||
elif args.dataset == 'ctd':
|
||||
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
|
||||
test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
|
||||
else:
|
||||
print('unrecognized dataset')
|
||||
quit()
|
||||
|
||||
test_data_iter = iter(test_dataset)
|
||||
# if rank == 0:
|
||||
worklog.info(f"Dataset size: {len(dataset)}")
|
||||
print(args.batch_size)
|
||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||
num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
||||
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
||||
test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False,
|
||||
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||
|
||||
# counter
|
||||
cur_iters = start_iters
|
||||
total_iters = args.minibatch_per_epoch * args.n_total_epoch
|
||||
t0 = time.perf_counter()
|
||||
test_idx = 0
|
||||
for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
|
||||
|
||||
# adjust learning rate
|
||||
@ -310,24 +390,59 @@ def main(args):
|
||||
# forward
|
||||
# left = left.transpose(1, 2).transpose(2, 3)
|
||||
left = left.transpose(1, 3).transpose(2, 3)
|
||||
right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
|
||||
flow_predictions = model(left.cuda(), right.cuda())
|
||||
# right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
|
||||
flow_predictions = model(left.cuda(), right.cuda(), self_attend_right=args.pattern_attention)
|
||||
|
||||
# loss & backword
|
||||
loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
|
||||
if batch_idx % 128 == 0:
|
||||
inference(
|
||||
mini_batch_data['left'][0],
|
||||
mini_batch_data['right'][0],
|
||||
mini_batch_data['disparity'][0],
|
||||
mini_batch_data['mask'][0],
|
||||
model,
|
||||
batch_idx,
|
||||
)
|
||||
if batch_idx % 512 == 0:
|
||||
test_idx = 0
|
||||
test_loss = 0
|
||||
for i, test_batch in enumerate(test_dataset):
|
||||
# test_batch = next(test_data_iter)
|
||||
if i >= 24:
|
||||
break
|
||||
|
||||
# TODO refactor, DRY
|
||||
left, right, gt_disp, valid_mask = (
|
||||
test_batch['left'],
|
||||
test_batch['right'],
|
||||
torch.Tensor(test_batch['disparity']).cuda(),
|
||||
torch.Tensor(test_batch['mask']).cuda(),
|
||||
)
|
||||
gt_disp = torch.dstack([gt_disp, gt_disp]).transpose(2,0).transpose(1,2)
|
||||
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
|
||||
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
||||
# print(f'left {left.shape}, right {right.shape}')
|
||||
# left = left.transpose([2, 0, 1])
|
||||
right = right.transpose([1, 2, 0])
|
||||
# right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
|
||||
# print(f'left {left.shape}, right {right.shape}')
|
||||
|
||||
model.eval()
|
||||
flow_predictions = inference(
|
||||
left,
|
||||
right,
|
||||
# gt_disp,
|
||||
torch.Tensor(test_batch['disparity']).cuda(),
|
||||
valid_mask,
|
||||
model,
|
||||
test_idx,
|
||||
wandb_log=i % 4 == 0,
|
||||
test=True,
|
||||
)
|
||||
test_idx += 1
|
||||
test_loss += sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8, test=True
|
||||
).data.item()
|
||||
model.train()
|
||||
|
||||
avg_test_loss = test_loss / test_idx
|
||||
print(f'test_loss: {test_loss}\nlen test: {test_idx}\navg. loss: {avg_test_loss}')
|
||||
metrics['test/loss'] = avg_test_loss
|
||||
# loss stats
|
||||
loss_item = loss.data.item()
|
||||
epoch_total_train_loss += loss_item
|
||||
|
346
train_lightning.py
Normal file
346
train_lightning.py
Normal file
@ -0,0 +1,346 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import yaml
|
||||
# from tensorboardX import SummaryWriter
|
||||
|
||||
from nets import Model
|
||||
# from dataset import CREStereoDataset
|
||||
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_lightning.lite import LightningLite
|
||||
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
seed_everything(42, workers=True)
|
||||
|
||||
|
||||
import wandb
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def normalize_and_colormap(img):
|
||||
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
|
||||
if isinstance(ret, torch.Tensor):
|
||||
ret = ret.cpu().detach().numpy()
|
||||
ret = ret.astype("uint8")
|
||||
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
|
||||
return ret
|
||||
|
||||
|
||||
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
|
||||
|
||||
print("Model Forwarding...")
|
||||
if isinstance(left, torch.Tensor):
|
||||
left = left# .cpu().detach().numpy()
|
||||
imgR = right# .cpu().detach().numpy()
|
||||
imgL = left
|
||||
imgR = right
|
||||
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
||||
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
||||
|
||||
flow_init = None
|
||||
|
||||
# chosen for convenience
|
||||
|
||||
imgL = torch.tensor(imgL.astype("float32"))
|
||||
imgR = torch.tensor(imgR.astype("float32"))
|
||||
imgL = imgL.transpose(2,3).transpose(1,2)
|
||||
if imgL.shape != imgR.shape:
|
||||
imgR = imgR.transpose(2,3).transpose(1,2)
|
||||
|
||||
imgL_dw2 = F.interpolate(
|
||||
imgL,
|
||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
).clamp(min=0, max=255)
|
||||
imgR_dw2 = F.interpolate(
|
||||
imgR,
|
||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
).clamp(min=0, max=255)
|
||||
if last_img is not None:
|
||||
print('using flow_initialization')
|
||||
print(last_img.shape)
|
||||
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
|
||||
print(last_img.max(), last_img.min())
|
||||
if last_img.min() < 0:
|
||||
# print('Negative disparity detected. shifting...')
|
||||
last_img = last_img - last_img.min()
|
||||
if last_img.max() > 255:
|
||||
# print('Excessive disparity detected. scaling...')
|
||||
last_img = last_img / (last_img.max() / 255)
|
||||
|
||||
|
||||
last_img = np.dstack([last_img, last_img])
|
||||
# last_img = np.dstack([last_img, last_img, last_img])
|
||||
last_img = np.dstack([last_img])
|
||||
last_img = last_img.reshape((1, 2, 480, 640))
|
||||
# print(last_img.shape)
|
||||
# print(last_img.dtype)
|
||||
# print(last_img.max(), last_img.min())
|
||||
flow_init = torch.tensor(last_img.astype("float32"))
|
||||
# flow_init = F.interpolate(
|
||||
# last_img,
|
||||
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
|
||||
# mode="bilinear",
|
||||
# align_corners=True,
|
||||
# )
|
||||
with torch.inference_mode():
|
||||
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
|
||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
|
||||
pf_base = pred_flow
|
||||
if isinstance(pf_base, list):
|
||||
pf_base = pred_flow[0]
|
||||
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy()
|
||||
print('pred_flow max min')
|
||||
print(pf.max(), pf.min())
|
||||
|
||||
|
||||
if not wandb_log:
|
||||
if test:
|
||||
return pred_flow
|
||||
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy()
|
||||
|
||||
log = {}
|
||||
in_h, in_w = left.shape[:2]
|
||||
|
||||
# Resize image in case the GPU memory overflows
|
||||
eval_h, eval_w = (in_h,in_w)
|
||||
|
||||
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
|
||||
pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy()
|
||||
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy()
|
||||
|
||||
# pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||
# pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||
|
||||
if i == n_iter-1:
|
||||
t = float(in_w) / float(eval_w)
|
||||
disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
||||
|
||||
log[f'disp_vis'] = wandb.Image(
|
||||
normalize_and_colormap(disp),
|
||||
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
)
|
||||
|
||||
log[f'pred_{i}'] = wandb.Image(
|
||||
np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]),
|
||||
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
)
|
||||
# log[f'pred_norm_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_norm.reshape(480, 640)]),
|
||||
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
# )
|
||||
|
||||
# log[f'pred_dw2_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_dw2.reshape(240, 320)]),
|
||||
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
# )
|
||||
# log[f'pred_dw2_norm_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_dw2_norm.reshape(240, 320)]),
|
||||
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
# )
|
||||
|
||||
|
||||
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
|
||||
input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
|
||||
if input_right.shape != (480, 640, 3):
|
||||
input_right.transpose(1,2,0)
|
||||
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
|
||||
|
||||
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
|
||||
|
||||
gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
|
||||
disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
|
||||
|
||||
disp_error = gt_disp - disp
|
||||
log['disp_error'] = wandb.Image(
|
||||
normalize_and_colormap(abs(disp_error)),
|
||||
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
|
||||
)
|
||||
|
||||
|
||||
log[f'gt_disp_vis'] = wandb.Image(
|
||||
normalize_and_colormap(gt_disp),
|
||||
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
|
||||
)
|
||||
|
||||
wandb.log(log)
|
||||
return pred_flow
|
||||
|
||||
|
||||
def parse_yaml(file_path: str) -> namedtuple:
|
||||
"""Parse yaml configuration file and return the object in `namedtuple`."""
|
||||
with open(file_path, "rb") as f:
|
||||
cfg: dict = yaml.safe_load(f)
|
||||
args = namedtuple("train_args", cfg.keys())(*cfg.values())
|
||||
return args
|
||||
|
||||
|
||||
def format_time(elapse):
|
||||
elapse = int(elapse)
|
||||
hour = elapse // 3600
|
||||
minute = elapse % 3600 // 60
|
||||
seconds = elapse % 60
|
||||
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
|
||||
|
||||
|
||||
def ensure_dir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch):
|
||||
|
||||
warm_up = 0.02
|
||||
const_range = 0.6
|
||||
min_lr_rate = 0.05
|
||||
|
||||
if epoch <= args.n_total_epoch * warm_up:
|
||||
lr = (1 - min_lr_rate) * args.base_lr / (
|
||||
args.n_total_epoch * warm_up
|
||||
) * epoch + min_lr_rate * args.base_lr
|
||||
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
|
||||
lr = args.base_lr
|
||||
else:
|
||||
lr = (min_lr_rate - 1) * args.base_lr / (
|
||||
(1 - const_range) * args.n_total_epoch
|
||||
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
|
||||
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
||||
'''
|
||||
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
|
||||
flow_preds[0]: (B, 2, H, W)
|
||||
flow_gt: (B, 2, H, W)
|
||||
'''
|
||||
if test:
|
||||
# print('sequence loss')
|
||||
if valid.shape != (2, 480, 640):
|
||||
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
|
||||
# print(valid.shape)
|
||||
#valid = torch.stack([valid, valid])
|
||||
# print(valid.shape)
|
||||
if valid.shape != (2, 480, 640):
|
||||
valid = valid.transpose(0,1)
|
||||
# print(valid.shape)
|
||||
# print(valid.shape)
|
||||
# print(flow_preds[0].shape)
|
||||
# print(flow_gt.shape)
|
||||
n_predictions = len(flow_preds)
|
||||
flow_loss = 0.0
|
||||
|
||||
# TEST
|
||||
flow_gt = torch.squeeze(flow_gt, dim=-1)
|
||||
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
i_loss = torch.abs(flow_preds[i] - flow_gt)
|
||||
flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
|
||||
|
||||
return flow_loss
|
||||
|
||||
|
||||
|
||||
class CREStereoLightning(LightningModule):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.batch_size = args.batch_size
|
||||
self.model = Model(
|
||||
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
|
||||
)
|
||||
|
||||
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
|
||||
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# loss = self(batch)
|
||||
left, right, gt_disp, valid_mask = batch
|
||||
left = torch.Tensor(left).to(self.device)
|
||||
right = torch.Tensor(right).to(self.device)
|
||||
left = left
|
||||
right = right
|
||||
flow_predictions = self.forward(left, right)
|
||||
loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
self.log("train_loss", loss)
|
||||
return loss
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
left, right, gt_disp, valid_mask = batch
|
||||
left = torch.Tensor(left).to(self.device)
|
||||
right = torch.Tensor(right).to(self.device)
|
||||
print(left.shape)
|
||||
print(right.shape)
|
||||
flow_predictions = self.forward(left, right)
|
||||
val_loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
self.log("val_loss", val_loss)
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
left, right, gt_disp, valid_mask = batch
|
||||
# left, right, gt_disp, valid_mask = (
|
||||
# batch["left"],
|
||||
# batch["right"],
|
||||
# batch["disparity"],
|
||||
# batch["mask"],
|
||||
# )
|
||||
left = torch.Tensor(left).to(self.device)
|
||||
right = torch.Tensor(right).to(self.device)
|
||||
flow_predictions = self.forward(left, right)
|
||||
test_loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
self.log("test_loss", test_loss)
|
||||
|
||||
def configure_optimizers(self):
|
||||
return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# train configuration
|
||||
args = parse_yaml("cfgs/train.yaml")
|
||||
# wandb.init(project="crestereo-lightning", entity="cpt-captain")
|
||||
# Lite(strategy='dp', accelerator='gpu', devices=2).run(args)
|
||||
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
||||
model = CREStereoLightning(args)
|
||||
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True)
|
||||
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True)
|
||||
print(len(dataset))
|
||||
print(len(test_dataset))
|
||||
wandb_logger = WandbLogger(project="crestereo-lightning")
|
||||
wandb.config.update(args._asdict())
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=args.n_total_epoch,
|
||||
accelerator='gpu',
|
||||
devices=2,
|
||||
# auto_scale_batch_size='binsearch',
|
||||
# strategy='ddp',
|
||||
deterministic=True,
|
||||
check_val_every_n_epoch=1,
|
||||
limit_val_batches=24,
|
||||
limit_test_batches=24,
|
||||
logger=wandb_logger,
|
||||
default_root_dir=args.log_dir_lightning,
|
||||
)
|
||||
# trainer.tune(model)
|
||||
trainer.fit(model, dataset, test_dataset)
|
Loading…
Reference in New Issue
Block a user