change a bunch of stuff, add wip lightning implementation

This commit is contained in:
Cpt.Captain 2022-08-24 16:25:12 +02:00
parent 11959eef61
commit 63da24f429
10 changed files with 916 additions and 112 deletions

View File

@ -1,3 +1,4 @@
import os
import json
from datetime import datetime
from typing import Union, Literal
@ -7,32 +8,149 @@ import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_tensorrt
from cv2 import cv2
from fastapi import FastAPI, File, UploadFile
from PIL import Image
from nets import Model
from train import inference as ctd_inference
app = FastAPI()
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
reference_pattern_path = '/home/nils/kinect_reference_far.png'
# reference_pattern_path = '/home/nils/kinect_diff_ref.png'
print(reference_pattern_path)
reference_pattern = cv2.imread(reference_pattern_path)
# shift reference pattern a few pixels to the left to simulate further backdrop
trans_mat = np.float32([[1, 0, 0], [0, 1, 0]])
reference_pattern = cv2.warpAffine(
reference_pattern, trans_mat, reference_pattern.shape[1::-1], flags=cv2.INTER_LINEAR
)
iters = 20
minimal_data = False
minimal_data = True
temporal_init = False
last_img = None
device = torch.device('cuda:0')
def load_model(epoch):
def downsize(img):
diff = (512 - 480) // 2
downsampled = cv2.pyrDown(img)
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
return img
if 1024 in reference_pattern.shape:
reference_pattern = downsize(reference_pattern)
def ghetto_lcn(img):
# gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
gray = img
float_gray = gray.astype(np.float32) / 255.0
blur = cv2.GaussianBlur(float_gray, (0, 0), sigmaX=2, sigmaY=2)
num = float_gray - blur
blur = cv2.GaussianBlur(num * num, (0, 0), sigmaX=20, sigmaY=20)
den = cv2.pow(blur, 0.5)
gray = num / den
# cv2.normalize(gray, dst=gray, alpha=0.0, beta=1.0, norm_type=cv2.NORM_MINMAX)
cv2.normalize(gray, dst=gray, alpha=0.0, beta=255.0, norm_type=cv2.NORM_MINMAX)
return gray
# reference_pattern = ghetto_lcn(reference_pattern)
def load_model(epoch, use_tensorrt=False):
global model
epoch = f'epoch-{epoch}.pth' if epoch != 'latest' else 'latest.pth'
model_path = f"train_log/models/{epoch}"
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
# FIXME WIP Workaround Dataparallel TensorRT incompatibility
model = nn.DataParallel(model, device_ids=[device])
# model.load_state_dict(torch.load(model_path), strict=False)
state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
if use_tensorrt:
np_model = Model(max_disp=256, mixed_precision=False, test_mode=True)
np_model.load_state_dict(model.module.state_dict(), strict=True)
np_model.to(device)
np_model.eval()
spec_dict = {
"inputs": [
torch_tensorrt.Input(
min_shape=[1, 2, 240, 320],
max_shape=[1, 2, 480, 640],
opt_shape=[1, 2, 480, 640],
dtype=torch.int32,
),
torch_tensorrt.Input(
min_shape=[1, 2, 240, 320],
max_shape=[1, 2, 480, 640],
opt_shape=[1, 2, 480, 640],
dtype=torch.int32,
),
],
"enabled_precisions": {torch.float, torch.half},
"refit": False,
"debug": False,
"device": {
"device_type": torch_tensorrt.DeviceType.GPU,
"gpu_id": 0,
"dla_core": 0,
"allow_gpu_fallback": True
},
"capability": torch_tensorrt.EngineCapability.default,
"num_min_timing_iters": 2,
"num_avg_timing_iters": 1,
}
spec = {
"forward":
torch_tensorrt.ts.TensorRTCompileSpec(**spec_dict)
}
# trt_model = torch_tensorrt.compile(np_model ,
# inputs=torch_tensorrt.Input(
# min_shape=[1, 2, 240, 320],
# max_shape=[1, 2, 480, 640],
# opt_shape=[1, 2, 480, 640],
# dtype=torch.int32,
# inputs = [torch_tensorrt.Input((1, 2, 480, 640)), torch_tensorrt.Input((1, 2, 480, 640))], # input shape
# enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16
# )
# trt_dw_model = torch_tensorrt.compile(np_model ,
# inputs = [torch_tensorrt.Input((1, 2, 240, 320)), torch_tensorrt.Input((1, 2, 240, 320))], # input shape
# enabled_precisions = {torch_tensorrt.dtype.half} # Run with FP16
# )
script_model = torch.jit.script(np_model.eval())
# script_dw_model = torch.jit.script(trt_dw_model.eval())
# save the TensorRT embedded Torchscript
# torch.jit.save(trt_model, 'trt_torchscript_module.ts')
# torch.jit.save(trt_dw_model, 'trt_torchscript_dw_module.ts')
print(script_model)
print(script_model.forward)
print(script_model.forward())
print(dir(script_model))
model = torch._C._jit_to_backend("tensorrt", script_model, spec)
print(f'loaded model {epoch}')
return model
@ -74,14 +192,37 @@ def inference(left, right, model, n_iter=20):
)
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
# pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
# pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
pred_flow_dw2 = model.forward(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model.forward(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
pred_disp = torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
return pred_disp
def get_reference():
refs = [ref.path for ref in os.scandir('/home/nils/references/')]
for ref in refs:
reference = cv2.imread(ref)
yield reference
references = get_reference()
@app.post('/params/update_reference')
async def update_reference():
global references, reference_pattern
try:
reference_pattern = downsize(next(references))
print(reference_pattern.shape)
return {'status': 'success'}
except StopIteration:
references = get_reference()
return {'status': 'finished'}
@app.post("/model/update/{epoch}")
async def change_model(epoch: Union[int, Literal['latest']]):
global model
@ -103,8 +244,15 @@ async def set_minimal_data(enable: bool):
minimal_data = enable
@app.post("/params/temporal_init/{enable}")
async def set_temporal_init(enable: bool):
global temporal_init
temporal_init = enable
@app.put("/ir")
async def read_ir_input(file: UploadFile = File(...)):
global last_img, minimal_data
try:
img = np.array(Image.open(BytesIO(await file.read())))
except Exception as e:
@ -114,24 +262,35 @@ async def read_ir_input(file: UploadFile = File(...)):
if len(img.shape) == 2:
img = cv2.merge([img for _ in range(3)])
if img.shape == (1024, 1280, 3):
diff = (512 - 480) // 2
downsampled = cv2.pyrDown(img)
img = downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
img = downsize(img)
img = img.transpose((1, 2, 0))
ref_pat = reference_pattern.transpose((1, 2, 0))
# img = img.transpose((1, 2, 0))
# ref_pat = reference_pattern.transpose((1, 2, 0))
ref_pat = reference_pattern
start = datetime.now()
pred_disp = inference(img, ref_pat, model, iters)
if temporal_init:
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False, last_img)
last_img = pred_disp
else:
pred_disp = ctd_inference(img, ref_pat, None, None, model, None, iters, False)
# pred_disp = inference(img, ref_pat, model, iters)
duration = (datetime.now() - start).total_seconds()
if minimal_data:
return json.dumps({'disp': pred_disp, 'duration': duration}, cls=NumpyEncoder)
else:
# return json.dumps({'disp': pred_disp, 'input': img, 'duration': duration},
return json.dumps({'disp': pred_disp, 'reference': ref_pat, 'input': img, 'duration': duration},
cls=NumpyEncoder)
@app.get('/temporal_init')
def get_temporal_init():
return {'status': 'enabled' if temporal_init else 'disabled'}
@app.get('/')
def main():
return {'test': 'abc'}

View File

@ -4,18 +4,22 @@ base_lr: 4.0e-4
nr_gpus: 3
batch_size: 4
n_total_epoch: 600
n_total_epoch: 300
minibatch_per_epoch: 500
loadmodel: ~
log_dir: "./train_log"
log_dir_lightning: "./train_log_lightning"
model_save_freq_epoch: 1
max_disp: 256
image_width: 640
image_height: 480
# training_data_path: "./stereo_trainset/crestereo"
training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
pattern_attention: true
dataset: "blender"
# training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
log_level: "logging.INFO"

View File

@ -17,7 +17,7 @@ class Augmentor:
scale_min=0.6,
scale_max=1.0,
seed=0,
):
):
super().__init__()
self.image_height = image_height
self.image_width = image_width
@ -234,12 +234,16 @@ class CREStereoDataset(Dataset):
class CTDDataset(Dataset):
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False):
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False):
super().__init__()
self.rng = np.random.RandomState(0)
self.augment = augment
self.blur = blur
self.imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
if test_set:
self.imgs = imgs[:int(split * len(imgs))]
else:
self.imgs = imgs[int(split * len(imgs)):]
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
if resize_pattern and self.pattern.shape != (480, 640, 3):
@ -271,6 +275,10 @@ class CTDDataset(Dataset):
# read img, disp
left_img = np.load(left_path)
if left_img.dtype == 'float32':
left_img = (left_img * 255).astype('uint8')
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
right_img = self.pattern
@ -307,3 +315,100 @@ class CTDDataset(Dataset):
def __len__(self):
return len(self.imgs)
class BlenderDataset(CTDDataset):
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False):
super().__init__(root, pattern_path)
self.use_lightning = use_lightning
imgs = [f for f in glob.glob(f"{root}/im_*.png", recursive=True) if not 'depth0001' in f]
if test_set:
self.imgs = imgs[:int(split * len(imgs))]
else:
self.imgs = imgs[int(split * len(imgs)):]
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
if resize_pattern and self.pattern.shape != (480, 640, 3):
downsampled = cv2.pyrDown(self.pattern)
diff = (downsampled.shape[0] - 480) // 2
self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
self.augmentor = Augmentor(
image_height=480,
image_width=640,
max_disp=256,
scale_min=0.6,
scale_max=1.0,
seed=0,
)
def __getitem__(self, index):
# find path
left_path = self.imgs[index]
left_disp_path = left_path.split('.')[0] + '_depth0001.png'
# read img, disp
left_img = cv2.imread(left_path)
if left_img.dtype == 'float32':
left_img = (left_img * 255).astype('uint8')
if left_img.shape != (480, 640, 3):
downsampled = cv2.pyrDown(left_img)
diff = (downsampled.shape[0] - 480) // 2
left_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
if left_img.shape[-1] != 3:
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
right_img = self.pattern
left_disp = self.get_disp(left_disp_path)
if False: # self.rng.binomial(1, 0.5):
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
left_disp[left_disp == np.inf] = 0
if self.blur:
kernel_size = random.sample([1,3,5,7,9], 1)[0]
kernel = (kernel_size, kernel_size)
left_img = cv2.GaussianBlur(left_img, kernel, 0)
# augmentation
if not self.augment:
_left_img, _right_img, _left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
else:
left_img, right_img, left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
if not self.use_lightning:
right_img = right_img.transpose((2, 0, 1)).astype("uint8")
return {
"left": left_img,
"right": right_img,
"disparity": left_disp,
"mask": disp_mask,
}
right_img = right_img.transpose((2, 0, 1)).astype("uint8")
left_img = left_img.transpose((2, 0, 1)).astype("uint8")
return left_img, right_img, left_disp, disp_mask
def get_disp(self, path):
baseline = 0.075 # meters
fl = 560. # as per CTD
depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
downsampled = cv2.pyrDown(depth)
diff = (downsampled.shape[0] - 480) // 2
depth = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
# disp = np.load(path).transpose(1,2,0)
# disp = baseline * fl / depth
# return disp.astype(np.float32) / 32
# FIXME temporarily increase disparity until new data with better depth values is generated
# higher values seem to speedup convergence, but introduce much stronger artifacting
# mystery_factor = 150
mystery_factor = 1
disp = (baseline * fl * mystery_factor) / depth
return disp.astype(np.float32)

View File

@ -86,8 +86,15 @@ class LocalFeatureTransformer(nn.Module):
"""
assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal"
for layer, name in zip(self.layers, self.layer_names):
# NOTE Workaround for non statically determinable zip
# for layer, name in zip(self.layers, self.layer_names):
# layer_zip = ((layer, self.layer_names[i]) for i, layer in enumerate(self.layers))
# layer_zip = []
# for i, layer in enumerate(self.layers):
# layer_zip.append((layer, self.layer_names[i]))
for i, layer in enumerate(self.layers):
name = self.layer_names[i]
if name == 'self':
feat0 = layer(feat0, feat0, mask0, mask0)
feat1 = layer(feat1, feat1, mask1, mask1)
@ -97,4 +104,4 @@ class LocalFeatureTransformer(nn.Module):
else:
raise KeyError
return feat0, feat1
return feat0, feat1

View File

@ -36,6 +36,12 @@ class CREStereo(nn.Module):
self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.dropout)
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
# # NOTE Position_encoding as workaround for TensorRt
# image1_shape = [1, 2, 480, 640]
# self.pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
# )
# loftr
self.self_att_fn = LocalFeatureTransformer(
d_model=256, nhead=8, layer_names=["self"] * 1, attention="linear"
@ -81,7 +87,7 @@ class CREStereo(nn.Module):
zero_flow = torch.cat((_x, _y), dim=1).to(fmap.device)
return zero_flow
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False):
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
""" Estimate optical flow between pair of frames """
image1 = 2 * (image1 / 255.0) - 1.0
@ -130,17 +136,22 @@ class CREStereo(nn.Module):
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
# positional encoding and self-attention
pos_encoding_fn_small = PositionEncodingSine(
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
)
# pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
# )
# 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap1_dw16)
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
# 'n c h w -> n (h w) c'
x_tmp = pos_encoding_fn_small(fmap2_dw16)
x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
# FIXME experimental ! no self-attention for pattern
if not self_attend_right:
fmap1_dw16, _ = self.self_att_fn(fmap1_dw16, fmap2_dw16)
else:
fmap1_dw16, fmap2_dw16 = self.self_att_fn(fmap1_dw16, fmap2_dw16)
fmap1_dw16, fmap2_dw16 = [
x.reshape(x.shape[0], image1.shape[2] // 16, -1, x.shape[2]).permute(0, 3, 1, 2)
for x in [fmap1_dw16, fmap2_dw16]
@ -258,3 +269,4 @@ class CREStereo(nn.Module):
return flow_up
return predictions

View File

@ -1,6 +1,8 @@
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# Ref: https://github.com/princeton-vl/RAFT/blob/master/core/extractor.py
class ResidualBlock(nn.Module):
@ -96,28 +98,43 @@ class BasicEncoder(nn.Module):
self.in_planes = dim
return nn.Sequential(*layers)
def forward(self, x):
def forward(self, x: List[Tensor]):
# NOTE always assume list, otherwise TensorRT is sad
# batch_dim = x[0].shape[0]
# x_tensor = torch.cat(list(x), dim=0)
# if input is list, combine batch dimension
is_list = isinstance(x, tuple) or isinstance(x, list)
if is_list:
batch_dim = x[0].shape[0]
x = torch.cat(x, dim=0)
x_tensor = torch.cat(x, dim=0)
else:
x_tensor = x
x = self.conv1(x)
x = self.norm1(x)
x = self.relu1(x)
print()
print()
print(x_tensor.shape)
print()
print()
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x_tensor = self.conv1(x_tensor)
x_tensor = self.norm1(x_tensor)
x_tensor = self.relu1(x_tensor)
x = self.conv2(x)
x_tensor = self.layer1(x_tensor)
x_tensor = self.layer2(x_tensor)
x_tensor = self.layer3(x_tensor)
x_tensor = self.conv2(x_tensor)
if self.dropout is not None:
x = self.dropout(x)
x_tensor = self.dropout(x_tensor)
if is_list:
x = torch.split(x, x.shape[0]//2, dim=0)
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
return x_list
return x
x_list = torch.split(x_tensor, x_tensor.shape[0]//2, dim=0)
return x_list
# return list(x)

View File

@ -77,7 +77,7 @@ class BasicUpdateBlock(nn.Module):
nn.ReLU(inplace=True),
nn.Conv2d(256, mask_size**2 *9, 1, padding=0))
def forward(self, net, inp, corr, flow, upsample=True):
def forward(self, net, inp, corr, flow, upsample: bool=True):
# print(inp.shape, corr.shape, flow.shape)
motion_features = self.encoder(flow, corr)
# print(motion_features.shape, inp.shape)

View File

@ -3,6 +3,7 @@ import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
import os
from nets import Model
@ -16,17 +17,20 @@ device = 'cuda'
wandb.init(project="crestereo", entity="cpt-captain")
def do_infer(left_img, right_img, gt_disp, model):
disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)
def do_infer(left_img, right_img, gt_disp, model, attend_pattern=True):
disp = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False, attend_pattern=attend_pattern)
disp_vis = normalize_and_colormap(disp)
gt_disp_vis = normalize_and_colormap(gt_disp)
if gt_disp.shape != disp.shape:
gt_disp = gt_disp.reshape(disp.shape)
disp_err = gt_disp - disp
disp_err = normalize_and_colormap(disp_err.abs())
# gt_disp_vis = normalize_and_colormap(gt_disp)
# if gt_disp.shape != disp.shape:
# gt_disp = gt_disp.reshape(disp.shape)
# disp_err = gt_disp - disp
# disp_err = normalize_and_colormap(disp_err.abs())
if isinstance(left_img, torch.Tensor):
left_img = left_img.cpu().detach().numpy().astype('uint8')
right_img = right_img.cpu().detach().numpy().astype('uint8')
wandb.log({
results = {
'disp': wandb.Image(
disp,
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
@ -35,41 +39,60 @@ def do_infer(left_img, right_img, gt_disp, model):
disp_vis,
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
),
'gt_disp_vis': wandb.Image(
gt_disp_vis,
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
),
'disp_err': wandb.Image(
disp_err,
caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
),
# 'disp_err': wandb.Image(
# disp_err,
# caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
# ),
'input_left': wandb.Image(
left_img.cpu().detach().numpy().astype('uint8'),
left_img,
caption=f"Input left",
),
'input_right': wandb.Image(
right_img.cpu().detach().numpy().astype('uint8'),
right_img,
caption=f"Input right",
),
})
}
if gt_disp is not None:
print('logging gt')
print(f'gt: {gt_disp.max()}/{gt_disp.min()}/{gt_disp.mean()}')
gt_disp_vis = normalize_and_colormap(gt_disp)
results.update({
'gt_disp_vis': wandb.Image(
gt_disp_vis,
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
)})
wandb.log(results)
def downsample(img, half_height_out=480):
downsampled = cv2.pyrDown(img)
diff = (downsampled.shape[0] - half_height_out) // 2
return downsampled[diff:downsampled.shape[0] - diff, 0:downsampled.shape[1]]
if __name__ == '__main__':
# model_path = "models/crestereo_eth3d.pth"
model_path = "train_log/models/latest.pth"
# model_path = "train_log/models/epoch-120.pth"
# model_path = "train_log/models/epoch-250.pth"
print(model_path)
# reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
# reference_pattern_path = '/home/nils/new_reference.png'
# reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
reference_pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
# reference_pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
# reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
# reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
data_type = 'kinect'
# data_type = 'kinect'
data_type = 'blender'
augment = False
args = parse_yaml("cfgs/train.yaml")
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment, 'data_type': data_type, 'pattern_self_attention': args.pattern_attention})
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model, device_ids=[device])
@ -78,16 +101,32 @@ if __name__ == '__main__':
model.to(device)
model.eval()
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
pattern_path=reference_pattern_path, augment=augment)
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
for batch in dataloader:
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
right = right.transpose(0, 2).transpose(0, 1)
left_img = left
imgL = left.cpu().detach().numpy()
right_img = right
imgR = right.cpu().detach().numpy()
gt_disp = disparity
do_infer(left_img, right_img, gt_disp, model)
# dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type,
# pattern_path=reference_pattern_path, augment=augment)
# dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
# num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
# for batch in dataloader:
gt_disp = None
right = downsample(cv2.imread(reference_pattern_path))
if data_type == 'blender':
img_path = '/media/Data1/connecting_the_dots_data/blender_renders/data/'
elif data_type == 'kinect':
img_path = '/home/nils/kinect_pngs/ir/'
for img in sorted(list(entry for entry in os.scandir(img_path) if 'depth' not in entry.name), key=lambda x:x.name)[:25]:
print(img.path)
if data_type == 'blender':
baseline = 0.075 # meters
fl = 560. # as per CTD
gt_path = img.path.rsplit('.')[0] + '_depth0001.png'
gt_depth = downsample(cv2.imread(gt_path))
mystery_factor = 35 # we don't get reasonable disparities due to incorrect depth scaling (or something like that)
gt_disp = (baseline * fl * mystery_factor) / gt_depth
left = downsample(cv2.imread(img.path))
do_infer(left, right, gt_disp, model, attend_pattern=args.pattern_attention)

191
train.py
View File

@ -9,7 +9,7 @@ import yaml
from nets import Model
# from dataset import CREStereoDataset
from dataset import CREStereoDataset, CTDDataset
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
import torch
import torch.nn as nn
@ -32,14 +32,18 @@ def normalize_and_colormap(img):
return ret
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
print("Model Forwarding...")
left = left.cpu().detach().numpy()
if isinstance(left, torch.Tensor):
left = left.cpu().detach().numpy()
imgR = right.cpu().detach().numpy()
imgL = left
imgR = right.cpu().detach().numpy()
imgR = right
imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :])
flow_init = None
# chosen for convenience
device = torch.device('cuda:0')
@ -55,19 +59,54 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
).clamp(min=0, max=255)
imgR_dw2 = F.interpolate(
imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
).clamp(min=0, max=255)
if last_img is not None:
print('using flow_initialization')
print(last_img.shape)
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
print(last_img.max(), last_img.min())
if last_img.min() < 0:
# print('Negative disparity detected. shifting...')
last_img = last_img - last_img.min()
if last_img.max() > 255:
# print('Excessive disparity detected. scaling...')
last_img = last_img / (last_img.max() / 255)
last_img = np.dstack([last_img, last_img])
# last_img = np.dstack([last_img, last_img, last_img])
last_img = np.dstack([last_img])
last_img = last_img.reshape((1, 2, 480, 640))
# print(last_img.shape)
# print(last_img.dtype)
# print(last_img.max(), last_img.min())
flow_init = torch.tensor(last_img.astype("float32")).to(device)
# flow_init = F.interpolate(
# last_img,
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
# mode="bilinear",
# align_corners=True,
# )
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
pf_base = pred_flow
if isinstance(pf_base, list):
pf_base = pred_flow[0]
pf = torch.squeeze(pf_base[:, 0, :, :]).cpu().detach().numpy()
print('pred_flow max min')
print(pf.max(), pf.min())
if not wandb_log:
if test:
return pred_flow
return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
log = {}
@ -96,30 +135,36 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
np.array([pred_disp.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_norm_{i}'] = wandb.Image(
np.array([pred_disp_norm.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
# log[f'pred_norm_{i}'] = wandb.Image(
# np.array([pred_disp_norm.reshape(480, 640)]),
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
# )
log[f'pred_dw2_{i}'] = wandb.Image(
np.array([pred_disp_dw2.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
)
log[f'pred_dw2_norm_{i}'] = wandb.Image(
np.array([pred_disp_dw2_norm.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
)
# log[f'pred_dw2_{i}'] = wandb.Image(
# np.array([pred_disp_dw2.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
# log[f'pred_dw2_norm_{i}'] = wandb.Image(
# np.array([pred_disp_dw2_norm.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")
input_right = right.cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
if input_right.shape != (480, 640, 3):
input_right.transpose(1,2,0)
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
gt_disp = gt_disp.cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
disp = disp.cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
disp_error = gt_disp - disp
log['disp_error'] = wandb.Image(
normalize_and_colormap(disp_error.abs()),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}",
normalize_and_colormap(abs(disp_error)),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
)
@ -129,6 +174,7 @@ def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=Tru
)
wandb.log(log)
return pred_flow
def parse_yaml(file_path: str) -> namedtuple:
@ -172,12 +218,25 @@ def adjust_learning_rate(optimizer, epoch):
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
'''
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
flow_preds[0]: (B, 2, H, W)
flow_gt: (B, 2, H, W)
'''
if test:
# print('sequence loss')
if valid.shape != (2, 480, 640):
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
# print(valid.shape)
#valid = torch.stack([valid, valid])
# print(valid.shape)
if valid.shape != (2, 480, 640):
valid = valid.transpose(0,1)
# print(valid.shape)
# print(valid.shape)
# print(flow_preds[0].shape)
# print(flow_gt.shape)
n_predictions = len(flow_preds)
flow_loss = 0.0
@ -260,20 +319,41 @@ def main(args):
start_iters = 0
# pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
pattern_path = '/home/nils/kinect_reference_cropped.png'
# pattern_path = '/home/nils/kinect_reference_cropped.png'
# pattern_path = '/home/nils/kinect_reference_far.png'
# pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
pattern_path = '/home/nils/miniprojekt/kinect_high_res_thresh_denoised.png'
# pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
# datasets
# dataset = CREStereoDataset(args.training_data_path)
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
# dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
# test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
if args.dataset == 'blender':
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path)
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True)
elif args.dataset == 'ctd':
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
test_dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path, test_set=True)
else:
print('unrecognized dataset')
quit()
test_data_iter = iter(test_dataset)
# if rank == 0:
worklog.info(f"Dataset size: {len(dataset)}")
print(args.batch_size)
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
test_dataloader = DataLoader(test_dataset, args.batch_size, shuffle=False,
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
# counter
cur_iters = start_iters
total_iters = args.minibatch_per_epoch * args.n_total_epoch
t0 = time.perf_counter()
test_idx = 0
for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
# adjust learning rate
@ -310,24 +390,59 @@ def main(args):
# forward
# left = left.transpose(1, 2).transpose(2, 3)
left = left.transpose(1, 3).transpose(2, 3)
right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
flow_predictions = model(left.cuda(), right.cuda())
# right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
flow_predictions = model(left.cuda(), right.cuda(), self_attend_right=args.pattern_attention)
# loss & backword
loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
if batch_idx % 128 == 0:
inference(
mini_batch_data['left'][0],
mini_batch_data['right'][0],
mini_batch_data['disparity'][0],
mini_batch_data['mask'][0],
model,
batch_idx,
)
if batch_idx % 512 == 0:
test_idx = 0
test_loss = 0
for i, test_batch in enumerate(test_dataset):
# test_batch = next(test_data_iter)
if i >= 24:
break
# TODO refactor, DRY
left, right, gt_disp, valid_mask = (
test_batch['left'],
test_batch['right'],
torch.Tensor(test_batch['disparity']).cuda(),
torch.Tensor(test_batch['mask']).cuda(),
)
gt_disp = torch.dstack([gt_disp, gt_disp]).transpose(2,0).transpose(1,2)
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
# print(f'left {left.shape}, right {right.shape}')
# left = left.transpose([2, 0, 1])
right = right.transpose([1, 2, 0])
# right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
# print(f'left {left.shape}, right {right.shape}')
model.eval()
flow_predictions = inference(
left,
right,
# gt_disp,
torch.Tensor(test_batch['disparity']).cuda(),
valid_mask,
model,
test_idx,
wandb_log=i % 4 == 0,
test=True,
)
test_idx += 1
test_loss += sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8, test=True
).data.item()
model.train()
avg_test_loss = test_loss / test_idx
print(f'test_loss: {test_loss}\nlen test: {test_idx}\navg. loss: {avg_test_loss}')
metrics['test/loss'] = avg_test_loss
# loss stats
loss_item = loss.data.item()
epoch_total_train_loss += loss_item

346
train_lightning.py Normal file
View File

@ -0,0 +1,346 @@
import os
import sys
import time
import logging
from collections import namedtuple
import yaml
# from tensorboardX import SummaryWriter
from nets import Model
# from dataset import CREStereoDataset
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning.lite import LightningLite
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
seed_everything(42, workers=True)
import wandb
import numpy as np
import cv2
def normalize_and_colormap(img):
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
if isinstance(ret, torch.Tensor):
ret = ret.cpu().detach().numpy()
ret = ret.astype("uint8")
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
return ret
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
print("Model Forwarding...")
if isinstance(left, torch.Tensor):
left = left# .cpu().detach().numpy()
imgR = right# .cpu().detach().numpy()
imgL = left
imgR = right
imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :])
flow_init = None
# chosen for convenience
imgL = torch.tensor(imgL.astype("float32"))
imgR = torch.tensor(imgR.astype("float32"))
imgL = imgL.transpose(2,3).transpose(1,2)
if imgL.shape != imgR.shape:
imgR = imgR.transpose(2,3).transpose(1,2)
imgL_dw2 = F.interpolate(
imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
).clamp(min=0, max=255)
imgR_dw2 = F.interpolate(
imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
).clamp(min=0, max=255)
if last_img is not None:
print('using flow_initialization')
print(last_img.shape)
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
print(last_img.max(), last_img.min())
if last_img.min() < 0:
# print('Negative disparity detected. shifting...')
last_img = last_img - last_img.min()
if last_img.max() > 255:
# print('Excessive disparity detected. scaling...')
last_img = last_img / (last_img.max() / 255)
last_img = np.dstack([last_img, last_img])
# last_img = np.dstack([last_img, last_img, last_img])
last_img = np.dstack([last_img])
last_img = last_img.reshape((1, 2, 480, 640))
# print(last_img.shape)
# print(last_img.dtype)
# print(last_img.max(), last_img.min())
flow_init = torch.tensor(last_img.astype("float32"))
# flow_init = F.interpolate(
# last_img,
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
# mode="bilinear",
# align_corners=True,
# )
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
pf_base = pred_flow
if isinstance(pf_base, list):
pf_base = pred_flow[0]
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy()
print('pred_flow max min')
print(pf.max(), pf.min())
if not wandb_log:
if test:
return pred_flow
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy()
log = {}
in_h, in_w = left.shape[:2]
# Resize image in case the GPU memory overflows
eval_h, eval_w = (in_h,in_w)
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy()
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy()
# pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
# pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
if i == n_iter-1:
t = float(in_w) / float(eval_w)
disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
log[f'disp_vis'] = wandb.Image(
normalize_and_colormap(disp),
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_{i}'] = wandb.Image(
np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
# log[f'pred_norm_{i}'] = wandb.Image(
# np.array([pred_disp_norm.reshape(480, 640)]),
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
# )
# log[f'pred_dw2_{i}'] = wandb.Image(
# np.array([pred_disp_dw2.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
# log[f'pred_dw2_norm_{i}'] = wandb.Image(
# np.array([pred_disp_dw2_norm.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
if input_right.shape != (480, 640, 3):
input_right.transpose(1,2,0)
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
disp_error = gt_disp - disp
log['disp_error'] = wandb.Image(
normalize_and_colormap(abs(disp_error)),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
)
log[f'gt_disp_vis'] = wandb.Image(
normalize_and_colormap(gt_disp),
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
)
wandb.log(log)
return pred_flow
def parse_yaml(file_path: str) -> namedtuple:
"""Parse yaml configuration file and return the object in `namedtuple`."""
with open(file_path, "rb") as f:
cfg: dict = yaml.safe_load(f)
args = namedtuple("train_args", cfg.keys())(*cfg.values())
return args
def format_time(elapse):
elapse = int(elapse)
hour = elapse // 3600
minute = elapse % 3600 // 60
seconds = elapse % 60
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
def ensure_dir(path):
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
def adjust_learning_rate(optimizer, epoch):
warm_up = 0.02
const_range = 0.6
min_lr_rate = 0.05
if epoch <= args.n_total_epoch * warm_up:
lr = (1 - min_lr_rate) * args.base_lr / (
args.n_total_epoch * warm_up
) * epoch + min_lr_rate * args.base_lr
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
lr = args.base_lr
else:
lr = (min_lr_rate - 1) * args.base_lr / (
(1 - const_range) * args.n_total_epoch
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
'''
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
flow_preds[0]: (B, 2, H, W)
flow_gt: (B, 2, H, W)
'''
if test:
# print('sequence loss')
if valid.shape != (2, 480, 640):
valid = torch.stack([valid, valid])#.transpose(0,1)#.transpose(1,2)
# print(valid.shape)
#valid = torch.stack([valid, valid])
# print(valid.shape)
if valid.shape != (2, 480, 640):
valid = valid.transpose(0,1)
# print(valid.shape)
# print(valid.shape)
# print(flow_preds[0].shape)
# print(flow_gt.shape)
n_predictions = len(flow_preds)
flow_loss = 0.0
# TEST
flow_gt = torch.squeeze(flow_gt, dim=-1)
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
i_loss = torch.abs(flow_preds[i] - flow_gt)
flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
return flow_loss
class CREStereoLightning(LightningModule):
def __init__(self, args):
super().__init__()
self.batch_size = args.batch_size
self.model = Model(
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
)
def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
def training_step(self, batch, batch_idx):
# loss = self(batch)
left, right, gt_disp, valid_mask = batch
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
left = left
right = right
flow_predictions = self.forward(left, right)
loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
self.log("train_loss", loss)
return loss
def validation_step(self, batch, batch_idx):
left, right, gt_disp, valid_mask = batch
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
print(left.shape)
print(right.shape)
flow_predictions = self.forward(left, right)
val_loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
self.log("val_loss", val_loss)
def test_step(self, batch, batch_idx):
left, right, gt_disp, valid_mask = batch
# left, right, gt_disp, valid_mask = (
# batch["left"],
# batch["right"],
# batch["disparity"],
# batch["mask"],
# )
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
flow_predictions = self.forward(left, right)
test_loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
self.log("test_loss", test_loss)
def configure_optimizers(self):
return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999))
if __name__ == "__main__":
# train configuration
args = parse_yaml("cfgs/train.yaml")
# wandb.init(project="crestereo-lightning", entity="cpt-captain")
# Lite(strategy='dp', accelerator='gpu', devices=2).run(args)
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
model = CREStereoLightning(args)
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True)
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True)
print(len(dataset))
print(len(test_dataset))
wandb_logger = WandbLogger(project="crestereo-lightning")
wandb.config.update(args._asdict())
trainer = Trainer(
max_epochs=args.n_total_epoch,
accelerator='gpu',
devices=2,
# auto_scale_batch_size='binsearch',
# strategy='ddp',
deterministic=True,
check_val_every_n_epoch=1,
limit_val_batches=24,
limit_test_batches=24,
logger=wandb_logger,
default_root_dir=args.log_dir_lightning,
)
# trainer.tune(model)
trainer.fit(model, dataset, test_dataset)