You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
269 lines
8.7 KiB
269 lines
8.7 KiB
import os
|
|
import sys
|
|
import time
|
|
import logging
|
|
from collections import namedtuple
|
|
|
|
import yaml
|
|
# from tensorboardX import SummaryWriter
|
|
|
|
from nets import Model
|
|
# from dataset import CREStereoDataset
|
|
from dataset import BlenderDataset, CREStereoDataset, CTDDataset
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.optim as optim
|
|
from torch.utils.data import DataLoader
|
|
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
|
|
from pytorch_lightning import Trainer, seed_everything
|
|
from pytorch_lightning.loggers import WandbLogger
|
|
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
|
|
|
# Fix the global random seed (42) as soon as the module is loaded; note this runs
# before the remaining imports below. ``workers=True`` is passed through to
# pytorch_lightning -- presumably to derive dataloader-worker seeds as well;
# confirm against the installed pytorch_lightning version.
seed_everything(42, workers=True)
|
|
|
|
|
|
import wandb
|
|
|
|
import numpy as np
|
|
import cv2
|
|
|
|
|
|
def normalize_and_colormap(img):
    """Min-max scale ``img`` to the 0..255 range and colorize it with INFERNO.

    Args:
        img: A ``torch.Tensor`` or any array type offering ``min``/``max``,
            arithmetic and (after conversion) ``astype`` -- e.g. a
            ``numpy.ndarray``. Presumably a single-channel 2-D map such as a
            disparity image; TODO confirm against callers.

    Returns:
        The image returned by ``cv2.applyColorMap`` for the ``uint8``-cast,
        rescaled input using ``cv2.COLORMAP_INFERNO``.
    """
    lowest = img.min()
    value_range = img.max() - lowest

    if value_range == 0:
        # Constant input: the plain formula would compute 0 / 0 and push NaNs
        # into the uint8 cast. Render such an image as all zeros instead.
        scaled = (img - lowest) * 0.0
    else:
        scaled = (img - lowest) / value_range * 255.0

    # cv2 works on numpy arrays, so tensors are moved to host memory first.
    if isinstance(scaled, torch.Tensor):
        scaled = scaled.cpu().detach().numpy()

    scaled = scaled.astype("uint8")
    return cv2.applyColorMap(scaled, cv2.COLORMAP_INFERNO)
|
|
|
|
|
|
def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
    """Assemble the image/caption payload for one sample of a batch.

    Only channel 0 of every input is used. The second sample of the batch is
    visualised when there is one; a batch holding a single sample falls back
    to that sample (a hard-coded index of 1 used to fail or pick an image row
    in that case).

    Args:
        left: 4-D tensor of left input images, indexed as ``[:, 0, :, :]``.
        right: 4-D tensor of right input images, indexed the same way.
        pred_disp: 4-D prediction tensor, or a list of such tensors of which
            the last entry is used.
        gt_disp: 4-D ground-truth tensor, indexed as ``[:, 0, :, :]``.
        wandb_logger: Unused; kept so existing callers keep working.

    Returns:
        dict with the entries ``key`` (``'samples'``), ``images`` (five
        arrays: colorized prediction, colorized absolute error, colorized
        ground truth, raw left input, raw right input) and ``caption`` (one
        caption string per image).
    """
    if isinstance(pred_disp, list):
        pred_disp = pred_disp[-1]

    # Plain indexing (no squeeze) keeps the batch axis even for batch size 1.
    pred_disp = pred_disp[:, 0, :, :]
    gt_disp = gt_disp[:, 0, :, :]
    left = left[:, 0, :, :]
    right = right[:, 0, :, :]

    sample_idx = min(1, pred_disp.shape[0] - 1)

    pred = pred_disp[sample_idx]
    gt = gt_disp[sample_idx]
    error = gt - pred
    abs_error = abs(error)

    input_left = left[sample_idx].cpu().detach().numpy()
    input_right = right[sample_idx].cpu().detach().numpy()

    return dict(
        key='samples',
        images=[
            normalize_and_colormap(pred),
            normalize_and_colormap(abs_error),
            normalize_and_colormap(gt),
            input_left,
            input_right,
        ],
        caption=[
            f"Disparity \n{pred.min():.2f}/{pred.max():.2f}",
            f"Disp. Error\n{error.min():.2f}/{error.max():.2f}\n{abs_error.mean():.2f}",
            f"GT Disp Vis \n{gt.min():.2f}/{gt.max():.2f}",
            "Input Left",
            "Input Right",
        ],
    )
|
|
|
|
|
|
def parse_yaml(file_path: str) -> namedtuple:
    """Load a YAML configuration file into a ``train_args`` namedtuple.

    The top-level keys of the YAML mapping become the field names and the
    corresponding values become the field values.
    """
    with open(file_path, "rb") as stream:
        config: dict = yaml.safe_load(stream)

    train_args_cls = namedtuple("train_args", list(config.keys()))
    return train_args_cls._make(config.values())
|
|
|
|
|
|
def format_time(elapse):
    """Format a duration given in seconds as ``HH:MM:SS``.

    Args:
        elapse: Duration in seconds; fractional seconds are truncated by
            ``int()``.

    Returns:
        A string with zero-padded hours, minutes and seconds. Hours are not
        capped at two digits. Negative durations are rendered as the positive
        duration with a leading ``-`` (floor division on the raw negative
        value used to produce strings such as ``-1:59:55`` for ``-5``).
    """
    total_seconds = int(elapse)
    sign = "-" if total_seconds < 0 else ""

    hours, remainder = divmod(abs(total_seconds), 3600)
    minutes, seconds = divmod(remainder, 60)
    return f"{sign}{hours:02d}:{minutes:02d}:{seconds:02d}"
|
|
|
|
|
|
def ensure_dir(path):
    """Create directory ``path`` with any missing parents if nothing exists there yet."""
    if os.path.exists(path):
        # Whatever already lives at this path is left untouched.
        return
    os.makedirs(path, exist_ok=True)
|
|
|
|
|
|
def adjust_learning_rate(
    optimizer,
    epoch,
    base_lr=None,
    n_total_epoch=None,
    warm_up=0.02,
    const_range=0.6,
    min_lr_rate=0.05,
):
    """Set the learning rate of every parameter group for ``epoch``.

    The schedule has three phases:

    1. Warm-up: linear ramp from ``min_lr_rate * base_lr`` at epoch 0 up to
       ``base_lr`` at ``n_total_epoch * warm_up``.
    2. Constant: ``base_lr`` until ``n_total_epoch * const_range``.
    3. Decay: linear ramp from ``base_lr`` down to ``min_lr_rate * base_lr``
       at ``n_total_epoch``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts
            whose ``'lr'`` entry is overwritten.
        epoch: Current epoch number.
        base_lr: Peak learning rate. Defaults to the module-level
            ``args.base_lr``.
        n_total_epoch: Total number of epochs. Defaults to the module-level
            ``args.n_total_epoch``.
        warm_up: Fraction of the run spent in the warm-up phase.
        const_range: Fraction of the run after which the decay phase starts.
        min_lr_rate: Lowest learning rate as a fraction of ``base_lr``.
    """
    # Fall back to the script-level configuration so existing two-argument
    # calls keep working unchanged.
    if base_lr is None:
        base_lr = args.base_lr
    if n_total_epoch is None:
        n_total_epoch = args.n_total_epoch

    warm_up_end = n_total_epoch * warm_up
    const_end = n_total_epoch * const_range

    # A zero-length warm-up is skipped to avoid dividing by zero.
    if warm_up_end > 0 and epoch <= warm_up_end:
        lr = (1 - min_lr_rate) * base_lr / warm_up_end * epoch + min_lr_rate * base_lr
    elif epoch <= const_end:
        lr = base_lr
    else:
        lr = (min_lr_rate - 1) * base_lr / (
            (1 - const_range) * n_total_epoch
        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * base_lr

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
|
|
|
|
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
    """Exponentially weighted, masked L1 loss over a sequence of predictions.

    Prediction ``i`` out of ``n`` is weighted by ``gamma ** (n - i - 1)``, so
    the last prediction has weight 1 and earlier ones count less.

    Args:
        flow_preds: Sequence of prediction tensors, per the original notes
            each ``(B, 2, H, W)``. A single tensor of the same rank as
            ``flow_gt`` is treated as a one-element sequence; it used to be
            iterated along its batch axis and broadcast against the ground
            truth, which produced a wrong loss.
        flow_gt: Ground truth, per the original notes ``(B, 2, H, W)``. A
            trailing axis of size 1 is squeezed away.
        valid: Mask, per the original notes ``(B, H, W)``; it is unsqueezed
            to ``(B, 1, H, W)`` and multiplied into the absolute error.
        gamma: Base of the exponential weighting.
        test: Enables a legacy shape fix-up of ``valid`` (see below).

    Returns:
        ``0.0`` for an empty sequence, otherwise a scalar tensor.
    """
    if test:
        # Legacy fix-up kept as is: it targets one hard-coded mask shape.
        # NOTE(review): presumably written for batch size 2 at 480x640 --
        # confirm before relying on ``test=True`` with other shapes.
        target_shape = (2, 480, 640)
        if valid.shape != target_shape:
            valid = torch.stack([valid, valid])
        if valid.shape != target_shape:
            valid = valid.transpose(0, 1)

    flow_gt = torch.squeeze(flow_gt, dim=-1)

    # A bare tensor with the ground truth's rank is one prediction, not a
    # sequence of per-sample predictions.
    if isinstance(flow_preds, torch.Tensor) and flow_preds.dim() == flow_gt.dim():
        flow_preds = [flow_preds]

    n_predictions = len(flow_preds)
    mask = valid.unsqueeze(1)

    flow_loss = 0.0
    for i, prediction in enumerate(flow_preds):
        i_weight = gamma ** (n_predictions - i - 1)
        i_loss = torch.abs(prediction - flow_gt)
        flow_loss += i_weight * (mask * i_loss).mean()

    return flow_loss
|
|
|
|
|
|
|
|
class CREStereoLightning(LightningModule):
    """LightningModule wrapping ``Model`` for training, validation and testing.

    Every step builds a two-channel ground-truth flow from the disparity
    (disparity in channel 0, zeros in channel 1), scores the model output
    with ``sequence_loss`` and logs the result. Validation and test steps
    additionally log sample images through the supplied logger.
    """

    def __init__(self, args, logger):
        """Create the wrapped model.

        Args:
            args: Configuration object providing ``batch_size``, ``max_disp``
                and ``mixed_precision``.
            logger: Logger whose ``log_image`` method receives the sample
                images (a ``WandbLogger`` in the training script).
        """
        super().__init__()
        self.batch_size = args.batch_size
        self.wandb_logger = logger
        self.model = Model(
            max_disp=args.max_disp,
            mixed_precision=args.mixed_precision,
            test_mode=False,
        )

    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
        """Forward all arguments positionally to the wrapped ``Model``."""
        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)

    def _evaluate_batch(self, batch, test_mode=False):
        """Run the model on ``batch`` and return ``(loss, predictions, gt_disp)``.

        ``gt_disp`` is returned with the channel axis already inserted, which
        is the form the image logging expects.
        """
        left, right, gt_disp, valid_mask = batch
        flow_predictions = self.forward(left, right, test_mode=test_mode)
        # Per the original notes: (B, H, W) -> (B, 1, H, W).
        gt_disp = torch.unsqueeze(gt_disp, dim=1)
        # Second flow channel is all zeros: (B, 2, H, W).
        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)
        loss = sequence_loss(flow_predictions, gt_flow, valid_mask, gamma=0.8)
        return loss, flow_predictions, gt_disp

    def training_step(self, batch, batch_idx):
        """Compute, log (``train_loss``) and return the training loss."""
        loss, _, _ = self._evaluate_batch(batch)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log ``val_loss`` and, for every fourth batch, sample images."""
        val_loss, flow_predictions, gt_disp = self._evaluate_batch(batch, test_mode=True)
        self.log("val_loss", val_loss)
        if batch_idx % 4 == 0:
            left, right = batch[0], batch[1]
            self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def test_step(self, batch, batch_idx):
        """Log ``test_loss`` and sample images for every batch.

        The ground-truth flow is now actually built here; its assignment had
        ended up inside a trailing comment, leaving ``gt_flow`` undefined.
        """
        test_loss, flow_predictions, gt_disp = self._evaluate_batch(batch, test_mode=True)
        self.log("test_loss", test_loss)
        print('test_batch_idx:', batch_idx)
        left, right = batch[0], batch[1]
        self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

    def configure_optimizers(self):
        """Return an Adam optimizer over the wrapped model's parameters."""
        # NOTE(review): lr is fixed at 0.1 here and does not read the
        # configured base_lr -- confirm this is intended.
        return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
|
|
|
|
|
|
if __name__ == "__main__":
    # Training configuration. The name ``args`` stays at module level because
    # adjust_learning_rate looks it up as a global.
    args = parse_yaml("cfgs/train.yaml")
    pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'

    # NOTE(review): wandb.config is used directly below; confirm a wandb run
    # is already active at this point for the installed versions.
    wandb_logger = WandbLogger(project="crestereo-lightning")
    wandb.config.update(args._asdict())

    model = CREStereoLightning(args, wandb_logger)

    # Both datasets read from the same root; only the test_set flag differs.
    dataset_kwargs = dict(
        root=args.training_data_path,
        pattern_path=pattern_path,
        use_lightning=True,
    )
    dataset = BlenderDataset(**dataset_kwargs)
    test_dataset = BlenderDataset(test_set=True, **dataset_kwargs)

    # Worker settings shared by the training and evaluation loaders.
    loader_kwargs = dict(
        num_workers=16,
        persistent_workers=True,
        pin_memory=True,
    )
    dataloader = DataLoader(
        dataset,
        args.batch_size,
        shuffle=True,
        drop_last=True,
        **loader_kwargs,
    )
    test_dataloader = DataLoader(
        test_dataset,
        args.batch_size,
        shuffle=False,
        drop_last=False,
        **loader_kwargs,
    )

    # Stop once val_loss has not improved for 4 validation runs.
    early_stopping = EarlyStopping(
        monitor="val_loss",
        mode="min",
        patience=4,
    )

    trainer = Trainer(
        accelerator='gpu',
        devices=2,
        max_epochs=args.n_total_epoch,
        callbacks=[early_stopping],
        accumulate_grad_batches=8,
        deterministic=True,
        check_val_every_n_epoch=1,
        limit_val_batches=24,
        limit_test_batches=24,
        logger=wandb_logger,
        default_root_dir=args.log_dir_lightning,
    )

    # The held-out loader serves as the validation loader during fitting.
    trainer.fit(model, dataloader, test_dataloader)
|
|
|