changed some stuff, re-added default CREStereo scheduler

commit 2731ef1ada
parent 37c537ca31
@@ -1,12 +1,12 @@
 seed: 0
-mixed_precision: false
-# base_lr: 4.0e-4
-base_lr: 0.001
-t_max: 161
+mixed_precision: true
+base_lr: 4.0e-4
+# base_lr: 0.00001
+t_max: 16100
 
 nr_gpus: 3
-batch_size: 2
-n_total_epoch: 300
+batch_size: 3
+n_total_epoch: 100
 minibatch_per_epoch: 500
 
 loadmodel: ~
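Note: the training script below reads this file via parse_yaml("cfgs/train.yaml") and later calls args._asdict(), which suggests a namedtuple-style config object. A minimal sketch of such a loader (the real parse_yaml in this repo may differ):

    from collections import namedtuple
    import yaml  # PyYAML

    def parse_yaml(path):
        with open(path) as f:
            cfg = yaml.safe_load(f)  # plain dict: {'seed': 0, 'base_lr': 0.0004, ...}
        # freeze into a namedtuple so fields read as args.base_lr, args.t_max, ...
        Config = namedtuple('Config', cfg.keys())
        return Config(**cfg)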
@@ -17,11 +17,18 @@ model_save_freq_epoch: 1
 max_disp: 256
 image_width: 640
 image_height: 480
 # dataset: "blender"
 # training_data_path: "./stereo_trainset/crestereo"
-pattern_attention: false
 dataset: "blender"
 # training_data_path: "/media/Data1/connecting_the_dots_data/ctd_data/"
-training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
+# training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders/data"
+training_data_path: "/media/Data1/connecting_the_dots_data/blender_renders_ctd_randomize_light/data"
+
+
+# FIXME any of this??
+pattern_attention: false
+scene_attention: true
+ignore_pattern_completely: false
+
 
 log_level: "logging.INFO"
dataset.py (71 changed lines)
@@ -7,6 +7,11 @@ from PIL import Image, ImageEnhance
 
 from megengine.data.dataset import Dataset
 
+def downsample(img):
+    downsampled = cv2.pyrDown(img)
+    diff = (downsampled.shape[0] - 480) // 2
+    return downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+
 
 class Augmentor:
     def __init__(
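A quick shape check of the new downsample helper (the 1024x1280 input is hypothetical): cv2.pyrDown blurs and halves both dimensions, and the slice then center-crops the rows to 480 while keeping every column.

    import cv2
    import numpy as np

    img = np.zeros((1024, 1280, 3), dtype=np.uint8)  # hypothetical full-size render
    out = cv2.pyrDown(img)                           # -> (512, 640, 3)
    diff = (out.shape[0] - 480) // 2                 # 16 rows to trim top and bottom
    print(out[diff:out.shape[0]-diff, 0:out.shape[1]].shape)  # (480, 640, 3)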
@@ -234,11 +239,12 @@ class CREStereoDataset(Dataset):
 
 
 class CTDDataset(Dataset):
-    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False):
+    def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=True):
         super().__init__()
         self.rng = np.random.RandomState(0)
         self.augment = augment
         self.blur = blur
+        self.use_lightning = use_lightning
         imgs = glob.glob(os.path.join(root, f"{data_type}/*/im0_*.npy"), recursive=True)
         if test_set:
             self.imgs = imgs[:int(split * len(imgs))]
@@ -248,10 +254,7 @@ class CTDDataset(Dataset):
 
         if resize_pattern and self.pattern.shape != (480, 640, 3):
             # self.pattern = cv2.resize(self.pattern, (640, 480))
-            print(self.pattern.shape)
-            downsampled = cv2.pyrDown(self.pattern)
-            diff = (downsampled.shape[0] - 480) // 2
-            self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+            self.pattern = downsample(self.pattern)
 
         self.augmentor = Augmentor(
             image_height=480,
@@ -304,14 +307,18 @@ class CTDDataset(Dataset):
                 left_img, right_img, left_disp
             )
 
-        right_img = right_img.transpose((2, 0, 1)).astype("uint8")
-
-        return {
-            "left": left_img,
-            "right": right_img,
-            "disparity": left_disp,
-            "mask": disp_mask,
-        }
+        if not self.use_lightning:
+            right_img = right_img.transpose((2, 0, 1)).astype("uint8")
+            return {
+                "left": left_img,
+                "right": right_img,
+                "disparity": left_disp,
+                "mask": disp_mask,
+            }
+
+        right_img = right_img.transpose((2, 0, 1)).astype("uint8")
+        left_img = left_img.transpose((2, 0, 1)).astype("uint8")
+        return left_img, right_img, left_disp, disp_mask
 
     def __len__(self):
         return len(self.imgs)
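The two return shapes serve different consumers: the dict matches the original megengine-style training loop, while the plain tuple is what the Lightning DataLoader path stacks via default_collate. A sketch of how a caller would unpack both (names here are assumptions, not code from this repo):

    sample = dataset[0]
    if isinstance(sample, dict):            # use_lightning=False
        left, right = sample["left"], sample["right"]
        disp, mask = sample["disparity"], sample["mask"]
    else:                                   # use_lightning=True
        left, right, disp, mask = sample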
@@ -321,17 +328,30 @@ class BlenderDataset(CTDDataset):
     def __init__(self, root, pattern_path: str, data_type: str = 'syn', augment=True, resize_pattern=True, blur=False, split=0.9, test_set=False, use_lightning=False):
         super().__init__(root, pattern_path)
         self.use_lightning = use_lightning
-        imgs = [f for f in glob.glob(f"{root}/im_*.png", recursive=True) if not 'depth0001' in f]
-        if test_set:
+        additional_img_types = {
+            'depth',
+            'disp',
+            'grad',
+        }
+
+        pngs = glob.glob(f"{root}/im_*.png", recursive=True)
+        imgs = [
+            img for img in pngs
+            if all(
+                map(
+                    lambda x: x not in img, additional_img_types
+                )
+            )
+        ]
+        if not test_set:
             self.imgs = imgs[:int(split * len(imgs))]
         else:
             self.imgs = imgs[int(split * len(imgs)):]
 
         self.pattern = cv2.imread(pattern_path)#, cv2.IMREAD_GRAYSCALE)
 
         if resize_pattern and self.pattern.shape != (480, 640, 3):
-            downsampled = cv2.pyrDown(self.pattern)
-            diff = (downsampled.shape[0] - 480) // 2
-            self.pattern = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+            self.pattern = downsample(self.pattern)
 
         self.augmentor = Augmentor(
             image_height=480,
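Sanity check of the broadened filename filter (file names hypothetical): instead of only excluding 'depth0001', every auxiliary render type is now skipped.

    additional_img_types = {'depth', 'disp', 'grad'}
    pngs = ['im_0001.png', 'im_0001_depth0001.png',
            'im_0001_disp0001.png', 'im_0001_grad0001.png']
    imgs = [img for img in pngs
            if all(t not in img for t in additional_img_types)]
    print(imgs)  # ['im_0001.png'] - only the plain camera renders survive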
@@ -345,7 +365,7 @@ class BlenderDataset(CTDDataset):
     def __getitem__(self, index):
         # find path
         left_path = self.imgs[index]
-        left_disp_path = left_path.split('.')[0] + '_depth0001.png'
+        left_disp_path = left_path.split('.')[0] + '_disp0001.png'
 
         # read img, disp
         left_img = cv2.imread(left_path)
@@ -354,14 +374,14 @@ class BlenderDataset(CTDDataset):
         left_img = (left_img * 255).astype('uint8')
 
         if left_img.shape != (480, 640, 3):
-            downsampled = cv2.pyrDown(left_img)
-            diff = (downsampled.shape[0] - 480) // 2
-            left_img = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+            left_img = downsample(left_img)
             if left_img.shape[-1] != 3:
                 left_img = cv2.merge([left_img, left_img, left_img]).reshape((480, 640, 3))
 
         right_img = self.pattern
-        left_disp = self.get_disp(left_disp_path)
+        # left_disp = self.get_disp(left_disp_path)
+        disp = cv2.imread(left_disp_path, cv2.IMREAD_UNCHANGED)
+        left_disp = downsample(disp)
 
         if False: # self.rng.binomial(1, 0.5):
             left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
@@ -384,7 +404,6 @@ class BlenderDataset(CTDDataset):
         )
 
         if not self.use_lightning:
-            # right_img = right_img.transpose((2, 0, 1)).astype("uint8")
             return {
                 "left": left_img,
                 "right": right_img,
@@ -400,15 +419,13 @@ class BlenderDataset(CTDDataset):
         baseline = 0.075 # meters
         fl = 560. # as per CTD
         depth = cv2.imread(path, cv2.IMREAD_UNCHANGED)
-        downsampled = cv2.pyrDown(depth)
-        diff = (downsampled.shape[0] - 480) // 2
-        depth = downsampled[diff:downsampled.shape[0]-diff, 0:downsampled.shape[1]]
+        depth = downsample(depth)
         # disp = np.load(path).transpose(1,2,0)
         # disp = baseline * fl / depth
         # return disp.astype(np.float32) / 32
         # FIXME temporarily increase disparity until new data with better depth values is generated
         # higher values seem to speedup convergence, but introduce much stronger artifacting
-        mystery_factor = 150
+        mystery_factor = 35
         # mystery_factor = 1
         disp = (baseline * fl * mystery_factor) / depth
         return disp.astype(np.float32)
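For scale, a worked pass through the disparity conversion (the raw depth value is hypothetical; the unit mismatch is exactly what the FIXME is about):

    baseline = 0.075        # meters
    fl = 560.               # focal length in pixels, as per CTD
    mystery_factor = 35
    depth = 500.            # hypothetical raw value from the 16-bit depth PNG
    disp = (baseline * fl * mystery_factor) / depth
    print(disp)             # 2.94 px; with the old factor of 150 it would be 12.6 px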
@@ -94,6 +94,32 @@ def format_time(elapse):
     return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
 
 
+def outlier_fraction(estimate, target, mask=None, threshold=0):
+    def _process_inputs(estimate, target, mask):
+        if estimate.shape != target.shape:
+            raise Exception(f'estimate and target have to be same shape (expected {estimate.shape} == {target.shape})')
+        if mask is None:
+            mask = np.ones(estimate.shape, dtype=np.bool)
+        else:
+            mask = mask != 0
+            if estimate.shape != mask.shape:
+                raise Exception(f'estimate and mask have to be same shape (expected {estimate.shape} == {mask.shape})')
+        return estimate, target, mask
+
+    estimate = torch.squeeze(estimate[:, 0, :, :])
+    target = torch.squeeze(target[:, 0, :, :])
+
+    estimate, target, mask = _process_inputs(estimate, target, mask)
+
+    mask = mask.cpu().detach().numpy()
+    estimate = estimate.cpu().detach().numpy()
+    target = target.cpu().detach().numpy()
+
+    diff = np.abs(estimate[mask] - target[mask])
+    m = (diff > threshold).sum() / mask.sum()
+    return m
+
+
 def ensure_dir(path):
     if not os.path.exists(path):
         os.makedirs(path, exist_ok=True)
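A toy call to the new metric (tensors made up; shapes follow the validation_step call site, where estimate and target are (N, 2, H, W) flow maps and the mask matches the squeezed channel-0 slice):

    import torch

    estimate = torch.zeros((4, 2, 480, 640))
    target = torch.zeros((4, 2, 480, 640))
    target[:, 0, :, :320] = 2.0             # left half is off by 2 px
    mask = torch.ones((4, 480, 640))        # every pixel counts as valid

    print(outlier_fraction(estimate, target, mask, threshold=1))  # 0.5

Note that the mask=None branch builds a numpy array that would fail at the later mask.cpu() call, so in practice a torch mask must always be passed (np.bool is also deprecated in newer numpy).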
@@ -134,12 +160,12 @@ def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8, test=False):
     return flow_loss
 
 
 class CREStereoLightning(LightningModule):
-    def __init__(self, args, logger, pattern_path, data_path):
+    def __init__(self, args, logger=None, pattern_path='', data_path=''):
         super().__init__()
         self.batch_size = args.batch_size
         self.wandb_logger = logger
+        self.data_type = 'blender' if 'blender' in data_path else 'ctd'
         self.lr = args.base_lr
         print(f'lr = {self.lr}')
         self.T_max = args.t_max if args.t_max else None
@@ -149,13 +175,25 @@ class CREStereoLightning(LightningModule):
         self.model = Model(
             max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
         )
+        # so I can access it in adjust learn rate more easily
+        self.n_total_epoch = args.n_total_epoch
+        self.base_lr = args.base_lr
+
+        self.automatic_optimization = False
 
     def train_dataloader(self):
-        dataset = BlenderDataset(
-            root=self.data_path,
-            pattern_path=self.pattern_path,
-            use_lightning=True,
-        )
+        if self.data_type == 'blender':
+            dataset = BlenderDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                use_lightning=True,
+            )
+        elif self.data_type == 'ctd':
+            dataset = CTDDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                use_lightning=True,
+            )
         dataloader = DataLoader(
             dataset,
             self.batch_size,
@@ -169,12 +207,20 @@ class CREStereoLightning(LightningModule):
         return dataloader
 
     def val_dataloader(self):
-        test_dataset = BlenderDataset(
-            root=self.data_path,
-            pattern_path=self.pattern_path,
-            test_set=True,
-            use_lightning=True,
-        )
+        if self.data_type == 'blender':
+            test_dataset = BlenderDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
+        elif self.data_type == 'ctd':
+            test_dataset = CTDDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
 
         test_dataloader = DataLoader(
             test_dataset,
@@ -190,12 +236,20 @@ class CREStereoLightning(LightningModule):
 
     def test_dataloader(self):
         # TODO change this to use IRL data?
-        test_dataset = BlenderDataset(
-            root=self.data_path,
-            pattern_path=self.pattern_path,
-            test_set=True,
-            use_lightning=True,
-        )
+        if self.data_type == 'blender':
+            test_dataset = CTDDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
+        elif self.data_type == 'ctd':
+            test_dataset = BlenderDataset(
+                root=self.data_path,
+                pattern_path=self.pattern_path,
+                test_set=True,
+                use_lightning=True,
+            )
 
         test_dataloader = DataLoader(
             test_dataset,
@@ -232,6 +286,10 @@ class CREStereoLightning(LightningModule):
             image_log['key'] = 'debug_train'
             self.wandb_logger.log_image(**image_log)
         self.log("train_loss", loss)
+
+        # update learn rate at the end of every epoch
+        if self.trainer.is_last_batch:
+            self.adjust_learning_rate()
         return loss
 
     def validation_step(self, batch, batch_idx):
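Because __init__ now sets self.automatic_optimization = False, Lightning no longer steps the optimizer itself; the hunk above only shows the scheduler call, so the usual manual-optimization pattern is sketched here as an assumption about the surrounding code:

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = ...                  # sequence_loss on this batch
        self.manual_backward(loss)  # replaces loss.backward() under Lightning
        opt.step()
        if self.trainer.is_last_batch:
            self.adjust_learning_rate()
        return loss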
@@ -243,6 +301,11 @@ class CREStereoLightning(LightningModule):
             flow_predictions, gt_flow, valid_mask, gamma=0.8
         )
         self.log("val_loss", val_loss)
+        of = {}
+        for threshold in [0.1, 0.5, 1, 2, 5]:
+            of[threshold] = outlier_fraction(flow_predictions[0], gt_flow, valid_mask, threshold)
+        self.log("outlier_fraction", of)
+        # print(', '.join(f'of{thr}={val}' for thr, val in of.items()))
         if batch_idx % 8 == 0:
             self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
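One caveat: Lightning's self.log expects a scalar metric, so logging the whole of dict is unlikely to chart as intended; the usual form is self.log_dict with string keys, sketched here as a suggested alternative rather than what the commit does:

    self.log_dict({f'outlier_fraction_{thr}': val for thr, val in of.items()})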
@@ -256,6 +319,9 @@ class CREStereoLightning(LightningModule):
         )
         self.log("test_loss", test_loss)
         self.wandb_logger.log_image(**log_images(left, right, flow_predictions, gt_disp))
+
+    def predict_step(self, batch, batch_idx, dataloader_idx=0):
+        return self(batch)
 
     def configure_optimizers(self):
         optimizer = optim.Adam(self.model.parameters(), lr=self.lr, betas=(0.9, 0.999))
@@ -268,35 +334,70 @@ class CREStereoLightning(LightningModule):
             'name': 'CosineAnnealingLRScheduler',
         }
         return [optimizer], [lr_scheduler]
 
+    def adjust_learning_rate(self):
+        optimizer = self.optimizers().optimizer
+        epoch = self.trainer.current_epoch + 1
+
+        warm_up = 0.02
+        const_range = 0.6
+        min_lr_rate = 0.05
+
+        if epoch <= self.n_total_epoch * warm_up:
+            lr = (1 - min_lr_rate) * self.base_lr / (
+                self.n_total_epoch * warm_up
+            ) * epoch + min_lr_rate * self.base_lr
+        elif self.n_total_epoch * warm_up < epoch <= self.n_total_epoch * const_range:
+            lr = self.base_lr
+        else:
+            lr = (min_lr_rate - 1) * self.base_lr / (
+                (1 - const_range) * self.n_total_epoch
+            ) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * self.base_lr
+
+        for param_group in optimizer.param_groups:
+            param_group['lr'] = lr
+
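This is the re-added default CREStereo schedule: a linear warm-up over the first 2% of epochs starting from 5% of base_lr, a plateau at base_lr until 60% of epochs, then a linear decay back down to 5%. Plugging in the new config values (n_total_epoch=100, base_lr=4.0e-4):

    # epoch 1 (warm-up):     (1 - 0.05) * 4e-4 / (100 * 0.02) * 1 + 0.05 * 4e-4 = 2.1e-4
    # epoch 2 (warm-up end): (1 - 0.05) * 4e-4 / (100 * 0.02) * 2 + 0.05 * 4e-4 = 4.0e-4
    # epochs 3-60:           base_lr = 4.0e-4
    # epoch 100 (final):     (0.05 - 1) * 4e-4 / (0.4 * 100) * 100
    #                        + (1 - 0.05 * 0.6) / 0.4 * 4e-4 = 2.0e-5  (5% of base_lr)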
 
 if __name__ == "__main__":
-    wandb_logger = WandbLogger(project="crestereo-lightning", log_model=True)
     # train configuration
     args = parse_yaml("cfgs/train.yaml")
-    pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
-    wandb_logger.experiment.config.update(args._asdict())
 
+    run = wandb.init(project="crestereo-lightning", config=args._asdict(), tags=['new_scheduler', 'default_lr', f'{"" if args.pattern_attention else "no-"}pattern-attention'], notes='')
+    run.config.update(args._asdict())
+    config = wandb.config
+    wandb_logger = WandbLogger(project="crestereo-lightning", id=run.id, log_model=True)
     # wandb_logger = WandbLogger(project="crestereo-lightning", log_model='all')
+    # wandb_logger.experiment.config.update(args._asdict())
+    if 'blender' in config.training_data_path:
+        # this was used for our blender renders
+        pattern_path = '/home/nils/miniprojekt/kinect_syn_ref.png'
+    if 'ctd' in config.training_data_path:
+        # this one is used (i hope) for ctd
+        pattern_path = '/home/nils/kinect_from_settings.png'
+
+    devices = min(config.nr_gpus, torch.cuda.device_count())
+    if devices != config.nr_gpus:
+        print(f'Using less devices than expected! ({devices} / {config.nr_gpus})')
 
     model = CREStereoLightning(
-        args,
+        # args,
+        config,
         wandb_logger,
         pattern_path,
-        args.training_data_path,
+        config.training_data_path,
         # lr=0.00017378008287493763, # found with auto_lr_find=True
     )
     # NOTE turn this down once it's working, this might use too much space
     # wandb_logger.watch(model, log_graph=False) #, log='all')
 
+    model_checkpoint = ModelCheckpoint(
+        monitor="val_loss",
+        mode="min",
+        save_top_k=2,
+        save_last=True,
+    )
+
     trainer = Trainer(
         accelerator='gpu',
-        devices=args.nr_gpus,
-        max_epochs=args.n_total_epoch,
+        devices=devices,
+        max_epochs=config.n_total_epoch,
         callbacks=[
             EarlyStopping(
                 monitor="val_loss",
@@ -304,25 +405,20 @@ if __name__ == "__main__":
                 patience=16,
             ),
             LearningRateMonitor(),
-            ModelCheckpoint(
-                monitor="val_loss",
-                mode="min",
-                save_top_k=2,
-                save_last=True,
-            )
+            model_checkpoint,
         ],
         strategy=DDPSpawnStrategy(find_unused_parameters=False),
         # auto_scale_batch_size='binsearch',
         # auto_lr_find=True,
-        accumulate_grad_batches=4,
+        # accumulate_grad_batches=4, # needed to disable for manual optimization
         deterministic=True,
         check_val_every_n_epoch=1,
         limit_val_batches=64,
         limit_test_batches=256,
         logger=wandb_logger,
-        default_root_dir=args.log_dir_lightning,
+        default_root_dir=config.log_dir_lightning,
     )
 
     # trainer.tune(model)
     trainer.fit(model)
     trainer.validate()
+    # trainer.validate(chkpt_path=model_checkpoint.best_model_path)
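The point of hoisting ModelCheckpoint into the model_checkpoint variable shows in that last commented line: the callback's best_model_path can be fed back into validation. A sketch of that call (note that Lightning's keyword is ckpt_path; chkpt_path in the comment above looks like a typo):

    trainer.fit(model)
    trainer.validate(model, ckpt_path=model_checkpoint.best_model_path)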