#include <sstream>
#include <iomanip>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "rf/forest.h"
#include "rf/spliteval.h"
// Split evaluator for HyperDepth-style per-row forests: at depths below
// depth_switch_ the entropy is computed over coarse column labels only; at
// depth_switch_ and deeper it switches to the fine labels
// (column * sub-pixel disparity bin).
class HyperdepthSplitEvaluator : public SplitEvaluator {
public:
  HyperdepthSplitEvaluator(bool normalize, int n_classes, int n_disp_bins, int depth_switch)
    : SplitEvaluator(normalize), n_classes_(n_classes), n_disp_bins_(n_disp_bins),
      depth_switch_(depth_switch) {}

  virtual ~HyperdepthSplitEvaluator() {}

protected:
  // Shannon entropy of the class histogram of targets (lower = purer).
  virtual float Purity(const std::vector<TrainDatum>& targets, int depth) const {
    if(targets.size() == 0) return 0;

    int n_classes = n_classes_;
    if(depth >= depth_switch_) {
      n_classes *= n_disp_bins_;
    }

    // Histogram over coarse (column) or fine (column x bin) classes.
    std::vector<int> ps(n_classes, 0);
    for(const auto& target : targets) {
      auto ctarget = std::static_pointer_cast<ClassificationTarget>(target.optimize_target);
      int cl = ctarget->cl();
      if(depth < depth_switch_) {
        cl /= n_disp_bins_;  // collapse sub-pixel bins into one column class
      }
      ps[cl] += 1;
    }

    float h = 0;
    for(int cl = 0; cl < n_classes; ++cl) {
      float fi = float(ps[cl]) / float(targets.size());
      if(fi > 0) {
        h = h - fi * std::log(fi);
      }
    }
    return h;
  }

private:
  int n_classes_;
  int n_disp_bins_;
  int depth_switch_;
};
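
// For reference, a minimal standalone sketch (not part of the original rf
// framework) of the entropy that Purity() computes, applied to a toy
// histogram:
//
//   std::vector<int> counts = {8, 1, 1};  // 10 samples, 3 classes
//   float h = shannon_entropy(counts);    // ~0.639 nats
//
static inline float shannon_entropy(const std::vector<int>& counts) {
  int total = 0;
  for(int c : counts) total += c;
  float h = 0;
  for(int c : counts) {
    if(c > 0) {
      float fi = float(c) / float(total);
      h -= fi * std::log(fi);
    }
  }
  return h;
}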
// Leaf statistic: a class histogram over the fine labels. argmax() returns
// the best and second-best class, prob_vec() the normalized histogram.
class HyperdepthLeafFunction {
public:
  HyperdepthLeafFunction() : n_classes_(-1) {}
  HyperdepthLeafFunction(int n_classes) : n_classes_(n_classes) {}
  virtual ~HyperdepthLeafFunction() {}

  virtual std::shared_ptr<HyperdepthLeafFunction> Copy() const {
    auto fcn = std::make_shared<HyperdepthLeafFunction>();
    fcn->n_classes_ = n_classes_;
    fcn->counts_ = counts_;
    fcn->sum_counts_ = sum_counts_;
    return fcn;
  }

  // Build the histogram for a leaf from its training samples.
  virtual std::shared_ptr<HyperdepthLeafFunction> Create(const std::vector<TrainDatum>& samples) {
    auto stat = std::make_shared<HyperdepthLeafFunction>();
    stat->counts_.resize(n_classes_, 0);
    for(const auto& sample : samples) {
      auto ctarget = std::static_pointer_cast<ClassificationTarget>(sample.target);
      stat->counts_[ctarget->cl()] += 1;
    }
    stat->sum_counts_ = samples.size();
    return stat;
  }

  // Merge the histograms of several leaf functions (e.g. across trees).
  virtual std::shared_ptr<HyperdepthLeafFunction> Reduce(const std::vector<std::shared_ptr<HyperdepthLeafFunction>>& fcns) const {
    auto stat = std::make_shared<HyperdepthLeafFunction>();
    stat->counts_.resize(fcns[0]->counts_.size(), 0);
    stat->sum_counts_ = 0;
    for(const auto& fcn : fcns) {
      for(size_t cl = 0; cl < stat->counts_.size(); ++cl) {
        stat->counts_[cl] += fcn->counts_[cl];
      }
      stat->sum_counts_ += fcn->sum_counts_;
    }
    return stat;
  }

  // Returns (best class, second-best class) by count.
  virtual std::tuple<int,int> argmax() const {
    int max_idx = 0;
    int max_count = counts_[0];
    int max2_idx = -1;
    int max2_count = -1;
    for(size_t idx = 1; idx < counts_.size(); ++idx) {
      if(counts_[idx] > max_count) {
        max2_count = max_count;
        max2_idx = max_idx;
        max_count = counts_[idx];
        max_idx = int(idx);
      }
      else if(counts_[idx] > max2_count) {
        max2_count = counts_[idx];
        max2_idx = int(idx);
      }
    }
    return std::make_tuple(max_idx, max2_idx);
  }

  // Normalized histogram; empty leaves yield all-zero probabilities.
  virtual std::vector<float> prob_vec() const {
    std::vector<float> probs(counts_.size(), 0.f);
    int sum = 0;
    for(int cnt : counts_) {
      sum += cnt;
    }
    if(sum == 0) return probs;  // guard against division by zero
    for(size_t idx = 0; idx < counts_.size(); ++idx) {
      probs[idx] = float(counts_[idx]) / sum;
    }
    return probs;
  }

  virtual void Save(SerializationOut& ar) const {
    ar << n_classes_;
    int n_counts = counts_.size();
    ar << n_counts;
    for(int idx = 0; idx < n_counts; ++idx) {
      ar << counts_[idx];
    }
    ar << sum_counts_;
  }

  virtual void Load(SerializationIn& ar) {
    ar >> n_classes_;
    int n_counts;
    ar >> n_counts;
    counts_.resize(n_counts);
    for(int idx = 0; idx < n_counts; ++idx) {
      ar >> counts_[idx];
    }
    ar >> sum_counts_;
  }

public:
  int n_classes_;
  std::vector<int> counts_;
  int sum_counts_;

  DISABLE_COPY_AND_ASSIGN(HyperdepthLeafFunction);
};
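
// A small usage sketch of the leaf statistic (hypothetical values, not part
// of the original pipeline): fill a histogram by hand and query the two best
// classes and the confidence of the winner.
static void leaf_function_demo() {
  HyperdepthLeafFunction fcn(4);
  fcn.counts_ = {1, 5, 3, 2};
  fcn.sum_counts_ = 11;
  int best, second;
  std::tie(best, second) = fcn.argmax();             // best = 1, second = 2
  std::cout << "best class " << best
            << " with p = " << fcn.prob_vec()[best]  // 5/11 ~ 0.45
            << ", runner-up " << second << std::endl;
}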
// One forest per scanline: pixel-difference split functions with
// class-histogram leaves.
typedef SplitFunctionPixelDifference HDSplitFunctionT;
typedef HyperdepthLeafFunction HDLeafFunctionT;
typedef HyperdepthSplitEvaluator HDSplitEvaluatorT;
typedef Forest<HDSplitFunctionT, HDLeafFunctionT> HDForest;
// Thin, non-owning view of a dense n x rows x cols row-major buffer.
template <typename T>
class Raw {
public:
  const T* raw;
  const int nsamples;
  const int rows;
  const int cols;

  Raw(const T* raw, int nsamples, int rows, int cols)
    : raw(raw), nsamples(nsamples), rows(rows), cols(cols) {}

  T operator()(int n, int r, int c) const {
    return raw[(n * rows + r) * cols + c];
  }
};
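
// Illustration of the Raw<T> layout (hypothetical helper, not in the original
// file): the flat offset of sample n, row r, column c in a row-major
// n x rows x cols volume, matching operator() above.
static inline int raw_flat_index(int n, int r, int c, int rows, int cols) {
  return (n * rows + r) * cols + c;
}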
// A patch sample centered at (rc, cc) of image n; reads outside the image
// are clamped to the border (replicate padding).
class RawSample : public Sample {
public:
  RawSample(const Raw<uint8_t>& raw, int n, int rc, int cc, int patch_height, int patch_width)
    : Sample(1, patch_height, patch_width), raw(raw), n(n), rc(rc), cc(cc) {}

  virtual float at(int ch, int r, int c) const {
    // Map patch coordinates to image coordinates and clamp to the image.
    r += rc - height_ / 2;
    c += cc - width_ / 2;
    r = std::max(0, std::min(raw.rows - 1, r));
    c = std::max(0, std::min(raw.cols - 1, c));
    return raw(n, r, c);
  }

protected:
  const Raw<uint8_t>& raw;
  int n;
  int rc;
  int cc;
};
// Collects one 32x32 patch sample per pixel of the given row, across all n
// images. The class label discretizes the projected position pos = col - d
// into n_disp_bins sub-pixel bins; if only_valid is set, pixels with invalid
// (negative) disparity or a negative label are skipped.
void extract_row_samples(const Raw<uint8_t>& im, const Raw<float>& disp, int row, int n_disp_bins, bool only_valid, std::vector<TrainDatum>& data) {
  for(int n = 0; n < im.nsamples; ++n) {
    for(int col = 0; col < im.cols; ++col) {
      float d = disp(n, row, col);
      float pos = col - d;
      int cl = pos * n_disp_bins;
      if((d < 0 || cl < 0) && only_valid) continue;

      auto sample = std::make_shared<RawSample>(im, n, row, col, 32, 32);
      auto target = std::make_shared<ClassificationTarget>(cl);
      data.push_back(TrainDatum(sample, target));
    }
  }
  std::cout << "extracted " << data.size() << " train samples" << std::endl;
  std::cout << "n_classes (" << im.cols << ") * n_disp_bins (" << n_disp_bins << ") = " << (im.cols * n_disp_bins) << std::endl;
}
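
// For reference, the label <-> disparity mapping used above and in eval()
// below, pulled out as hypothetical helpers (not part of the original file):
// a pixel at column col with disparity d projects to pos = col - d, which is
// discretized into n_disp_bins sub-pixel bins per column.
static inline int disparity_to_class(int col, float d, int n_disp_bins) {
  return int((col - d) * n_disp_bins);
}
static inline float class_to_disparity(int col, int cl, int n_disp_bins) {
  return col - float(cl) / n_disp_bins;
}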
// Trains one forest per row in [row_from, row_to) and saves each one to
// <forest_prefix><row>.bin.
void train(int row_from, int row_to, TrainParameters params, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, std::string forest_prefix) {
  Raw<uint8_t> raw_ims(ims, n, h, w);
  Raw<float> raw_disps(disps, n, h, w);

  int n_classes = w;
  auto gen_split_fcn = std::make_shared<HDSplitFunctionT>();
  auto gen_leaf_fcn = std::make_shared<HDLeafFunctionT>(n_classes * n_disp_bins);
  auto split_eval = std::make_shared<HDSplitEvaluatorT>(true, n_classes, n_disp_bins, depth_switch);

  for(int row = row_from; row < row_to; ++row) {
    std::cout << "train row " << row << std::endl;
    std::vector<TrainDatum> data;
    extract_row_samples(raw_ims, raw_disps, row, n_disp_bins, true, data);

    TrainForestQueued<HDSplitFunctionT, HDLeafFunctionT, HDSplitEvaluatorT> train(params, gen_split_fcn, gen_leaf_fcn, split_eval, n_threads, true);
    auto forest = train.Train(data, TrainType::TRAIN, nullptr);

    std::ostringstream forest_path;
    forest_path << forest_prefix << row << ".bin";
    std::cout << "save forest of row " << row << " to " << forest_path.str() << std::endl;
    BinarySerializationOut fout(forest_path.str());
    forest->Save(fout);
  }
}
// Loads the per-row forests and writes, for every pixel, three channels to
// out (shape n x h x w x 3): the predicted disparity, the probability of the
// winning class, and the distance to the second-best disparity hypothesis.
void eval(int row_from, int row_to, const uint8_t* ims, const float* disps, int n, int h, int w, int n_disp_bins, int depth_switch, int n_threads, std::string forest_prefix, float* out) {
  Raw<uint8_t> raw_ims(ims, n, h, w);
  Raw<float> raw_disps(disps, n, h, w);

  for(int row = row_from; row < row_to; ++row) {
    std::vector<TrainDatum> data;
    extract_row_samples(raw_ims, raw_disps, row, n_disp_bins, false, data);

    std::ostringstream forest_path;
    forest_path << forest_prefix << row << ".bin";
    std::cout << "eval row " << row << " - " << forest_path.str() << std::endl;
    BinarySerializationIn fin(forest_path.str());
    HDForest forest;
    forest.Load(fin);

    auto res = forest.inferencemt(data, n_threads);
    for(int nidx = 0; nidx < n; ++nidx) {
      for(int col = 0; col < w; ++col) {
        auto fcn = res[nidx * w + col];
        int pos, pos2;
        std::tie(pos, pos2) = fcn->argmax();

        // Invert the label mapping: disp = col - cl / n_disp_bins.
        float disp = col - float(pos) / n_disp_bins;
        float disp2 = col - float(pos2) / n_disp_bins;
        float prob = fcn->prob_vec()[pos];

        out[((nidx * h + row) * w + col) * 3 + 0] = disp;
        out[((nidx * h + row) * w + col) * 3 + 1] = prob;
        out[((nidx * h + row) * w + col) * 3 + 2] = std::abs(disp - disp2);
      }
    }
  }
}
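
// Hypothetical end-to-end usage (a sketch; TrainParameters comes from the rf
// framework and must be configured by the caller). ims holds n rectified
// uint8 images of size h x w, disps the matching ground-truth disparity maps,
// and out must provide n*h*w*3 floats for (disp, prob, |disp - disp2|).
static void train_and_eval_all_rows(const TrainParameters& params,
                                    const uint8_t* ims, const float* disps,
                                    int n, int h, int w,
                                    int n_disp_bins, int depth_switch,
                                    int n_threads, const std::string& forest_prefix,
                                    float* out) {
  train(0, h, params, ims, disps, n, h, w, n_disp_bins, depth_switch,
        n_threads, forest_prefix);
  eval(0, h, ims, disps, n, h, w, n_disp_bins, depth_switch,
       n_threads, forest_prefix, out);
}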