12#include "linearsolvers.h"
20RBFSpline::RBFSpline(
const DataTable& samples, RadialBasisFunctionType type,
22 : RBFSpline(samples, type, false)
26RBFSpline::RBFSpline(
const DataTable& samples, RadialBasisFunctionType type,
27 DenseMatrix w,
double e)
31 dim(samples.getNumVariables()),
32 numSamples(samples.getNumSamples())
34 if (type == RadialBasisFunctionType::THIN_PLATE_SPLINE)
36 fn = std::shared_ptr<RadialBasisFunction>(
new ThinPlateSpline());
39 else if (type == RadialBasisFunctionType::MULTIQUADRIC)
41 fn = std::shared_ptr<RadialBasisFunction>(
new Multiquadric());
43 else if (type == RadialBasisFunctionType::INVERSE_QUADRIC)
45 fn = std::shared_ptr<RadialBasisFunction>(
new InverseQuadric());
47 else if (type == RadialBasisFunctionType::INVERSE_MULTIQUADRIC)
49 fn = std::shared_ptr<RadialBasisFunction>(
new InverseMultiquadric());
51 else if (type == RadialBasisFunctionType::GAUSSIAN)
53 fn = std::shared_ptr<RadialBasisFunction>(
new Gaussian());
56 else if (type == RadialBasisFunctionType::LINEAR)
58 fn = std::shared_ptr<RadialBasisFunction>(
new Linear());
60 else if (type == RadialBasisFunctionType::CUBIC)
62 fn = std::shared_ptr<RadialBasisFunction>(
new Cubic());
64 else if (type == RadialBasisFunctionType::QUINTIC)
66 fn = std::shared_ptr<RadialBasisFunction>(
new Quintic());
70 fn = std::shared_ptr<RadialBasisFunction>(
new ThinPlateSpline());
76RBFSpline::RBFSpline(
const DataTable& samples, RadialBasisFunctionType type,
77 bool normalized,
double e)
79 normalized(normalized),
81 dim(samples.getNumVariables()),
82 numSamples(samples.getNumSamples())
84 if (type == RadialBasisFunctionType::THIN_PLATE_SPLINE)
86 fn = std::shared_ptr<RadialBasisFunction>(
new ThinPlateSpline());
88 else if (type == RadialBasisFunctionType::MULTIQUADRIC)
90 fn = std::shared_ptr<RadialBasisFunction>(
new Multiquadric());
92 else if (type == RadialBasisFunctionType::INVERSE_QUADRIC)
94 fn = std::shared_ptr<RadialBasisFunction>(
new InverseQuadric());
96 else if (type == RadialBasisFunctionType::INVERSE_MULTIQUADRIC)
98 fn = std::shared_ptr<RadialBasisFunction>(
new InverseMultiquadric());
100 else if (type == RadialBasisFunctionType::GAUSSIAN)
102 fn = std::shared_ptr<RadialBasisFunction>(
new Gaussian());
105 else if (type == RadialBasisFunctionType::LINEAR)
107 fn = std::shared_ptr<RadialBasisFunction>(
new Linear());
109 else if (type == RadialBasisFunctionType::CUBIC)
111 fn = std::shared_ptr<RadialBasisFunction>(
new Cubic());
113 else if (type == RadialBasisFunctionType::QUINTIC)
115 fn = std::shared_ptr<RadialBasisFunction>(
new Quintic());
119 fn = std::shared_ptr<RadialBasisFunction>(
new ThinPlateSpline());
132 A.setZero(numSamples, numSamples);
134 b.setZero(numSamples, 1);
137 for (
auto it1 = samples.cbegin(); it1 != samples.cend(); ++it1, ++i)
142 for (auto it2 = samples.cbegin(); it2 != samples.cend(); ++it2, ++j)
144 double val = fn->eval(dist(* it1, * it2));
154 double y = it1->getY();
171 DenseMatrix P = computePreconditionMatrix();
173 DenseMatrix Ap = P * A;
174 DenseMatrix bp = P * b;
180 Foam::Info <<
"Computing RBF weights using dense solver." << Foam::endl;
181 Foam::Info <<
"The radius of the RBF is equal to " << e << Foam::endl;
197 weights = A.colPivHouseholderQr().solve(b);
200 double err = (A * weights - b).norm() / b.norm();
201 Foam::Info <<
"Error: " << Foam::setprecision(10) << err << Foam::endl;
210double RBFSpline::eval(DenseVector x)
const
212 std::vector<double> y;
214 for (
int i = 0; i < x.rows(); i++)
222double RBFSpline::eval(std::vector<double> x)
const
224 assert(x.size() == dim);
225 double fval, sum = 0, sumw = 0;
228 for (
auto it = samples.cbegin(); it != samples.cend(); ++it, ++i)
230 fval = fn->eval(dist(x, it->getX()));
231 sumw += weights(i) * fval;
235 return normalized ? sumw / sum : sumw;
293DenseMatrix RBFSpline::computePreconditionMatrix()
const
296 P.setZero(numSamples, numSamples);
299 int sigma = std::max(1.0,
300 std::floor(0.1 * numSamples));
303 for (
auto it1 = samples.cbegin(); it1 != samples.cend(); ++it1, ++i)
305 Point p1(it1->getX());
307 std::vector<Point> shifted_points;
310 for (
auto it2 = samples.cbegin(); it2 != samples.cend(); ++it2, ++j)
312 Point p2(it2->getX());
315 shifted_points.push_back(p3);
318 std::sort(shifted_points.begin(), shifted_points.end());
320 std::vector<Point> points;
321 std::vector<int> indices;
323 for (
int j = 0; j < sigma; j++)
325 Point p(shifted_points.at(j));
326 indices.push_back(p.getIndex());
328 p2.setIndex(p.getIndex());
336 for (
int k = 0; k < 1; k++)
338 Point p(shifted_points.at(shifted_points.size() - 1 - k));
339 indices.push_back(p.getIndex());
341 p2.setIndex(p.getIndex());
342 points.push_back(p2);
346 int m = points.size();
354 assert(points.front().getIndex() == i);
356 for (
int k1 = 0; k1 < m; k1++)
358 for (
int k2 = 0; k2 < m; k2++)
360 Point p = points.at(k1) - points.at(k2);
361 B(k1, k2) = fn->eval(p.dist());
368 Eigen::JacobiSVD<DenseMatrix> svd(B, Eigen::ComputeThinU | Eigen::ComputeThinV);
372 for (
unsigned int j = 0; j < numSamples; j++)
374 auto it = find(indices.begin(), indices.end(), j);
376 if (it != indices.end())
378 int k = it - indices.begin();
381 assert(points.at(k).getIndex() == j);
392double RBFSpline::dist(std::vector<double> x, std::vector<double> y)
const
394 assert(x.size() == y.size());
397 for (
unsigned int i = 0; i < x.size(); i++)
399 sum += (x.at(i) - y.at(i)) * (x.at(i) - y.at(i));
402 return std::sqrt(sum);
408double RBFSpline::dist(DataPoint x, DataPoint y)
const
410 return dist(x.getX(), y.getX());
413bool RBFSpline::dist_sort(DataPoint x, DataPoint y)
const
415 std::vector<double> zeros(x.getDimX(), 0);
416 DataPoint origin(zeros, 0.0);
417 double x_dist = dist(x, origin);
418 double y_dist = dist(y, origin);
419 return (x_dist < y_dist);