fix lightning, prepare sweeps

Branch: main
Cpt.Captain committed 2 years ago
parent d8169e01bc
commit 37c537ca31

Changed files:
  1. cfgs/train.yaml (6 changes)
  2. dataset.py (6 changes)
  3. nets/crestereo.py (14 changes)
  4. train.py (1 change)
  5. train_lightning.py (211 changes)

cfgs/train.yaml
@@ -1,6 +1,8 @@
 seed: 0
 mixed_precision: false
-base_lr: 4.0e-4
+# base_lr: 4.0e-4
+base_lr: 0.001
+t_max: 161
 nr_gpus: 3
 batch_size: 2
@@ -16,7 +18,7 @@ max_disp: 256
 image_width: 640
 image_height: 480
 # training_data_path: "./stereo_trainset/crestereo"
-pattern_attention: true
+pattern_attention: false
 dataset: "blender"
 # training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
 training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
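As a reference for how these keys reach the training script, here is a minimal sketch of a parse_yaml helper. The real implementation lives in the repo; this version is only inferred from the calls parse_yaml("cfgs/train.yaml") and args._asdict() in train_lightning.py below.

from collections import namedtuple
import yaml

def parse_yaml(path):
    # Load the config into a dict, e.g. {'seed': 0, 'base_lr': 0.001, ...}
    with open(path) as f:
        cfg = yaml.safe_load(f)
    # Expose keys as attributes (args.base_lr, args.t_max, ...); a namedtuple
    # also provides the args._asdict() used when configuring wandb.
    return namedtuple('Args', cfg.keys())(**cfg)

args = parse_yaml('cfgs/train.yaml')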

dataset.py
@@ -384,7 +384,7 @@ class BlenderDataset(CTDDataset):
         )
         if not self.use_lightning:
-            right_img = right_img.transpose((2, 0, 1)).astype("uint8")
+            # right_img = right_img.transpose((2, 0, 1)).astype("uint8")
         return {
             "left": left_img,
             "right": right_img,
@@ -408,7 +408,7 @@ class BlenderDataset(CTDDataset):
         # return disp.astype(np.float32) / 32
         # FIXME temporarily increase disparity until new data with better depth values is generated
         # higher values seem to speedup convergence, but introduce much stronger artifacting
-        # mystery_factor = 150
-        mystery_factor = 1
+        mystery_factor = 150
+        # mystery_factor = 1
         disp = (baseline * fl * mystery_factor) / depth
         return disp.astype(np.float32)
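The conversion above is the standard stereo relation disp = baseline * focal_length / depth, scaled by the temporary mystery_factor. A quick worked sketch with made-up calibration values (the real baseline and focal length come from the dataset):

import numpy as np

# Illustrative calibration values only; the real baseline and focal length
# come from the dataset.
baseline = 0.075      # camera baseline in metres
fl = 560.0            # focal length in pixels
mystery_factor = 150  # temporary scale factor, see the FIXME above

depth = np.array([0.5, 1.0, 2.0], dtype=np.float32)  # depth in metres
disp = (baseline * fl * mystery_factor) / depth
print(disp)  # [12600.  6300.  3150.] -> larger factor, larger disparities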

nets/crestereo.py
@@ -38,10 +38,10 @@ class CREStereo(nn.Module):
         self.update_block = BasicUpdateBlock(hidden_dim=self.hidden_dim, cor_planes=4 * 9, mask_size=4)
         # # NOTE Position_encoding as workaround for TensorRt
-        image1_shape = [1, 2, 480, 640]
-        self.pos_encoding_fn_small = PositionEncodingSine(
-            d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
-        )
+        # image1_shape = [1, 2, 480, 640]
+        # self.pos_encoding_fn_small = PositionEncodingSine(
+        #     d_model=256, max_shape=(image1_shape[2] // 16, image1_shape[3] // 16)
+        # )
         # loftr
         self.self_att_fn = LocalFeatureTransformer(
@@ -141,10 +141,12 @@ class CREStereo(nn.Module):
                 d_model=256, max_shape=(image1.shape[2] // 16, image1.shape[3] // 16)
             )
             # 'n c h w -> n (h w) c'
-            x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
+            # x_tmp = self.pos_encoding_fn_small(fmap1_dw16)
+            x_tmp = pos_encoding_fn_small(fmap1_dw16)
             fmap1_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
             # 'n c h w -> n (h w) c'
-            x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
+            # x_tmp = self.pos_encoding_fn_small(fmap2_dw16)
+            x_tmp = pos_encoding_fn_small(fmap2_dw16)
             fmap2_dw16 = x_tmp.permute(0, 2, 3, 1).reshape(x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1])
             # FIXME experimental ! no self-attention for pattern
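The edit above stops using the PositionEncodingSine instance that __init__ built for a fixed 480x640 input (the TensorRT workaround) and instead uses a pos_encoding_fn_small constructed inside forward from the actual image1.shape, so the encoding always matches the incoming resolution. The surrounding reshape is the 'n c h w -> n (h w) c' flattening; a runnable check of just that step:

import torch

# Runnable check of the 'n c h w -> n (h w) c' flattening used above; the
# positional-encoding module itself is repo code and omitted here.
x_tmp = torch.randn(1, 256, 480 // 16, 640 // 16)  # (N, C, H/16, W/16)
flat = x_tmp.permute(0, 2, 3, 1).reshape(
    x_tmp.shape[0], x_tmp.shape[2] * x_tmp.shape[3], x_tmp.shape[1]
)
print(flat.shape)  # torch.Size([1, 1200, 256]) == (N, H*W, C)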

train.py
@@ -419,6 +419,7 @@ def main(args):
             # print(f'left {left.shape}, right {right.shape}')
             # left = left.transpose([2, 0, 1])
             right = right.transpose([1, 2, 0])
             # right = right.transpose(2, 3).transpose(2, 3)#.transpose(1, 2)
             # print(f'left {left.shape}, right {right.shape}')

train_lightning.py
@@ -5,10 +5,8 @@ import logging
 from collections import namedtuple
 import yaml
-# from tensorboardX import SummaryWriter
 from nets import Model
-# from dataset import CREStereoDataset
 from dataset import BlenderDataset, CREStereoDataset, CTDDataset
 import torch
@@ -18,8 +16,11 @@ import torch.optim as optim
 from torch.utils.data import DataLoader
 from pytorch_lightning import LightningDataModule, LightningModule, Trainer
 from pytorch_lightning import Trainer, seed_everything
-from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.callbacks.early_stopping import EarlyStopping
+from pytorch_lightning.callbacks import LearningRateMonitor
+from pytorch_lightning.callbacks import ModelCheckpoint
+from pytorch_lightning.loggers import WandbLogger
+from pytorch_lightning.strategies import DDPSpawnStrategy

 seed_everything(42, workers=True)
@@ -39,11 +40,9 @@ def normalize_and_colormap(img):
     return ret

-def log_images(left, right, pred_disp, gt_disp, wandb_logger=None):
-    # wandb_logger.log_text('test')
-    # return
+def log_images(left, right, pred_disp, gt_disp):
     log = {}
-    batch_idx = 1
+    batch_idx = 0

     if isinstance(pred_disp, list):
         pred_disp = pred_disp[-1]
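From its call sites (self.wandb_logger.log_image(**log_images(...)) and image_log['key'] = 'debug_train' below), log_images returns the keyword arguments for Lightning's WandbLogger.log_image. A hypothetical sketch of that return shape; the exact keys and captions here are assumptions, not the repo's code:

import wandb

def log_images_sketch(left, right, pred_disp, gt_disp):
    # Hypothetical return shape only; the real log_images also runs
    # normalize_and_colormap on the disparities before logging.
    if isinstance(pred_disp, list):   # one prediction per refinement iteration
        pred_disp = pred_disp[-1]     # keep the final one
    batch_idx = 0                     # log the first sample in the batch
    return {
        'key': 'samples',             # training_step overrides this with 'debug_train'
        'images': [wandb.Image(pred_disp[batch_idx]), wandb.Image(gt_disp[batch_idx])],
        'caption': ['pred_disp', 'gt_disp'],
    }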
@@ -100,32 +99,13 @@ def ensure_dir(path):
     os.makedirs(path, exist_ok=True)

-def adjust_learning_rate(optimizer, epoch):
-    warm_up = 0.02
-    const_range = 0.6
-    min_lr_rate = 0.05
-
-    if epoch <= args.n_total_epoch * warm_up:
-        lr = (1 - min_lr_rate) * args.base_lr / (
-            args.n_total_epoch * warm_up
-        ) * epoch + min_lr_rate * args.base_lr
-    elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
-        lr = args.base_lr
-    else:
-        lr = (min_lr_rate - 1) * args.base_lr / (
-            (1 - const_range) * args.n_total_epoch
-        ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
-
-    for param_group in optimizer.param_groups:
-        param_group['lr'] = lr

 def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
     '''
     valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
     flow_preds[0]: (B, 2, H, W)
     flow_gt: (B, 2, H, W)
     '''
+    """
     if test:
         # print('sequence loss')
         if valid.shape != (2, 480, 640):
@@ -136,6 +116,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
         if valid.shape != (2, 480, 640):
             valid = valid.transpose(0,1)
         # print(valid.shape)
+    """
     # print(valid.shape)
     # print(flow_preds[0].shape)
     # print(flow_gt.shape)
@@ -143,7 +124,7 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
     flow_loss = 0.0

     # TEST
-    flow_gt = torch.squeeze(flow_gt, dim=-1)
+    # flow_gt = torch.squeeze(flow_gt, dim=-1)

     for i in range(n_predictions):
         i_weight = gamma ** (n_predictions - i - 1)
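For context, this is the exponentially weighted sequence loss the edits above adjust: prediction i of N gets weight gamma^(N-1-i), so the final refinement dominates. A self-contained sketch, with the masking simplified relative to the real function:

import torch

def sequence_loss_sketch(flow_preds, flow_gt, valid, gamma=0.8):
    # flow_preds: list of N (B, 2, H, W) refinement outputs, first to last
    # flow_gt:    (B, 2, H, W) ground-truth flow (disparity in channel 0)
    # valid:      (B, 1, H, W) mask selecting the pixels that count
    n_predictions = len(flow_preds)
    flow_loss = 0.0
    for i in range(n_predictions):
        i_weight = gamma ** (n_predictions - i - 1)  # last prediction gets weight 1.0
        i_loss = (flow_preds[i] - flow_gt).abs()     # per-pixel L1 error
        flow_loss += i_weight * (valid * i_loss).mean()
    return flow_loss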
@@ -155,16 +136,88 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):

 class CREStereoLightning(LightningModule):
-    def __init__(self, args, logger):
+    def __init__(self, args, logger, pattern_path, data_path):
         super().__init__()
         self.batch_size = args.batch_size
         self.wandb_logger = logger
+        self.lr = args.base_lr
+        print(f'lr = {self.lr}')
+        self.T_max = args.t_max if args.t_max else None
+        self.pattern_attention = args.pattern_attention
+        self.pattern_path = pattern_path
+        self.data_path = data_path
         self.model = Model(
             max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
         )

-    def forward(self, image1, image2, flow_init=None, iters=10, upsample=True, test_mode=False, self_attend_right=True):
-        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self_attend_right)
+    def train_dataloader(self):
+        dataset = BlenderDataset(
+            root=self.data_path,
+            pattern_path=self.pattern_path,
+            use_lightning=True,
+        )
+        dataloader = DataLoader(
+            dataset,
+            self.batch_size,
+            shuffle=True,
+            num_workers=4,
+            drop_last=True,
+            persistent_workers=True,
+            pin_memory=True,
+        )
+        # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
+        return dataloader
+
+    def val_dataloader(self):
+        test_dataset = BlenderDataset(
+            root=self.data_path,
+            pattern_path=self.pattern_path,
+            test_set=True,
+            use_lightning=True,
+        )
+        test_dataloader = DataLoader(
+            test_dataset,
+            self.batch_size,
+            shuffle=False,
+            num_workers=4,
+            drop_last=False,
+            persistent_workers=True,
+            pin_memory=True
+        )
+        # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
+        return test_dataloader
+
+    def test_dataloader(self):
+        # TODO change this to use IRL data?
+        test_dataset = BlenderDataset(
+            root=self.data_path,
+            pattern_path=self.pattern_path,
+            test_set=True,
+            use_lightning=True,
+        )
+        test_dataloader = DataLoader(
+            test_dataset,
+            self.batch_size,
+            shuffle=False,
+            num_workers=4,
+            drop_last=False,
+            persistent_workers=True,
+            pin_memory=True
+        )
+        return test_dataloader
+
+    def forward(
+        self,
+        image1,
+        image2,
+        flow_init=None,
+        iters=10,
+        upsample=True,
+        test_mode=False,
+    ):
+        return self.model(image1, image2, flow_init, iters, upsample, test_mode, self.pattern_attention)

     def training_step(self, batch, batch_idx):
         left, right, gt_disp, valid_mask = batch
@@ -174,6 +227,10 @@ class CREStereoLightning(LightningModule):
         loss = sequence_loss(
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
+        if batch_idx % 128 == 0:
+            image_log = log_images(left, right, flow_predictions, gt_disp)
+            image_log['key'] = 'debug_train'
+            self.wandb_logger.log_image(**image_log)
         self.log("train_loss", loss)
         return loss
@@ -186,22 +243,31 @@ class CREStereoLightning(LightningModule):
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("val_loss", val_loss)
-        if batch_idx % 4 == 0:
+        if batch_idx % 8 == 0:
             self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

     def test_step(self, batch, batch_idx):
         left, right, gt_disp, valid_mask = batch
-        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512] gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
+        gt_disp = torch.unsqueeze(gt_disp, dim=1)  # [2, 384, 512] -> [2, 1, 384, 512]
+        gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1)  # [2, 2, 384, 512]
         flow_predictions = self.forward(left, right, test_mode=True)
         test_loss = sequence_loss(
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("test_loss", test_loss)
+        print('test_batch_idx:', batch_idx)
         self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))

     def configure_optimizers(self):
-        return optim.Adam(self.model.parameters(), lr=0.1, betas=(0.9, 0.999))
+        optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
+        print('len(self.train_dataloader)', len(self.train_dataloader()))
+        lr_scheduler = {
+            'scheduler': torch.optim.lr_scheduler.CosineAnnealingLR(
+                optimizer,
+                T_max=self.T_max if self.T_max else len(self.train_dataloader())/self.batch_size,
+            ),
+            'name': 'CosineAnnealingLRScheduler',
+        }
+        return [optimizer], [lr_scheduler]
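The returned scheduler follows PyTorch's CosineAnnealingLR schedule, lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2, so with base_lr: 0.001 and t_max: 161 from cfgs/train.yaml the rate decays from 1e-3 to eta_min (0 by default) over 161 steps. A quick sanity check:

import torch
from torch import optim

model = torch.nn.Linear(1, 1)  # stand-in model for illustration
optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=161)

for step in range(161):
    optimizer.step()   # no gradients here; this only exercises the schedule
    scheduler.step()
print(optimizer.param_groups[0]['lr'])  # ~0.0 after T_max steps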
@@ -209,61 +275,54 @@ if __name__ == "__main__":
     args = parse_yaml("cfgs/train.yaml")
     pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'

-    wandb_logger = WandbLogger(project="crestereo-lightning")
-    wandb.config.update(args._asdict())
+    run = wandb.init(project="crestereo-lightning", config=args._asdict(), tags=['new_scheduler', 'default_lr', f'{"" if args.pattern_attention else "no-"}pattern-attention'], notes='')
+    run.config.update(args._asdict())
+    config = wandb.config
+    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
+    # wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all')
+    # wandb_logger.experiment.config.update(args._asdict())

-    model = CREStereoLightning(args, wandb_logger)
-    dataset = BlenderDataset(
-        root=args.training_data_path,
-        pattern_path=pattern_path,
-        use_lightning=True,
-    )
-    test_dataset = BlenderDataset(
-        root=args.training_data_path,
-        pattern_path=pattern_path,
-        test_set=True,
-        use_lightning=True,
-    )
-    dataloader = DataLoader(
-        dataset,
-        args.batch_size,
-        shuffle=True,
-        num_workers=16,
-        drop_last=True,
-        persistent_workers=True,
-        pin_memory=True,
-    )
-    # num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
-    test_dataloader = DataLoader(
-        test_dataset,
-        args.batch_size,
-        shuffle=False,
-        num_workers=16,
-        drop_last=False,
-        persistent_workers=True,
-        pin_memory=True
-    )
+    model = CREStereoLightning(
+        # args,
+        config,
+        wandb_logger,
+        pattern_path,
+        args.training_data_path,
+        # lr=0.00017378008287493763,  # found with auto_lr_find=True
+    )
+    # NOTE turn this down once it's working, this might use too much space
+    # wandb_logger.watch(model, log_graph=False)  # , log='all')

     trainer = Trainer(
         accelerator='gpu',
-        devices=2,
+        devices=args.nr_gpus,
         max_epochs=args.n_total_epoch,
         callbacks=[
             EarlyStopping(
                 monitor="val_loss",
                 mode="min",
-                patience=4,
-            )
+                patience=16,
+            ),
+            LearningRateMonitor(),
+            ModelCheckpoint(
+                monitor="val_loss",
+                mode="min",
+                save_top_k=2,
+                save_last=True,
+            )
         ],
-        accumulate_grad_batches=8,
+        strategy=DDPSpawnStrategy(find_unused_parameters=False),
+        # auto_scale_batch_size='binsearch',
+        # auto_lr_find=True,
+        accumulate_grad_batches=4,
         deterministic=True,
         check_val_every_n_epoch=1,
-        limit_val_batches=24,
-        limit_test_batches=24,
+        limit_val_batches=64,
+        limit_test_batches=256,
         logger=wandb_logger,
         default_root_dir=args.log_dir_lightning,
     )

-    trainer.fit(model, dataloader, test_dataloader)
+    # trainer.tune(model)
+    trainer.fit(model)
+    trainer.validate()
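The commit message says "prepare sweeps", and the commented auto_lr_find/auto_scale_batch_size lines plus the run tags point the same way, but no sweep definition appears in this diff. A hypothetical W&B sweep over the knobs this commit exposes might look like the following; the parameter names mirror cfgs/train.yaml, and train is a stand-in entry point that would wrap the __main__ block above.

import wandb

# Hypothetical sweep definition, not part of this commit.
sweep_config = {
    'method': 'bayes',
    'metric': {'name': 'val_loss', 'goal': 'minimize'},
    'parameters': {
        'base_lr': {'min': 1e-5, 'max': 1e-3},
        'batch_size': {'values': [2, 4, 8]},
        'pattern_attention': {'values': [True, False]},
    },
}
sweep_id = wandb.sweep(sweep_config, project='crestereo-lightning')
# wandb.agent(sweep_id, function=train)  # train() would wrap the training script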
