Merge pull request #2 from Qjizhi/main

enable training
This commit is contained in:
Ibai Gorordo 2022-04-12 10:12:23 +09:00 committed by GitHub
commit d91a867d5d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 512 additions and 0 deletions

2
.gitignore vendored
View File

@ -150,3 +150,5 @@ cython_debug/
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
vis_results/

View File

@ -19,6 +19,7 @@
# References:
- CREStereo: https://github.com/megvii-research/CREStereo
- CREStereo-Pytorch: https://github.com/ibaiGorordo/CREStereo-Pytorch
- RAFT: https://github.com/princeton-vl/RAFT
- LoFTR: https://github.com/zju3dv/LoFTR
- Grid sample replacement: https://zenn.dev/pinto0309/scraps/7d4032067d0160

20
cfgs/train.yaml Normal file
View File

@ -0,0 +1,20 @@
seed: 0
mixed_precision: false
base_lr: 4.0e-4
nr_gpus: 8
batch_size: 4
n_total_epoch: 600
minibatch_per_epoch: 500
loadmodel: ~
log_dir: "./train_log"
model_save_freq_epoch: 1
max_disp: 256
image_width: 512
image_height: 384
training_data_path: "./stereo_trainset/crestereo"
log_level: "logging.INFO"

215
dataset.py Normal file
View File

@ -0,0 +1,215 @@
import os
import cv2
import glob
import numpy as np
from PIL import Image, ImageEnhance
from megengine.data.dataset import Dataset
class Augmentor:
def __init__(
self,
image_height=384,
image_width=512,
max_disp=256,
scale_min=0.6,
scale_max=1.0,
seed=0,
):
super().__init__()
self.image_height = image_height
self.image_width = image_width
self.max_disp = max_disp
self.scale_min = scale_min
self.scale_max = scale_max
self.rng = np.random.RandomState(seed)
def chromatic_augmentation(self, img):
random_brightness = np.random.uniform(0.8, 1.2)
random_contrast = np.random.uniform(0.8, 1.2)
random_gamma = np.random.uniform(0.8, 1.2)
img = Image.fromarray(img)
enhancer = ImageEnhance.Brightness(img)
img = enhancer.enhance(random_brightness)
enhancer = ImageEnhance.Contrast(img)
img = enhancer.enhance(random_contrast)
gamma_map = [
255 * 1.0 * pow(ele / 255.0, random_gamma) for ele in range(256)
] * 3
img = img.point(gamma_map) # use PIL's point-function to accelerate this part
img_ = np.array(img)
return img_
def __call__(self, left_img, right_img, left_disp):
# 1. chromatic augmentation
left_img = self.chromatic_augmentation(left_img)
right_img = self.chromatic_augmentation(right_img)
# 2. spatial augmentation
# 2.1) rotate & vertical shift for right image
if self.rng.binomial(1, 0.5):
angle, pixel = 0.1, 2
px = self.rng.uniform(-pixel, pixel)
ag = self.rng.uniform(-angle, angle)
image_center = (
self.rng.uniform(0, right_img.shape[0]),
self.rng.uniform(0, right_img.shape[1]),
)
rot_mat = cv2.getRotationMatrix2D(image_center, ag, 1.0)
right_img = cv2.warpAffine(
right_img, rot_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
)
trans_mat = np.float32([[1, 0, 0], [0, 1, px]])
right_img = cv2.warpAffine(
right_img, trans_mat, right_img.shape[1::-1], flags=cv2.INTER_LINEAR
)
# 2.2) random resize
resize_scale = self.rng.uniform(self.scale_min, self.scale_max)
left_img = cv2.resize(
left_img,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
right_img = cv2.resize(
right_img,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
disp_mask = (left_disp < float(self.max_disp / resize_scale)) & (left_disp > 0)
disp_mask = disp_mask.astype("float32")
disp_mask = cv2.resize(
disp_mask,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
left_disp = (
cv2.resize(
left_disp,
None,
fx=resize_scale,
fy=resize_scale,
interpolation=cv2.INTER_LINEAR,
)
* resize_scale
)
# 2.3) random crop
h, w, c = left_img.shape
dx = w - self.image_width
dy = h - self.image_height
dy = self.rng.randint(min(0, dy), max(0, dy) + 1)
dx = self.rng.randint(min(0, dx), max(0, dx) + 1)
M = np.float32([[1.0, 0.0, -dx], [0.0, 1.0, -dy]])
left_img = cv2.warpAffine(
left_img,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
right_img = cv2.warpAffine(
right_img,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
left_disp = cv2.warpAffine(
left_disp,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
disp_mask = cv2.warpAffine(
disp_mask,
M,
(self.image_width, self.image_height),
flags=cv2.INTER_LINEAR,
borderValue=0,
)
# 3. add random occlusion to right image
if self.rng.binomial(1, 0.5):
sx = int(self.rng.uniform(50, 100))
sy = int(self.rng.uniform(50, 100))
cx = int(self.rng.uniform(sx, right_img.shape[0] - sx))
cy = int(self.rng.uniform(sy, right_img.shape[1] - sy))
right_img[cx - sx : cx + sx, cy - sy : cy + sy] = np.mean(
np.mean(right_img, 0), 0
)[np.newaxis, np.newaxis]
return left_img, right_img, left_disp, disp_mask
class CREStereoDataset(Dataset):
def __init__(self, root):
super().__init__()
self.imgs = glob.glob(os.path.join(root, "**/*_left.jpg"), recursive=True)
self.augmentor = Augmentor(
image_height=384,
image_width=512,
max_disp=256,
scale_min=0.6,
scale_max=1.0,
seed=0,
)
self.rng = np.random.RandomState(0)
def get_disp(self, path):
disp = cv2.imread(path, cv2.IMREAD_UNCHANGED)
return disp.astype(np.float32) / 32
def __getitem__(self, index):
# find path
left_path = self.imgs[index]
prefix = left_path[: left_path.rfind("_")]
right_path = prefix + "_right.jpg"
left_disp_path = prefix + "_left.disp.png"
right_disp_path = prefix + "_right.disp.png"
# read img, disp
left_img = cv2.imread(left_path, cv2.IMREAD_COLOR)
right_img = cv2.imread(right_path, cv2.IMREAD_COLOR)
left_disp = self.get_disp(left_disp_path)
right_disp = self.get_disp(right_disp_path)
if self.rng.binomial(1, 0.5):
left_img, right_img = np.fliplr(right_img), np.fliplr(left_img)
left_disp, right_disp = np.fliplr(right_disp), np.fliplr(left_disp)
left_disp[left_disp == np.inf] = 0
# augmentaion
left_img, right_img, left_disp, disp_mask = self.augmentor(
left_img, right_img, left_disp
)
left_img = left_img.transpose(2, 0, 1).astype("uint8")
right_img = right_img.transpose(2, 0, 1).astype("uint8")
return {
"left": left_img,
"right": right_img,
"disparity": left_disp,
"mask": disp_mask,
}
def __len__(self):
return len(self.imgs)

BIN
models/crestereo_eth3d.mge Normal file

Binary file not shown.

274
train.py Normal file
View File

@ -0,0 +1,274 @@
import os
import sys
import time
import logging
from collections import namedtuple
import yaml
from tensorboardX import SummaryWriter
from nets import Model
from dataset import CREStereoDataset
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
def parse_yaml(file_path: str) -> namedtuple:
"""Parse yaml configuration file and return the object in `namedtuple`."""
with open(file_path, "rb") as f:
cfg: dict = yaml.safe_load(f)
args = namedtuple("train_args", cfg.keys())(*cfg.values())
return args
def format_time(elapse):
elapse = int(elapse)
hour = elapse // 3600
minute = elapse % 3600 // 60
seconds = elapse % 60
return "{:02d}:{:02d}:{:02d}".format(hour, minute, seconds)
def ensure_dir(path):
if not os.path.exists(path):
os.makedirs(path, exist_ok=True)
def adjust_learning_rate(optimizer, epoch):
warm_up = 0.02
const_range = 0.6
min_lr_rate = 0.05
if epoch <= args.n_total_epoch * warm_up:
lr = (1 - min_lr_rate) * args.base_lr / (
args.n_total_epoch * warm_up
) * epoch + min_lr_rate * args.base_lr
elif args.n_total_epoch * warm_up < epoch <= args.n_total_epoch * const_range:
lr = args.base_lr
else:
lr = (min_lr_rate - 1) * args.base_lr / (
(1 - const_range) * args.n_total_epoch
) * epoch + (1 - min_lr_rate * const_range) / (1 - const_range) * args.base_lr
for param_group in optimizer.param_groups:
param_group['lr'] = lr
def sequence_loss(flow_preds, flow_gt, valid, gamma=0.8):
'''
valid: (2, 384, 512) (B, H, W) -> (B, 1, H, W)
flow_preds[0]: (B, 2, H, W)
flow_gt: (B, 2, H, W)
'''
n_predictions = len(flow_preds)
flow_loss = 0.0
for i in range(n_predictions):
i_weight = gamma ** (n_predictions - i - 1)
i_loss = torch.abs(flow_preds[i] - flow_gt)
flow_loss += i_weight * (valid.unsqueeze(1) * i_loss).mean()
return flow_loss
def main(args):
# initial info
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)
# rank, world_size = dist.get_rank(), dist.get_world_size()
world_size = torch.cuda.device_count() # number of GPU(s)
# directory check
log_model_dir = os.path.join(args.log_dir, "models")
ensure_dir(log_model_dir)
# model / optimizer
model = Model(
max_disp=args.max_disp, mixed_precision=args.mixed_precision, test_mode=False
)
model = nn.DataParallel(model,device_ids=[i for i in range(world_size)])
model.cuda()
optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
# model = nn.DataParallel(model,device_ids=[0])
tb_log = SummaryWriter(os.path.join(args.log_dir, "train.events"))
# worklog
logging.basicConfig(level=eval(args.log_level))
worklog = logging.getLogger("train_logger")
worklog.propagate = False
fileHandler = logging.FileHandler(
os.path.join(args.log_dir, "worklog.txt"), mode="a", encoding="utf8"
)
formatter = logging.Formatter(
fmt="%(asctime)s %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
)
fileHandler.setFormatter(formatter)
consoleHandler = logging.StreamHandler(sys.stdout)
formatter = logging.Formatter(
fmt="\x1b[32m%(asctime)s\x1b[0m %(message)s", datefmt="%Y/%m/%d %H:%M:%S"
)
consoleHandler.setFormatter(formatter)
worklog.handlers = [fileHandler, consoleHandler]
# params stat
worklog.info(f"Use {world_size} GPU(s)")
worklog.info("Params: %s" % sum([p.numel() for p in model.parameters()]))
# load pretrained model if exist
chk_path = os.path.join(log_model_dir, "latest.pth")
if args.loadmodel is not None:
chk_path = args.loadmodel
elif not os.path.exists(chk_path):
chk_path = None
if chk_path is not None:
# if rank == 0:
worklog.info(f"loading model: {chk_path}")
state_dict = torch.load(chk_path)
model.load_state_dict(state_dict['state_dict'])
optimizer.load_state_dict(state_dict['optim_state_dict'])
resume_epoch_idx = state_dict["epoch"]
resume_iters = state_dict["iters"]
start_epoch_idx = resume_epoch_idx + 1
start_iters = resume_iters
else:
start_epoch_idx = 1
start_iters = 0
# datasets
dataset = CREStereoDataset(args.training_data_path)
# if rank == 0:
worklog.info(f"Dataset size: {len(dataset)}")
dataloader = DataLoader(dataset, args.batch_size, shuffle=True,
num_workers=0, drop_last=True, persistent_workers=False, pin_memory=True)
# counter
cur_iters = start_iters
total_iters = args.minibatch_per_epoch * args.n_total_epoch
t0 = time.perf_counter()
for epoch_idx in range(start_epoch_idx, args.n_total_epoch + 1):
# adjust learning rate
epoch_total_train_loss = 0
adjust_learning_rate(optimizer, epoch_idx)
model.train()
t1 = time.perf_counter()
# batch_idx = 0
# for mini_batch_data in dataloader:
for batch_idx, mini_batch_data in enumerate(dataloader):
if batch_idx % args.minibatch_per_epoch == 0 and batch_idx != 0:
break
# batch_idx += 1
cur_iters += 1
# parse data
left, right, gt_disp, valid_mask = (
mini_batch_data["left"],
mini_batch_data["right"],
mini_batch_data["disparity"].cuda(),
mini_batch_data["mask"].cuda(),
)
t2 = time.perf_counter()
optimizer.zero_grad()
# pre-process
gt_disp = torch.unsqueeze(gt_disp, dim=1) # [2, 384, 512] -> [2, 1, 384, 512]
gt_flow = torch.cat([gt_disp, gt_disp * 0], dim=1) # [2, 2, 384, 512]
# forward
flow_predictions = model(left.cuda(), right.cuda())
# loss & backword
loss = sequence_loss(
flow_predictions, gt_flow, valid_mask, gamma=0.8
)
# loss stats
loss_item = loss.data.item()
epoch_total_train_loss += loss_item
loss.backward()
optimizer.step()
t3 = time.perf_counter()
if cur_iters % 10 == 0:
tdata = t2 - t1
time_train_passed = t3 - t0
time_iter_passed = t3 - t1
step_passed = cur_iters - start_iters
eta = (
(total_iters - cur_iters)
/ max(step_passed, 1e-7)
* time_train_passed
)
meta_info = list()
meta_info.append("{:.2g} b/s".format(1.0 / time_iter_passed))
meta_info.append("passed:{}".format(format_time(time_train_passed)))
meta_info.append("eta:{}".format(format_time(eta)))
meta_info.append(
"data_time:{:.2g}".format(tdata / time_iter_passed)
)
meta_info.append(
"lr:{:.5g}".format(optimizer.param_groups[0]["lr"])
)
meta_info.append(
"[{}/{}:{}/{}]".format(
epoch_idx,
args.n_total_epoch,
batch_idx,
args.minibatch_per_epoch,
)
)
loss_info = [" ==> {}:{:.4g}".format("loss", loss_item)]
# exp_name = ['\n' + os.path.basename(os.getcwd())]
info = [",".join(meta_info)] + loss_info
worklog.info("".join(info))
# minibatch loss
tb_log.add_scalar("train/loss_batch", loss_item, cur_iters)
tb_log.add_scalar(
"train/lr", optimizer.param_groups[0]["lr"], cur_iters
)
tb_log.flush()
t1 = time.perf_counter()
tb_log.add_scalar(
"train/loss",
epoch_total_train_loss / args.minibatch_per_epoch,
epoch_idx,
)
tb_log.flush()
# save model params
ckp_data = {
"epoch": epoch_idx,
"iters": cur_iters,
"batch_size": args.batch_size,
"epoch_size": args.minibatch_per_epoch,
"train_loss": epoch_total_train_loss / args.minibatch_per_epoch,
"state_dict": model.state_dict(),
"optim_state_dict": optimizer.state_dict(),
}
torch.save(ckp_data, os.path.join(log_model_dir, "latest.pth"))
if epoch_idx % args.model_save_freq_epoch == 0:
save_path = os.path.join(log_model_dir, "epoch-%d.pth" % epoch_idx)
worklog.info(f"Model params saved: {save_path}")
torch.save(ckp_data, save_path)
worklog.info("Training is done, exit.")
if __name__ == "__main__":
# train configuration
args = parse_yaml("cfgs/train.yaml")
main(args)