You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

106 lines
2.4 KiB

#pragma once
#include <random>
/// Abstract base class for a thresholded split decision.
///
/// A subclass supplies the scalar response via Compute(); Split() compares
/// that response against the stored threshold.
class SplitFunction {
public:
  SplitFunction() = default;
  virtual ~SplitFunction() = default;

  /// Returns the scalar response of this split function for `sample`.
  virtual float Compute(SamplePtr sample) const = 0;

  /// Returns true when Compute(sample) is strictly less than the threshold.
  virtual bool Split(SamplePtr sample) const {
    return Compute(sample) < threshold_;
  }

  /// Writes the threshold to the archive `ar`.
  virtual void Save(SerializationOut& ar) const {
    ar << threshold_;
  }

  /// Reads the threshold from the archive `ar`, overwriting the current one.
  virtual void Load(SerializationIn& ar) {
    ar >> threshold_;
  }

  /// Returns the current threshold.
  virtual float threshold() const { return threshold_; }

  /// Replaces the current threshold with `threshold`.
  virtual void set_threshold(float threshold) { threshold_ = threshold; }

protected:
  // Given a defined starting value so that threshold(), Split() and Save()
  // never read an indeterminate float on a freshly constructed object.
  float threshold_ = 0.0f;
};
class SplitFunctionPixelDifference : public SplitFunction {
public:
SplitFunctionPixelDifference() {}
virtual ~SplitFunctionPixelDifference() {}
virtual std::shared_ptr<SplitFunctionPixelDifference> Copy() const {
std::shared_ptr<SplitFunctionPixelDifference> split_fcn = std::make_shared<SplitFunctionPixelDifference>();
split_fcn->threshold_ = threshold_;
split_fcn->c0_ = c0_;
split_fcn->c1_ = c1_;
split_fcn->h0_ = h0_;
split_fcn->h1_ = h1_;
split_fcn->w0_ = w0_;
split_fcn->w1_ = w1_;
return split_fcn;
}
virtual std::shared_ptr<SplitFunctionPixelDifference> Generate(std::mt19937& rng, const SamplePtr sample) const {
std::shared_ptr<SplitFunctionPixelDifference> split_fcn = std::make_shared<SplitFunctionPixelDifference>();
std::uniform_int_distribution<int> cdist(0, sample->channels()-1);
split_fcn->c0_ = cdist(rng);
split_fcn->c1_ = cdist(rng);
std::uniform_int_distribution<int> hdist(0, sample->height()-1);
split_fcn->h0_ = hdist(rng);
split_fcn->h1_ = hdist(rng);
std::uniform_int_distribution<int> wdist(0, sample->width()-1);
split_fcn->w0_ = wdist(rng);
split_fcn->w1_ = wdist(rng);
return split_fcn;
}
virtual float Compute(SamplePtr sample) const {
return (*sample)(c0_, h0_, w0_) - (*sample)(c1_, h1_, w1_);
}
virtual void Save(SerializationOut& ar) const {
SplitFunction::Save(ar);
ar << c0_;
ar << c1_;
ar << h0_;
ar << h1_;
ar << w0_;
ar << w1_;
}
virtual void Load(SerializationIn& ar) {
SplitFunction::Load(ar);
ar >> c0_;
ar >> c1_;
ar >> h0_;
ar >> h1_;
ar >> w0_;
ar >> w1_;
}
private:
int c0_;
int c1_;
int h0_;
int h1_;
int w0_;
int w1_;
DISABLE_COPY_AND_ASSIGN(SplitFunctionPixelDifference);
};