You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
87 lines
3.4 KiB
87 lines
3.4 KiB
6 years ago
|
cimport cython
|
||
|
import numpy as np
|
||
|
cimport numpy as np
|
||
|
|
||
|
from libc.stdlib cimport free, malloc
|
||
|
from libcpp cimport bool
|
||
|
from libcpp.string cimport string
|
||
|
from cpython cimport PyObject, Py_INCREF
|
||
|
|
||
|
CREATE_INIT = True # workaround, so cython builds a init function
|
||
|
|
||
|
np.import_array()
|
||
|
|
||
|
|
||
|
ctypedef unsigned char uint8_t
|
||
|
|
||
|
cdef extern from "rf/train.h":
|
||
|
cdef cppclass TrainParameters:
|
||
|
int n_trees;
|
||
|
int max_tree_depth;
|
||
|
int n_test_split_functions;
|
||
|
int n_test_thresholds;
|
||
|
int n_test_samples;
|
||
|
int min_samples_to_split;
|
||
|
int min_samples_for_leaf;
|
||
|
int print_node_info;
|
||
|
TrainParameters();
|
||
|
|
||
|
|
||
|
cdef extern from "hyperdepth.h":
|
||
|
void train(int row_from, int row_to, TrainParameters params, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, string forest_prefix);
|
||
|
void eval(int row_from, int row_to, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, string forest_prefix, float* out);
|
||
|
|
||
|
|
||
|
|
||
|
|
||
|
cdef class TrainParams:
|
||
|
cdef TrainParameters params;
|
||
|
|
||
|
def __cinit__(self, int n_trees=6, int max_tree_depth=8, int n_test_split_functions=50, int n_test_thresholds=10, int n_test_samples=4096, int min_samples_to_split=16, int min_samples_for_leaf=8, int print_node_info=100):
|
||
|
self.params.n_trees = n_trees
|
||
|
self.params.max_tree_depth = max_tree_depth
|
||
|
self.params.n_test_split_functions = n_test_split_functions
|
||
|
self.params.n_test_thresholds = n_test_thresholds
|
||
|
self.params.n_test_samples = n_test_samples
|
||
|
self.params.min_samples_to_split = min_samples_to_split
|
||
|
self.params.min_samples_for_leaf = min_samples_for_leaf
|
||
|
self.params.print_node_info = print_node_info
|
||
|
|
||
|
def __str__(self):
|
||
|
return f'n_trees={self.params.n_trees}, max_tree_depth={self.params.max_tree_depth}, n_test_split_functions={self.params.n_test_split_functions}, n_test_thresholds={self.params.n_test_thresholds}, n_test_samples={self.params.n_test_samples}, min_samples_to_split={self.params.min_samples_to_split}, min_samples_for_leaf={self.params.min_samples_for_leaf}'
|
||
|
|
||
|
|
||
|
def train_forest(TrainParams params, uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
|
||
|
cdef int n = ims.shape[0]
|
||
|
cdef int h = ims.shape[1]
|
||
|
cdef int w = ims.shape[2]
|
||
|
|
||
|
if row_from < 0:
|
||
|
row_from = 0
|
||
|
if row_to > h or row_to < 0:
|
||
|
row_to = h
|
||
|
|
||
|
if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
|
||
|
raise Exception('ims.shape != disps.shape')
|
||
|
|
||
|
train(row_from, row_to, params.params, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode())
|
||
|
|
||
|
|
||
|
def eval_forest(uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
|
||
|
cdef int n = ims.shape[0]
|
||
|
cdef int h = ims.shape[1]
|
||
|
cdef int w = ims.shape[2]
|
||
|
|
||
|
if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
|
||
|
raise Exception('ims.shape != disps.shape')
|
||
|
|
||
|
if row_from < 0:
|
||
|
row_from = 0
|
||
|
if row_to > h or row_to < 0:
|
||
|
row_to = h
|
||
|
|
||
|
out = np.empty((n, h, w, 3), dtype=np.float32)
|
||
|
cdef float[:,:,:,::1] out_view = out
|
||
|
eval(row_from, row_to, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode(), &out_view[0,0,0,0])
|
||
|
return out
|