38 #ifndef PCL_ML_FERNS_FERN_TRAINER_HPP_ 39 #define PCL_ML_FERNS_FERN_TRAINER_HPP_ 42 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
// Tail of a member-initializer list (presumably the FernTrainer constructor;
// its signature line and first initializer are not part of this listing).
// Defaults visible here: 1000 candidate features, 10 thresholds per feature,
// and both handler pointers start out as NULL.
45 , num_of_features_ (1000)
46 , num_of_thresholds_ (10)
47 , feature_handler_ (NULL)
48 , stats_estimator_ (NULL)
// NOTE(review): only the template header of this member definition survives in
// the listing; the signature and body that follow it are missing, so which
// member it belongs to cannot be determined from here.
57 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
64 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
// Body of the fern training routine (presumably train(); the signature line,
// all braces, and the definition of num_of_branches are missing from this
// listing). Visible stages:
//   1. create random candidate features and evaluate each of them,
//   2. per depth level, pick the (feature, threshold) pair with the highest
//      accumulated information gain and split every branch accordingly,
//   3. re-evaluate the selected features, map each example to a node index,
//   4. let the stats estimator compute the statistics of every node.
70 const size_t num_of_examples = examples_.size ();
// Ask the feature handler for num_of_features_ random candidate features.
73 std::vector<FeatureType> features;
74 feature_handler_->createRandomFeatures (num_of_features_, features);
// One result vector and one flag vector per candidate feature.
80 std::vector<std::vector<float> > feature_results (num_of_features_);
81 std::vector<std::vector<unsigned char> > flags (num_of_features_);
83 for (
size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
85 feature_results[feature_index].reserve (num_of_examples);
86 flags[feature_index].reserve (num_of_examples);
// Evaluate the candidate; outputs go to feature_results/flags of this feature
// (the middle argument lines of this call are missing from the listing).
88 feature_handler_->evaluateFeature (features[feature_index],
91 feature_results[feature_index],
92 flags[feature_index] );
// Per-feature branch bookkeeping, indexed as [feature][branch][example].
96 std::vector<std::vector<std::vector<float> > > branch_feature_results (num_of_features_);
97 std::vector<std::vector<std::vector<unsigned char> > > branch_flags (num_of_features_);
98 std::vector<std::vector<std::vector<ExampleIndex> > > branch_examples (num_of_features_);
99 std::vector<std::vector<std::vector<LabelType> > > branch_label_data (num_of_features_);
// Every feature starts with a single branch that holds all examples.
102 for (
size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
104 branch_feature_results[feature_index].resize (1);
105 branch_flags[feature_index].resize (1);
106 branch_examples[feature_index].resize (1);
107 branch_label_data[feature_index].resize (1);
109 branch_feature_results[feature_index][0] = feature_results[feature_index];
110 branch_flags[feature_index][0] = flags[feature_index];
111 branch_examples[feature_index][0] = examples_;
112 branch_label_data[feature_index][0] = label_data_;
// One selection round per fern depth level.
115 for (
size_t depth_index = 0; depth_index < fern_depth_; ++depth_index)
// Candidate thresholds, one list per feature.
118 std::vector<std::vector<float> > thresholds (num_of_features_);
120 for (
size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
// NOTE(review): this reserves on the outer vector (already sized to
// num_of_features_), not on thresholds[feature_index] -- confirm intent.
// The statement that fills thresholds[feature_index] is missing from the listing.
122 thresholds.reserve (num_of_thresholds_);
// Best candidate found so far for this depth level; index -1 means "none yet".
127 int best_feature_index = -1;
128 float best_feature_threshold = 0.0f;
129 float best_feature_information_gain = 0.0f;
131 for (
size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
133 for (
size_t threshold_index = 0; threshold_index < num_of_thresholds_; ++threshold_index)
// Accumulate the gain of this (feature, threshold) pair over all current branches.
135 float information_gain = 0.0f;
136 for (
size_t branch_index = 0; branch_index < branch_feature_results[feature_index].size (); ++branch_index)
// Arguments of a call whose head is missing from the listing (presumably
// computeInformationGain); its result is used below as branch_information_gain.
139 branch_examples[feature_index][branch_index],
140 branch_label_data[feature_index][branch_index],
141 branch_feature_results[feature_index][branch_index],
142 branch_flags[feature_index][branch_index],
143 thresholds[feature_index][threshold_index]);
// Weight the branch gain by the number of examples in that branch.
145 information_gain += branch_information_gain * branch_feature_results[feature_index][branch_index].size ();
// Strictly-greater comparison: a candidate is only accepted if its gain exceeds
// the best so far, which starts at 0.0f.
148 if (information_gain > best_feature_information_gain)
150 best_feature_information_gain = information_gain;
151 best_feature_index =
static_cast<int> (feature_index);
152 best_feature_threshold = thresholds[feature_index][threshold_index];
// Store the winning feature for this depth level in the fern.
// NOTE(review): if no candidate had a gain > 0, best_feature_index is still -1
// here and no guard is visible in this listing -- verify against the full file.
158 fern.
accessFeature (depth_index) = features[best_feature_index];
// Split the branches of every feature into num_of_branches sub-branches each.
162 for (
size_t feature_index = 0; feature_index < num_of_features_; ++feature_index)
164 std::vector<std::vector<float> > & cur_branch_feature_results = branch_feature_results[feature_index];
165 std::vector<std::vector<unsigned char> > & cur_branch_flags = branch_flags[feature_index];
166 std::vector<std::vector<ExampleIndex> > & cur_branch_examples = branch_examples[feature_index];
167 std::vector<std::vector<LabelType> > & cur_branch_label_data = branch_label_data[feature_index];
169 const size_t total_num_of_new_branches = num_of_branches * cur_branch_feature_results.size ();
171 std::vector<std::vector<float> > new_branch_feature_results (total_num_of_new_branches);
172 std::vector<std::vector<unsigned char> > new_branch_flags (total_num_of_new_branches);
173 std::vector<std::vector<ExampleIndex> > new_branch_examples (total_num_of_new_branches);
174 std::vector<std::vector<LabelType> > new_branch_label_data (total_num_of_new_branches);
176 for (
size_t branch_index = 0; branch_index < cur_branch_feature_results.size (); ++branch_index)
178 const size_t num_of_examples_in_this_branch = cur_branch_feature_results[branch_index].size ();
180 std::vector<unsigned char> branch_indices;
181 branch_indices.reserve (num_of_examples_in_this_branch);
// Arguments of a call whose head and last argument are missing from the listing
// (presumably computeBranchIndices, which would fill branch_indices).
184 cur_branch_flags[branch_index],
185 best_feature_threshold,
// Sub-branches of branch b occupy the slots [b * num_of_branches, (b + 1) * num_of_branches).
189 const size_t base_branch_index = branch_index * num_of_branches;
190 for (
size_t example_index = 0; example_index < num_of_examples_in_this_branch; ++example_index)
192 const size_t combined_branch_index = base_branch_index + branch_indices[example_index];
// Move the example's result, flag, index and label into its new sub-branch.
194 new_branch_feature_results[combined_branch_index].push_back (cur_branch_feature_results[branch_index][example_index]);
195 new_branch_flags[combined_branch_index].push_back (cur_branch_flags[branch_index][example_index]);
196 new_branch_examples[combined_branch_index].push_back (cur_branch_examples[branch_index][example_index]);
197 new_branch_label_data[combined_branch_index].push_back (cur_branch_label_data[branch_index][example_index]);
// Replace the old branch layout of this feature with the refined one.
// NOTE(review): these are copy-assignments of nested vectors.
201 branch_feature_results[feature_index] = new_branch_feature_results;
202 branch_flags[feature_index] = new_branch_flags;
203 branch_examples[feature_index] = new_branch_examples;
204 branch_label_data[feature_index] = new_branch_label_data;
// Re-evaluate the selected feature of every depth level and derive the branch
// index per example.
210 std::vector<std::vector<float> > final_feature_results (fern_depth_);
211 std::vector<std::vector<unsigned char> > final_flags (fern_depth_);
212 std::vector<std::vector<unsigned char> > final_branch_indices (fern_depth_);
213 for (
size_t depth_index = 0; depth_index < fern_depth_; ++depth_index)
215 final_feature_results[depth_index].reserve (num_of_examples);
216 final_flags[depth_index].reserve (num_of_examples);
217 final_branch_indices[depth_index].reserve (num_of_examples);
// (the middle argument lines of this call are missing from the listing)
219 feature_handler_->evaluateFeature (fern.
accessFeature (depth_index),
222 final_feature_results[depth_index],
223 final_flags[depth_index] );
// Arguments of a call whose head is missing from the listing (presumably
// computeBranchIndices writing into final_branch_indices[depth_index]).
226 final_flags[depth_index],
228 final_branch_indices[depth_index]);
// Per-node label and example lists; the node count is 2^fern_depth_.
232 std::vector<std::vector<LabelType> > node_labels (0x1 << fern_depth_);
233 std::vector<std::vector<ExampleIndex> > node_examples (0x1 << fern_depth_);
235 for (
size_t example_index = 0; example_index < num_of_examples; ++example_index)
// Build the node index digit by digit in radix num_of_branches, one digit per
// depth level.
// NOTE(review): the containers above are sized 1 << fern_depth_, which matches
// this indexing only if num_of_branches is 2 -- confirm.
237 size_t node_index = 0;
238 for (
size_t depth_index = 0; depth_index < fern_depth_; ++depth_index)
240 node_index *= num_of_branches;
241 node_index += final_branch_indices[depth_index][example_index];
244 node_labels[node_index].push_back (label_data_[example_index]);
245 node_examples[node_index].push_back (examples_[example_index]);
// Let the stats estimator compute and store the statistics of every fern node.
249 const size_t num_of_nodes = 0x1 << fern_depth_;
250 for (
size_t node_index = 0; node_index < num_of_nodes; ++node_index)
252 stats_estimator_->
computeAndSetNodeStats (data_set_, node_examples[node_index], node_labels[node_index], fern[node_index]);
257 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
// Parameter list and body of the threshold generator (presumably
// createThresholdsUniform; the return-type/name line and the braces are missing
// from this listing).
//   num_of_thresholds  number of thresholds to generate
//   values             input values whose min/max define the covered range
//   thresholds         output; resized to num_of_thresholds and overwritten
260 const size_t num_of_thresholds,
261 std::vector<float> & values,
262 std::vector<float> & thresholds)
// Find the minimum and maximum of the supplied values; the sentinels guarantee
// that the first value replaces both.
// NOTE(review): if values is empty the loop never runs and the sentinels remain.
265 float min_value = ::std::numeric_limits<float>::max();
266 float max_value = -::std::numeric_limits<float>::max();
268 const size_t num_of_values = values.size ();
// NOTE(review): the int index is compared against a size_t bound (signed/unsigned mix).
269 for (
int value_index = 0; value_index < num_of_values; ++value_index)
271 const float value = values[value_index];
273 if (value < min_value) min_value = value;
274 if (value > max_value) max_value = value;
// The step divides the range by num_of_thresholds+2, so the thresholds written
// below (min_value + step*1 .. min_value + step*num_of_thresholds) stay
// strictly inside the range whenever the range is positive.
277 const float range = max_value - min_value;
278 const float step = range / (num_of_thresholds+2);
281 thresholds.resize (num_of_thresholds);
// NOTE(review): same int vs. size_t comparison as in the loop above.
283 for (
int threshold_index = 0; threshold_index < num_of_thresholds; ++threshold_index)
285 thresholds[threshold_index] = min_value + step*(threshold_index+1);
static void createThresholdsUniform(const size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
float & accessThreshold(const size_t threshold_index)
Access operator for thresholds.
virtual void computeBranchIndices(std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold, std::vector< unsigned char > &branch_indices) const =0
Computes the branch indices obtained by the specified threshold on the supplied feature evaluation re...
FeatureType & accessFeature(const size_t feature_index)
Access operator for features.
virtual float computeInformationGain(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, std::vector< float > &results, std::vector< unsigned char > &flags, const float threshold) const =0
Computes the information gain obtained by the specified threshold on the supplied feature evaluation ...
virtual size_t getNumOfBranches() const =0
Returns the number of branches a node can have (e.g.
Class representing a Fern.
virtual ~FernTrainer()
Destructor.
void train(Fern< FeatureType, NodeType > &fern)
Trains a fern using the previously set training data and settings.
FernTrainer()
Constructor.
void initialize(const size_t num_of_decisions)
Initializes the fern.
virtual void computeAndSetNodeStats(DataSet &data_set, std::vector< ExampleIndex > &examples, std::vector< LabelDataType > &label_data, NodeType &node) const =0
Computes and sets the statistics for a node.