You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 
connecting_the_dots/hyperdepth/hyperdepth.pyx

86 lines
3.4 KiB

cimport cython
import numpy as np
cimport numpy as np
from libc.stdlib cimport free, malloc
from libcpp cimport bool
from libcpp.string cimport string
from cpython cimport PyObject, Py_INCREF
CREATE_INIT = True # workaround, so cython builds a init function
np.import_array()
ctypedef unsigned char uint8_t
cdef extern from "rf/train.h":
cdef cppclass TrainParameters:
int n_trees;
int max_tree_depth;
int n_test_split_functions;
int n_test_thresholds;
int n_test_samples;
int min_samples_to_split;
int min_samples_for_leaf;
int print_node_info;
TrainParameters();
cdef extern from "hyperdepth.h":
void train(int row_from, int row_to, TrainParameters params, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, string forest_prefix);
void eval(int row_from, int row_to, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, string forest_prefix, float* out);
cdef class TrainParams:
cdef TrainParameters params;
def __cinit__(self, int n_trees=6, int max_tree_depth=8, int n_test_split_functions=50, int n_test_thresholds=10, int n_test_samples=4096, int min_samples_to_split=16, int min_samples_for_leaf=8, int print_node_info=100):
self.params.n_trees = n_trees
self.params.max_tree_depth = max_tree_depth
self.params.n_test_split_functions = n_test_split_functions
self.params.n_test_thresholds = n_test_thresholds
self.params.n_test_samples = n_test_samples
self.params.min_samples_to_split = min_samples_to_split
self.params.min_samples_for_leaf = min_samples_for_leaf
self.params.print_node_info = print_node_info
def __str__(self):
return f'n_trees={self.params.n_trees}, max_tree_depth={self.params.max_tree_depth}, n_test_split_functions={self.params.n_test_split_functions}, n_test_thresholds={self.params.n_test_thresholds}, n_test_samples={self.params.n_test_samples}, min_samples_to_split={self.params.min_samples_to_split}, min_samples_for_leaf={self.params.min_samples_for_leaf}'
def train_forest(TrainParams params, uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
cdef int n = ims.shape[0]
cdef int h = ims.shape[1]
cdef int w = ims.shape[2]
if row_from < 0:
row_from = 0
if row_to > h or row_to < 0:
row_to = h
if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
raise Exception('ims.shape != disps.shape')
train(row_from, row_to, params.params, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode())
def eval_forest(uint8_t[:,:,::1] ims, float[:,:,::1] disps, int n_disp_bins=10, int depth_switch=0, int n_threads=18, str forest_prefix='forest', int row_from=-1, int row_to=-1):
cdef int n = ims.shape[0]
cdef int h = ims.shape[1]
cdef int w = ims.shape[2]
if n != disps.shape[0] or h != disps.shape[1] or w != disps.shape[2]:
raise Exception('ims.shape != disps.shape')
if row_from < 0:
row_from = 0
if row_to > h or row_to < 0:
row_to = h
out = np.empty((n, h, w, 3), dtype=np.float32)
cdef float[:,:,:,::1] out_view = out
eval(row_from, row_to, &ims[0,0,0], &disps[0,0,0], n, h, w, n_disp_bins, depth_switch, n_threads, forest_prefix.encode(), &out_view[0,0,0,0])
return out