Alexandria  2.27.0
SDC-CH common library for the Euclid project
SOMTrainer.h
Go to the documentation of this file.
1 /*
2  * Copyright (C) 2012-2022 Euclid Science Ground Segment
3  *
4  * This library is free software; you can redistribute it and/or modify it under
5  * the terms of the GNU Lesser General Public License as published by the Free
6  * Software Foundation; either version 3.0 of the License, or (at your option)
7  * any later version.
8  *
9  * This library is distributed in the hope that it will be useful, but WITHOUT
10  * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
11  * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
12  * details.
13  *
14  * You should have received a copy of the GNU Lesser General Public License
15  * along with this library; if not, write to the Free Software Foundation, Inc.,
16  * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17  */
18 
19 /*
20  * @file SOMTrainer.h
21  * @author nikoapos
22  */
23 
24 #ifndef SOM_SOMTRAINER_H
25 #define SOM_SOMTRAINER_H
26 
28 #include "SOM/NeighborhoodFunc.h"
29 #include "SOM/SOM.h"
30 #include "SOM/SamplingPolicy.h"
31 
32 namespace Euclid {
33 namespace SOM {
34 
35 template <typename NeighborhoodFunction>
36 class SOMTrainer {
37 
38 public:
39  SOMTrainer(NeighborhoodFunction neighborhood_func, LearningRestraintFunc::Signature learning_restraint_func)
40  : m_neighborhood_func(std::move(neighborhood_func))
41  , m_learning_restraint_func(std::move(learning_restraint_func)) {}
42 
43  template <typename DistFunc, typename InputIter, typename InputToWeightFunc,
44  template <class> class Sampler = SamplingPolicy::FullSet>
45  void train(SOM<DistFunc>& som, std::size_t iter_no, InputIter begin, InputIter end, InputToWeightFunc weight_func,
46  const Sampler<InputIter>& sampling_policy = Sampler<InputIter>{}) {
47 
48  // We repeat the training for iter_no iterations
49  for (std::size_t i = 0; i < iter_no; ++i) {
50 
51  // Compute the factor of the current iteration
52  auto learn_factor = m_learning_restraint_func(i, iter_no);
53  if (learn_factor == 0) {
54  continue;
55  }
56 
57  // Go through the training sample of the iteration
58  for (auto it = sampling_policy.start(begin, end); it != end; it = sampling_policy.next(it)) {
59 
60  // Get the weights of the input object
61  auto input_weights = weight_func(*it);
62 
63  // Find the coordinates of the BMU for the input
64  std::size_t bmu_x, bmu_y;
65  double nd_distance;
66  std::tie(bmu_x, bmu_y, nd_distance) = som.findBMU(*it, weight_func);
67 
68  // Now go through all the cells and update their values according their coordinates
69  std::size_t size_x, size_y;
70  std::tie(size_x, size_y) = som.getSize();
71 
72  for (std::size_t cell_y = 0; cell_y < size_y; ++cell_y) {
73  for (std::size_t cell_x = 0; cell_x < size_x; ++cell_x) {
74  auto cell = som(cell_x, cell_y);
75 
76  // Compute the factor based on the distance of the BMU and the cell
77  auto neighborhood_factor = m_neighborhood_func({bmu_x, bmu_y}, {cell_x, cell_y}, i, iter_no);
78 
79  // Get the weights of the cell and update them
80  if (neighborhood_factor != 0) {
81  for (std::size_t wi = 0; wi < som.getDimensions(); ++wi) {
82  cell[wi] = cell[wi] + neighborhood_factor * learn_factor * (input_weights[wi] - cell[wi]);
83  }
84  }
85  }
86  }
87  }
88  }
89  }
90 
91 private:
92  NeighborhoodFunction m_neighborhood_func;
94 };
95 
96 } // namespace SOM
97 } // namespace Euclid
98 
99 #endif /* SOM_SOMTRAINER_H */
void train(SOM< DistFunc > &som, std::size_t iter_no, InputIter begin, InputIter end, InputToWeightFunc weight_func, const Sampler< InputIter > &sampling_policy=Sampler< InputIter >{})
Definition: SOMTrainer.h:45
NeighborhoodFunction m_neighborhood_func
Definition: SOMTrainer.h:92
LearningRestraintFunc::Signature m_learning_restraint_func
Definition: SOMTrainer.h:93
SOMTrainer(NeighborhoodFunction neighborhood_func, LearningRestraintFunc::Signature learning_restraint_func)
Definition: SOMTrainer.h:39
STL namespace.
T tie(T... args)