Loading...
Searching...
No Matches
mtbGPR.C
1/*---------------------------------------------------------------------------*\
2 ██╗████████╗██╗ ██╗ █████╗ ██████╗ █████╗ ███████╗██╗ ██╗
3 ██║╚══██╔══╝██║ ██║██╔══██╗██╔════╝██╔══██╗ ██╔════╝██║ ██║
4 ██║ ██║ ███████║███████║██║ ███████║█████╗█████╗ ██║ ██║
5 ██║ ██║ ██╔══██║██╔══██║██║ ██╔══██║╚════╝██╔══╝ ╚██╗ ██╔╝
6 ██║ ██║ ██║ ██║██║ ██║╚██████╗██║ ██║ ██║ ╚████╔╝
7 ╚═╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝╚═╝ ╚═╝ ╚═╝ ╚═══╝
8
9 * In real Time Highly Advanced Computational Applications for Finite Volumes
10 * Copyright (C) 2017 by the ITHACA-FV authors
11-------------------------------------------------------------------------------
12
13License
14 This file is part of ITHACA-FV
15
16 ITHACA-FV is free software: you can redistribute it and/or modify
17 it under the terms of the GNU Lesser General Public License as published by
18 the Free Software Foundation, either version 3 of the License, or
19 (at your option) any later version.
20
21 ITHACA-FV is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU Lesser General Public License for more details.
25
26 You should have received a copy of the GNU Lesser General Public License
27 along with ITHACA-FV. If not, see <http://www.gnu.org/licenses/>.
28
29\*---------------------------------------------------------------------------*/
30
#include "mtbGPR.H"

#include <algorithm>
#include <cctype>
#include <string>

#include "error.H"
33
34mtbGPR::mtbGPR(const Foam::dictionary& dict)
35{
36 kernelTypeWord_ = dict.lookupOrDefault<Foam::word>("kernel", "matern");
37 useDataNormalization_ = dict.lookupOrDefault<bool>("normalize", true);
38 optimizeHyperparams_ = dict.lookupOrDefault<bool>("optimizeHyperparams", true);
39 kernelScale_ = dict.lookupOrDefault<Foam::scalar>("kernelScale", 0.10);
40 lengthScale_ = dict.lookupOrDefault<Foam::scalar>("lengthScale", 0.10);
41 noise_ = dict.lookupOrDefault<Foam::scalar>("noise", 1e-5);
42}
43
44mtbGPR::~mtbGPR() = default;
45
46mathtoolbox::GaussianProcessRegressor::KernelType mtbGPR::parseKernelType(const Foam::word& kernelWord) const
47{
48 const Foam::word lower = kernelWord;
49 if (lower == "matern")
50 {
51 return mathtoolbox::GaussianProcessRegressor::KernelType::ArdMatern52;
52 }
53 else if (lower == "squared_exp")
54 {
55 return mathtoolbox::GaussianProcessRegressor::KernelType::ArdSquaredExp;
56 }
57
58 FatalErrorInFunction
59 << "Unknown GPR kernel: " << kernelWord
60 << ". Valid options are: matern, squared_exp"
61 << Foam::exit(Foam::FatalError);
62 return mathtoolbox::GaussianProcessRegressor::KernelType::ArdMatern52; // unreachable
63}
64
65void mtbGPR::fit(const Eigen::MatrixXd& X, const Eigen::VectorXd& y)
66{
67 if (X.rows() == 0)
68 {
69 FatalErrorInFunction << "Input matrix has zero rows" << Foam::exit(Foam::FatalError);
70 }
71 if (X.cols() != y.size())
72 {
73 FatalErrorInFunction
74 << "Input size mismatch: cols(X) = " << X.cols() << ", size(y) = " << y.size()
75 << Foam::exit(Foam::FatalError);
76 }
77
78 const auto kernelType = parseKernelType(kernelTypeWord_);
79
80 Eigen::VectorXd kernelHyperparams = Eigen::VectorXd::Constant(X.rows() + 1, lengthScale_);
81 kernelHyperparams[0] = kernelScale_;
82
83 impl_ = std::make_unique<mathtoolbox::GaussianProcessRegressor>(X, y, kernelType, useDataNormalization_);
84
85 if (optimizeHyperparams_)
86 {
87 impl_->PerformMaximumLikelihood(kernelHyperparams, noise_);
88 }
89 else
90 {
91 impl_->SetHyperparams(kernelHyperparams, noise_);
92 }
93}
94
95Foam::scalar mtbGPR::predict(const Eigen::VectorXd& x)
96{
97 if (!impl_)
98 {
99 FatalErrorInFunction << "mtbGPR used before calling fit()" << Foam::exit(Foam::FatalError);
100 }
101 return impl_->PredictMean(x);
102}
103
104Eigen::VectorXd mtbGPR::predict(const Eigen::MatrixXd& X)
105{
106 if (!impl_)
107 {
108 FatalErrorInFunction << "mtbGPR used before calling fit()" << Foam::exit(Foam::FatalError);
109 }
110
111 Eigen::VectorXd result(X.cols());
112 for (int i = 0; i < X.cols(); ++i)
113 {
114 result(i) = impl_->PredictMean(X.col(i));
115 }
116 return result;
117}
118
119void mtbGPR::printInfo() const
120{
121 Foam::Info << "mtbGPR Model Info:" << Foam::endl;
122 Foam::Info << "\t kernel: " << kernelTypeWord_ << Foam::endl;
123 Foam::Info << "\t normalize: " << (useDataNormalization_ ? "true" : "false") << Foam::endl;
124 Foam::Info << "\t optimizeHyperparams: " << (optimizeHyperparams_ ? "true" : "false") << Foam::endl;
125 Foam::Info << "\t kernelScale: " << kernelScale_ << Foam::endl;
126 Foam::Info << "\t lengthScale: " << lengthScale_ << Foam::endl;
127 Foam::Info << "\t noise: " << noise_ << Foam::endl;
128}