|
|
|
import numpy as np
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
import cv2
|
|
|
|
from pathlib import Path
|
|
|
|
import sys
|
|
|
|
|
|
|
|
import hyperdepth as hd
|
|
|
|
|
|
|
|
sys.path.append('../')
|
|
|
|
import dataset
|
|
|
|
|
|
|
|
|
|
|
|
def get_data(n, row_from, row_to, train, *, imsize=(256, 384), focal_length=160):
    """Load a horizontal band of rows from ``n`` synthetic samples.

    Builds a ``dataset.SynDataset`` configured with a single image size and
    focal length. For each sample index ``0 .. n-1`` it copies rows
    ``row_from:row_to`` of plane 0 into preallocated arrays:
    - ``'im0'`` is scaled by 255 and cast to ``uint8``.
    - ``'disp0'`` is stored as ``float32``.

    Args:
        n: Number of samples to load.
        row_from: First row of the band (inclusive).
        row_to: End row of the band (exclusive).
        train: Forwarded to ``SynDataset`` as ``train=``.
        imsize: ``(height, width)`` passed to ``SynDataset`` as
            ``imsizes=[imsize]``. The width sets the column count of the
            returned arrays; the height bounds the allowed row band.
        focal_length: Passed to ``SynDataset`` as ``focal_lengths=[focal_length]``.

    Returns:
        ``(ims, disps)``, both of shape ``(n, row_to - row_from, width)``,
        with dtypes ``uint8`` and ``float32`` respectively.

    Raises:
        ValueError: If the row band is empty or does not lie within
            ``[0, height]``.
    """
    height, width = imsize
    # Validate before building the dataset so a bad band fails immediately
    # instead of with a broadcast error once sample loading has started.
    if not 0 <= row_from < row_to <= height:
        raise ValueError(
            f'row range [{row_from}, {row_to}) must be non-empty and lie '
            f'within [0, {height}]'
        )

    dset = dataset.SynDataset(
        n, imsizes=[imsize], focal_lengths=[focal_length], train=train
    )
    n_rows = row_to - row_from
    ims = np.empty((n, n_rows, width), dtype=np.uint8)
    disps = np.empty((n, n_rows, width), dtype=np.float32)
    for idx in range(n):
        print(f'load sample {idx} train={train}')
        sample = dset[idx]
        # NOTE(review): assumes 'im0'/'disp0' are (1, H, W)-like arrays and
        # that 'im0' holds intensities in [0, 1]; confirm against SynDataset.
        # astype() truncates rather than rounds, as in the original code.
        ims[idx] = (sample['im0'][0, row_from:row_to] * 255).astype(np.uint8)
        disps[idx] = sample['disp0'][0, row_from:row_to]
    return ims, disps
|
|
|
|
|
|
|
|
|
|
|
|
def make_train_params(max_tree_depth):
    """Return the forest training parameters used by the depth sweep.

    Every setting is fixed except ``max_tree_depth``, which is the quantity
    the sweep below varies.
    """
    return hd.TrainParams(
        n_trees=4,
        max_tree_depth=max_tree_depth,
        n_test_split_functions=50,
        n_test_thresholds=10,
        n_test_samples=4096,
        min_samples_to_split=16,
        min_samples_for_leaf=8,
    )


n_disp_bins = 20

# Only rows [row_from, row_to) of each image are used for training/testing.
row_from = 100
row_to = 108
n_train_samples = 1024
n_test_samples = 32

tree_depths = [8, 10, 12, 14, 16]

train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)

for tree_depth in tree_depths:
    # Rebuild the parameters each iteration so that the depth encoded in the
    # output folder name (td<depth>) is the depth that is actually trained.
    params = make_train_params(tree_depth)
    # depth_switch is kept 4 levels shallower than the maximum depth.
    # NOTE(review): its exact meaning is defined by hyperdepth.train_forest.
    depth_switch = tree_depth - 4

    out_dir = Path('./forests') / f'td{tree_depth}_ds{depth_switch}'
    out_dir.mkdir(parents=True, exist_ok=True)
    forest_prefix = str(out_dir / 'fr')

    hd.train_forest(params, train_ims, train_disps, n_disp_bins=n_disp_bins,
                    depth_switch=depth_switch, forest_prefix=forest_prefix)

    es = hd.eval_forest(test_ims, test_disps, n_disp_bins=n_disp_bins,
                        depth_switch=depth_switch, forest_prefix=forest_prefix)

    # Save the test disparities (ta.npy) next to the estimates (es.npy) so
    # each output folder is self-contained.
    np.save(str(out_dir / 'ta.npy'), test_disps)
    np.save(str(out_dir / 'es.npy'), es)

    # Optional visual check: uncomment to compare the test disparity and the
    # estimate for the first test sample at this tree depth.
    # plt.figure()
    # plt.subplot(2, 1, 1); plt.imshow(test_disps[0], vmin=0, vmax=4)
    # plt.subplot(2, 1, 2); plt.imshow(es[0], vmin=0, vmax=4)
    # plt.show()
|