change a bunch of stuff, add wip lightning implementation

main
Cpt.Captain 2 years ago
parent 11959eef61
commit 63da24f429
  1. 181
      api_server.py
  2. 8
      cfgs/train.yaml
  3. 111
      dataset.py
  4. 13
      nets/attention/transformer.py
  5. 26
      nets/crestereo.py
  6. 41
      nets/extractor.py
  7. 2
      nets/update.py
  8. 111
      test_model.py
  9. 193
      train.py
  10. 346
      train_lightning.py

@ -1,3 +1,4 @@
import os
import json import json
from datetime import datetime from datetime import datetime
from typing import Union, Literal from typing import Union, Literal
@ -7,32 +8,149 @@ import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch_tensorrt
from cv2 import cv2 from cv2 import cv2
from fastapi import FastAPI, File, UploadFile from fastapi import FastAPI, File, UploadFile
from PIL import Image from PIL import Image
from nets import Model from nets import Model
from train import inference as ctd_inference
app = FastAPI() app = FastAPI()
reference_pattern_path = '/home/nils/kinect_reference_cropped.png' # reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern_path = '/home/nils/kinect_reference_far.png'
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
print(reference_pattern_path)
reference_pattern = cv2.imread(reference_pattern_path) reference_pattern = cv2.imread(reference_pattern_path)
# shift reference pattern a few pixels to the left to simulate further backdrop
trans_mat = np.float32([[1, 0, 0], [0, 1, 0]])
reference_pattern = cv2.warpAffine(
reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR
)
iters = 20 iters = 20
minimal_data = False minimal_data = True
temporal_init = False
last_img = None
device = torch.device('cuda:0') device = torch.device('cuda:0')
def load_model(epoch): def downsize(img):
diff = (512 - 480) // 2
downsampled = cv2.pyrDown(img)
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
return img
if 1024 in reference_pattern.shape:
reference_pattern = downsize(reference_pattern)
def ghetto_lcn(img):
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray = img
float_gray = gray.astype(np.float32) / 255.0
blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
num = float_gray - blur
blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
den = cv2.pow(blur, 0.5)
gray = num / den
# cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
return gray
# reference_pattern = ghetto_lcn(reference_pattern)
def load_model(epoch, use_tensorrt=False):
global model global model
epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth' epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
model_path = f"train_log/models/{epoch}" model_path = f"train_log/models/{epoch}"
model = Model(max_disp=256, mixed_precision=False, test_mode=True) model = Model(max_disp=256, mixed_precision=False, test_mode=True)
# FIXME WIP Workaround Dataparallel TensorRT incompatibility
model = nn.DataParallel(model, device_ids=[device]) model = nn.DataParallel(model, device_ids=[device])
# model.load_state_dict(torch.load(model_path), strict=False) # model.load_state_dict(torch.load(model_path), strict=False)
state_dict = torch.load(model_path)['state_dict'] state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True) model.load_state_dict(state_dict, strict=True)
model.to(device) model.to(device)
model.eval() model.eval()
if use_tensorrt:
np_model = Model(max_disp=256, mixed_precision=False, test_mode=True)
np_model.load_state_dict(model.module.state_dict(), strict=True)
np_model.to(device)
np_model.eval()
spec_dict = {
"inputs": [
torch_tensorrt.Input(
min_shape=[1, 2, 240, 320],
max_shape=[1, 2, 480, 640],
opt_shape=[1, 2, 480, 640],
dtype=torch.int32,
),
torch_tensorrt.Input(
min_shape=[1, 2, 240, 320],
max_shape=[1, 2, 480, 640],
opt_shape=[1, 2, 480, 640],
dtype=torch.int32,
),
],
"enabled_precisions": {torch.float, torch.half},
"refit": False,
"debug": False,
"device": {
"device_type": torch_tensorrt.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": True
},
"capability": torch_tensorrt.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
}
spec = {
"forward":
torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict)
}
# trt_model = torch_tensorrt.compile(np_model ,
# inputs=torch_tensorrt.Input(
# min_shape=[1, 2, 240, 320],
# max_shape=[1, 2, 480, 640],
# opt_shape=[1, 2, 480, 640],
# dtype=torch.int32,
# inputs = [torch_tensorrt.Input((1, 2, 480, 640)), torch_tensorrt.Input((1, 2, 480, 640))], # input shape
# enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16
# )
# trt_dw_model = torch_tensorrt.compile(np_model ,
# inputs = [torch_tensorrt.Input((1, 2, 240, 320)), torch_tensorrt.Input((1, 2, 240, 320))], # input shape
# enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16
# )
script_model = torch.jit.script(np_model.eval())
# script_dw_model = torch.jit.script(trt_dw_model.eval())
# save the TensorRT embedded Torchscript
# torch.jit.save(trt_model, 'trt_torchscript_module.ts')
# torch.jit.save(trt_dw_model, 'trt_torchscript_dw_module.ts')
print(script_model)
print(script_model.forward)
print(script_model.forward())
print(dir(script_model))
model = torch._C._jit_to_backend("tensorrt", script_model, spec)
print(f'loaded model {epoch}') print(f'loaded model {epoch}')
return model return model
@ -74,14 +192,37 @@ def inference(left, right, model, n_iter=20):
) )
with torch.inference_mode(): with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) # pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) # pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
pred_flow_dw2 = model.forward(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model.forward(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
return pred_disp return pred_disp
def get_reference():
refs = [ref.path for ref in os.scandir('/home/nils/references/')]
for ref in refs:
reference = cv2.imread(ref)
yield reference
references = get_reference()
@app.post('/params/update_reference')
async def update_reference():
global references, reference_pattern
try:
reference_pattern = downsize(next(references))
print(reference_pattern.shape)
return {'status': 'success'}
except StopIteration:
references = get_reference()
return {'status': 'finished'}
@app.post("/model/update/{epoch}") @app.post("/model/update/{epoch}")
async def change_model(epoch: Union[int, Literal['latest']]): async def change_model(epoch: Union[int, Literal['latest']]):
global model global model
@ -103,8 +244,15 @@ async def set_minimal_data(enable: bool):
minimal_data = enable minimal_data = enable
@app.post("/params/temporal_init/{enable}")
async def set_temporal_init(enable: bool):
global temporal_init
temporal_init = enable
@app.put("/ir") @app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)): async def read_ir_input(file: UploadFile = File(...)):
global last_img, minimal_data
try: try:
img = np.array(Image.open(BytesIO(await file.read()))) img = np.array(Image.open(BytesIO(await file.read())))
except Exception as e: except Exception as e:
@ -114,24 +262,35 @@ async def read_ir_input(file: UploadFile = File(...)):
if len(img.shape) == 2: if len(img.shape) == 2:
img = cv2.merge([img for _ in range(3)]) img = cv2.merge([img for _ in range(3)])
if img.shape == (1024, 1280, 3): if img.shape == (1024, 1280, 3):
diff = (512 - 480) // 2 img = downsize(img)
downsampled = cv2.pyrDown(img)
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
img = img.transpose((1, 2, 0)) # img = img.transpose((1, 2, 0))
ref_pat = reference_pattern.transpose((1, 2, 0)) # ref_pat = reference_pattern.transpose((1, 2, 0))
ref_pat = reference_pattern
start = datetime.now() start = datetime.now()
pred_disp = inference(img, ref_pat, model, iters) if temporal_init:
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img)
last_img = pred_disp
else:
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False)
# pred_disp = inference(img, ref_pat, model, iters)
duration = (datetime.now() - start).total_seconds() duration = (datetime.now() - start).total_seconds()
if minimal_data: if minimal_data:
return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder) return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
else: else:
# return json.dumps({'disp': pred_disp, 'input': img, 'duration': duration},
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration}, return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
cls=NumpyEncoder) cls=NumpyEncoder)
@app.get('/temporal_init')
def get_temporal_init():
return {'status': 'enabled' if temporal_init else 'disabled'}
@app.get('/') @app.get('/')
def main(): def main():
return {'test': 'abc'} return {'test': 'abc'}

@ -4,18 +4,22 @@ base_lr: 4.0e-4
nr_gpus: 3 nr_gpus: 3
batch_size: 4 batch_size: 4
n_total_epoch: 600 n_total_epoch: 300
minibatch_per_epoch: 500 minibatch_per_epoch: 500
loadmodel: ~ loadmodel: ~
log_dir: "./train_log" log_dir: "./train_log"
log_dir_lightning: "./train_log_lightning"
model_save_freq_epoch: 1 model_save_freq_epoch: 1
max_disp: 256 max_disp: 256
image_width: 640 image_width: 640
image_height: 480 image_height: 480
# training_data_path: "./stereo_trainset/crestereo" # training_data_path: "./stereo_trainset/crestereo"
training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/" pattern_attention: true
dataset: "blender"
# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
log_level: "logging.INFO" log_level: "logging.INFO"

@ -17,7 +17,7 @@ class Augmentor:
scale_min=0.6, scale_min=0.6,
scale_max=1.0, scale_max=1.0,
seed=0, seed=0,
): ):
super().__init__() super().__init__()
self.image_height = image_height self.image_height = image_height
self.image_width = image_width self.image_width = image_width
@ -234,12 +234,16 @@ class CREStereoDataset(Dataset):
class CTDDataset(Dataset): class CTDDataset(Dataset):
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False): def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False):
super().__init__() super().__init__()
self.rng = np.random.RandomState(0) self.rng = np.random.RandomState(0)
self.augment = augment self.augment = augment
self.blur = blur self.blur = blur
self.imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True) imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
if test_set:
self.imgs = imgs[:int(split * len(imgs))]
else:
self.imgs = imgs[int(split * len(imgs)):]
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE) self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
if resize_pattern and self.pattern.shape != (480, 640, 3): if resize_pattern and self.pattern.shape != (480, 640, 3):
@ -271,6 +275,10 @@ class CTDDataset(Dataset):
# read img, disp # read img, disp
left_img = np.load(left_path) left_img = np.load(left_path)
if left_img.dtype == 'float32':
left_img = (left_img * 255).astype('uint8')
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3)) left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
right_img = self.pattern right_img = self.pattern
@ -307,3 +315,100 @@ class CTDDataset(Dataset):
def __len__(self): def __len__(self):
return len(self.imgs) return len(self.imgs)
class BlenderDataset(CTDDataset):
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False):
super().__init__(root, pattern_path)
self.use_lightning = use_lightning
imgs = [f for f in glob.glob(f"{root}/im_*.png", recursive=True) if not 'depth0001' in f]
if test_set:
self.imgs = imgs[:int(split * len(imgs))]
else:
self.imgs = imgs[int(split * len(imgs)):]
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
if resize_pattern and self.pattern.shape != (480, 640, 3):
downsampled = cv2.pyrDown(self.pattern)
diff = (downsampled.shape[0] - 480) // 2
self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
self.augmentor = Augmentor(
image_height=480,
image_width=640,
max_disp=256,
scale_min=0.6,
scale_max=1.0,
seed=0,
)
def __getitem__(self, index):
# find path
left_path = self.imgs[index]
left_disp_path = left_path.split('.')[0] + '_depth0001.png'
# read img, disp
left_img = cv2.imread(left_path)
if left_img.dtype == 'float32':
left_img = (left_img * 255).astype('uint8')
if left_img.shape != (480, 640, 3):
downsampled = cv2.pyrDown(left_img)
diff = (downsampled.shape[0] - 480) // 2
left_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
if left_img.shape[-1] != 3:
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
right_img = self.pattern
left_disp = self.get_disp(left_disp_path)
if False: # self.rng.binomial(1, 0.5):
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
left_disp[left_disp == np.inf] = 0
if self.blur:
kernel_size = random.sample([1,3,5,7,9], 1)[0]
kernel = (kernel_size, kernel_size)
left_img = cv2.GaussianBlur(left_img, kernel, 0)
# augmentation
if not self.augment:
_left_img, _right_img, _left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
else:
left_img, right_img, left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
if not self.use_lightning:
right_img = right_img.transpose((2, 0, 1)).astype("uint8")
return {
"left": left_img,
"right": right_img,
"disparity": left_disp,
"mask": disp_mask,
}
right_img = right_img.transpose((2, 0, 1)).astype("uint8")
left_img = left_img.transpose((2, 0, 1)).astype("uint8")
return left_img, right_img, left_disp, disp_mask
def get_disp(self, path):
baseline = 0.075 # meters
fl = 560. # as per CTD
depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
downsampled = cv2.pyrDown(depth)
diff = (downsampled.shape[0] - 480) // 2
depth = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
# disp = np.load(path).transpose(1,2,0)
# disp = baseline * fl / depth
# return disp.astype(np.float32) / 32
# FIXME temporarily increase disparity until new data with better depth values is generated
# higher values seem to speedup convergence, but introduce much stronger artifacting
# mystery_factor = 150
mystery_factor = 1
disp = (baseline * fl * mystery_factor) / depth
return disp.astype(np.float32)

@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module):
""" """
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
for layer, name in zip(self.layers, self.layer_names): # NOTE Workaround for non statically determinable zip
# for layer, name in zip(self.layers, self.layer_names):
# layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers))
# layer_zip = []
# for i, layer in enumerate(self.layers):
# layer_zip.append((layer, self.layer_names[i]))
for i, layer in enumerate(self.layers):
name = self.layer_names[i]
if name == 'self': if name == 'self':
feat0 = layer(feat0, feat0, mask0, mask0) feat0 = layer(feat0, feat0, mask0, mask0)
feat1 = layer(feat1, feat1, mask1, mask1) feat1 = layer(feat1, feat1, mask1, mask1)
@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module):
else: else:
raise KeyError raise KeyError
return feat0, feat1 return feat0, feat1

@ -36,6 +36,12 @@ class CREStereo(nn.Module):
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout) self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4) self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
# # NOTE Position_encoding as workaround for TensorRt
# image1_shape = [1, 2, 480, 640]
# self.pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
# )
# loftr # loftr
self.self_att_fn = LocalFeatureTransformer( self.self_att_fn = LocalFeatureTransformer(
d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear" d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
@ -81,7 +87,7 @@ class CREStereo(nn.Module):
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device) zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
return zero_flow return zero_flow
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False): def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
""" Estimate optical flow between pair of frames """ """ Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0 image1 = 2 * (image1 / 255.0) - 1.0
@ -130,17 +136,22 @@ class CREStereo(nn.Module):
inp_dw16 = F.avg_pool2d(inp, 4, stride=4) inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
# positional encoding and self-attention # positional encoding and self-attention
pos_encoding_fn_small = PositionEncodingSine( # pos_encoding_fn_small = PositionEncodingSine(
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16) # d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
) # )
# 'n c h w -> n (h w) c' # 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap1_dw16) x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]) fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
# 'n c h w -> n (h w) c' # 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap2_dw16) x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]) fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16) # FIXME experimental ! no self-attention for pattern
if not self_attend_right:
fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
else:
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
fmap1_dw16, fmap2_dw16 = [ fmap1_dw16, fmap2_dw16 = [
x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2) x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
for x in [fmap1_dw16, fmap2_dw16] for x in [fmap1_dw16, fmap2_dw16]
@ -258,3 +269,4 @@ class CREStereo(nn.Module):
return flow_up return flow_up
return predictions return predictions

@ -1,6 +1,8 @@
from typing import List
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch import Tensor
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py # Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
class ResidualBlock(nn.Module): class ResidualBlock(nn.Module):
@ -96,28 +98,43 @@ class BasicEncoder(nn.Module):
self.in_planes = dim self.in_planes = dim
return nn.Sequential(*layers) return nn.Sequential(*layers)
def forward(self, x): def forward(self, x: List[Tensor]):
# NOTE always assume list, otherwise TensorRT is sad
# batch_dim = x[0].shape[0]
# x_tensor = torch.cat(list(x), dim=0)
# if input is list, combine batch dimension # if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list) is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list: if is_list:
batch_dim = x[0].shape[0] batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0) x_tensor = torch.cat(x, dim=0)
else:
x_tensor = x
print()
print()
print(x_tensor.shape)
print()
print()
x = self.conv1(x) x_tensor = self.conv1(x_tensor)
x = self.norm1(x) x_tensor = self.norm1(x_tensor)
x = self.relu1(x) x_tensor = self.relu1(x_tensor)
x = self.layer1(x) x_tensor = self.layer1(x_tensor)
x = self.layer2(x) x_tensor = self.layer2(x_tensor)
x = self.layer3(x) x_tensor = self.layer3(x_tensor)
x = self.conv2(x) x_tensor = self.conv2(x_tensor)
if self.dropout is not None: if self.dropout is not None:
x = self.dropout(x) x_tensor = self.dropout(x_tensor)
if is_list: if is_list:
x = torch.split(x, x.shape[0]//2, dim=0) x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
return x_list
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
return x_list
return x # return list(x)

@ -77,7 +77,7 @@ class BasicUpdateBlock(nn.Module):
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(256, mask_size**2 *9, 1, padding=0)) nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
def forward(self, net, inp, corr, flow, upsample=True): def forward(self, net, inp, corr, flow, upsample: bool=True):
# print(inp.shape, corr.shape, flow.shape) # print(inp.shape, corr.shape, flow.shape)
motion_features = self.encoder(flow, corr) motion_features = self.encoder(flow, corr)
# print(motion_features.shape, inp.shape) # print(motion_features.shape, inp.shape)

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
import cv2 import cv2
import os
from nets import Model from nets import Model
@ -16,17 +17,20 @@ device = 'cuda'
wandb.init(project="crestereo", entity="cpt-captain") wandb.init(project="crestereo", entity="cpt-captain")
def do_infer(left_img, right_img, gt_disp, model): def do_infer(left_img, right_img, gt_disp, model, attend_pattern=True):
disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False) disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False, attend_pattern=attend_pattern)
disp_vis = normalize_and_colormap(disp) disp_vis = normalize_and_colormap(disp)
gt_disp_vis = normalize_and_colormap(gt_disp) # gt_disp_vis = normalize_and_colormap(gt_disp)
if gt_disp.shape != disp.shape: # if gt_disp.shape != disp.shape:
gt_disp = gt_disp.reshape(disp.shape) # gt_disp = gt_disp.reshape(disp.shape)
disp_err = gt_disp - disp # disp_err = gt_disp - disp
disp_err = normalize_and_colormap(disp_err.abs()) # disp_err = normalize_and_colormap(disp_err.abs())
if isinstance(left_img, torch.Tensor):
wandb.log({ left_img = left_img.cpu().detach().numpy().astype('uint8')
right_img = right_img.cpu().detach().numpy().astype('uint8')
results = {
'disp': wandb.Image( 'disp': wandb.Image(
disp, disp,
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
@ -35,41 +39,60 @@ def do_infer(left_img, right_img, gt_disp, model):
disp_vis, disp_vis,
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}", caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
), ),
'gt_disp_vis': wandb.Image( # 'disp_err': wandb.Image(
gt_disp_vis, # disp_err,
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}", # caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
), # ),
'disp_err': wandb.Image(
disp_err,
caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
),
'input_left': wandb.Image( 'input_left': wandb.Image(
left_img.cpu().detach().numpy().astype('uint8'), left_img,
caption=f"Input left", caption=f"Input left",
), ),
'input_right': wandb.Image( 'input_right': wandb.Image(
right_img.cpu().detach().numpy().astype('uint8'), right_img,
caption=f"Input right", caption=f"Input right",
), ),
}) }
if gt_disp is not None:
print('logging gt')
print(f'gt: {gt_disp.max()}/{gt_disp.min()}/{gt_disp.mean()}')
gt_disp_vis = normalize_and_colormap(gt_disp)
results.update({
'gt_disp_vis': wandb.Image(
gt_disp_vis,
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
)})
wandb.log(results)
def downsample(img, half_height_out=480):
downsampled = cv2.pyrDown(img)
diff = (downsampled.shape[0] - half_height_out) // 2
return downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
if __name__ == '__main__': if __name__ == '__main__':
# model_path = "models/crestereo_eth3d.pth" # model_path = "models/crestereo_eth3d.pth"
model_path = "train_log/models/latest.pth" model_path = "train_log/models/latest.pth"
# model_path = "train_log/models/epoch-120.pth"
# model_path = "train_log/models/epoch-250.pth"
print(model_path)
# reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png' # reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
# reference_pattern_path = '/home/nils/new_reference.png' # reference_pattern_path = '/home/nils/new_reference.png'
# reference_pattern_path = '/home/nils/kinect_reference_high_res.png' # reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
reference_pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
# reference_pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
# reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png' # reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
data_type = 'kinect' # data_type = 'kinect'
data_type = 'blender'
augment = False augment = False
args = parse_yaml("cfgs/train.yaml") args = parse_yaml("cfgs/train.yaml")
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment}) wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment, 'data_type': data_type, 'pattern_self_attention': args.pattern_attention})
model = Model(max_disp=256, mixed_precision=False, test_mode=True) model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model, device_ids=[device]) model = nn.DataParallel(model, device_ids=[device])
@ -78,16 +101,32 @@ if __name__ == '__main__':
model.to(device) model.to(device)
model.eval() model.eval()
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, # dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
pattern_path=reference_pattern_path, augment=augment) # pattern_path=reference_pattern_path, augment=augment)
dataloader = DataLoader(dataset, args.batch_size, shuffle=True, # dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True) # num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
for batch in dataloader:
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']): # for batch in dataloader:
right = right.transpose(0, 2).transpose(0, 1) gt_disp = None
left_img = left right = downsample(cv2.imread(reference_pattern_path))
imgL = left.cpu().detach().numpy()
right_img = right if data_type == 'blender':
imgR = right.cpu().detach().numpy() img_path = '/media/Data1/connecting_the_dots_data/blender_renders/data/'
gt_disp = disparity elif data_type == 'kinect':
do_infer(left_img, right_img, gt_disp, model) img_path = '/home/nils/kinect_pngs/ir/'
for img in sorted(list(entry for entry in os.scandir(img_path) if 'depth' not in entry.name), key=lambda x:x.name)[:25]:
print(img.path)
if data_type == 'blender':
baseline = 0.075 # meters
fl = 560. # as per CTD
gt_path = img.path.rsplit('.')[0] + '_depth0001.png'
gt_depth = downsample(cv2.imread(gt_path))
mystery_factor = 35 # we don't get reasonable disparities due to incorrect depth scaling (or something like that)
gt_disp = (baseline * fl * mystery_factor) / gt_depth
left = downsample(cv2.imread(img.path))
do_infer(left, right, gt_disp, model, attend_pattern=args.pattern_attention)

@ -9,7 +9,7 @@ import yaml
from nets import Model from nets import Model
# from dataset import CREStereoDataset # from dataset import CREStereoDataset
from dataset import CREStereoDataset, CTDDataset from dataset import BlenderDataset, CREStereoDataset, CTDDataset
import torch import torch
import torch.nn as nn import torch.nn as nn
@ -32,14 +32,18 @@ def normalize_and_colormap(img):
return ret return ret
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True): def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
print("Model Forwarding...") print("Model Forwarding...")
left = left.cpu().detach().numpy() if isinstance(left, torch.Tensor):
left = left.cpu().detach().numpy()
imgR = right.cpu().detach().numpy()
imgL = left imgL = left
imgR = right.cpu().detach().numpy() imgR = right
imgL = np.ascontiguousarray(imgL[None, :, :, :]) imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :]) imgR = np.ascontiguousarray(imgR[None, :, :, :])
flow_init = None
# chosen for convenience # chosen for convenience
device = torch.device('cuda:0') device = torch.device('cuda:0')
@ -55,19 +59,54 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear", mode="bilinear",
align_corners=True, align_corners=True,
) ).clamp(min=0, max=255)
imgR_dw2 = F.interpolate( imgR_dw2 = F.interpolate(
imgR, imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2), size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear", mode="bilinear",
align_corners=True, align_corners=True,
) ).clamp(min=0, max=255)
if last_img is not None:
print('using flow_initialization')
print(last_img.shape)
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
print(last_img.max(), last_img.min())
if last_img.min() < 0:
# print('Negative disparity detected. shifting...')
last_img = last_img - last_img.min()
if last_img.max() > 255:
# print('Excessive disparity detected. scaling...')
last_img = last_img / (last_img.max() / 255)
last_img = np.dstack([last_img, last_img])
# last_img = np.dstack([last_img, last_img, last_img])
last_img = np.dstack([last_img])
last_img = last_img.reshape((1, 2, 480, 640))
# print(last_img.shape)
# print(last_img.dtype)
# print(last_img.max(), last_img.min())
flow_init = torch.tensor(last_img.astype("float32")).to(device)
# flow_init = F.interpolate(
# last_img,
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
# mode="bilinear",
# align_corners=True,
# )
with torch.inference_mode(): with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None) pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2) pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
pf_base = pred_flow
if isinstance(pf_base, list):
pf_base = pred_flow[0]
pf = torch.squeeze(pf_base[:, 0, :, :]).cpu().detach().numpy()
print('pred_flow max min')
print(pf.max(), pf.min())
if not wandb_log: if not wandb_log:
if test:
return pred_flow
return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy() return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
log = {} log = {}
@ -96,30 +135,36 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
np.array([pred_disp.reshape(480, 640)]), np.array([pred_disp.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
) )
log[f'pred_norm_{i}'] = wandb.Image( # log[f'pred_norm_{i}'] = wandb.Image(
np.array([pred_disp_norm.reshape(480, 640)]), # np.array([pred_disp_norm.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}", # caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
) # )
log[f'pred_dw2_{i}'] = wandb.Image( # log[f'pred_dw2_{i}'] = wandb.Image(
np.array([pred_disp_dw2.reshape(240, 320)]), # np.array([pred_disp_dw2.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", # caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
) # )
log[f'pred_dw2_norm_{i}'] = wandb.Image( # log[f'pred_dw2_norm_{i}'] = wandb.Image(
np.array([pred_disp_dw2_norm.reshape(240, 320)]), # np.array([pred_disp_dw2_norm.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}", # caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
) # )
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left") log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right") input_right = right.cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
if input_right.shape != (480, 640, 3):
input_right.transpose(1,2,0)
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}") log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
gt_disp = gt_disp.cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
disp = disp.cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
disp_error = gt_disp - disp disp_error = gt_disp - disp
log['disp_error'] = wandb.Image( log['disp_error'] = wandb.Image(
normalize_and_colormap(disp_error.abs()), normalize_and_colormap(abs(disp_error)),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}", caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
) )
@ -129,6 +174,7 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
) )
wandb.log(log) wandb.log(log)
return pred_flow
def parse_yaml(file_path: str) -> namedtuple: def parse_yaml(file_path: str) -> namedtuple:
@ -172,12 +218,25 @@ def adjust_learning_rate(optimizer, epoch):
for param_group in optimizer.param_groups: for param_group in optimizer.param_groups:
param_group['lr'] = lr param_group['lr'] = lr
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8): def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
''' '''
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W) valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
flow_preds[0]: (B, 2, H, W) flow_preds[0]: (B, 2, H, W)
flow_gt: (B, 2, H, W) flow_gt: (B, 2, H, W)
''' '''
if test:
# print('sequence loss')
if valid.shape != (2, 480, 640):
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
# print(valid.shape)
#valid = torch.stack([valid, valid])
# print(valid.shape)
if valid.shape != (2, 480, 640):
valid = valid.transpose(0,1)
# print(valid.shape)
# print(valid.shape)
# print(flow_preds[0].shape)
# print(flow_gt.shape)
n_predictions = len(flow_preds) n_predictions = len(flow_preds)
flow_loss = 0.0 flow_loss = 0.0
@ -260,20 +319,41 @@ def main(args):
start_iters = 0 start_iters = 0
# pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png' # pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
pattern_path = '/home/nils/kinect_reference_cropped.png' # pattern_path = '/home/nils/kinect_reference_cropped.png'
# pattern_path = '/home/nils/kinect_reference_far.png'
# pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
# pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png' # pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
# datasets # datasets
# dataset = CREStereoDataset(args.training_data_path) # dataset = CREStereoDataset(args.training_data_path)
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path) # dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
# test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
if args.dataset == 'blender':
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path)
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True)
elif args.dataset == 'ctd':
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
else:
print('unrecognized dataset')
quit()
test_data_iter = iter(test_dataset)
# if rank == 0: # if rank == 0:
worklog.info(f"Dataset size: {len(dataset)}") worklog.info(f"Dataset size: {len(dataset)}")
print(args.batch_size)
dataloader = DataLoader(dataset, args.batch_size, shuffle=True, dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True) num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False,
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
# counter # counter
cur_iters = start_iters cur_iters = start_iters
total_iters = args.minibatch_per_epoch * args.n_total_epoch total_iters = args.minibatch_per_epoch * args.n_total_epoch
t0 = time.perf_counter() t0 = time.perf_counter()
test_idx = 0
for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1): for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
# adjust learning rate # adjust learning rate
@ -310,24 +390,59 @@ def main(args):
# forward # forward
# left = left.transpose(1, 2).transpose(2, 3) # left = left.transpose(1, 2).transpose(2, 3)
left = left.transpose(1, 3).transpose(2, 3) left = left.transpose(1, 3).transpose(2, 3)
right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2) # right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
flow_predictions = model(left.cuda(), right.cuda()) flow_predictions = model(left.cuda(), right.cuda(), self_attend_right=args.pattern_attention)
# loss & backword # loss & backword
loss = sequence_loss( loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8 flow_predictions, gt_flow, valid_mask, gamma=0.8
) )
if batch_idx % 128 == 0: if batch_idx % 512 == 0:
inference( test_idx = 0
mini_batch_data['left'][0], test_loss = 0
mini_batch_data['right'][0], for i, test_batch in enumerate(test_dataset):
mini_batch_data['disparity'][0], # test_batch = next(test_data_iter)
mini_batch_data['mask'][0], if i >= 24:
model, break
batch_idx,
) # TODO refactor, DRY
left, right, gt_disp, valid_mask = (
test_batch['left'],
test_batch['right'],
torch.Tensor(test_batch['disparity']).cuda(),
torch.Tensor(test_batch['mask']).cuda(),
)
gt_disp = torch.dstack([gt_disp, gt_disp]).transpose(2,0).transpose(1,2)
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
# print(f'left {left.shape}, right {right.shape}')
# left = left.transpose([2, 0, 1])
right = right.transpose([1, 2, 0])
# right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
# print(f'left {left.shape}, right {right.shape}')
model.eval()
flow_predictions = inference(
left,
right,
# gt_disp,
torch.Tensor(test_batch['disparity']).cuda(),
valid_mask,
model,
test_idx,
wandb_log=i % 4 == 0,
test=True,
)
test_idx += 1
test_loss += sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8, test=True
).data.item()
model.train()
avg_test_loss = test_loss / test_idx
print(f'test_loss: {test_loss}\nlen test: {test_idx}\navg. loss: {avg_test_loss}')
metrics['test/loss'] = avg_test_loss
# loss stats # loss stats
loss_item = loss.data.item() loss_item = loss.data.item()
epoch_total_train_loss += loss_item epoch_total_train_loss += loss_item

@ -0,0 +1,346 @@
import os
import sys
import time
import logging
from collections import namedtuple
import yaml
# from tensorboardX import SummaryWriter
from nets import Model
# from dataset import CREStereoDataset
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning.lite import LightningLite
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
seed_everything(42, workers=True)
import wandb
import numpy as np
import cv2
def normalize_and_colormap(img):
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
if isinstance(ret, torch.Tensor):
ret = ret.cpu().detach().numpy()
ret = ret.astype("uint8")
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
return ret
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
print("Model Forwarding...")
if isinstance(left, torch.Tensor):
left = left# .cpu().detach().numpy()
imgR = right# .cpu().detach().numpy()
imgL = left
imgR = right
imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :])
flow_init = None
# chosen for convenience
imgL = torch.tensor(imgL.astype("float32"))
imgR = torch.tensor(imgR.astype("float32"))
imgL = imgL.transpose(2,3).transpose(1,2)
if imgL.shape != imgR.shape:
imgR = imgR.transpose(2,3).transpose(1,2)
imgL_dw2 = F.interpolate(
imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
).clamp(min=0, max=255)
imgR_dw2 = F.interpolate(
imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
).clamp(min=0, max=255)
if last_img is not None:
print('using flow_initialization')
print(last_img.shape)
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
print(last_img.max(), last_img.min())
if last_img.min() < 0:
# print('Negative disparity detected. shifting...')
last_img = last_img - last_img.min()
if last_img.max() > 255:
# print('Excessive disparity detected. scaling...')
last_img = last_img / (last_img.max() / 255)
last_img = np.dstack([last_img, last_img])
# last_img = np.dstack([last_img, last_img, last_img])
last_img = np.dstack([last_img])
last_img = last_img.reshape((1, 2, 480, 640))
# print(last_img.shape)
# print(last_img.dtype)
# print(last_img.max(), last_img.min())
flow_init = torch.tensor(last_img.astype("float32"))
# flow_init = F.interpolate(
# last_img,
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
# mode="bilinear",
# align_corners=True,
# )
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
pf_base = pred_flow
if isinstance(pf_base, list):
pf_base = pred_flow[0]
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy()
print('pred_flow max min')
print(pf.max(), pf.min())
if not wandb_log:
if test:
return pred_flow
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy()
log = {}
in_h, in_w = left.shape[:2]
# Resize image in case the GPU memory overflows
eval_h, eval_w = (in_h,in_w)
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy()
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy()
# pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
# pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
if i == n_iter-1:
t = float(in_w) / float(eval_w)
disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
log[f'disp_vis'] = wandb.Image(
normalize_and_colormap(disp),
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_{i}'] = wandb.Image(
np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
# log[f'pred_norm_{i}'] = wandb.Image(
# np.array([pred_disp_norm.reshape(480, 640)]),
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
# )
# log[f'pred_dw2_{i}'] = wandb.Image(
# np.array([pred_disp_dw2.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
# log[f'pred_dw2_norm_{i}'] = wandb.Image(
# np.array([pred_disp_dw2_norm.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
if input_right.shape != (480, 640, 3):
input_right.transpose(1,2,0)
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
disp_error = gt_disp - disp
log['disp_error'] = wandb.Image(
normalize_and_colormap(abs(disp_error)),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
)
log[f'gt_disp_vis'] = wandb.Image(
normalize_and_colormap(gt_disp),
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
)
wandb.log(log)
return pred_flow
def parse_yaml(file_path: str) -> namedtuple:
"""Parse yaml configuration file and return the object in `namedtuple`."""
with open(file_path, "rb") as f:
cfg: dict = yaml.safe_load(f)
args = namedtuple("train_args", cfg.keys())(*cfg.values())
return args
def format_time(elapse):
elapse = int(elapse)
hour = elapse // 3600
minute = elapse % 3600 // 60
seconds = elapse % 60
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
def ensure_dir(path):
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
def adjust_learning_rate(optimizer, epoch):
warm_up = 0.02
const_range = 0.6
min_lr_rate = 0.05
if epoch <= args.n_total_epoch * warm_up:
lr = (1 - min_lr_rate) * args.base_lr / (
args.n_total_epoch * warm_up
) * epoch + min_lr_rate * args.base_lr
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
lr = args.base_lr
else:
lr = (min_lr_rate - 1) * args.base_lr / (
(1 - const_range) * args.n_total_epoch
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
'''
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
flow_preds[0]: (B, 2, H, W)
flow_gt: (B, 2, H, W)
'''
if test:
# print('sequence loss')
if valid.shape != (2, 480, 640):
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
# print(valid.shape)
#valid = torch.stack([valid, valid])
# print(valid.shape)
if valid.shape != (2, 480, 640):
valid = valid.transpose(0,1)
# print(valid.shape)
# print(valid.shape)
# print(flow_preds[0].shape)
# print(flow_gt.shape)
n_predictions = len(flow_preds)
flow_loss = 0.0
# TEST
flow_gt = torch.squeeze(flow_gt, dim=-1)
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
i_loss = torch.abs(flow_preds[i] - flow_gt)
flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
return flow_loss
class CREStereoLightning(LightningModule):
def __init__(self, args):
super().__init__()
self.batch_size = args.batch_size
self.model = Model(
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
)
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
def training_step(self, batch, batch_idx):
# loss = self(batch)
left, right, gt_disp, valid_mask = batch
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
left = left
right = right
flow_predictions = self.forward(left, right)
loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
left, right, gt_disp, valid_mask = batch
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
print(left.shape)
print(right.shape)
flow_predictions = self.forward(left, right)
val_loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
self.log("val_loss", val_loss)
def test_step(self, batch, batch_idx):
left, right, gt_disp, valid_mask = batch
# left, right, gt_disp, valid_mask = (
# batch["left"],
# batch["right"],
# batch["disparity"],
# batch["mask"],
# )
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
flow_predictions = self.forward(left, right)
test_loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
self.log("test_loss", test_loss)
def configure_optimizers(self):
return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999))
if __name__ == "__main__":
# train configuration
args = parse_yaml("cfgs/train.yaml")
# wandb.init(project="crestereo-lightning", entity="cpt-captain")
# Lite(strategy='dp', accelerator='gpu', devices=2).run(args)
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
model = CREStereoLightning(args)
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True)
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True)
print(len(dataset))
print(len(test_dataset))
wandb_logger = WandbLogger(project="crestereo-lightning")
wandb.config.update(args._asdict())
trainer = Trainer(
max_epochs=args.n_total_epoch,
accelerator='gpu',
devices=2,
# auto_scale_batch_size='binsearch',
# strategy='ddp',
deterministic=True,
check_val_every_n_epoch=1,
limit_val_batches=24,
limit_test_batches=24,
logger=wandb_logger,
default_root_dir=args.log_dir_lightning,
)
# trainer.tune(model)
trainer.fit(model, dataset, test_dataset)
Loading…
Cancel
Save