You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
56 lines
1.3 KiB
56 lines
1.3 KiB
6 years ago
|
#pragma once
|
||
|
|
||
|
#include "node.h"
|
||
|
|
||
|
template <typename SplitFunctionT, typename LeafFunctionT>
|
||
|
class Tree {
|
||
|
public:
|
||
|
Tree() : root_(nullptr) {}
|
||
|
Tree(NodePtr root) : root_(root) {}
|
||
|
|
||
|
virtual ~Tree() {}
|
||
|
|
||
|
std::shared_ptr<LeafFunctionT> inference(const SamplePtr sample) const {
|
||
|
if(root_ == nullptr) {
|
||
|
std::cout << "[ERROR] tree inference root node is NULL";
|
||
|
exit(-1);
|
||
|
}
|
||
|
|
||
|
NodePtr node = root_;
|
||
|
while(node->type() == SplitNode<SplitFunctionT, LeafFunctionT>::TYPE) {
|
||
|
auto splitNode = std::static_pointer_cast<SplitNode<SplitFunctionT, LeafFunctionT>>(node);
|
||
|
bool left = splitNode->Split(sample);
|
||
|
if(left) {
|
||
|
node = splitNode->left();
|
||
|
}
|
||
|
else {
|
||
|
node = splitNode->right();
|
||
|
}
|
||
|
}
|
||
|
|
||
|
auto leaf_node = std::static_pointer_cast<LeafNode<LeafFunctionT>>(node);
|
||
|
return leaf_node->leaf_node_fcn();
|
||
|
}
|
||
|
|
||
|
NodePtr root() const { return root_; }
|
||
|
void set_root(NodePtr root) { root_ = root; }
|
||
|
|
||
|
virtual void Save(SerializationOut& ar) const {
|
||
|
int type = root_->type();
|
||
|
ar << type;
|
||
|
root_->Save(ar);
|
||
|
}
|
||
|
|
||
|
virtual void Load(SerializationIn& ar) {
|
||
|
int type;
|
||
|
ar >> type;
|
||
|
root_ = MakeNode<SplitFunctionT, LeafFunctionT>(type);
|
||
|
root_->Load(ar);
|
||
|
}
|
||
|
|
||
|
|
||
|
public:
|
||
|
NodePtr root_;
|
||
|
};
|
||
|
|