add wandb, make compatible with ctd data
This commit is contained in:
parent
b333d5515f
commit
70e4bf6fe1
121
dataset.py
121
dataset.py
@ -1,4 +1,5 @@
|
||||
import os
|
||||
import random
|
||||
import cv2
|
||||
import glob
|
||||
import numpy as np
|
||||
@ -48,8 +49,8 @@ class Augmentor:
|
||||
|
||||
def __call__(self, left_img, right_img, left_disp):
|
||||
# 1. chromatic augmentation
|
||||
left_img = self.chromatic_augmentation(left_img)
|
||||
right_img = self.chromatic_augmentation(right_img)
|
||||
# left_img = self.chromatic_augmentation(left_img)
|
||||
# right_img = self.chromatic_augmentation(right_img)
|
||||
|
||||
# 2. spatial augmentation
|
||||
# 2.1) rotate & vertical shift for right image
|
||||
@ -62,6 +63,7 @@ class Augmentor:
|
||||
self.rng.uniform(0, right_img.shape[1]),
|
||||
)
|
||||
rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
|
||||
# right_img = right_img.transpose(2, 1, 0)
|
||||
right_img = cv2.warpAffine(
|
||||
right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
|
||||
)
|
||||
@ -69,6 +71,7 @@ class Augmentor:
|
||||
right_img = cv2.warpAffine(
|
||||
right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
|
||||
)
|
||||
# right_img = right_img.transpose(1, 2, 0)
|
||||
|
||||
# 2.2) random resize
|
||||
resize_scale = self.rng.uniform(self.scale_min, self.scale_max)
|
||||
@ -80,6 +83,10 @@ class Augmentor:
|
||||
fy=resize_scale,
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
if len(left_img.shape) == 2:
|
||||
left_img.shape += 1,
|
||||
# left_img = cv2.merge([left_img, left_img, left_img])
|
||||
|
||||
right_img = cv2.resize(
|
||||
right_img,
|
||||
None,
|
||||
@ -87,6 +94,9 @@ class Augmentor:
|
||||
fy=resize_scale,
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
if len(left_img.shape) == 2:
|
||||
left_img.shape += 1,
|
||||
# left_img = cv2.merge([left_img, left_img, left_img])
|
||||
|
||||
disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
|
||||
disp_mask = disp_mask.astype("float32")
|
||||
@ -110,7 +120,11 @@ class Augmentor:
|
||||
)
|
||||
|
||||
# 2.3) random crop
|
||||
h, w, c = left_img.shape
|
||||
if len(left_img.shape) == 3:
|
||||
h, w, c = left_img.shape
|
||||
else:
|
||||
h, w = left_img.shape
|
||||
c = 1
|
||||
dx = w - self.image_width
|
||||
dy = h - self.image_height
|
||||
dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
|
||||
@ -144,7 +158,7 @@ class Augmentor:
|
||||
(self.image_width, self.image_height),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderValue=0,
|
||||
)
|
||||
)
|
||||
|
||||
# 3. add random occlusion to right image
|
||||
if self.rng.binomial(1, 0.5):
|
||||
@ -156,6 +170,9 @@ class Augmentor:
|
||||
np.mean(right_img, 0), 0
|
||||
)[np.newaxis, np.newaxis]
|
||||
|
||||
if len(left_img.shape) == 2:
|
||||
left_img = cv2.merge([left_img, left_img, left_img])
|
||||
|
||||
return left_img, right_img, left_disp, disp_mask
|
||||
|
||||
|
||||
@ -197,7 +214,8 @@ class CREStereoDataset(Dataset):
|
||||
left_disp[left_disp == np.inf] = 0
|
||||
|
||||
# augmentation
|
||||
left_img, right_img, left_disp, disp_mask = self.augmentor(
|
||||
# left_img, right_img, left_disp, disp_mask = self.augmentor(
|
||||
_, _, left_disp, disp_mask = self.augmentor(
|
||||
left_img, right_img, left_disp
|
||||
)
|
||||
|
||||
@ -213,3 +231,96 @@ class CREStereoDataset(Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
||||
|
||||
|
||||
class CTDDataset(Dataset):
    """Dataset that pairs per-sample ``im0_*.npy`` images with one fixed pattern image.

    Each sample uses the array loaded from an ``im0_*.npy`` file as the left
    view, the shared reference pattern (loaded once in ``__init__``) as the
    right view, and the array from the sibling ``disp0_*.npy`` file as the
    left disparity.
    """

    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False):
        """Index the sample files and load the reference pattern.

        Args:
            root: Directory that contains the ``<data_type>/*/im0_*.npy`` files.
            pattern_path: Image file read with ``cv2.imread`` and used as the
                right view of every sample.
            data_type: Name of the sub-directory of ``root`` to read from.
            augment: When true, ``__getitem__`` returns the augmented images
                and disparity; when false only the augmentor's mask is used.
            resize_pattern: When true and the pattern is not ``(480, 640, 3)``,
                the pattern is halved with ``cv2.pyrDown`` and cropped
                vertically to 480 rows.
            blur: When true, the left image is Gaussian-blurred with a random
                odd kernel size before augmentation.

        Raises:
            FileNotFoundError: If ``pattern_path`` cannot be read as an image.
        """
        super().__init__()
        self.rng = np.random.RandomState(0)
        self.augment = augment
        self.blur = blur
        # Sorted so that a given index always maps to the same file; glob
        # returns entries in arbitrary, filesystem-dependent order.
        self.imgs = sorted(
            glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
        )

        self.pattern = cv2.imread(pattern_path)
        if self.pattern is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read reference pattern: {pattern_path}")

        if resize_pattern and self.pattern.shape != (480, 640, 3):
            print(self.pattern.shape)
            downsampled = cv2.pyrDown(self.pattern)
            # Centre-crop to exactly 480 rows. Slicing by `top + 480` (rather
            # than trimming the same amount from both ends) also yields 480
            # rows when the surplus height is odd.
            top = max((downsampled.shape[0] - 480) // 2, 0)
            self.pattern = downsampled[top:top + 480, 0:downsampled.shape[1]]

        self.augmentor = Augmentor(
            image_height=480,
            image_width=640,
            max_disp=256,
            scale_min=0.6,
            scale_max=1.0,
            seed=0,
        )

    @staticmethod
    def _disp_path_for(left_path):
        """Return the disparity file path that belongs to an ``im0_*.npy`` path.

        Only the file name is rewritten (first ``im`` -> ``disp``), so
        directory components that happen to contain "im" are left untouched.
        """
        folder, name = os.path.split(left_path)
        return os.path.join(folder, name.replace('im', 'disp', 1))

    def get_disp(self, path):
        """Load a disparity array from ``path`` and move axis 0 to the end."""
        disp = np.load(path).transpose(1, 2, 0)
        return disp

    def __getitem__(self, index):
        """Return the sample at ``index`` as a dict.

        Returns:
            dict with keys ``"left"``, ``"right"``, ``"disparity"`` and
            ``"mask"``. ``"right"`` is transposed with ``(2, 0, 1)`` and cast
            to uint8; ``"left"`` is returned without that transpose.
        """
        # find path
        left_path = self.imgs[index]
        left_disp_path = self._disp_path_for(left_path)

        # read img, disp
        left_img = np.load(left_path)
        # FIXME DO WE NEED THIS? OTHERWISE IT's PRETTY DARK
        left_img = cv2.normalize(left_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        # assumes the loaded frame holds exactly 480*640 values - TODO confirm
        left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))

        right_img = self.pattern
        left_disp = self.get_disp(left_disp_path)
        left_disp[left_disp == np.inf] = 0

        if self.blur:
            kernel_size = random.sample([1,3,5,7,9], 1)[0]
            kernel = (kernel_size, kernel_size)
            left_img = cv2.GaussianBlur(left_img, kernel, 0)

        # augmentation
        # The augmentor always runs because it is the only producer of the
        # validity mask; its image/disparity outputs are kept only when
        # augmentation is enabled.
        # NOTE(review): with augment=False the mask comes from the augmented
        # (resized/cropped) disparity while the returned disparity is the
        # un-augmented one - confirm that this pairing is intended.
        aug_left, aug_right, aug_disp, disp_mask = self.augmentor(
            left_img, right_img, left_disp
        )
        if self.augment:
            left_img, right_img, left_disp = aug_left, aug_right, aug_disp

        right_img = right_img.transpose(2, 0, 1).astype("uint8")

        return {
            "left": left_img,
            "right": right_img,
            "disparity": left_disp,
            "mask": disp_mask,
        }

    def __len__(self):
        """Return the number of indexed ``im0_*.npy`` files."""
        return len(self.imgs)
|
||||
|
@ -155,6 +155,8 @@ class CREStereo(nn.Module):
|
||||
flow = None
|
||||
flow_up = None
|
||||
if flow_init is not None:
|
||||
if isinstance(flow_init, list):
|
||||
flow_init = flow_init[0]
|
||||
scale = fmap1.shape[2] / flow_init.shape[2]
|
||||
flow = -scale * F.interpolate(
|
||||
flow_init,
|
||||
|
217
test_model.py
217
test_model.py
@ -1,12 +1,22 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
import cv2
|
||||
from imread_from_url import imread_from_url
|
||||
|
||||
from nets import Model
|
||||
|
||||
import wandb
|
||||
|
||||
import random
|
||||
from torch.utils.data import DataLoader
|
||||
from dataset import CTDDataset
|
||||
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference
|
||||
|
||||
device = 'cuda'
|
||||
wandb.init(project="crestereo", entity="cpt-captain")
|
||||
|
||||
|
||||
|
||||
#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
|
||||
def inference(left, right, model, n_iter=20):
|
||||
@ -41,39 +51,198 @@ def inference(left, right, model, n_iter=20):
|
||||
|
||||
return pred_disp
|
||||
|
||||
|
||||
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
    """Run a half-resolution then full-resolution forward pass and log the results to wandb.

    The model is first evaluated on inputs downscaled by two; that output is
    passed as ``flow_init`` to a second evaluation at full size. Every pair
    of entries from the two model outputs is logged as images, followed by
    the inputs, the ground truth and the signed disparity error.

    Args:
        left: Left image tensor. It is permuted from channel-last to
            channel-first below, so presumably H x W x C - TODO confirm.
        right: Right image tensor, fed to the model without permutation and
            logged after a ``(1, 2, 0)`` transpose, so presumably C x H x W -
            TODO confirm.
        gt_disp: Ground-truth disparity; must support ``.min()``/``.max()``
            and subtraction of the predicted disparity array.
        mask: Unused.
        model: Callable invoked as ``model(image1, image2, iters, flow_init)``;
            its return value is iterated entry by entry.
        epoch: Unused.
        n_iter: Value passed as ``iters`` to the model; the prediction at
            index ``n_iter - 1`` is used for the error image (the last
            prediction is used if that index is never reached).

    Returns:
        None. Everything is emitted through ``wandb.log``.
    """
    print("Model Forwarding...")
    left = left.cpu().detach().numpy()
    imgL = left
    imgR = right.cpu().detach().numpy()
    imgL = np.ascontiguousarray(imgL[None, :, :, :])
    imgR = np.ascontiguousarray(imgR[None, :, :, :])

    # chosen for convenience
    device = torch.device('cuda:0')

    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    imgL = imgL.transpose(2,3).transpose(1,2)

    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(
        imgL,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )
    with torch.inference_mode():
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    log = {}
    disp = None       # prediction used for the error image
    pred_disp = None  # most recent full-resolution prediction
    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
        pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()

        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

        if i == n_iter - 1:
            disp = pred_disp

        # A leading axis is added instead of reshaping to a fixed 480x640 /
        # 240x320 so that other input sizes are logged as well.
        log[f'pred_{i}'] = wandb.Image(
            pred_disp[np.newaxis],
            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
        )
        log[f'pred_norm_{i}'] = wandb.Image(
            pred_disp_norm[np.newaxis],
            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
        )

        log[f'pred_dw2_{i}'] = wandb.Image(
            pred_disp_dw2[np.newaxis],
            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
        )
        log[f'pred_dw2_norm_{i}'] = wandb.Image(
            pred_disp_dw2_norm[np.newaxis],
            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
        )

    if disp is None:
        # Fewer than n_iter predictions were returned: use the last one.
        disp = pred_disp

    log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
    log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")

    log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")

    if disp is not None:
        # `disp` was previously never assigned in this function, so this
        # section always raised NameError.
        disp_error = gt_disp - disp
        log['disp_error'] = wandb.Image(
            normalize_and_colormap(disp_error),
            caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}",
        )

    wandb.log(log)
|
||||
|
||||
|
||||
def do_infer(left_img, right_img, gt_disp, model):
    """Predict disparity for one image pair and log prediction, ground truth and error to wandb.

    Args:
        left_img: Left image; its first two dimensions are read as
            (height, width). Forwarded unchanged to ``ctd_inference``.
        right_img: Right image, forwarded unchanged to ``ctd_inference``.
        gt_disp: Ground-truth disparity. It is reshaped to the prediction's
            shape when the shapes differ, and the difference to the
            prediction must support ``.abs()``.
        model: Model handed to ``ctd_inference``.

    Returns:
        None. Five images are emitted through ``wandb.log``.
    """
    in_h, in_w = left_img.shape[:2]

    # Evaluation currently runs at the input resolution, so the rescale
    # factor `t` below is 1.
    eval_w = in_w

    # numpy copies of the inputs, used only for logging
    imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img
    imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img

    pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)

    t = float(in_w) / float(eval_w)
    disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t

    disp_vis = normalize_and_colormap(disp)
    gt_disp_vis = normalize_and_colormap(gt_disp)
    if gt_disp.shape != disp.shape:
        gt_disp = gt_disp.reshape(disp.shape)

    # Keep the numeric error separate from its colour-mapped rendering so the
    # caption reports real error values rather than colour-map pixel values.
    abs_err = (gt_disp - disp).abs()
    disp_err_vis = normalize_and_colormap(abs_err)

    wandb.log({
        'disp_vis': wandb.Image(
            disp_vis,
            caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
        ),
        'gt_disp_vis': wandb.Image(
            gt_disp_vis,
            caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
        ),
        'disp_err': wandb.Image(
            disp_err_vis,
            caption=f"Disparity Error\n{abs_err.min():.{2}f}/{abs_err.max():.{2}f}",
        ),
        'input_left': wandb.Image(
            imgL.astype('uint8'),
            caption=f"Input left",
        ),
        'input_right': wandb.Image(
            imgR.astype('uint8'),
            caption=f"Input right",
        ),
    })
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# model_path = "models/crestereo_eth3d.pth"
|
||||
model_path = "train_log/models/latest.pth"
|
||||
|
||||
left_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/left.png")
|
||||
right_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/right.png")
|
||||
# reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
|
||||
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
# reference_pattern_path = '/home/nils/new_reference.png'
|
||||
# reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
|
||||
# reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
|
||||
|
||||
in_h, in_w = left_img.shape[:2]
|
||||
data_type = 'kinect'
|
||||
augment = False
|
||||
|
||||
# Resize image in case the GPU memory overflows
|
||||
eval_h, eval_w = (in_h,in_w)
|
||||
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||
args = parse_yaml("cfgs/train.yaml")
|
||||
|
||||
model_path = "models/crestereo_eth3d.pth"
|
||||
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
|
||||
|
||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
model.load_state_dict(torch.load(model_path), strict=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
|
||||
model = nn.DataParallel(model,device_ids=[device])
|
||||
# model.load_state_dict(torch.load(model_path), strict=False)
|
||||
state_dict = torch.load(model_path)['state_dict']
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
model.eval()
|
||||
|
||||
pred = inference(imgL, imgR, model, n_iter=20)
|
||||
CTD = True
|
||||
if not CTD:
|
||||
left_img = cv2.imread("../test_imgs/left.png")
|
||||
right_img = cv2.imread("../test_imgs/right.png")
|
||||
in_h, in_w = left_img.shape[:2]
|
||||
|
||||
t = float(in_w) / float(eval_w)
|
||||
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
||||
# Resize image in case the GPU memory overflows
|
||||
eval_h, eval_w = (in_h,in_w)
|
||||
|
||||
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
|
||||
disp_vis = disp_vis.astype("uint8")
|
||||
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
||||
# FIXME borked for some reason, hopefully not very important
|
||||
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
|
||||
|
||||
combined_img = np.hstack((left_img, disp_vis))
|
||||
cv2.namedWindow("output", cv2.WINDOW_NORMAL)
|
||||
cv2.imshow("output", combined_img)
|
||||
cv2.imwrite("output.jpg", disp_vis)
|
||||
cv2.waitKey(0)
|
||||
pred = inference(imgL, imgR, model, n_iter=20)
|
||||
|
||||
t = float(in_w) / float(eval_w)
|
||||
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
||||
|
||||
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
|
||||
disp_vis = disp_vis.astype("uint8")
|
||||
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
|
||||
|
||||
combined_img = np.hstack((left_img, disp_vis))
|
||||
# cv2.namedWindow("output", cv2.WINDOW_NORMAL)
|
||||
# cv2.imshow("output", combined_img)
|
||||
cv2.imwrite("output.jpg", disp_vis)
|
||||
# cv2.waitKey(0)
|
||||
|
||||
else:
|
||||
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, pattern_path=reference_pattern_path, augment=augment)
|
||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
|
||||
for batch in dataloader:
|
||||
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
|
||||
right = right.transpose(0, 2).transpose(0, 1)
|
||||
left_img = left
|
||||
imgL = left.cpu().detach().numpy()
|
||||
right_img = right
|
||||
imgR = right.cpu().detach().numpy()
|
||||
gt_disp = disparity
|
||||
do_infer(left_img, right_img, gt_disp, model)
|
||||
|
||||
|
175
train.py
175
train.py
@ -5,16 +5,131 @@ import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import yaml
|
||||
from tensorboardX import SummaryWriter
|
||||
# from tensorboardX import SummaryWriter
|
||||
|
||||
from nets import Model
|
||||
from dataset import CREStereoDataset
|
||||
# from dataset import CREStereoDataset
|
||||
from dataset import CREStereoDataset, CTDDataset
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
import wandb
|
||||
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def normalize_and_colormap(img):
    """Min-max scale ``img`` to 0..255 and colour it with OpenCV's INFERNO map.

    Args:
        img: A ``torch.Tensor`` or array-like of numbers. Tensors are moved
            to the CPU and converted to numpy first.

    Returns:
        The result of ``cv2.applyColorMap`` on the uint8-scaled image.
        A constant image (max == min) is rendered as all zeros instead of
        dividing by zero.
    """
    if isinstance(img, torch.Tensor):
        img = img.cpu().detach().numpy()
    img = np.asarray(img)

    lowest = img.min()
    span = img.max() - lowest
    if span > 0:
        ret = (img - lowest) / span * 255.0
    else:
        # Zero (or NaN) range: the division would give NaN/inf, and casting
        # those to uint8 is undefined.
        ret = np.zeros(img.shape, dtype=np.float32)

    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
|
||||
|
||||
|
||||
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
    """Run a half-resolution then full-resolution forward pass; return or log the result.

    The model is first evaluated on inputs downscaled by two; that output is
    passed as ``flow_init`` to a second evaluation at full size.

    Args:
        left: Left image tensor. It is permuted from channel-last to
            channel-first below, so presumably H x W x C - TODO confirm.
        right: Right image tensor. It receives the same permutation only if
            its batched shape differs from the permuted left image.
        gt_disp: Ground-truth disparity; only used for logging. Must support
            ``.min()``/``.max()`` and the difference to the prediction must
            support ``.abs()``.
        mask: Unused.
        model: Callable invoked as ``model(image1, image2, iters, flow_init)``.
        epoch: Unused.
        n_iter: Value passed as ``iters`` to the model; the prediction at
            index ``n_iter - 1`` is used for the disparity and error images
            (the last prediction is used if that index is never reached).
        wandb_log: When false, nothing is logged and the prediction is
            returned instead.

    Returns:
        With ``wandb_log=False``: channel 0 of the full-resolution model
        output, squeezed, as a numpy array (the output is indexed as a single
        tensor in this mode). With ``wandb_log=True``: ``None``; the model
        outputs are iterated entry by entry and logged through ``wandb.log``.
    """
    print("Model Forwarding...")
    left = left.cpu().detach().numpy()
    imgL = left
    imgR = right.cpu().detach().numpy()
    imgL = np.ascontiguousarray(imgL[None, :, :, :])
    imgR = np.ascontiguousarray(imgR[None, :, :, :])

    # chosen for convenience
    device = torch.device('cuda:0')

    imgL = torch.tensor(imgL.astype("float32")).to(device)
    imgR = torch.tensor(imgR.astype("float32")).to(device)
    imgL = imgL.transpose(2,3).transpose(1,2)
    if imgL.shape != imgR.shape:
        imgR = imgR.transpose(2,3).transpose(1,2)

    half_size = (imgL.shape[2] // 2, imgL.shape[3] // 2)
    imgL_dw2 = F.interpolate(
        imgL,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )
    imgR_dw2 = F.interpolate(
        imgR,
        size=half_size,
        mode="bilinear",
        align_corners=True,
    )
    with torch.inference_mode():
        pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
        pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)

    if not wandb_log:
        return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()

    log = {}
    in_h, in_w = left.shape[:2]

    # Evaluation runs at the input resolution, so the rescale factor is 1.
    eval_w = in_w
    t = float(in_w) / float(eval_w)

    chosen_pred = None  # prediction used for the disparity / error images
    pred_disp = None    # most recent full-resolution prediction
    for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
        pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
        pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()

        pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
        pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)

        if i == n_iter - 1:
            chosen_pred = pred_disp

        # A leading axis is added instead of reshaping to a fixed 480x640 /
        # 240x320 so that other input sizes are logged as well.
        log[f'pred_{i}'] = wandb.Image(
            pred_disp[np.newaxis],
            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
        )
        log[f'pred_norm_{i}'] = wandb.Image(
            pred_disp_norm[np.newaxis],
            caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
        )

        log[f'pred_dw2_{i}'] = wandb.Image(
            pred_disp_dw2[np.newaxis],
            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
        )
        log[f'pred_dw2_norm_{i}'] = wandb.Image(
            pred_disp_dw2_norm[np.newaxis],
            caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
        )

    if chosen_pred is None:
        # Fewer than n_iter predictions were returned: use the last one so
        # the error image below can still be built.
        chosen_pred = pred_disp

    disp = None
    if chosen_pred is not None:
        disp = cv2.resize(chosen_pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
        log[f'disp_vis'] = wandb.Image(
            normalize_and_colormap(disp),
            caption=f"Disparity \n{chosen_pred.min():.{2}f}/{chosen_pred.max():.{2}f}",
        )

    log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
    log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")

    log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")

    if disp is not None:
        disp_error = gt_disp - disp
        log['disp_error'] = wandb.Image(
            normalize_and_colormap(disp_error.abs()),
            caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}",
        )

    log[f'gt_disp_vis'] = wandb.Image(
        normalize_and_colormap(gt_disp),
        caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
    )

    wandb.log(log)
|
||||
|
||||
|
||||
def parse_yaml(file_path: str) -> namedtuple:
|
||||
"""Parse yaml configuration file and return the object in `namedtuple`."""
|
||||
@ -65,6 +180,10 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
|
||||
'''
|
||||
n_predictions = len(flow_preds)
|
||||
flow_loss = 0.0
|
||||
|
||||
# TEST
|
||||
flow_gt = torch.squeeze(flow_gt, dim=-1)
|
||||
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
i_loss = torch.abs(flow_preds[i] - flow_gt)
|
||||
@ -93,7 +212,9 @@ def main(args):
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
|
||||
# model = nn.DataParallel(model,device_ids=[0])
|
||||
|
||||
tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
|
||||
# tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
|
||||
wandb.watch(model)
|
||||
metrics = {}
|
||||
|
||||
# worklog
|
||||
logging.basicConfig(level=eval(args.log_level))
|
||||
@ -138,8 +259,12 @@ def main(args):
|
||||
start_epoch_idx = 1
|
||||
start_iters = 0
|
||||
|
||||
# pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
|
||||
pattern_path = '/home/nils/kinect_reference_cropped.png'
|
||||
# pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
|
||||
# datasets
|
||||
dataset = CREStereoDataset(args.training_data_path)
|
||||
# dataset = CREStereoDataset(args.training_data_path)
|
||||
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
|
||||
# if rank == 0:
|
||||
worklog.info(f"Dataset size: {len(dataset)}")
|
||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||
@ -183,6 +308,9 @@ def main(args):
|
||||
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
||||
|
||||
# forward
|
||||
# left = left.transpose(1, 2).transpose(2, 3)
|
||||
left = left.transpose(1, 3).transpose(2, 3)
|
||||
right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
|
||||
flow_predictions = model(left.cuda(), right.cuda())
|
||||
|
||||
# loss & backward
|
||||
@ -190,6 +318,16 @@ def main(args):
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
|
||||
if batch_idx % 128 == 0:
|
||||
inference(
|
||||
mini_batch_data['left'][0],
|
||||
mini_batch_data['right'][0],
|
||||
mini_batch_data['disparity'][0],
|
||||
mini_batch_data['mask'][0],
|
||||
model,
|
||||
batch_idx,
|
||||
)
|
||||
|
||||
# loss stats
|
||||
loss_item = loss.data.item()
|
||||
epoch_total_train_loss += loss_item
|
||||
@ -234,20 +372,25 @@ def main(args):
|
||||
worklog.info("".join(info))
|
||||
|
||||
# minibatch loss
|
||||
tb_log.add_scalar("train/loss_batch", loss_item, cur_iters)
|
||||
tb_log.add_scalar(
|
||||
"train/lr", optimizer.param_groups[0]["lr"], cur_iters
|
||||
)
|
||||
tb_log.flush()
|
||||
# tb_log.add_scalar("train/loss_batch", loss_item, cur_iters)
|
||||
metrics['train/loss_batch'] = loss_item
|
||||
# tb_log.add_scalar(
|
||||
# "train/lr", optimizer.param_groups[0]["lr"], cur_iters
|
||||
# )
|
||||
metrics['train/lr'] = optimizer.param_groups[0]["lr"]
|
||||
# tb_log.flush()
|
||||
wandb.log(metrics)
|
||||
|
||||
t1 = time.perf_counter()
|
||||
|
||||
tb_log.add_scalar(
|
||||
"train/loss",
|
||||
epoch_total_train_loss / args.minibatch_per_epoch,
|
||||
epoch_idx,
|
||||
)
|
||||
tb_log.flush()
|
||||
# tb_log.add_scalar(
|
||||
# "train/loss",
|
||||
# epoch_total_train_loss / args.minibatch_per_epoch,
|
||||
# epoch_idx,
|
||||
# )
|
||||
metrics['train/loss'] = epoch_total_train_loss / args.minibatch_per_epoch
|
||||
# tb_log.flush()
|
||||
wandb.log(metrics)
|
||||
|
||||
# save model params
|
||||
ckp_data = {
|
||||
@ -271,4 +414,6 @@ def main(args):
|
||||
if __name__ == "__main__":
|
||||
# train configuration
|
||||
args = parse_yaml("cfgs/train.yaml")
|
||||
wandb.init(project="crestereo", entity="cpt-captain")
|
||||
wandb.config.update(args._asdict())
|
||||
main(args)
|
||||
|
Loading…
Reference in New Issue
Block a user