Alexandria  2.27.0
SDC-CH common library for the Euclid project
GridInterpolation.icpp
Go to the documentation of this file.
1 #ifndef GRIDINTERPOLATION_IMPL
2 #error Please, include "MathUtils/interpolation/GridInterpolation.h"
3 #endif
4 
5 #include "AlexandriaKernel/Tuples.h"
6 #include "MathUtils/interpolation/interpolation.h"
7 
8 namespace Euclid {
9 namespace MathUtils {
10 
11 template <typename T, typename Enable = void>
12 struct InterpolationImpl;
13 
14 /**
15  * Trait for continuous types
16  */
17 template <typename T>
18 struct InterpolationImpl<T, typename std::enable_if<std::is_floating_point<T>::value>::type> {
19  static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
20  bool extrapolate) {
21  return simple_interpolation(x, knots, values, extrapolate);
22  }
23 
24  template <typename... Rest>
25  static double interpolate(const T x, const std::vector<T>& knots,
26  const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool extrapolate,
27  const Rest... rest) {
28  // If no extrapolation, and the value if out-of-bounds, just clip at 0
29  if ((x < knots.front() || x > knots.back()) && !extrapolate) {
30  return 0.;
31  }
32 
33  if (knots.size() == 1) {
34  return (*interpolators[0])(rest...);
35  }
36 
37  std::size_t x2i = std::lower_bound(knots.begin(), knots.end(), x) - knots.begin();
38  if (x2i == 0) {
39  ++x2i;
40  } else if (x2i == knots.size()) {
41  --x2i;
42  }
43  std::size_t x1i = x2i - 1;
44 
45  double y1 = (*interpolators[x1i])(rest...);
46  double y2 = (*interpolators[x2i])(rest...);
47 
48  return simple_interpolation(x, knots[x1i], knots[x2i], y1, y2, extrapolate);
49  }
50 
51  static void checkOrder(const std::vector<T>& knots) {
52  if (!std::is_sorted(knots.begin(), knots.end())) {
53  throw InterpolationException("coordinates must be sorted");
54  }
55  }
56 };
57 
58 template <typename T>
59 struct InterpolationImpl<T, typename std::enable_if<std::is_integral<T>::value>::type> {
60  static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
61  bool /*extrapolate*/) {
62  if (x < knots.front() || x > knots.back())
63  return 0.;
64  return values[x];
65  }
66 
67  template <typename... Rest>
68  static double interpolate(const T x, const std::vector<T>& knots,
69  const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool,
70  const Rest... rest) {
71  if (x < knots.front() || x > knots.back())
72  return 0.;
73  return (*interpolators[x])(rest...);
74  }
75 
76  static void checkOrder(const std::vector<T>& knots) {
77  if (knots.front() != 0) {
78  throw InterpolationException("int axis must start at 0");
79  }
80  for (auto b = knots.begin() + 1; b != knots.end(); ++b) {
81  if (*b - *(b - 1) != 1) {
82  throw InterpolationException("int values must be contiguous");
83  }
84  }
85  }
86 };
87 
88 /**
89  * Trait for discrete types
90  */
91 template <typename T>
92 struct InterpolationImpl<T, typename std::enable_if<!std::is_arithmetic<T>::value>::type> {
93  static double interpolate(const T x, const std::vector<T>& knots, const std::vector<double>& values,
94  bool /*extrapolate*/) {
95  std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
96  if (i >= knots.size() || knots[i] != x)
97  return 0.;
98  return values[i];
99  }
100 
101  template <typename... Rest>
102  static double interpolate(const T x, const std::vector<T>& knots,
103  const std::vector<std::unique_ptr<InterpN<Rest...>>>& interpolators, bool,
104  const Rest... rest) {
105  std::size_t i = std::find(knots.begin(), knots.end(), x) - knots.begin();
106  if (i >= knots.size() || knots[i] != x)
107  return 0.;
108  return (*interpolators[i])(rest...);
109  }
110 
111  static void checkOrder(const std::vector<T>&) {
112  // Discrete axes do not need to be in order
113  }
114 };
115 
116 /**
117  * Specialization (and end of the recursion) for a 1-dimensional interpolation.
118  */
119 template <typename T>
120 class InterpN<T> {
121 public:
122  /**
123  * Constructor
124  * @param grid
125  * A 1-dimensional grid
126  * @param values
127  * @param type
128  * @param extrapolate
129  */
130  InterpN(const std::tuple<std::vector<T>>& grid, const NdArray::NdArray<double>& values, bool extrapolate)
131  : m_knots(std::get<0>(grid)), m_values(values.begin(), values.end()), m_extrapolate(extrapolate) {
132  if (values.shape().size() != 1) {
133  throw InterpolationException() << "values and coordinates dimensionalities must match: " << values.shape().size()
134  << " != 1";
135  }
136  if (m_knots.size() != values.size()) {
137  throw InterpolationException() << "The size of the grid and the size of the values do not match: "
138  << m_knots.size() << " != " << m_values.size();
139  }
140  }
141 
142  /**
143  * Call as a function
144  * @param x
145  * Coordinate value
146  * @return
147  * Interpolated value
148  */
149  double operator()(const T x) const {
150  return InterpolationImpl<T>::interpolate(x, m_knots, m_values, m_extrapolate);
151  }
152 
153  /// Copy constructor
154  InterpN(const InterpN&) = default;
155 
156  /// Move constructor
157  InterpN(InterpN&&) = default;
158 
159 private:
160  std::vector<T> m_knots;
161  std::vector<double> m_values;
162  bool m_extrapolate;
163 };
164 
165 /**
166  * Recursive specialization of an N-Dimensional interpolator
167  * @tparam N Dimensionality (N > 1)
168  * @tparam F The first element of the index sequence
169  * @tparam Rest The rest of the elements from the index sequence
170  */
171 template <typename T, typename... Rest>
172 class InterpN<T, Rest...> {
173 public:
174  /**
175  * Constructor
176  * @param grid
177  * @param values
178  * @param type
179  * @param extrapolate
180  */
181  InterpN(const std::tuple<std::vector<T>, std::vector<Rest>...>& grid, const NdArray::NdArray<double>& values,
182  bool extrapolate)
183  : m_extrapolate(extrapolate) {
184  constexpr std::size_t N = sizeof...(Rest) + 1;
185 
186  if (values.shape().size() != N) {
187  throw InterpolationException() << "values and coordinates dimensionality must match: " << values.shape().size()
188  << " != " << N;
189  }
190  m_knots = std::get<0>(grid);
191  InterpolationImpl<T>::checkOrder(m_knots);
192  if (m_knots.size() != values.shape().back()) {
193  throw InterpolationException("coordinates and value sizes must match");
194  }
195  // Build nested interpolators
196  auto subgrid = Tuple::Tail(std::move(grid));
197  m_interpolators.resize(m_knots.size());
198  for (size_t i = 0; i < m_knots.size(); ++i) {
199  auto subvalues = values.rslice(i);
200  m_interpolators[i].reset(new InterpN<Rest...>(subgrid, subvalues, extrapolate));
201  }
202  }
203 
204  /**
205  * Call as a function
206  * @param x Value for the axis for the first dimension
207  * @param rest Values for the next set of axes
208  * @return The interpolated value
209  * @details
210  * Doubles<Rest>... is used to expand into (N-1) doubles
211  * x is used to find the interpolators for x1 and x2 s.t. x1 <= x <=x2
212  * Those two interpolators are used to compute y1 for x1, and y2 for x2 (based on the rest of the parameters)
213  * A final linear interpolator is used to get the value of y at the position x
214  */
215  double operator()(T x, Rest... rest) const {
216  return InterpolationImpl<T>::interpolate(x, m_knots, m_interpolators, m_extrapolate, rest...);
217  }
218 
219  /// Copy constructor
220  InterpN(const InterpN& other) : m_knots(other.m_knots), m_extrapolate(other.m_extrapolate) {
221  m_interpolators.resize(m_knots.size());
222  for (size_t i = 0; i < m_interpolators.size(); ++i) {
223  m_interpolators[i].reset(new InterpN<Rest...>(*other.m_interpolators[i]));
224  }
225  }
226 
227 private:
228  std::vector<T> m_knots;
229  std::vector<std::unique_ptr<InterpN<Rest...>>> m_interpolators;
230  bool m_extrapolate;
231 };
232 
233 } // namespace MathUtils
234 } // namespace Euclid