CREStereo Repository for the 'Towards accurate and robust depth estimation' project
CREStereo-pytorch-nxt/train_lightning.py
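Training entry point for the CREStereo PyTorch port: it wraps the model in a PyTorch Lightning module, builds dataloaders for the Blender and CTD datasets, trains with the iterative sequence loss under a hand-rolled learning-rate schedule, and logs losses, outlier fractions, and sample disparity maps to Weights & Biases.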


import os
import sys
import time
import logging
from collections import namedtuple

import yaml
import wandb
import numpy as np
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.strategies import DDPSpawnStrategy

from nets import Model
from dataset import BlenderDataset, CTDDataset

seed_everything(42, workers=True)

def normalize_and_colormap(img):
    """Min-max normalize `img` to [0, 255] and apply the inferno colormap."""
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)

def log_images(left, right, pred_disp, gt_disp):
    """Build a W&B image-logging dict from the first sample of a batch."""
    batch_idx = 0
    # The model returns a list of iterative refinements; visualize the last one.
    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]
    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])
    disp_error = gt_disp - pred_disp
    input_left = left[batch_idx].cpu().detach().numpy()
    input_right = right[batch_idx].cpu().detach().numpy()
    wandb_log = dict(
        key='samples',
        images=[
            normalize_and_colormap(pred_disp[batch_idx]),
            normalize_and_colormap(abs(disp_error[batch_idx])),
            normalize_and_colormap(gt_disp[batch_idx]),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity\n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
            f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n{abs(disp_error[batch_idx]).mean():.2f}",
            f"GT Disp Vis\n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
    return wandb_log
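
# The returned dict is unpacked straight into WandbLogger.log_image(**...) in the
# training/validation/test steps below; its `key`, `images` and `caption` entries
# match that method's parameters.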

def parse_yaml(file_path: str) -> namedtuple:
    """Parse a YAML configuration file and return its contents as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args

def format_time(elapse):
    """Format a duration in seconds as HH:MM:SS."""
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)

def outlier_fraction(estimate, target, mask=None, threshold=0):
    """Fraction of valid pixels whose absolute error exceeds `threshold`."""
    def _process_inputs(estimate, target, mask):
        if estimate.shape != target.shape:
            raise ValueError(f'estimate and target must have the same shape ({estimate.shape} != {target.shape})')
        if mask is None:
            # np.bool was removed from NumPy; build a boolean tensor instead so
            # the .cpu().detach() calls below also work on the default mask.
            mask = torch.ones_like(estimate, dtype=torch.bool)
        else:
            mask = mask != 0
        if estimate.shape != mask.shape:
            raise ValueError(f'estimate and mask must have the same shape ({estimate.shape} != {mask.shape})')
        return estimate, target, mask
    # Compare only the disparity (x-flow) channel.
    estimate = torch.squeeze(estimate[:, 0, :, :])
    target = torch.squeeze(target[:, 0, :, :])
    estimate, target, mask = _process_inputs(estimate, target, mask)
    mask = mask.cpu().detach().numpy()
    estimate = estimate.cpu().detach().numpy()
    target = target.cpu().detach().numpy()
    diff = np.abs(estimate[mask] - target[mask])
    return (diff > threshold).sum() / mask.sum()
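
# With threshold=1 this is the standard "bad-1" stereo metric: the share of
# valid pixels whose absolute disparity error exceeds one pixel.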

def ensure_dir(path):
    """Create `path` (including parents) if it does not exist yet."""
    os.makedirs(path, exist_ok=True)

def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
    """Weighted L1 loss over the full sequence of iterative flow predictions.

    valid:         (B, H, W) validity mask, broadcast to (B, 1, H, W)
    flow_preds[i]: (B, 2, H, W)
    flow_gt:       (B, 2, H, W)
    """
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    for i in range(n_predictions):
        # Exponentially increasing weights: later, more refined predictions
        # contribute more strongly to the total loss.
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
    return flow_loss
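
# With gamma=0.8 and, say, four predictions, the weights come out as
# 0.8**3 .. 0.8**0 = 0.512, 0.64, 0.8, 1.0 -- the final refinement dominates.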

class CREStereoLightning(LightningModule):
    """LightningModule wrapping CREStereo, its dataloaders and LR schedule."""

    def __init__(self, args, logger=None, pattern_path='', data_path=''):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.data_type = 'blender' if 'blender' in data_path else 'ctd'
        self.lr = args.base_lr
        print(f'lr = {self.lr}')
        self.T_max = args.t_max if args.t_max else None
        self.pattern_attention = args.pattern_attention
        self.pattern_path = pattern_path
        self.data_path = data_path
        self.model = Model(
            max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
        )
        # Kept as attributes so adjust_learning_rate() can reach them easily.
        self.n_total_epoch = args.n_total_epoch
        self.base_lr = args.base_lr
        # The LR schedule is applied by hand in adjust_learning_rate().
        self.automatic_optimization = False
    def train_dataloader(self):
        if self.data_type == 'blender':
            dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                use_lightning=True,
            )
        return DataLoader(
            dataset,
            self.batch_size,
            shuffle=True,
            num_workers=4,
            drop_last=True,
            persistent_workers=True,
            pin_memory=True,
        )
    def val_dataloader(self):
        if self.data_type == 'blender':
            test_dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            test_dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        return DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
    def test_dataloader(self):
        # TODO change this to use IRL data?
        # Note: the dataset types are swapped relative to the train/val loaders,
        # i.e. a model trained on Blender data is tested on CTD data and vice versa.
        if self.data_type == 'blender':
            test_dataset = CTDDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        elif self.data_type == 'ctd':
            test_dataset = BlenderDataset(
                root=self.data_path,
                pattern_path=self.pattern_path,
                test_set=True,
                use_lightning=True,
            )
        return DataLoader(
            test_dataset,
            self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            persistent_workers=True,
            pin_memory=True,
        )
    def forward(
        self,
        image1,
        image2,
        flow_init=None,
        iters=10,
        upsample=True,
        test_mode=False,
    ):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)
    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W); the y-flow is zero
        loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        # automatic_optimization is disabled, so backward and the optimizer
        # step have to be performed explicitly here.
        optimizer = self.optimizers()
        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        if batch_idx % 128 == 0:
            image_log = log_images(left, right, flow_predictions, gt_disp)
            image_log['key'] = 'debug_train'
            self.wandb_logger.log_image(**image_log)
        self.log("train_loss", loss)
        # Apply the hand-rolled LR schedule once per epoch, on the last batch.
        if self.trainer.is_last_batch:
            self.adjust_learning_rate()
        return loss
    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        val_loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        self.log("val_loss", val_loss)
        # self.log() expects scalars, not dicts, so log one value per threshold.
        for threshold in [0.1, 0.5, 1, 2, 5]:
            of = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
            self.log(f"outlier_fraction_{threshold}", of)
        if batch_idx % 8 == 0:
            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # (B, H, W) -> (B, 1, H, W)
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # (B, 2, H, W)
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(
            flow_predictions, gt_flow, valid_mask, gamma=0.8
        )
        self.log("test_loss", test_loss)
        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        # Batches carry (left, right, gt_disp, valid_mask); prediction only
        # needs the image pair.
        left, right, _, _ = batch
        return self(left, right, test_mode=True)
    def configure_optimizers(self):
        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
        print('len(self.train_dataloader)', len(self.train_dataloader()))
        # With automatic optimization disabled this scheduler is never stepped
        # by Lightning; the effective schedule is adjust_learning_rate() below.
        lr_scheduler = {
            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=self.T_max if self.T_max else len(self.train_dataloader()) / self.batch_size,
            ),
            'name': 'CosineAnnealingLRScheduler',
        }
        return [optimizer], [lr_scheduler]
    def adjust_learning_rate(self):
        """Piecewise LR schedule: linear warm-up, constant plateau, linear decay."""
        optimizer = self.optimizers().optimizer
        epoch = self.trainer.current_epoch + 1
        warm_up = 0.02      # warm up over the first 2% of epochs
        const_range = 0.6   # hold the base LR until 60% of epochs are done
        min_lr_rate = 0.05  # floor at 5% of the base LR
        if epoch <= self.n_total_epoch * warm_up:
            # Linear ramp from min_lr_rate * base_lr up to base_lr.
            lr = (1 - min_lr_rate) * self.base_lr / (
                self.n_total_epoch * warm_up
            ) * epoch + min_lr_rate * self.base_lr
        elif self.n_total_epoch * warm_up < epoch <= self.n_total_epoch * const_range:
            lr = self.base_lr
        else:
            # Linear decay from base_lr down to min_lr_rate * base_lr.
            lr = (min_lr_rate - 1) * self.base_lr / (
                (1 - const_range) * self.n_total_epoch
            ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr
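
# For example (hypothetical values): with n_total_epoch=100 and base_lr=4e-4,
# the LR ramps linearly up to 4e-4 over epochs 1-2, holds it through epoch 60,
# then decays linearly to 0.05 * 4e-4 = 2e-5 at epoch 100.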

if __name__ == "__main__":
    wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)
    # train configuration
    args = parse_yaml("cfgs/train.yaml")
    wandb_logger.experiment.config.update(args._asdict())
    config = wandb.config
    if 'blender' in config.training_data_path:
        # pattern used for our Blender renders
        pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
    if 'ctd' in config.training_data_path:
        # pattern used for the CTD data
        pattern_path = '/home/nils/kinect_from_settings.png'
    devices = min(config.nr_gpus, torch.cuda.device_count())
    if devices != config.nr_gpus:
        print(f'Using fewer devices than requested! ({devices} / {config.nr_gpus})')
    model = CREStereoLightning(
        config,
        wandb_logger,
        pattern_path,
        config.training_data_path,
        # lr=0.00017378008287493763,  # found with auto_lr_find=True
    )
    # NOTE turn this down once it's working, it might use too much space
    # wandb_logger.watch(model, log_graph=False)  # , log='all')
    model_checkpoint = ModelCheckpoint(
        monitor="val_loss",
        mode="min",
        save_top_k=2,
        save_last=True,
    )
    trainer = Trainer(
        accelerator='gpu',
        devices=devices,
        max_epochs=config.n_total_epoch,
        callbacks=[
            EarlyStopping(
                monitor="val_loss",
                mode="min",
                patience=16,
            ),
            LearningRateMonitor(),
            model_checkpoint,
        ],
        strategy=DDPSpawnStrategy(find_unused_parameters=False),
        # auto_scale_batch_size='binsearch',
        # auto_lr_find=True,
        # accumulate_grad_batches=4,  # disabled: incompatible with manual optimization
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=64,
        limit_test_batches=256,
        logger=wandb_logger,
        default_root_dir=config.log_dir_lightning,
    )
    # trainer.tune(model)
    trainer.fit(model)
    # trainer.validate(ckpt_path=model_checkpoint.best_model_path)
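
# A minimal sketch of the expected cfgs/train.yaml. The keys are exactly the
# config fields read above; the values are placeholders, not taken from the
# original repository:
#
#   mixed_precision: true
#   base_lr: 4.0e-4
#   t_max: null
#   nr_gpus: 2
#   batch_size: 2
#   n_total_epoch: 300
#   max_disp: 256
#   pattern_attention: false
#   training_data_path: /data/blender_renders
#   log_dir_lightning: ./train_log_lightning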