finish lightningification\n\nTraining still seems borked
This commit is contained in:
parent
0e2a4b2340
commit
d8169e01bc
@ -3,7 +3,7 @@ mixed_precision: false
|
||||
base_lr: 4.0e-4
|
||||
|
||||
nr_gpus: 3
|
||||
batch_size: 4
|
||||
batch_size: 2
|
||||
n_total_epoch: 300
|
||||
minibatch_per_epoch: 500
|
||||
|
||||
|
@ -22,9 +22,10 @@ except:
|
||||
|
||||
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
|
||||
class CREStereo(nn.Module):
|
||||
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False):
|
||||
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False, batch_size=4):
|
||||
super(CREStereo, self).__init__()
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.max_flow = max_disp
|
||||
self.mixed_precision = mixed_precision
|
||||
self.test_mode = test_mode
|
||||
@ -37,10 +38,10 @@ class CREStereo(nn.Module):
|
||||
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
|
||||
|
||||
# # NOTE Position_encoding as workaround for TensorRt
|
||||
# image1_shape = [1, 2, 480, 640]
|
||||
# self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
# )
|
||||
image1_shape = [1, 2, 480, 640]
|
||||
self.pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
|
||||
)
|
||||
|
||||
# loftr
|
||||
self.self_att_fn = LocalFeatureTransformer(
|
||||
@ -136,9 +137,9 @@ class CREStereo(nn.Module):
|
||||
inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
|
||||
|
||||
# positional encoding and self-attention
|
||||
# pos_encoding_fn_small = PositionEncodingSine(
|
||||
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
# )
|
||||
pos_encoding_fn_small = PositionEncodingSine(
|
||||
d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
|
||||
)
|
||||
# 'n c h w -> n (h w) c'
|
||||
x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
|
||||
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
|
||||
|
@ -111,12 +111,6 @@ class BasicEncoder(nn.Module):
|
||||
else:
|
||||
x_tensor = x
|
||||
|
||||
print()
|
||||
print()
|
||||
print(x_tensor.shape)
|
||||
print()
|
||||
print()
|
||||
|
||||
x_tensor = self.conv1(x_tensor)
|
||||
x_tensor = self.norm1(x_tensor)
|
||||
x_tensor = self.relu1(x_tensor)
|
||||
|
@ -16,10 +16,10 @@ import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
from pytorch_lightning.lite import LightningLite
|
||||
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
||||
from pytorch_lightning import Trainer, seed_everything
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
||||
|
||||
seed_everything(42, workers=True)
|
||||
|
||||
@ -39,148 +39,44 @@ def normalize_and_colormap(img):
|
||||
return ret
|
||||
|
||||
|
||||
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True):
|
||||
|
||||
print("Model Forwarding...")
|
||||
if isinstance(left, torch.Tensor):
|
||||
left = left# .cpu().detach().numpy()
|
||||
imgR = right# .cpu().detach().numpy()
|
||||
imgL = left
|
||||
imgR = right
|
||||
imgL = np.ascontiguousarray(imgL[None, :, :, :])
|
||||
imgR = np.ascontiguousarray(imgR[None, :, :, :])
|
||||
|
||||
flow_init = None
|
||||
|
||||
# chosen for convenience
|
||||
|
||||
imgL = torch.tensor(imgL.astype("float32"))
|
||||
imgR = torch.tensor(imgR.astype("float32"))
|
||||
imgL = imgL.transpose(2,3).transpose(1,2)
|
||||
if imgL.shape != imgR.shape:
|
||||
imgR = imgR.transpose(2,3).transpose(1,2)
|
||||
|
||||
imgL_dw2 = F.interpolate(
|
||||
imgL,
|
||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
).clamp(min=0, max=255)
|
||||
imgR_dw2 = F.interpolate(
|
||||
imgR,
|
||||
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
|
||||
mode="bilinear",
|
||||
align_corners=True,
|
||||
).clamp(min=0, max=255)
|
||||
if last_img is not None:
|
||||
print('using flow_initialization')
|
||||
print(last_img.shape)
|
||||
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
|
||||
print(last_img.max(), last_img.min())
|
||||
if last_img.min() < 0:
|
||||
# print('Negative disparity detected. shifting...')
|
||||
last_img = last_img - last_img.min()
|
||||
if last_img.max() > 255:
|
||||
# print('Excessive disparity detected. scaling...')
|
||||
last_img = last_img / (last_img.max() / 255)
|
||||
|
||||
|
||||
last_img = np.dstack([last_img, last_img])
|
||||
# last_img = np.dstack([last_img, last_img, last_img])
|
||||
last_img = np.dstack([last_img])
|
||||
last_img = last_img.reshape((1, 2, 480, 640))
|
||||
# print(last_img.shape)
|
||||
# print(last_img.dtype)
|
||||
# print(last_img.max(), last_img.min())
|
||||
flow_init = torch.tensor(last_img.astype("float32"))
|
||||
# flow_init = F.interpolate(
|
||||
# last_img,
|
||||
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
|
||||
# mode="bilinear",
|
||||
# align_corners=True,
|
||||
# )
|
||||
with torch.inference_mode():
|
||||
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
|
||||
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
|
||||
pf_base = pred_flow
|
||||
if isinstance(pf_base, list):
|
||||
pf_base = pred_flow[0]
|
||||
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy()
|
||||
print('pred_flow max min')
|
||||
print(pf.max(), pf.min())
|
||||
|
||||
|
||||
if not wandb_log:
|
||||
if test:
|
||||
return pred_flow
|
||||
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy()
|
||||
|
||||
def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
|
||||
# wandb_logger.log_text('test')
|
||||
# return
|
||||
log = {}
|
||||
in_h, in_w = left.shape[:2]
|
||||
batch_idx = 1
|
||||
|
||||
# Resize image in case the GPU memory overflows
|
||||
eval_h, eval_w = (in_h,in_w)
|
||||
if isinstance(pred_disp, list):
|
||||
pred_disp = pred_disp[-1]
|
||||
|
||||
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
|
||||
pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy()
|
||||
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy()
|
||||
|
||||
# pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||
# pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U)
|
||||
|
||||
if i == n_iter-1:
|
||||
t = float(in_w) / float(eval_w)
|
||||
disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t
|
||||
|
||||
log[f'disp_vis'] = wandb.Image(
|
||||
normalize_and_colormap(disp),
|
||||
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
)
|
||||
|
||||
log[f'pred_{i}'] = wandb.Image(
|
||||
np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]),
|
||||
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
)
|
||||
# log[f'pred_norm_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_norm.reshape(480, 640)]),
|
||||
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
|
||||
# )
|
||||
|
||||
# log[f'pred_dw2_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_dw2.reshape(240, 320)]),
|
||||
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
# )
|
||||
# log[f'pred_dw2_norm_{i}'] = wandb.Image(
|
||||
# np.array([pred_disp_dw2_norm.reshape(240, 320)]),
|
||||
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
|
||||
# )
|
||||
|
||||
|
||||
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
|
||||
input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
|
||||
if input_right.shape != (480, 640, 3):
|
||||
input_right.transpose(1,2,0)
|
||||
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
|
||||
|
||||
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
|
||||
|
||||
gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
|
||||
disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
|
||||
pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
|
||||
gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
|
||||
left = torch.squeeze(left[:, 0, :, :])
|
||||
right = torch.squeeze(right[:, 0, :, :])
|
||||
|
||||
disp = pred_disp
|
||||
disp_error = gt_disp - disp
|
||||
log['disp_error'] = wandb.Image(
|
||||
normalize_and_colormap(abs(disp_error)),
|
||||
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
|
||||
)
|
||||
|
||||
input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
|
||||
input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
|
||||
|
||||
log[f'gt_disp_vis'] = wandb.Image(
|
||||
normalize_and_colormap(gt_disp),
|
||||
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
|
||||
)
|
||||
|
||||
wandb.log(log)
|
||||
return pred_flow
|
||||
wandb_log = dict(
|
||||
key='samples',
|
||||
images=[
|
||||
normalize_and_colormap(pred_disp[batch_idx]),
|
||||
normalize_and_colormap(abs(disp_error[batch_idx])),
|
||||
normalize_and_colormap(gt_disp[batch_idx]),
|
||||
input_left,
|
||||
input_right,
|
||||
],
|
||||
caption=[
|
||||
f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
|
||||
f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}",
|
||||
f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}",
|
||||
"Input Left",
|
||||
"Input Right"
|
||||
],
|
||||
)
|
||||
return wandb_log
|
||||
|
||||
|
||||
def parse_yaml(file_path: str) -> namedtuple:
|
||||
@ -259,9 +155,10 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
|
||||
|
||||
|
||||
class CREStereoLightning(LightningModule):
|
||||
def __init__(self, args):
|
||||
def __init__(self, args, logger):
|
||||
super().__init__()
|
||||
self.batch_size = args.batch_size
|
||||
self.wandb_logger = logger
|
||||
self.model = Model(
|
||||
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
|
||||
)
|
||||
@ -270,13 +167,10 @@ class CREStereoLightning(LightningModule):
|
||||
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
|
||||
|
||||
def training_step(self, batch, batch_idx):
|
||||
# loss = self(batch)
|
||||
left, right, gt_disp, valid_mask = batch
|
||||
left = torch.Tensor(left).to(self.device)
|
||||
right = torch.Tensor(right).to(self.device)
|
||||
left = left
|
||||
right = right
|
||||
flow_predictions = self.forward(left, right)
|
||||
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
|
||||
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
||||
loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
@ -285,62 +179,91 @@ class CREStereoLightning(LightningModule):
|
||||
|
||||
def validation_step(self, batch, batch_idx):
|
||||
left, right, gt_disp, valid_mask = batch
|
||||
left = torch.Tensor(left).to(self.device)
|
||||
right = torch.Tensor(right).to(self.device)
|
||||
print(left.shape)
|
||||
print(right.shape)
|
||||
flow_predictions = self.forward(left, right)
|
||||
flow_predictions = self.forward(left, right, test_mode=True)
|
||||
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
|
||||
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
||||
val_loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
self.log("val_loss", val_loss)
|
||||
if batch_idx % 4 == 0:
|
||||
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
||||
|
||||
def test_step(self, batch, batch_idx):
|
||||
left, right, gt_disp, valid_mask = batch
|
||||
# left, right, gt_disp, valid_mask = (
|
||||
# batch["left"],
|
||||
# batch["right"],
|
||||
# batch["disparity"],
|
||||
# batch["mask"],
|
||||
# )
|
||||
left = torch.Tensor(left).to(self.device)
|
||||
right = torch.Tensor(right).to(self.device)
|
||||
flow_predictions = self.forward(left, right)
|
||||
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
||||
flow_predictions = self.forward(left, right, test_mode=True)
|
||||
test_loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
self.log("test_loss", test_loss)
|
||||
print('test_batch_idx:', batch_idx)
|
||||
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
|
||||
|
||||
def configure_optimizers(self):
|
||||
return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999))
|
||||
return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# train configuration
|
||||
args = parse_yaml("cfgs/train.yaml")
|
||||
# wandb.init(project="crestereo-lightning", entity="cpt-captain")
|
||||
# Lite(strategy='dp', accelerator='gpu', devices=2).run(args)
|
||||
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
|
||||
model = CREStereoLightning(args)
|
||||
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True)
|
||||
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True)
|
||||
print(len(dataset))
|
||||
print(len(test_dataset))
|
||||
|
||||
wandb_logger = WandbLogger(project="crestereo-lightning")
|
||||
wandb.config.update(args._asdict())
|
||||
|
||||
model = CREStereoLightning(args, wandb_logger)
|
||||
|
||||
dataset = BlenderDataset(
|
||||
root=args.training_data_path,
|
||||
pattern_path=pattern_path,
|
||||
use_lightning=True,
|
||||
)
|
||||
test_dataset = BlenderDataset(
|
||||
root=args.training_data_path,
|
||||
pattern_path=pattern_path,
|
||||
test_set=True,
|
||||
use_lightning=True,
|
||||
)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=16,
|
||||
drop_last=True,
|
||||
persistent_workers=True,
|
||||
pin_memory=True,
|
||||
)
|
||||
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
||||
test_dataloader = DataLoader(
|
||||
test_dataset,
|
||||
args.batch_size,
|
||||
shuffle=False,
|
||||
num_workers=16,
|
||||
drop_last=False,
|
||||
persistent_workers=True,
|
||||
pin_memory=True
|
||||
)
|
||||
|
||||
trainer = Trainer(
|
||||
max_epochs=args.n_total_epoch,
|
||||
accelerator='gpu',
|
||||
devices=2,
|
||||
# auto_scale_batch_size='binsearch',
|
||||
# strategy='ddp',
|
||||
max_epochs=args.n_total_epoch,
|
||||
callbacks=[
|
||||
EarlyStopping(
|
||||
monitor="val_loss",
|
||||
mode="min",
|
||||
patience=4,
|
||||
)
|
||||
],
|
||||
accumulate_grad_batches=8,
|
||||
deterministic=True,
|
||||
check_val_every_n_epoch=1,
|
||||
limit_val_batches=24,
|
||||
limit_test_batches=24,
|
||||
logger=wandb_logger,
|
||||
default_root_dir=args.log_dir_lightning,
|
||||
)
|
||||
# trainer.tune(model)
|
||||
trainer.fit(model, dataset, test_dataset)
|
||||
)
|
||||
|
||||
trainer.fit(model, dataloader, test_dataloader)
|
||||
|
Loading…
Reference in New Issue
Block a user