enable training
This commit is contained in:
parent
982b5027b9
commit
a35086dc02
2
.gitignore
vendored
2
.gitignore
vendored
@ -150,3 +150,5 @@ cython_debug/
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
vis_results/
|
@ -19,6 +19,7 @@
|
||||
|
||||
# References:
|
||||
- CREStereo: https://github.com/megvii-research/CREStereo
|
||||
- CREStereo-Pytorch: https://github.com/ibaiGorordo/CREStereo-Pytorch
|
||||
- RAFT: https://github.com/princeton-vl/RAFT
|
||||
- LoFTR: https://github.com/zju3dv/LoFTR
|
||||
- Grid sample replacement: https://zenn.dev/pinto0309/scraps/7d4032067d0160
|
||||
|
20
cfgs/train.yaml
Normal file
20
cfgs/train.yaml
Normal file
@ -0,0 +1,20 @@
|
||||
seed: 0
|
||||
mixed_precision: false
|
||||
base_lr: 4.0e-4
|
||||
|
||||
nr_gpus: 8
|
||||
batch_size: 4
|
||||
n_total_epoch: 600
|
||||
minibatch_per_epoch: 500
|
||||
|
||||
loadmodel: ~
|
||||
log_dir: "./train_log"
|
||||
model_save_freq_epoch: 1
|
||||
|
||||
max_disp: 256
|
||||
image_width: 512
|
||||
image_height: 384
|
||||
training_data_path: "./stereo_trainset/crestereo"
|
||||
|
||||
log_level: "logging.INFO"
|
||||
|
215
dataset.py
Normal file
215
dataset.py
Normal file
@ -0,0 +1,215 @@
|
||||
import os
|
||||
import cv2
|
||||
import glob
|
||||
import numpy as np
|
||||
from PIL import Image, ImageEnhance
|
||||
|
||||
from megengine.data.dataset import Dataset
|
||||
|
||||
|
||||
class Augmentor:
|
||||
def __init__(
|
||||
self,
|
||||
image_height=384,
|
||||
image_width=512,
|
||||
max_disp=256,
|
||||
scale_min=0.6,
|
||||
scale_max=1.0,
|
||||
seed=0,
|
||||
):
|
||||
super().__init__()
|
||||
self.image_height = image_height
|
||||
self.image_width = image_width
|
||||
self.max_disp = max_disp
|
||||
self.scale_min = scale_min
|
||||
self.scale_max = scale_max
|
||||
self.rng = np.random.RandomState(seed)
|
||||
|
||||
def chromatic_augmentation(self, img):
|
||||
random_brightness = np.random.uniform(0.8, 1.2)
|
||||
random_contrast = np.random.uniform(0.8, 1.2)
|
||||
random_gamma = np.random.uniform(0.8, 1.2)
|
||||
|
||||
img = Image.fromarray(img)
|
||||
|
||||
enhancer = ImageEnhance.Brightness(img)
|
||||
img = enhancer.enhance(random_brightness)
|
||||
enhancer = ImageEnhance.Contrast(img)
|
||||
img = enhancer.enhance(random_contrast)
|
||||
|
||||
gamma_map = [
|
||||
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
|
||||
] * 3
|
||||
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
|
||||
|
||||
img_ = np.array(img)
|
||||
|
||||
return img_
|
||||
|
||||
def __call__(self, left_img, right_img, left_disp):
|
||||
# 1. chromatic augmentation
|
||||
left_img = self.chromatic_augmentation(left_img)
|
||||
right_img = self.chromatic_augmentation(right_img)
|
||||
|
||||
# 2. spatial augmentation
|
||||
# 2.1) rotate & vertical shift for right image
|
||||
if self.rng.binomial(1, 0.5):
|
||||
angle, pixel = 0.1, 2
|
||||
px = self.rng.uniform(-pixel, pixel)
|
||||
ag = self.rng.uniform(-angle, angle)
|
||||
image_center = (
|
||||
self.rng.uniform(0, right_img.shape[0]),
|
||||
self.rng.uniform(0, right_img.shape[1]),
|
||||
)
|
||||
rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
|
||||
right_img = cv2.warpAffine(
|
||||
right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
|
||||
)
|
||||
trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
|
||||
right_img = cv2.warpAffine(
|
||||
right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
|
||||
)
|
||||
|
||||
# 2.2) random resize
|
||||
resize_scale = self.rng.uniform(self.scale_min, self.scale_max)
|
||||
|
||||
left_img = cv2.resize(
|
||||
left_img,
|
||||
None,
|
||||
fx=resize_scale,
|
||||
fy=resize_scale,
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
right_img = cv2.resize(
|
||||
right_img,
|
||||
None,
|
||||
fx=resize_scale,
|
||||
fy=resize_scale,
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
|
||||
disp_mask = disp_mask.astype("float32")
|
||||
disp_mask = cv2.resize(
|
||||
disp_mask,
|
||||
None,
|
||||
fx=resize_scale,
|
||||
fy=resize_scale,
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
|
||||
left_disp = (
|
||||
cv2.resize(
|
||||
left_disp,
|
||||
None,
|
||||
fx=resize_scale,
|
||||
fy=resize_scale,
|
||||
interpolation=cv2.INTER_LINEAR,
|
||||
)
|
||||
* resize_scale
|
||||
)
|
||||
|
||||
# 2.3) random crop
|
||||
h, w, c = left_img.shape
|
||||
dx = w - self.image_width
|
||||
dy = h - self.image_height
|
||||
dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
|
||||
dx = self.rng.randint(min(0, dx), max(0, dx) + 1)
|
||||
|
||||
M = np.float32([[1.0, 0.0, -dx], [0.0, 1.0, -dy]])
|
||||
left_img = cv2.warpAffine(
|
||||
left_img,
|
||||
M,
|
||||
(self.image_width, self.image_height),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderValue=0,
|
||||
)
|
||||
right_img = cv2.warpAffine(
|
||||
right_img,
|
||||
M,
|
||||
(self.image_width, self.image_height),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderValue=0,
|
||||
)
|
||||
left_disp = cv2.warpAffine(
|
||||
left_disp,
|
||||
M,
|
||||
(self.image_width, self.image_height),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderValue=0,
|
||||
)
|
||||
disp_mask = cv2.warpAffine(
|
||||
disp_mask,
|
||||
M,
|
||||
(self.image_width, self.image_height),
|
||||
flags=cv2.INTER_LINEAR,
|
||||
borderValue=0,
|
||||
)
|
||||
|
||||
# 3. add random occlusion to right image
|
||||
if self.rng.binomial(1, 0.5):
|
||||
sx = int(self.rng.uniform(50, 100))
|
||||
sy = int(self.rng.uniform(50, 100))
|
||||
cx = int(self.rng.uniform(sx, right_img.shape[0] - sx))
|
||||
cy = int(self.rng.uniform(sy, right_img.shape[1] - sy))
|
||||
right_img[cx - sx : cx + sx, cy - sy : cy + sy] = np.mean(
|
||||
np.mean(right_img, 0), 0
|
||||
)[np.newaxis, np.newaxis]
|
||||
|
||||
return left_img, right_img, left_disp, disp_mask
|
||||
|
||||
|
||||
class CREStereoDataset(Dataset):
|
||||
def __init__(self, root):
|
||||
super().__init__()
|
||||
self.imgs = glob.glob(os.path.join(root, "**/*_left.jpg"), recursive=True)
|
||||
self.augmentor = Augmentor(
|
||||
image_height=384,
|
||||
image_width=512,
|
||||
max_disp=256,
|
||||
scale_min=0.6,
|
||||
scale_max=1.0,
|
||||
seed=0,
|
||||
)
|
||||
self.rng = np.random.RandomState(0)
|
||||
|
||||
def get_disp(self, path):
|
||||
disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
|
||||
return disp.astype(np.float32) / 32
|
||||
|
||||
def __getitem__(self, index):
|
||||
# find path
|
||||
left_path = self.imgs[index]
|
||||
prefix = left_path[: left_path.rfind("_")]
|
||||
right_path = prefix + "_right.jpg"
|
||||
left_disp_path = prefix + "_left.disp.png"
|
||||
right_disp_path = prefix + "_right.disp.png"
|
||||
|
||||
# read img, disp
|
||||
left_img = cv2.imread(left_path, cv2.IMREAD_COLOR)
|
||||
right_img = cv2.imread(right_path, cv2.IMREAD_COLOR)
|
||||
left_disp = self.get_disp(left_disp_path)
|
||||
right_disp = self.get_disp(right_disp_path)
|
||||
|
||||
if self.rng.binomial(1, 0.5):
|
||||
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
|
||||
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
|
||||
left_disp[left_disp == np.inf] = 0
|
||||
|
||||
# augmentaion
|
||||
left_img, right_img, left_disp, disp_mask = self.augmentor(
|
||||
left_img, right_img, left_disp
|
||||
)
|
||||
|
||||
left_img = left_img.transpose(2, 0, 1).astype("uint8")
|
||||
right_img = right_img.transpose(2, 0, 1).astype("uint8")
|
||||
|
||||
return {
|
||||
"left": left_img,
|
||||
"right": right_img,
|
||||
"disparity": left_disp,
|
||||
"mask": disp_mask,
|
||||
}
|
||||
|
||||
def __len__(self):
|
||||
return len(self.imgs)
|
BIN
models/crestereo_eth3d.mge
Normal file
BIN
models/crestereo_eth3d.mge
Normal file
Binary file not shown.
274
train.py
Normal file
274
train.py
Normal file
@ -0,0 +1,274 @@
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import yaml
|
||||
from tensorboardX import SummaryWriter
|
||||
|
||||
from nets import Model
|
||||
from dataset import CREStereoDataset
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
def parse_yaml(file_path: str) -> namedtuple:
|
||||
"""Parse yaml configuration file and return the object in `namedtuple`."""
|
||||
with open(file_path, "rb") as f:
|
||||
cfg: dict = yaml.safe_load(f)
|
||||
args = namedtuple("train_args", cfg.keys())(*cfg.values())
|
||||
return args
|
||||
|
||||
|
||||
def format_time(elapse):
|
||||
elapse = int(elapse)
|
||||
hour = elapse // 3600
|
||||
minute = elapse % 3600 // 60
|
||||
seconds = elapse % 60
|
||||
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
|
||||
|
||||
|
||||
def ensure_dir(path):
|
||||
if not os.path.exists(path):
|
||||
os.makedirs(path, exist_ok=True)
|
||||
|
||||
|
||||
def adjust_learning_rate(optimizer, epoch):
|
||||
|
||||
warm_up = 0.02
|
||||
const_range = 0.6
|
||||
min_lr_rate = 0.05
|
||||
|
||||
if epoch <= args.n_total_epoch * warm_up:
|
||||
lr = (1 - min_lr_rate) * args.base_lr / (
|
||||
args.n_total_epoch * warm_up
|
||||
) * epoch + min_lr_rate * args.base_lr
|
||||
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
|
||||
lr = args.base_lr
|
||||
else:
|
||||
lr = (min_lr_rate - 1) * args.base_lr / (
|
||||
(1 - const_range) * args.n_total_epoch
|
||||
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
|
||||
|
||||
for param_group in optimizer.param_groups:
|
||||
param_group['lr'] = lr
|
||||
|
||||
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
|
||||
'''
|
||||
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
|
||||
flow_preds[0]: (B, 2, H, W)
|
||||
flow_gt: (B, 2, H, W)
|
||||
'''
|
||||
n_predictions = len(flow_preds)
|
||||
flow_loss = 0.0
|
||||
for i in range(n_predictions):
|
||||
i_weight = gamma ** (n_predictions - i - 1)
|
||||
i_loss = torch.abs(flow_preds[i] - flow_gt)
|
||||
flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
|
||||
|
||||
return flow_loss
|
||||
|
||||
|
||||
def main(args):
|
||||
# initial info
|
||||
torch.manual_seed(args.seed)
|
||||
torch.cuda.manual_seed(args.seed)
|
||||
# rank, world_size = dist.get_rank(), dist.get_world_size()
|
||||
world_size = torch.cuda.device_count() # number of GPU(s)
|
||||
|
||||
# directory check
|
||||
log_model_dir = os.path.join(args.log_dir, "models")
|
||||
ensure_dir(log_model_dir)
|
||||
|
||||
# model / optimizer
|
||||
model = Model(
|
||||
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
|
||||
)
|
||||
model = nn.DataParallel(model,device_ids=[i for i in range(world_size)])
|
||||
model.cuda()
|
||||
optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
|
||||
# model = nn.DataParallel(model,device_ids=[0])
|
||||
|
||||
tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
|
||||
|
||||
# worklog
|
||||
logging.basicConfig(level=eval(args.log_level))
|
||||
worklog = logging.getLogger("train_logger")
|
||||
worklog.propagate = False
|
||||
fileHandler = logging.FileHandler(
|
||||
os.path.join(args.log_dir, "worklog.txt"), mode="a", encoding="utf8"
|
||||
)
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
|
||||
)
|
||||
fileHandler.setFormatter(formatter)
|
||||
consoleHandler = logging.StreamHandler(sys.stdout)
|
||||
formatter = logging.Formatter(
|
||||
fmt="\x1b[32m%(asctime)s\x1b[0m %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
|
||||
)
|
||||
consoleHandler.setFormatter(formatter)
|
||||
worklog.handlers = [fileHandler, consoleHandler]
|
||||
|
||||
# params stat
|
||||
worklog.info(f"Use {world_size} GPU(s)")
|
||||
worklog.info("Params: %s" % sum([p.numel() for p in model.parameters()]))
|
||||
|
||||
# load pretrained model if exist
|
||||
chk_path = os.path.join(log_model_dir, "latest.pth")
|
||||
if args.loadmodel is not None:
|
||||
chk_path = args.loadmodel
|
||||
elif not os.path.exists(chk_path):
|
||||
chk_path = None
|
||||
|
||||
if chk_path is not None:
|
||||
# if rank == 0:
|
||||
worklog.info(f"loading model: {chk_path}")
|
||||
state_dict = torch.load(chk_path)
|
||||
model.load_state_dict(state_dict['state_dict'])
|
||||
optimizer.load_state_dict(state_dict['optim_state_dict'])
|
||||
resume_epoch_idx = state_dict["epoch"]
|
||||
resume_iters = state_dict["iters"]
|
||||
start_epoch_idx = resume_epoch_idx + 1
|
||||
start_iters = resume_iters
|
||||
else:
|
||||
start_epoch_idx = 1
|
||||
start_iters = 0
|
||||
|
||||
# datasets
|
||||
dataset = CREStereoDataset(args.training_data_path)
|
||||
# if rank == 0:
|
||||
worklog.info(f"Dataset size: {len(dataset)}")
|
||||
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
|
||||
num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
|
||||
|
||||
# counter
|
||||
cur_iters = start_iters
|
||||
total_iters = args.minibatch_per_epoch * args.n_total_epoch
|
||||
t0 = time.perf_counter()
|
||||
for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
|
||||
|
||||
# adjust learning rate
|
||||
epoch_total_train_loss = 0
|
||||
adjust_learning_rate(optimizer, epoch_idx)
|
||||
model.train()
|
||||
|
||||
t1 = time.perf_counter()
|
||||
# batch_idx = 0
|
||||
|
||||
# for mini_batch_data in dataloader:
|
||||
for batch_idx, mini_batch_data in enumerate(dataloader):
|
||||
|
||||
if batch_idx % args.minibatch_per_epoch == 0 and batch_idx != 0:
|
||||
break
|
||||
# batch_idx += 1
|
||||
cur_iters += 1
|
||||
|
||||
# parse data
|
||||
left, right, gt_disp, valid_mask = (
|
||||
mini_batch_data["left"],
|
||||
mini_batch_data["right"],
|
||||
mini_batch_data["disparity"].cuda(),
|
||||
mini_batch_data["mask"].cuda(),
|
||||
)
|
||||
|
||||
t2 = time.perf_counter()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# pre-process
|
||||
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
|
||||
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
|
||||
|
||||
# forward
|
||||
flow_predictions = model(left.cuda(), right.cuda())
|
||||
|
||||
# loss & backword
|
||||
loss = sequence_loss(
|
||||
flow_predictions, gt_flow, valid_mask, gamma=0.8
|
||||
)
|
||||
|
||||
# loss stats
|
||||
loss_item = loss.data.item()
|
||||
epoch_total_train_loss += loss_item
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
t3 = time.perf_counter()
|
||||
|
||||
if cur_iters % 10 == 0:
|
||||
tdata = t2 - t1
|
||||
time_train_passed = t3 - t0
|
||||
time_iter_passed = t3 - t1
|
||||
step_passed = cur_iters - start_iters
|
||||
eta = (
|
||||
(total_iters - cur_iters)
|
||||
/ max(step_passed, 1e-7)
|
||||
* time_train_passed
|
||||
)
|
||||
|
||||
meta_info = list()
|
||||
meta_info.append("{:.2g} b/s".format(1.0 / time_iter_passed))
|
||||
meta_info.append("passed:{}".format(format_time(time_train_passed)))
|
||||
meta_info.append("eta:{}".format(format_time(eta)))
|
||||
meta_info.append(
|
||||
"data_time:{:.2g}".format(tdata / time_iter_passed)
|
||||
)
|
||||
|
||||
meta_info.append(
|
||||
"lr:{:.5g}".format(optimizer.param_groups[0]["lr"])
|
||||
)
|
||||
meta_info.append(
|
||||
"[{}/{}:{}/{}]".format(
|
||||
epoch_idx,
|
||||
args.n_total_epoch,
|
||||
batch_idx,
|
||||
args.minibatch_per_epoch,
|
||||
)
|
||||
)
|
||||
loss_info = [" ==> {}:{:.4g}".format("loss", loss_item)]
|
||||
# exp_name = ['\n' + os.path.basename(os.getcwd())]
|
||||
|
||||
info = [",".join(meta_info)] + loss_info
|
||||
worklog.info("".join(info))
|
||||
|
||||
# minibatch loss
|
||||
tb_log.add_scalar("train/loss_batch", loss_item, cur_iters)
|
||||
tb_log.add_scalar(
|
||||
"train/lr", optimizer.param_groups[0]["lr"], cur_iters
|
||||
)
|
||||
tb_log.flush()
|
||||
|
||||
t1 = time.perf_counter()
|
||||
|
||||
tb_log.add_scalar(
|
||||
"train/loss",
|
||||
epoch_total_train_loss / args.minibatch_per_epoch,
|
||||
epoch_idx,
|
||||
)
|
||||
tb_log.flush()
|
||||
|
||||
# save model params
|
||||
ckp_data = {
|
||||
"epoch": epoch_idx,
|
||||
"iters": cur_iters,
|
||||
"batch_size": args.batch_size,
|
||||
"epoch_size": args.minibatch_per_epoch,
|
||||
"train_loss": epoch_total_train_loss / args.minibatch_per_epoch,
|
||||
"state_dict": model.state_dict(),
|
||||
"optim_state_dict": optimizer.state_dict(),
|
||||
}
|
||||
torch.save(ckp_data, os.path.join(log_model_dir, "latest.pth"))
|
||||
if epoch_idx % args.model_save_freq_epoch == 0:
|
||||
save_path = os.path.join(log_model_dir, "epoch-%d.pth" % epoch_idx)
|
||||
worklog.info(f"Model params saved: {save_path}")
|
||||
torch.save(ckp_data, save_path)
|
||||
|
||||
worklog.info("Training is done, exit.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# train configuration
|
||||
args = parse_yaml("cfgs/train.yaml")
|
||||
main(args)
|
Loading…
Reference in New Issue
Block a user