CREStereo Repository for the 'Towards accurate and robust depth estimation' project
CREStereo-pytorch-nxt/train_lightning.py


import os
import sys
import time
import logging
from collections import namedtuple

import yaml
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from pytorch_lightning import LightningDataModule, LightningModule, Trainer, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
import wandb
import numpy as np
import cv2

# from tensorboardX import SummaryWriter
from nets import Model
from dataset import BlenderDataset, CREStereoDataset, CTDDataset

seed_everything(42, workers=True)

def normalize_and_colormap(img):
    """Normalize an image to [0, 255] and apply the inferno colormap."""
    ret = (img - img.min()) / (img.max() - img.min()) * 255.0
    if isinstance(ret, torch.Tensor):
        ret = ret.cpu().detach().numpy()
    ret = ret.astype("uint8")
    return cv2.applyColorMap(ret, cv2.COLORMAP_INFERNO)

def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
    """Assemble the keyword arguments for `WandbLogger.log_image` from one sample."""
    batch_idx = 1  # always visualize the second sample of the batch
    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

    pred_disp = torch.squeeze(pred_disp[:, 0, :, :])
    gt_disp = torch.squeeze(gt_disp[:, 0, :, :])
    left = torch.squeeze(left[:, 0, :, :])
    right = torch.squeeze(right[:, 0, :, :])
    disp_error = gt_disp - pred_disp

    input_left = left[batch_idx].cpu().detach().numpy()
    input_right = right[batch_idx].cpu().detach().numpy()

    wandb_log = dict(
        key='samples',
        images=[
            normalize_and_colormap(pred_disp[batch_idx]),
            normalize_and_colormap(abs(disp_error[batch_idx])),
            normalize_and_colormap(gt_disp[batch_idx]),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity\n{pred_disp[batch_idx].min():.2f}/{pred_disp[batch_idx].max():.2f}",
            f"Disp. Error\n{disp_error[batch_idx].min():.2f}/{disp_error[batch_idx].max():.2f}\n"
            f"{abs(disp_error[batch_idx]).mean():.2f}",
            f"GT Disp\n{gt_disp[batch_idx].min():.2f}/{gt_disp[batch_idx].max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
    return wandb_log

def parse_yaml(file_path: str) -> namedtuple:
    """Parse a YAML configuration file and return its contents as a `namedtuple`."""
    with open(file_path, "rb") as f:
        cfg: dict = yaml.safe_load(f)
    args = namedtuple("train_args", cfg.keys())(*cfg.values())
    return args

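# A minimal sketch of `cfgs/train.yaml` covering the keys this script reads via
# `args` (the values below are placeholders, not the project's actual settings):
#
#   batch_size: 2
#   max_disp: 256
#   mixed_precision: false
#   base_lr: 4.0e-4
#   n_total_epoch: 300
#   training_data_path: /path/to/blender/dataset
#   log_dir_lightning: ./lightning_logs
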
def format_time(elapse):
    """Format a duration in seconds as HH:MM:SS."""
    elapse = int(elapse)
    hour = elapse // 3600
    minute = elapse % 3600 // 60
    seconds = elapse % 60
    return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)


def ensure_dir(path):
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)

def adjust_learning_rate(optimizer, epoch):
    """Warm-up / constant / linear-decay schedule from the original CREStereo training code.

    Relies on the module-level `args` defined in the `__main__` block below.
    """
    warm_up = 0.02
    const_range = 0.6
    min_lr_rate = 0.05

    if epoch <= args.n_total_epoch * warm_up:
        # linear warm-up from min_lr_rate * base_lr to base_lr
        lr = (1 - min_lr_rate) * args.base_lr / (
            args.n_total_epoch * warm_up
        ) * epoch + min_lr_rate * args.base_lr
    elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
        lr = args.base_lr
    else:
        # linear decay towards min_lr_rate * base_lr at the final epoch
        lr = (min_lr_rate - 1) * args.base_lr / (
            (1 - const_range) * args.n_total_epoch
        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

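# `adjust_learning_rate` is never invoked in this Lightning port, since Lightning
# owns the training loop. A minimal sketch of how the same schedule could be
# attached through `configure_optimizers` instead (illustrative only, not part of
# the original setup; the optimizer must be created with lr=base_lr for the
# multiplicative factor to be correct):
def make_crestereo_scheduler(optimizer, n_total_epoch, base_lr):
    def factor(epoch):
        warm_up, const_range, min_lr_rate = 0.02, 0.6, 0.05
        if epoch <= n_total_epoch * warm_up:
            lr = (1 - min_lr_rate) * base_lr / (n_total_epoch * warm_up) * epoch \
                + min_lr_rate * base_lr
        elif epoch <= n_total_epoch * const_range:
            lr = base_lr
        else:
            lr = (min_lr_rate - 1) * base_lr / ((1 - const_range) * n_total_epoch) * epoch \
                + (1 - min_lr_rate * const_range) / (1 - const_range) * base_lr
        return lr / base_lr  # LambdaLR scales the optimizer's initial lr by this factor
    return optim.lr_scheduler.LambdaLR(optimizer, factor)
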
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    """Exponentially weighted L1 loss over the sequence of flow predictions.

    valid: (B, H, W), broadcast below to (B, 1, H, W)
    flow_preds[i]: (B, 2, H, W)
    flow_gt: (B, 2, H, W)
    """
    if test:
        # Bring `valid` to (B, H, W) for the hardcoded 480x640 test resolution.
        if valid.shape != (2, 480, 640):
            valid = torch.stack([valid, valid])
        if valid.shape != (2, 480, 640):
            valid = valid.transpose(0, 1)

    n_predictions = len(flow_preds)
    flow_loss = 0.0
    flow_gt = torch.squeeze(flow_gt, dim=-1)  # drop a possible trailing singleton dimension

    for i in range(n_predictions):
        # later refinement iterations receive exponentially larger weights
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(flow_preds[i] - flow_gt)
        flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
    return flow_loss

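# Shape sanity check for `sequence_loss` (illustrative helper, not called during
# training; the tensor sizes follow the comments in the training steps below):
def _sequence_loss_example():
    preds = [torch.randn(2, 2, 384, 512) for _ in range(3)]  # three refinement iterations
    gt_flow = torch.randn(2, 2, 384, 512)
    valid = torch.ones(2, 384, 512)
    return sequence_loss(preds, gt_flow, valid, gamma=0.8)
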
class CREStereoLightning(LightningModule):
    def __init__(self, args, logger):
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.model = Model(
            max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
        )

    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True,
                test_mode=False, self_attend_right=True):
        return self.model(image1, image2, flow_init, iters, upsample, test_mode,
                          self_attend_right)

    def training_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=True)
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        val_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("val_loss", val_loss)
        if batch_idx % 4 == 0:
            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def test_step(self, batch, batch_idx):
        left, right, gt_disp, valid_mask = batch
        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
        flow_predictions = self.forward(left, right, test_mode=True)
        test_loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        self.log("test_loss", test_loss)
        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def configure_optimizers(self):
        # NOTE: the learning rate is hardcoded here and ignores args.base_lr.
        return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))

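# Minimal inference sketch (illustrative only; how to obtain `left`/`right`
# tensors and a trained checkpoint is outside this script). Channel 0 of the
# final prediction is the horizontal disparity.
def _predict_disparity(model, left, right):
    model.eval()
    with torch.no_grad():
        flow_predictions = model(left, right, test_mode=True)
    final = flow_predictions[-1] if isinstance(flow_predictions, list) else flow_predictions
    return final[:, 0, :, :]
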
if __name__ == "__main__":
# train configuration
args = parse_yaml("cfgs/train.yaml")
pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
wandb_logger = WandbLogger(project="crestereo-lightning")
wandb.config.update(args._asdict())
model = CREStereoLightning(args, wandb_logger)
dataset = BlenderDataset(
root=args.training_data_path,
pattern_path=pattern_path,
use_lightning=True,
)
test_dataset = BlenderDataset(
root=args.training_data_path,
pattern_path=pattern_path,
test_set=True,
use_lightning=True,
)
dataloader = DataLoader(
dataset,
args.batch_size,
shuffle=True,
num_workers=16,
drop_last=True,
persistent_workers=True,
pin_memory=True,
)
# num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
test_dataloader = DataLoader(
test_dataset,
args.batch_size,
shuffle=False,
num_workers=16,
drop_last=False,
persistent_workers=True,
pin_memory=True
)
trainer = Trainer(
accelerator='gpu',
devices=2,
max_epochs=args.n_total_epoch,
callbacks=[
EarlyStopping(
monitor="val_loss",
mode="min",
patience=4,
)
],
accumulate_grad_batches=8,
deterministic=True,
check_val_every_n_epoch=1,
limit_val_batches=24,
limit_test_batches=24,
logger=wandb_logger,
default_root_dir=args.log_dir_lightning,
)
trainer.fit(model, dataloader, test_dataloader)
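    # To also run the held-out test loop after training (exercising `test_step`
    # above), one could follow up with, e.g.:
    #     trainer.test(model, test_dataloader)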