Loading...
Searching...
No Matches
torch2Foam.C
Go to the documentation of this file.
1/*---------------------------------------------------------------------------*\
2 ██╗████████╗██╗ ██╗ █████╗ ██████╗ █████╗ ███████╗██╗ ██╗
3 ██║╚══██╔══╝██║ ██║██╔══██╗██╔════╝██╔══██╗ ██╔════╝██║ ██║
4 ██║ ██║ ███████║███████║██║ ███████║█████╗█████╗ ██║ ██║
5 ██║ ██║ ██╔══██║██╔══██║██║ ██╔══██║╚════╝██╔══╝ ╚██╗ ██╔╝
6 ██║ ██║ ██║ ██║██║ ██║╚██████╗██║ ██║ ██║ ╚████╔╝
7 ╚═╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝╚═╝ ╚═╝ ╚═╝ ╚═══╝
8
9 * In real Time Highly Advanced Computational Applications for Finite Volumes
10 * Copyright (C) 2017 by the ITHACA-FV authors
11-------------------------------------------------------------------------------
12
13 License
14 This file is part of ITHACA-FV
15
16 ITHACA-FV is free software: you can redistribute it and/or modify
17 it under the terms of the GNU Lesser General Public License as published by
18 the Free Software Foundation, either version 3 of the License, or
19 (at your option) any later version.
20
21 ITHACA-FV is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU Lesser General Public License for more details.
25
26 You should have received a copy of the GNU Lesser General Public License
27 along with ITHACA-FV. If not, see <http://www.gnu.org/licenses/>.
28
29\*---------------------------------------------------------------------------*/
30#include "torch2Foam.H"
31using namespace ITHACAtorch::torch2Foam;
32
33namespace ITHACAtorch
34{
35namespace torch2Foam
36{
37
38template<>
39torch::Tensor field2Torch(Field<vector>& field)
40{
41 int rows = 1;
42 int cols = field.size() * 3;
43 double* dataPtr = & field[0][0];
44 return torch::from_blob(dataPtr, {rows, cols}, {torch::kFloat64}).clone().to(
45 torch::kFloat32);
46}
47
48template<>
49torch::Tensor field2Torch(Field<scalar>& field)
50{
51 int rows = 1;
52 int cols = field.size();
53 double* dataPtr = & field[0];
54 return torch::from_blob(dataPtr, {rows, cols}, {torch::kFloat64}).clone().to(
55 torch::kFloat32);
56}
57
58template<>
59Field<vector> torch2Field(torch::Tensor& torchTensor)
60{
61 std::string error_message("The provided tensor has " + std::to_string(
62 torchTensor.dim()) +
63 " dimensions and with the current implementation only 1-D tensor can be casted in an OpenFOAM field.");
64 M_Assert(torchTensor.dim() <= 2, error_message.c_str());
65 M_Assert(torchTensor.dim() != 0, "The provided tensor has 0 dimension");
66 Field<vector> a(torchTensor.numel() / 3);
67 std::memcpy(& a[0][0], torchTensor.to(torch::kFloat64).data_ptr(),
68 sizeof (double) * torchTensor.numel());
69 return a;
70}
71
72template<>
73Field<scalar> torch2Field(torch::Tensor& torchTensor)
74{
75 std::string error_message("The provided tensor has " + std::to_string(
76 torchTensor.dim()) +
77 " dimensions and with the current implementation only 1-D tensor can be casted in an OpenFOAM field.");
78 M_Assert(torchTensor.dim() <= 2, error_message.c_str());
79 M_Assert(torchTensor.dim() != 0, "The provided tensor has 0 dimension");
80 Field<scalar> a(torchTensor.numel());
81 std::memcpy(& a[0], torchTensor.to(torch::kFloat64).data_ptr(),
82 sizeof (double) * torchTensor.numel());
83 return a;
84}
85
86template<>
87torch::Tensor ptrList2Torch(PtrList<Field<vector >> & ptrList)
88{
89 int Nrows = ptrList.size();
90 int Ncols = ptrList[0].size() * 3;
91 torch::Tensor out = torch::randn({Nrows, Ncols});
92 for (auto i = 0; i < ptrList.size(); i++)
93 {
94 out.slice(0, i, i + 1) = field2Torch(ptrList[i]);
95 }
96
97 return out;
98}
99
100template<>
101torch::Tensor ptrList2Torch(PtrList<Field<scalar >> & ptrList)
102{
103 int Nrows = ptrList.size();
104 int Ncols = ptrList[0].size();
105 torch::Tensor out = torch::randn({Nrows, Ncols});
106 for (auto i = 0; i < ptrList.size(); i++)
107 {
108 out.slice(0, i, i + 1) = field2Torch(ptrList[i]);
109 }
110
111 return out;
112}
113
114template<class type_f>
115PtrList<Field<type_f >> torch2PtrList(torch::Tensor& tTensor)
116{
117 PtrList<Field<type_f >> out;
118
119 for (auto i = 0; i < tTensor.size(0); i++)
120 {
121 torch::Tensor t = tTensor.slice(0, i, i + 1);
122 out.append(tmp<Field<type_f >> (torch2Field<type_f>(t)));
123 }
124 return out;
125}
126
// Explicit instantiations of torch2PtrList for scalar and vector fields.
127template PtrList<Field<scalar >> torch2PtrList<scalar>(torch::Tensor& tTensor);
128template PtrList<Field<vector >> torch2PtrList<vector>(torch::Tensor& tTensor);
129
130
131
132
133}
134
135}
#define M_Assert(Expr, Msg)
torch::Tensor field2Torch(Field< vector > &field)
Definition torch2Foam.C:39
Field< vector > torch2Field(torch::Tensor &torchTensor)
Definition torch2Foam.C:59
torch::Tensor ptrList2Torch(PtrList< Field< vector > > &ptrList)
Definition torch2Foam.C:87
PtrList< Field< type_f > > torch2PtrList(torch::Tensor &tTensor)
Definition torch2Foam.C:115
template PtrList< Field< scalar > > torch2PtrList< scalar >(torch::Tensor &tTensor)
template PtrList< Field< vector > > torch2PtrList< vector >(torch::Tensor &tTensor)
label i
Definition pEqn.H:46
Header file of the torch2Foam namespace.