You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
46 lines
1.3 KiB
46 lines
1.3 KiB
#include "hyperdepth.h"
|
|
|
|
|
|
int main() {
|
|
cv::Mat_<uint8_t> im = read_im(0);
|
|
cv::Mat_<uint16_t> disp = read_disp(0);
|
|
int im_rows = im.rows;
|
|
int im_cols = im.cols;
|
|
std::cout << im.rows << "/" << im.cols << std::endl;
|
|
std::cout << disp.rows << "/" << disp.cols << std::endl;
|
|
|
|
cv::Mat_<uint16_t> ta_disp(im_rows, im_cols);
|
|
cv::Mat_<uint16_t> es_disp(im_rows, im_cols);
|
|
|
|
int n_disp_bins = 16;
|
|
|
|
for(int row = 0; row < im_rows; ++row) {
|
|
std::vector<TrainDatum> data;
|
|
extract_row_samples(im, disp, row, data, false, n_disp_bins);
|
|
|
|
std::ostringstream forest_path;
|
|
forest_path << "cforest_" << row << ".bin";
|
|
BinarySerializationIn fin(forest_path.str());
|
|
HDForest forest;
|
|
forest.Load(fin);
|
|
|
|
auto res = forest.inferencemt(data, 18);
|
|
for(int col = 0; col < im_cols; ++col) {
|
|
auto fcn = res[col];
|
|
auto target = std::static_pointer_cast<ClassificationTarget>(data[col].target);
|
|
|
|
float ta = col - float(target->cl()) / n_disp_bins;
|
|
float es = col - float(fcn->argmax()) / n_disp_bins;
|
|
es = std::max(0.f, es);
|
|
|
|
ta_disp(row, col) = int(ta * 16);
|
|
es_disp(row, col) = int(es * 16);
|
|
}
|
|
}
|
|
|
|
cv::imwrite("disp_orig.png", disp);
|
|
cv::imwrite("disp_ta.png", ta_disp);
|
|
cv::imwrite("disp_es.png", es_disp);
|
|
}
|
|
|
|
|
|
|