MLPACK  1.0.11
decision_stump.hpp
Go to the documentation of this file.
1 
22 #ifndef __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
23 #define __MLPACK_METHODS_DECISION_STUMP_DECISION_STUMP_HPP
24 
25 #include <mlpack/core.hpp>
26 
27 namespace mlpack {
28 namespace decision_stump {
29 
43 template <typename MatType = arma::mat>
45 {
46  public:
56  DecisionStump(const MatType& data,
57  const arma::Row<size_t>& labels,
58  const size_t classes,
59  size_t inpBucketSize);
60 
69  void Classify(const MatType& test, arma::Row<size_t>& predictedLabels);
70 
83  DecisionStump(const DecisionStump<>& other,
84  const MatType& data,
85  const arma::rowvec& weights,
86  const arma::Row<size_t>& labels);
87 
89  int SplitAttribute() const { return splitAttribute; }
91  int& SplitAttribute() { return splitAttribute; }
92 
94  const arma::vec& Split() const { return split; }
96  arma::vec& Split() { return split; }
97 
99  const arma::Col<size_t> BinLabels() const { return binLabels; }
101  arma::Col<size_t>& BinLabels() { return binLabels; }
102 
103  private:
105  size_t numClass;
106 
109 
111  size_t bucketSize;
112 
114  arma::vec split;
115 
117  arma::Col<size_t> binLabels;
118 
127  template <bool isWeight>
128  double SetupSplitAttribute(const arma::rowvec& attribute,
129  const arma::Row<size_t>& labels,
130  const arma::rowvec& weightD);
131 
139  template <typename rType> void TrainOnAtt(const arma::rowvec& attribute,
140  const arma::Row<size_t>& labels);
141 
146  void MergeRanges();
147 
154  template <typename rType> rType CountMostFreq(const arma::Row<rType>&
155  subCols);
156 
162  template <typename rType> int IsDistinct(const arma::Row<rType>& featureRow);
163 
171  template <typename LabelType, bool isWeight>
172  double CalculateEntropy(arma::subview_row<LabelType> labels, int begin,
173  const arma::rowvec& tempD);
174 
182  template <bool isWeight>
183  void Train(const MatType& data, const arma::Row<size_t>& labels,
184  const arma::rowvec& weightD);
185 
186 };
187 
188 }; // namespace decision_stump
189 }; // namespace mlpack
190 
191 #include "decision_stump_impl.hpp"
192 
193 #endif