Loading...
Searching...
No Matches
torch2Foam.C
Go to the documentation of this file.
1/*---------------------------------------------------------------------------*\
2 ██╗████████╗██╗ ██╗ █████╗ ██████╗ █████╗ ███████╗██╗ ██╗
3 ██║╚══██╔══╝██║ ██║██╔══██╗██╔════╝██╔══██╗ ██╔════╝██║ ██║
4 ██║ ██║ ███████║███████║██║ ███████║█████╗█████╗ ██║ ██║
5 ██║ ██║ ██╔══██║██╔══██║██║ ██╔══██║╚════╝██╔══╝ ╚██╗ ██╔╝
6 ██║ ██║ ██║ ██║██║ ██║╚██████╗██║ ██║ ██║ ╚████╔╝
7 ╚═╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝╚═╝ ╚═╝ ╚═╝ ╚═══╝
8
9 * In real Time Highly Advanced Computational Applications for Finite Volumes
10 * Copyright (C) 2017 by the ITHACA-FV authors
11-------------------------------------------------------------------------------
12
13 License
14 This file is part of ITHACA-FV
15
16 ITHACA-FV is free software: you can redistribute it and/or modify
17 it under the terms of the GNU Lesser General Public License as published by
18 the Free Software Foundation, either version 3 of the License, or
19 (at your option) any later version.
20
21 ITHACA-FV is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU Lesser General Public License for more details.
25
26 You should have received a copy of the GNU Lesser General Public License
27 along with ITHACA-FV. If not, see <http://www.gnu.org/licenses/>.
28
29\*---------------------------------------------------------------------------*/
30#include "torch2Foam.H"
31using namespace ITHACAtorch::torch2Foam;
32
33namespace ITHACAtorch
34{
35namespace torch2Foam
36{
37
38template<>
39torch::Tensor field2Torch(Field<vector>& field)
40{
41 int rows = 1;
42 int cols = field.size() * 3;
43 double* dataPtr = &field[0][0];
44 return torch::from_blob(dataPtr, {rows, cols}, {torch::kFloat64}).clone().to(
45 torch::kFloat32);
46}
47
48template<>
49torch::Tensor field2Torch(Field<scalar>& field)
50{
51 int rows = 1;
52 int cols = field.size();
53 double* dataPtr = &field[0];
54 return torch::from_blob(dataPtr, {rows, cols}, {torch::kFloat64}).clone().to(
55 torch::kFloat32);
56}
57
58template<>
59Field<vector> torch2Field(torch::Tensor& torchTensor)
60{
61 std::string error_message("The provided tensor has " + std::to_string(
62 torchTensor.dim()) +
63 " dimensions and with the current implementation only 1-D tensor can be casted in an OpenFOAM field.");
64 M_Assert(torchTensor.dim() <= 2, error_message.c_str());
65 M_Assert(torchTensor.dim() != 0, "The provided tensor has 0 dimension");
66 Field<vector> a(torchTensor.numel() / 3);
67 std::memcpy(&a[0][0], torchTensor.to(torch::kFloat64).data_ptr(),
68 sizeof (double)*torchTensor.numel());
69 return a;
70}
71
72template<>
73Field<scalar> torch2Field(torch::Tensor& torchTensor)
74{
75 std::string error_message("The provided tensor has " + std::to_string(
76 torchTensor.dim()) +
77 " dimensions and with the current implementation only 1-D tensor can be casted in an OpenFOAM field.");
78 M_Assert(torchTensor.dim() <= 2, error_message.c_str());
79 M_Assert(torchTensor.dim() != 0, "The provided tensor has 0 dimension");
80 Field<scalar> a(torchTensor.numel());
81 std::memcpy(&a[0], torchTensor.to(torch::kFloat64).data_ptr(),
82 sizeof (double)*torchTensor.numel());
83 return a;
84}
85
86template<>
87torch::Tensor ptrList2Torch(PtrList<Field<vector>>& ptrList)
88{
89 int Nrows = ptrList.size();
90 int Ncols = ptrList[0].size() * 3;
91 torch::Tensor out = torch::randn({Nrows, Ncols});
92
93 for (auto i = 0; i < ptrList.size(); i++)
94 {
95 out.slice(0, i, i + 1) = field2Torch(ptrList[i]);
96 }
97
98 return out;
99}
100
101template<>
102torch::Tensor ptrList2Torch(PtrList<Field<scalar>>& ptrList)
103{
104 int Nrows = ptrList.size();
105 int Ncols = ptrList[0].size();
106 torch::Tensor out = torch::randn({Nrows, Ncols});
107
108 for (auto i = 0; i < ptrList.size(); i++)
109 {
110 out.slice(0, i, i + 1) = field2Torch(ptrList[i]);
111 }
112
113 return out;
114}
115
116template<class type_f>
117PtrList<Field<type_f>> torch2PtrList(torch::Tensor& tTensor)
118{
119 PtrList<Field<type_f>> out;
120
121 for (auto i = 0; i < tTensor.size(0); i++)
122 {
123 torch::Tensor t = tTensor.slice(0, i, i + 1);
124 out.append(tmp<Field<type_f>>(torch2Field<type_f>(t)));
125 }
126
127 return out;
128}
129
130template PtrList<Field<scalar>> torch2PtrList<scalar>(torch::Tensor& tTensor);
131template PtrList<Field<vector>> torch2PtrList<vector>(torch::Tensor& tTensor);
132
133
134
135
136}
137
138}
#define M_Assert(Expr, Msg)
torch::Tensor field2Torch(Field< vector > &field)
Definition torch2Foam.C:39
Field< vector > torch2Field(torch::Tensor &torchTensor)
Definition torch2Foam.C:59
torch::Tensor ptrList2Torch(PtrList< Field< vector > > &ptrList)
Definition torch2Foam.C:87
PtrList< Field< type_f > > torch2PtrList(torch::Tensor &tTensor)
Definition torch2Foam.C:117
template PtrList< Field< scalar > > torch2PtrList< scalar >(torch::Tensor &tTensor)
template PtrList< Field< vector > > torch2PtrList< vector >(torch::Tensor &tTensor)
label i
Definition pEqn.H:46
Header file of the torch2Foam namespace.