finish lightningification\n\nTraining still seems borked

main
Cpt.Captain 2 years ago
parent 0e2a4b2340
commit d8169e01bc
  1. 2
      cfgs/train.yaml
  2. 17
      nets/crestereo.py
  3. 6
      nets/extractor.py
  4. 261
      train_lightning.py

@ -3,7 +3,7 @@ mixed_precision: false
base_lr: 4.0e-4 base_lr: 4.0e-4
nr_gpus: 3 nr_gpus: 3
batch_size: 4 batch_size: 2
n_total_epoch: 300 n_total_epoch: 300
minibatch_per_epoch: 500 minibatch_per_epoch: 500

@ -22,9 +22,10 @@ except:
#Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py #Ref: https://github.com/princeton-vl/RAFT/blob/master/core/raft.py
class CREStereo(nn.Module): class CREStereo(nn.Module):
def __init__(self, max_disp=192, mixed_precision=False, test_mode=False): def __init__(self, max_disp=192, mixed_precision=False, test_mode=False, batch_size=4):
super(CREStereo, self).__init__() super(CREStereo, self).__init__()
self.batch_size = batch_size
self.max_flow = max_disp self.max_flow = max_disp
self.mixed_precision = mixed_precision self.mixed_precision = mixed_precision
self.test_mode = test_mode self.test_mode = test_mode
@ -37,10 +38,10 @@ class CREStereo(nn.Module):
self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4) self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
# # NOTE Position_encoding as workaround for TensorRt # # NOTE Position_encoding as workaround for TensorRt
# image1_shape = [1, 2, 480, 640] image1_shape = [1, 2, 480, 640]
# self.pos_encoding_fn_small = PositionEncodingSine( self.pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16) d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
# ) )
# loftr # loftr
self.self_att_fn = LocalFeatureTransformer( self.self_att_fn = LocalFeatureTransformer(
@ -136,9 +137,9 @@ class CREStereo(nn.Module):
inp_dw16 = F.avg_pool2d(inp, 4, stride=4) inp_dw16 = F.avg_pool2d(inp, 4, stride=4)
# positional encoding and self-attention # positional encoding and self-attention
# pos_encoding_fn_small = PositionEncodingSine( pos_encoding_fn_small = PositionEncodingSine(
# d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16) d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
# ) )
# 'n c h w -> n (h w) c' # 'n c h w -> n (h w) c'
x_tmp = self.pos_encoding_fn_small(fmap1_dw16) x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]) fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])

@ -111,12 +111,6 @@ class BasicEncoder(nn.Module):
else: else:
x_tensor = x x_tensor = x
print()
print()
print(x_tensor.shape)
print()
print()
x_tensor = self.conv1(x_tensor) x_tensor = self.conv1(x_tensor)
x_tensor = self.norm1(x_tensor) x_tensor = self.norm1(x_tensor)
x_tensor = self.relu1(x_tensor) x_tensor = self.relu1(x_tensor)

@ -16,10 +16,10 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torch.optim as optim import torch.optim as optim
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from pytorch_lightning.lite import LightningLite
from pytorch_lightning import LightningDataModule, LightningModule, Trainer from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning import Trainer, seed_everything from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
seed_everything(42, workers=True) seed_everything(42, workers=True)
@ -39,148 +39,44 @@ def normalize_and_colormap(img):
return ret return ret
def inference(left, right, gt_disp, mask, model, epoch, n_iter=20, wandb_log=True, last_img=None, test=False, attend_pattern=True): def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
# wandb_logger.log_text('test')
print("Model Forwarding...") # return
if isinstance(left, torch.Tensor):
left = left# .cpu().detach().numpy()
imgR = right# .cpu().detach().numpy()
imgL = left
imgR = right
imgL = np.ascontiguousarray(imgL[None, :, :, :])
imgR = np.ascontiguousarray(imgR[None, :, :, :])
flow_init = None
# chosen for convenience
imgL = torch.tensor(imgL.astype("float32"))
imgR = torch.tensor(imgR.astype("float32"))
imgL = imgL.transpose(2,3).transpose(1,2)
if imgL.shape != imgR.shape:
imgR = imgR.transpose(2,3).transpose(1,2)
imgL_dw2 = F.interpolate(
imgL,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
).clamp(min=0, max=255)
imgR_dw2 = F.interpolate(
imgR,
size=(imgL.shape[2] // 2, imgL.shape[3] // 2),
mode="bilinear",
align_corners=True,
).clamp(min=0, max=255)
if last_img is not None:
print('using flow_initialization')
print(last_img.shape)
# FIXME trying 8bit normalization because sometimes we get really high values that don't seem to help
print(last_img.max(), last_img.min())
if last_img.min() < 0:
# print('Negative disparity detected. shifting...')
last_img = last_img - last_img.min()
if last_img.max() > 255:
# print('Excessive disparity detected. scaling...')
last_img = last_img / (last_img.max() / 255)
last_img = np.dstack([last_img, last_img])
# last_img = np.dstack([last_img, last_img, last_img])
last_img = np.dstack([last_img])
last_img = last_img.reshape((1, 2, 480, 640))
# print(last_img.shape)
# print(last_img.dtype)
# print(last_img.max(), last_img.min())
flow_init = torch.tensor(last_img.astype("float32"))
# flow_init = F.interpolate(
# last_img,
# size=(last_img.shape[0] // 2, last_img.shape[1] // 2),
# mode="bilinear",
# align_corners=True,
# )
with torch.inference_mode():
pred_flow_dw2 = model(image1=imgL_dw2, image2=imgR_dw2, iters=n_iter, flow_init=flow_init, self_attend_right=attend_pattern)
pred_flow = model(imgL, imgR, iters=n_iter, flow_init=pred_flow_dw2, self_attend_right=attend_pattern)
pf_base = pred_flow
if isinstance(pf_base, list):
pf_base = pred_flow[0]
pf = torch.squeeze(pf_base[:, 0, :, :])# .cpu().detach().numpy()
print('pred_flow max min')
print(pf.max(), pf.min())
if not wandb_log:
if test:
return pred_flow
return torch.squeeze(pred_flow[:, 0, :, :])# .cpu().detach().numpy()
log = {} log = {}
in_h, in_w = left.shape[:2] batch_idx = 1
# Resize image in case the GPU memory overflows
eval_h, eval_w = (in_h,in_w)
for i, (pf, pf_dw2) in enumerate(zip(pred_flow, pred_flow_dw2)):
pred_disp = torch.squeeze(pf[:, 0, :, :])# .cpu().detach().numpy()
pred_disp_dw2 = torch.squeeze(pf_dw2[:, 0, :, :])# .cpu().detach().numpy()
# pred_disp_norm = cv2.normalize(pred_disp, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) if isinstance(pred_disp, list):
# pred_disp_dw2_norm = cv2.normalize(pred_disp_dw2, None, 0, 255, cv2.NORM_MINMAX, cv2.CV_8U) pred_disp = pred_disp[-1]
if i == n_iter-1: pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
t = float(in_w) / float(eval_w) gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
disp = cv2.resize(pred_disp.cpu().detach().numpy(), (in_w, in_h), interpolation=cv2.INTER_LINEAR) * t left = torch.squeeze(left[:, 0, :, :])
right = torch.squeeze(right[:, 0, :, :])
log[f'disp_vis'] = wandb.Image(
normalize_and_colormap(disp),
caption=f"Disparity \n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
log[f'pred_{i}'] = wandb.Image(
np.array([pred_disp.cpu().detach().numpy().reshape(480, 640)]),
caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
)
# log[f'pred_norm_{i}'] = wandb.Image(
# np.array([pred_disp_norm.reshape(480, 640)]),
# caption=f"Pred. Disp. It {i}\n{pred_disp.min():.{2}f}/{pred_disp.max():.{2}f}",
# )
# log[f'pred_dw2_{i}'] = wandb.Image(
# np.array([pred_disp_dw2.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
# log[f'pred_dw2_norm_{i}'] = wandb.Image(
# np.array([pred_disp_dw2_norm.reshape(240, 320)]),
# caption=f"Pred. Disp. Dw2 It {i}\n{pred_disp_dw2.min():.{2}f}/{pred_disp_dw2.max():.{2}f}",
# )
log['input_left'] = wandb.Image(left.astype('uint8'), caption="Input Left")
input_right = right# .cpu().detach().numpy() if isinstance(right, torch.Tensor) else right
if input_right.shape != (480, 640, 3):
input_right.transpose(1,2,0)
log['input_right'] = wandb.Image(input_right.astype('uint8'), caption="Input Right")
log['gt_disp'] = wandb.Image(gt_disp, caption=f"GT Disparity\n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}")
gt_disp = gt_disp# .cpu().detach().numpy() if isinstance(gt_disp, torch.Tensor) else gt_disp
disp = disp# .cpu().detach().numpy() if isinstance(disp, torch.Tensor) else disp
disp = pred_disp
disp_error = gt_disp - disp disp_error = gt_disp - disp
log['disp_error'] = wandb.Image(
normalize_and_colormap(abs(disp_error)),
caption=f"Disp. Error\n{disp_error.min():.{2}f}/{disp_error.max():.{2}f}\n{abs(disp_error).mean():.{2}f}",
)
log[f'gt_disp_vis'] = wandb.Image( input_left = left[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
normalize_and_colormap(gt_disp), input_right = right[batch_idx].cpu().detach().numpy()# .transpose(1,2,0)
caption=f"GT Disp Vis \n{gt_disp.min():.{2}f}/{gt_disp.max():.{2}f}",
wandb_log = dict(
key='samples',
images=[
normalize_and_colormap(pred_disp[batch_idx]),
normalize_and_colormap(abs(disp_error[batch_idx])),
normalize_and_colormap(gt_disp[batch_idx]),
input_left,
input_right,
],
caption=[
f"Disparity \n{pred_disp[batch_idx].min():.{2}f}/{pred_disp[batch_idx].max():.{2}f}",
f"Disp. Error\n{disp_error[batch_idx].min():.{2}f}/{disp_error[batch_idx].max():.{2}f}\n{abs(disp_error[batch_idx]).mean():.{2}f}",
f"GT Disp Vis \n{gt_disp[batch_idx].min():.{2}f}/{gt_disp[batch_idx].max():.{2}f}",
"Input Left",
"Input Right"
],
) )
return wandb_log
wandb.log(log)
return pred_flow
def parse_yaml(file_path: str) -> namedtuple: def parse_yaml(file_path: str) -> namedtuple:
@ -259,9 +155,10 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
class CREStereoLightning(LightningModule): class CREStereoLightning(LightningModule):
def __init__(self, args): def __init__(self, args, logger):
super().__init__() super().__init__()
self.batch_size = args.batch_size self.batch_size = args.batch_size
self.wandb_logger = logger
self.model = Model( self.model = Model(
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
) )
@ -270,13 +167,10 @@ class CREStereoLightning(LightningModule):
return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right) return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# loss = self(batch)
left, right, gt_disp, valid_mask = batch left, right, gt_disp, valid_mask = batch
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
left = left
right = right
flow_predictions = self.forward(left, right) flow_predictions = self.forward(left, right)
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
loss = sequence_loss( loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8 flow_predictions, gt_flow, valid_mask, gamma=0.8
) )
@ -285,56 +179,85 @@ class CREStereoLightning(LightningModule):
def validation_step(self, batch, batch_idx): def validation_step(self, batch, batch_idx):
left, right, gt_disp, valid_mask = batch left, right, gt_disp, valid_mask = batch
left = torch.Tensor(left).to(self.device) flow_predictions = self.forward(left, right, test_mode=True)
right = torch.Tensor(right).to(self.device) gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
print(left.shape) gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
print(right.shape)
flow_predictions = self.forward(left, right)
val_loss = sequence_loss( val_loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8 flow_predictions, gt_flow, valid_mask, gamma=0.8
) )
self.log("val_loss", val_loss) self.log("val_loss", val_loss)
if batch_idx % 4 == 0:
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
def test_step(self, batch, batch_idx): def test_step(self, batch, batch_idx):
left, right, gt_disp, valid_mask = batch left, right, gt_disp, valid_mask = batch
# left, right, gt_disp, valid_mask = ( gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512] gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
# batch["left"], flow_predictions = self.forward(left, right, test_mode=True)
# batch["right"],
# batch["disparity"],
# batch["mask"],
# )
left = torch.Tensor(left).to(self.device)
right = torch.Tensor(right).to(self.device)
flow_predictions = self.forward(left, right)
test_loss = sequence_loss( test_loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8 flow_predictions, gt_flow, valid_mask, gamma=0.8
) )
self.log("test_loss", test_loss) self.log("test_loss", test_loss)
print('test_batch_idx:', batch_idx)
self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
def configure_optimizers(self): def configure_optimizers(self):
return optim.Adam(self.model.parameters(), lr=0.01, betas=(0.9, 0.999)) return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
if __name__ == "__main__": if __name__ == "__main__":
# train configuration # train configuration
args = parse_yaml("cfgs/train.yaml") args = parse_yaml("cfgs/train.yaml")
# wandb.init(project="crestereo-lightning", entity="cpt-captain")
# Lite(strategy='dp', accelerator='gpu', devices=2).run(args)
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png' pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
model = CREStereoLightning(args)
dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, use_lightning=True)
test_dataset = BlenderDataset(root=args.training_data_path, pattern_path=pattern_path, test_set=True, use_lightning=True)
print(len(dataset))
print(len(test_dataset))
wandb_logger = WandbLogger(project="crestereo-lightning") wandb_logger = WandbLogger(project="crestereo-lightning")
wandb.config.update(args._asdict()) wandb.config.update(args._asdict())
model = CREStereoLightning(args, wandb_logger)
dataset = BlenderDataset(
root=args.training_data_path,
pattern_path=pattern_path,
use_lightning=True,
)
test_dataset = BlenderDataset(
root=args.training_data_path,
pattern_path=pattern_path,
test_set=True,
use_lightning=True,
)
dataloader = DataLoader(
dataset,
args.batch_size,
shuffle=True,
num_workers=16,
drop_last=True,
persistent_workers=True,
pin_memory=True,
)
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
test_dataloader = DataLoader(
test_dataset,
args.batch_size,
shuffle=False,
num_workers=16,
drop_last=False,
persistent_workers=True,
pin_memory=True
)
trainer = Trainer( trainer = Trainer(
max_epochs=args.n_total_epoch,
accelerator='gpu', accelerator='gpu',
devices=2, devices=2,
# auto_scale_batch_size='binsearch', max_epochs=args.n_total_epoch,
# strategy='ddp', callbacks=[
EarlyStopping(
monitor="val_loss",
mode="min",
patience=4,
)
],
accumulate_grad_batches=8,
deterministic=True, deterministic=True,
check_val_every_n_epoch=1, check_val_every_n_epoch=1,
limit_val_batches=24, limit_val_batches=24,
@ -342,5 +265,5 @@ if __name__ == "__main__":
logger=wandb_logger, logger=wandb_logger,
default_root_dir=args.log_dir_lightning, default_root_dir=args.log_dir_lightning,
) )
# trainer.tune(model)
trainer.fit(model, dataset, test_dataset) trainer.fit(model, dataloader, test_dataloader)

Loading…
Cancel
Save