You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
connecting_the_dots/hyperdepth/hyperparam_search.py

66 lines
1.8 KiB

5 years ago
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pathlib import Path
import sys
import hyperdepth as hd
sys.path.append('../')
import dataset
def get_data(n, row_from, row_to, train):
    """Load a band of image rows from the first `n` synthetic samples.

    Creates a `dataset.SynDataset` with `n` samples, a single image size
    entry `(256, 384)` and a single focal length `160`, then copies rows
    `row_from:row_to` of every sample's `im0` and `disp0` entries into two
    preallocated arrays.

    Args:
        n: Number of samples to load (indices `0 .. n - 1`).
        row_from: First row of the band (inclusive).
        row_to: End row of the band (exclusive).
        train: Forwarded unchanged to `SynDataset` as its `train` argument.

    Returns:
        Tuple `(ims, disps)`, both of shape `(n, row_to - row_from, 384)`;
        `ims` has dtype uint8 and `disps` has dtype float32.
    """
    imsizes = [(256, 384)]
    focal_lengths = [160]
    synthetic = dataset.SynDataset(n, imsizes=imsizes, focal_lengths=focal_lengths, train=train)

    # Both outputs share one shape; the last axis is the second entry of the
    # configured image size (presumably the image width -- TODO confirm).
    out_shape = (n, row_to - row_from, imsizes[0][1])
    ims = np.empty(out_shape, dtype=np.uint8)
    disps = np.empty(out_shape, dtype=np.float32)

    band = slice(row_from, row_to)
    for idx in range(n):
        print(f'load sample {idx} train={train}')
        item = synthetic[idx]
        # Index 0 on the first axis is presumably the single channel, and the
        # * 255 presumably maps a [0, 1] image to 8 bit -- TODO confirm
        # against SynDataset.
        ims[idx] = (item['im0'][0, band] * 255).astype(np.uint8)
        disps[idx] = item['disp0'][0, band]

    return ims, disps
# Forest settings shared by every run of the sweep. `max_tree_depth` is left
# out on purpose: it is the hyperparameter being searched and is supplied per
# run inside the loop below (the original passed it without a value, which is
# a syntax error, and never updated it per run).
base_train_params = dict(
    n_trees=4,
    n_test_split_functions=50,
    n_test_thresholds=10,
    n_test_samples=4096,
    min_samples_to_split=16,
    min_samples_for_leaf=8,
)
n_disp_bins = 20

# Only the image rows [row_from, row_to) are loaded for training/evaluation.
row_from = 100
row_to = 108
n_train_samples = 1024
n_test_samples = 32

train_ims, train_disps = get_data(n_train_samples, row_from, row_to, True)
test_ims, test_disps = get_data(n_test_samples, row_from, row_to, False)

for tree_depth in [8, 10, 12, 14, 16]:
    depth_switch = tree_depth - 4

    # NOTE(review): assumes the swept `tree_depth` is meant to be the forest's
    # `max_tree_depth` (the output directory is named after it) -- confirm.
    params = hd.TrainParams(max_tree_depth=tree_depth, **base_train_params)

    # One output directory per configuration, e.g. ./forests/td8_ds4/
    run_name = f'td{tree_depth}_ds{depth_switch}'
    prefix = Path(f'./forests/{run_name}/')
    prefix.mkdir(parents=True, exist_ok=True)
    forest_prefix = str(prefix / 'fr')

    hd.train_forest(
        params,
        train_ims,
        train_disps,
        n_disp_bins=n_disp_bins,
        depth_switch=depth_switch,
        forest_prefix=forest_prefix,
    )
    es = hd.eval_forest(
        test_ims,
        test_disps,
        n_disp_bins=n_disp_bins,
        depth_switch=depth_switch,
        forest_prefix=forest_prefix,
    )

    # Store the test disparities ('ta.npy') next to the forest's output
    # ('es.npy') so each run can be inspected offline.
    np.save(str(prefix / 'ta.npy'), test_disps)
    np.save(str(prefix / 'es.npy'), es)

    # Optional visual check of the first test sample:
    # plt.figure();
    # plt.subplot(2,1,1); plt.imshow(test_disps[0], vmin=0, vmax=4);
    # plt.subplot(2,1,2); plt.imshow(es[0], vmin=0, vmax=4);
    # plt.show()