add wandb, make compatible with ctd data

This commit is contained in:
Nils Koch 2022-05-30 16:13:06 +02:00
parent b333d5515f
commit 70e4bf6fe1
4 changed files with 471 additions and 44 deletions

View File

@ -1,4 +1,5 @@
import os
import random
import cv2
import glob
import numpy as np
@ -48,8 +49,8 @@ class Augmentor:
def __call__(self, left_img, right_img, left_disp):
# 1. chromatic augmentation
left_img = self.chromatic_augmentation(left_img)
right_img = self.chromatic_augmentation(right_img)
# left_img = self.chromatic_augmentation(left_img)
# right_img = self.chromatic_augmentation(right_img)
# 2. spatial augmentation
# 2.1) rotate & vertical shift for right image
@ -62,6 +63,7 @@ class Augmentor:
self.rng.uniform(0, right_img.shape[1]),
)
rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
# right_img = right_img.transpose(2, 1, 0)
right_img = cv2.warpAffine(
right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
)
@ -69,6 +71,7 @@ class Augmentor:
right_img = cv2.warpAffine(
right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
)
# right_img = right_img.transpose(1, 2, 0)
# 2.2) random resize
resize_scale = self.rng.uniform(self.scale_min, self.scale_max)
@ -80,6 +83,10 @@ class Augmentor:
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
if len(left_img.shape) == 2:
left_img.shape += 1,
# left_img = cv2.merge([left_img, left_img, left_img])
right_img = cv2.resize(
right_img,
None,
@ -87,6 +94,9 @@ class Augmentor:
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
if len(left_img.shape) == 2:
left_img.shape += 1,
# left_img = cv2.merge([left_img, left_img, left_img])
disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
disp_mask = disp_mask.astype("float32")
@ -110,7 +120,11 @@ class Augmentor:
)
# 2.3) random crop
h, w, c = left_img.shape
if len(left_img.shape) == 3:
h, w, c = left_img.shape
else:
h, w = left_img.shape
c = 1
dx = w - self.image_width
dy = h - self.image_height
dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
@ -144,7 +158,7 @@ class Augmentor:
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
)
# 3. add random occlusion to right image
if self.rng.binomial(1, 0.5):
@ -156,6 +170,9 @@ class Augmentor:
np.mean(right_img, 0), 0
)[np.newaxis, np.newaxis]
if len(left_img.shape) == 2:
left_img = cv2.merge([left_img, left_img, left_img])
return left_img, right_img, left_disp, disp_mask
@ -197,7 +214,8 @@ class CREStereoDataset(Dataset):
left_disp[left_disp == np.inf] = 0
# augmentaion
left_img, right_img, left_disp, disp_mask = self.augmentor(
# left_img, right_img, left_disp, disp_mask = self.augmentor(
_, _, left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
@ -213,3 +231,96 @@ class CREStereoDataset(Dataset):
def __len__(self):
return len(self.imgs)
class CTDDataset(Dataset):
def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False):
super().__init__()
self.rng = np.random.RandomState(0)
self.augment = augment
self.blur = blur
self.imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
if resize_pattern and self.pattern.shape != (480, 640, 3):
# self.pattern = cv2.resize(self.pattern, (640, 480))
print(self.pattern.shape)
downsampled = cv2.pyrDown(self.pattern)
diff = (downsampled.shape[0] - 480) // 2
self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
self.augmentor = Augmentor(
image_height=480,
image_width=640,
max_disp=256,
scale_min=0.6,
scale_max=1.0,
seed=0,
)
def get_disp(self, path):
# disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
disp = np.load(path).transpose(1,2,0)
# return disp.astype(np.float32) / 32
return disp
def __getitem__(self, index):
# find path
left_path = self.imgs[index]
prefix = left_path[: left_path.rfind("_")]
# right_path = prefix + "_right.jpg"
left_disp_path = left_path.replace('im', 'disp')
# right_disp_path = prefix + "_right.disp.png"
# read img, disp
# left_img = cv2.imread(left_path, cv2.IMREAD_COLOR)
# right_img = cv2.imread(right_path, cv2.IMREAD_COLOR)
# left_img = np.load(left_path).transpose(1,2,0)
left_img = np.load(left_path)
# left_img = cv2.cvtColor(left_img, cv2.COLOR_GRAY2RGB)
# FIXME DO WE NEED THIS? OTHERWISE IT's PRETTY DARK
left_img = cv2.normalize(left_img, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
# left_img = cv2.merge([left_img, left_img, left_img]).reshape((3, 480, 640))
left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
right_img = self.pattern
# right_img = cv2.merge([right_img, right_img, right_img]).reshape((3, 480, 640))
left_disp = self.get_disp(left_disp_path)
# right_disp = self.get_disp(right_disp_path)
if False: # self.rng.binomial(1, 0.5):
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
left_disp[left_disp == np.inf] = 0
if self.blur:
kernel_size = random.sample([1,3,5,7,9], 1)[0]
kernel = (kernel_size, kernel_size)
left_img = cv2.GaussianBlur(left_img, kernel, 0)
# augmentation
if not self.augment:
_left_img, _right_img, _left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
else:
left_img, right_img, left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
# left_img = left_img.transpose(2, 0, 1).astype("uint8")
# left_img = left_img.transpose(2, 0, 1).astype("float32")
# left_img = left_img.astype("float32")
right_img = right_img.transpose(2, 0, 1).astype("uint8")
# right_img = right_img.astype("float32")
# print('post_augment', left_img)
return {
"left": left_img,
"right": right_img,
"disparity": left_disp,
"mask": disp_mask,
}
def __len__(self):
return len(self.imgs)

View File

@ -155,6 +155,8 @@ class CREStereo(nn.Module):
flow = None
flow_up = None
if flow_init is not None:
if isinstance(flow_init, list):
flow_init = flow_init[0]
scale = fmap1.shape[2] / flow_init.shape[2]
flow = -scale * F.interpolate(
flow_init,

View File

@ -1,12 +1,22 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from imread_from_url import imread_from_url
from nets import Model
import wandb
import random
from torch.utils.data import DataLoader
from dataset import CTDDataset
from train import normalize_and_colormap, parse_yaml, inference as ctd_inference
device = 'cuda'
wandb.init(project="crestereo", entity="cpt-captain")
#Ref: https://github.com/megvii-research/CREStereo/blob/master/test.py
def inference(left, right, model, n_iter=20):
@ -41,39 +51,198 @@ def inference(left, right, model, n_iter=20):
return pred_disp
def inference_ctd(left, right, gt_disp, mask, model, epoch, n_iter=20):
print("Model Forwarding...")
# print(left.shape)
left = left.cpu().detach().numpy()
imgL = left
imgR = right.cpu().detach().numpy()
imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :])
# chosen for convenience
device = torch.device('cuda:0')
imgL = torch.tensor(imgL.astype("float32")).to(device)
imgR = torch.tensor(imgR.astype("float32")).to(device)
imgL = imgL.transpose(2,3).transpose(1,2)
imgL_dw2 = F.interpolate(
imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
imgR_dw2 = F.interpolate(
imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
log = {}
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
log[f'pred_{i}'] = wandb.Image(
np.array([pred_disp.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_norm_{i}'] = wandb.Image(
np.array([pred_disp_norm.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_dw2_{i}'] = wandb.Image(
np.array([pred_disp_dw2.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
)
log[f'pred_dw2_norm_{i}'] = wandb.Image(
np.array([pred_disp_dw2_norm.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
)
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
disp_error = gt_disp - disp
log['disp_error'] = wandb.Image(
normalize_and_colormap(disp_error),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.mean():.{2}f}",
)
wandb.log(log)
def do_infer(left_img, right_img, gt_disp, model):
in_h, in_w = left_img.shape[:2]
# Resize image in case the GPU memory overflows
eval_h, eval_w = (in_h,in_w)
# FIXME borked for some reason, hopefully not very important
imgL = left_img.cpu().detach().numpy() if isinstance(left_img, torch.Tensor) else left_img
imgR = right_img.cpu().detach().numpy() if isinstance(right_img, torch.Tensor) else right_img
imgL = cv2.resize(imgL, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
imgR = cv2.resize(imgR, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
# pred = ctd_inference(imgL, imgR, gt_disp, None, model, None, n_iter=20)
pred = ctd_inference(left_img, right_img, gt_disp, None, model, None, n_iter=20, wandb_log=False)
t = float(in_w) / float(eval_w)
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
disp_vis = normalize_and_colormap(disp)
gt_disp_vis = normalize_and_colormap(gt_disp)
if gt_disp.shape != disp.shape:
gt_disp = gt_disp.reshape(disp.shape)
disp_err = gt_disp - disp
disp_err = normalize_and_colormap(disp_err.abs())
wandb.log({
'disp_vis': wandb.Image(
disp_vis,
caption=f"Pred. Disparity \n{disp.min():.{2}f}/{disp.max():.{2}f}",
),
'gt_disp_vis': wandb.Image(
gt_disp_vis,
caption=f"GT Disparity \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
),
'disp_err': wandb.Image(
disp_err,
caption=f"Disparity Error\n{disp_err.min():.{2}f}/{disp_err.max():.{2}f}",
),
'input_left': wandb.Image(
left_img.cpu().detach().numpy().astype('uint8'),
caption=f"Input left",
),
'input_right': wandb.Image(
right_img.cpu().detach().numpy().astype('uint8'),
caption=f"Input right",
),
})
if __name__ == '__main__':
# model_path = "models/crestereo_eth3d.pth"
model_path = "train_log/models/latest.pth"
left_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/left.png")
right_img = imread_from_url("https://raw.githubusercontent.com/megvii-research/CREStereo/master/img/test/right.png")
# reference_pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
reference_pattern_path = '/home/nils/kinect_reference_cropped.png'
# reference_pattern_path = '/home/nils/new_reference.png'
# reference_pattern_path = '/home/nils/kinect_reference_high_res.png'
# reference_pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
in_h, in_w = left_img.shape[:2]
data_type = 'kinect'
augment = False
# Resize image in case the GPU memory overflows
eval_h, eval_w = (in_h,in_w)
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
args = parse_yaml("cfgs/train.yaml")
model_path = "models/crestereo_eth3d.pth"
wandb.config.update({'model_path': model_path, 'reference_pattern': reference_pattern_path, 'augment': augment})
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model.load_state_dict(torch.load(model_path), strict=True)
model.to(device)
model.eval()
model = Model(max_disp=256, mixed_precision=False, test_mode=True)
model = nn.DataParallel(model,device_ids=[device])
# model.load_state_dict(torch.load(model_path), strict=False)
state_dict = torch.load(model_path)['state_dict']
model.load_state_dict(state_dict, strict=True)
model.to(device)
model.eval()
pred = inference(imgL, imgR, model, n_iter=20)
CTD = True
if not CTD:
left_img = cv2.imread("../test_imgs/left.png")
right_img = cv2.imread("../test_imgs/right.png")
in_h, in_w = left_img.shape[:2]
t = float(in_w) / float(eval_w)
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
# Resize image in case the GPU memory overflows
eval_h, eval_w = (in_h,in_w)
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
disp_vis = disp_vis.astype("uint8")
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
# FIXME borked for some reason, hopefully not very important
imgL = cv2.resize(left_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
imgR = cv2.resize(right_img, (eval_w, eval_h), interpolation=cv2.INTER_LINEAR)
combined_img = np.hstack((left_img, disp_vis))
cv2.namedWindow("output", cv2.WINDOW_NORMAL)
cv2.imshow("output", combined_img)
cv2.imwrite("output.jpg", disp_vis)
cv2.waitKey(0)
pred = inference(imgL, imgR, model, n_iter=20)
t = float(in_w) / float(eval_w)
disp = cv2.resize(pred, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
disp_vis = (disp - disp.min()) / (disp.max() - disp.min()) * 255.0
disp_vis = disp_vis.astype("uint8")
disp_vis = cv2.applyColorMap(disp_vis, cv2.COLORMAP_INFERNO)
combined_img = np.hstack((left_img, disp_vis))
# cv2.namedWindow("output", cv2.WINDOW_NORMAL)
# cv2.imshow("output", combined_img)
cv2.imwrite("output.jpg", disp_vis)
# cv2.waitKey(0)
else:
dataset = CTDDataset('/media/Data1/connecting_the_dots_data/ctd_data/', data_type=data_type, pattern_path=reference_pattern_path, augment=augment)
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
num_workers=0, drop_last=False, persistent_workers=False, pin_memory=True)
for batch in dataloader:
for left, right, disparity in zip(batch['left'], batch['right'], batch['disparity']):
right = right.transpose(0, 2).transpose(0, 1)
left_img = left
imgL = left.cpu().detach().numpy()
right_img = right
imgR = right.cpu().detach().numpy()
gt_disp = disparity
do_infer(left_img, right_img, gt_disp, model)

175
train.py
View File

@ -5,16 +5,131 @@ import logging
from collections import namedtuple
import yaml
from tensorboardX import SummaryWriter
# from tensorboardX import SummaryWriter
from nets import Model
from dataset import CREStereoDataset
# from dataset import CREStereoDataset
from dataset import CREStereoDataset, CTDDataset
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import numpy as np
import cv2
def normalize_and_colormap(img):
ret = (img - img.min()) / (img.max() - img.min()) * 255.0
if isinstance(ret, torch.Tensor):
ret = ret.cpu().detach().numpy()
ret = ret.astype("uint8")
ret = cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)
return ret
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True):
print("Model Forwarding...")
left = left.cpu().detach().numpy()
imgL = left
imgR = right.cpu().detach().numpy()
imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :])
# chosen for convenience
device = torch.device('cuda:0')
imgL = torch.tensor(imgL.astype("float32")).to(device)
imgR = torch.tensor(imgR.astype("float32")).to(device)
imgL = imgL.transpose(2,3).transpose(1,2)
if imgL.shape != imgR.shape:
imgR = imgR.transpose(2,3).transpose(1,2)
imgL_dw2 = F.interpolate(
imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
imgR_dw2 = F.interpolate(
imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
)
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=None)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2)
if not wandb_log:
return torch.squeeze(pred_flow[:, 0, :, :]).cpu().detach().numpy()
log = {}
in_h, in_w = left.shape[:2]
# Resize image in case the GPU memory overflows
eval_h, eval_w = (in_h,in_w)
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
pred_disp = torch.squeeze(pf[:, 0, :, :]).cpu().detach().numpy()
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :]).cpu().detach().numpy()
pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
if i == n_iter-1:
t = float(in_w) / float(eval_w)
disp = cv2.resize(pred_disp, (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
log[f'disp_vis'] = wandb.Image(
normalize_and_colormap(disp),
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_{i}'] = wandb.Image(
np.array([pred_disp.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_norm_{i}'] = wandb.Image(
np.array([pred_disp_norm.reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_dw2_{i}'] = wandb.Image(
np.array([pred_disp_dw2.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
)
log[f'pred_dw2_norm_{i}'] = wandb.Image(
np.array([pred_disp_dw2_norm.reshape(240, 320)]),
caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
)
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
log['input_right'] = wandb.Image(right.cpu().detach().numpy().transpose(1,2,0).astype('uint8'), caption="Input Right")
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
disp_error = gt_disp - disp
log['disp_error'] = wandb.Image(
normalize_and_colormap(disp_error.abs()),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{disp_error.abs().mean():.{2}f}",
)
log[f'gt_disp_vis'] = wandb.Image(
normalize_and_colormap(gt_disp),
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
)
wandb.log(log)
def parse_yaml(file_path: str) -> namedtuple:
"""Parse yaml configuration file and return the object in `namedtuple`."""
@ -65,6 +180,10 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
'''
n_predictions = len(flow_preds)
flow_loss = 0.0
# TEST
flow_gt = torch.squeeze(flow_gt, dim=-1)
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
i_loss = torch.abs(flow_preds[i] - flow_gt)
@ -93,7 +212,9 @@ def main(args):
optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
# model = nn.DataParallel(model,device_ids=[0])
tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
# tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
wandb.watch(model)
metrics = {}
# worklog
logging.basicConfig(level=eval(args.log_level))
@ -138,8 +259,12 @@ def main(args):
start_epoch_idx = 1
start_iters = 0
# pattern_path = '/home/nils/kinect_reference_high_res_scaled_down.png'
pattern_path = '/home/nils/kinect_reference_cropped.png'
# pattern_path = '/home/nils/orig_ctd/connecting_the_dots/data/kinect_pattern.png'
# datasets
dataset = CREStereoDataset(args.training_data_path)
# dataset = CREStereoDataset(args.training_data_path)
dataset = CTDDataset(args.training_data_path, pattern_path=pattern_path)
# if rank == 0:
worklog.info(f"Dataset size: {len(dataset)}")
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
@ -183,6 +308,9 @@ def main(args):
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
# forward
# left = left.transpose(1, 2).transpose(2, 3)
left = left.transpose(1, 3).transpose(2, 3)
right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
flow_predictions = model(left.cuda(), right.cuda())
# loss & backword
@ -190,6 +318,16 @@ def main(args):
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
if batch_idx % 128 == 0:
inference(
mini_batch_data['left'][0],
mini_batch_data['right'][0],
mini_batch_data['disparity'][0],
mini_batch_data['mask'][0],
model,
batch_idx,
)
# loss stats
loss_item = loss.data.item()
epoch_total_train_loss += loss_item
@ -234,20 +372,25 @@ def main(args):
worklog.info("".join(info))
# minibatch loss
tb_log.add_scalar("train/loss_batch", loss_item, cur_iters)
tb_log.add_scalar(
"train/lr", optimizer.param_groups[0]["lr"], cur_iters
)
tb_log.flush()
# tb_log.add_scalar("train/loss_batch", loss_item, cur_iters)
metrics['train/loss_batch'] = loss_item
# tb_log.add_scalar(
# "train/lr", optimizer.param_groups[0]["lr"], cur_iters
# )
metrics['train/lr'] = optimizer.param_groups[0]["lr"]
# tb_log.flush()
wandb.log(metrics)
t1 = time.perf_counter()
tb_log.add_scalar(
"train/loss",
epoch_total_train_loss / args.minibatch_per_epoch,
epoch_idx,
)
tb_log.flush()
# tb_log.add_scalar(
# "train/loss",
# epoch_total_train_loss / args.minibatch_per_epoch,
# epoch_idx,
# )
metrics['train/loss'] = epoch_total_train_loss / args.minibatch_per_epoch
# tb_log.flush()
wandb.log(metrics)
# save model params
ckp_data = {
@ -271,4 +414,6 @@ def main(args):
if __name__ == "__main__":
# train configuration
args = parse_yaml("cfgs/train.yaml")
wandb.init(project="crestereo", entity="cpt-captain")
wandb.config.update(args._asdict())
main(args)