MLPACK  1.0.11
cosine_tree.hpp
Go to the documentation of this file.
1 
23 #ifndef __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
24 #define __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
25 
26 #include <mlpack/core.hpp>
27 #include <boost/heap/priority_queue.hpp>
28 
29 namespace mlpack {
30 namespace tree {
31 
32 // Predeclare classes for CosineNodeQueue typedef.
33 class CompareCosineNode;
34 class CosineTree;
35 
36 // CosineNodeQueue typedef.
37 typedef boost::heap::priority_queue<CosineTree*,
38  boost::heap::compare<CompareCosineNode> > CosineNodeQueue;
39 
41 {
42  public:
43 
52  CosineTree(const arma::mat& dataset);
53 
63  CosineTree(CosineTree& parentNode, const std::vector<size_t>& subIndices);
64 
79  CosineTree(const arma::mat& dataset,
80  const double epsilon,
81  const double delta);
82 
87  ~CosineTree();
88 
98  void ModifiedGramSchmidt(CosineNodeQueue& treeQueue,
99  arma::vec& centroid,
100  arma::vec& newBasisVector,
101  arma::vec* addBasisVector = NULL);
102 
115  double MonteCarloError(CosineTree* node,
116  CosineNodeQueue& treeQueue,
117  arma::vec* addBasisVector1 = NULL,
118  arma::vec* addBasisVector2 = NULL);
119 
125  void ConstructBasis(CosineNodeQueue& treeQueue);
126 
132  void CosineNodeSplit();
133 
140  void ColumnSamplesLS(std::vector<size_t>& sampledIndices,
141  arma::vec& probabilities, size_t numSamples);
142 
149  size_t ColumnSampleLS();
150 
163  size_t BinarySearch(arma::vec& cDistribution, double value, size_t start,
164  size_t end);
165 
173  void CalculateCosines(arma::vec& cosines);
174 
179  void CalculateCentroid();
180 
182  void GetFinalBasis(arma::mat& finalBasis) { finalBasis = basis; }
183 
185  const arma::mat& GetDataset() const { return dataset; }
186 
188  std::vector<size_t>& VectorIndices() { return indices; }
189 
191  void L2Error(const double error) { this->l2Error = error; }
192 
194  double L2Error() const { return l2Error; }
195 
197  arma::vec& Centroid() { return centroid; }
198 
200  void BasisVector(arma::vec& bVector) { this->basisVector = bVector; }
201 
203  arma::vec& BasisVector() { return basisVector; }
204 
206  CosineTree* Left() { return left; }
207 
209  CosineTree* Right() { return right; }
210 
212  size_t NumColumns() const { return numColumns; }
213 
215  double FrobNormSquared() const { return frobNormSquared; }
216 
218  size_t SplitPointIndex() const { return indices[splitPointIndex]; }
219 
220  private:
222  const arma::mat& dataset;
224  double epsilon;
226  double delta;
228  arma::mat basis;
236  std::vector<size_t> indices;
238  arma::vec l2NormsSquared;
240  arma::vec centroid;
242  arma::vec basisVector;
246  size_t numColumns;
248  double l2Error;
251 };
252 
254 {
255  public:
256 
257  // Comparison function for construction of priority queue.
258  bool operator() (const CosineTree* a, const CosineTree* b) const
259  {
260  return a->L2Error() < b->L2Error();
261  }
262 };
263 
264 }; // namespace tree
265 }; // namespace mlpack
266 
267 #endif