31using namespace ITHACAtorch::torch2Foam;
// Copies an OpenFOAM vector field into a newly allocated torch tensor of shape
// {rows, 3 * field.size()}. `rows` is declared outside this view (presumably 1,
// since ptrList2Torch assigns the result to a single-row slice) - TODO confirm.
// The code assumes the vector components are stored contiguously as doubles
// (x0, y0, z0, x1, ...), which is why &field[0][0] is viewed as `cols` doubles.
// NOTE(review): field[0] is accessed unconditionally, so an empty field is not
// handled - confirm callers never pass one.
39torch::Tensor field2Torch(Field<vector>& field)
42 int cols = field.size() * 3;
43 double* dataPtr = & field[0][0];
// from_blob only wraps the field's memory without taking ownership, so clone()
// makes an owning copy before the tensor leaves this function. The dtype passed
// to .to( is on a line outside this view.
44 return torch::from_blob(dataPtr, {rows, cols}, {torch::kFloat64}).clone().to(
// Copies an OpenFOAM scalar field into a newly allocated torch tensor of shape
// {rows, field.size()}. `rows` is declared outside this view (presumably 1,
// since ptrList2Torch assigns the result to a single-row slice) - TODO confirm.
// The scalar data is read through a double* and tagged as kFloat64, so this
// assumes OpenFOAM's scalar is double precision.
// NOTE(review): field[0] is accessed unconditionally, so an empty field is not
// handled - confirm callers never pass one.
49torch::Tensor field2Torch(Field<scalar>& field)
52 int cols = field.size();
53 double* dataPtr = & field[0];
// from_blob only wraps the field's memory without taking ownership, so clone()
// makes an owning copy before the tensor leaves this function. The dtype passed
// to .to( is on a line outside this view.
54 return torch::from_blob(dataPtr, {rows, cols}, {torch::kFloat64}).clone().to(
59Field<vector> torch2Field(torch::Tensor& torchTensor)
61 std::string error_message(
"The provided tensor has " + std::to_string(
63 " dimensions and with the current implementation only 1-D tensor can be casted in an OpenFOAM field.");
64 M_Assert(torchTensor.dim() <= 2, error_message.c_str());
65 M_Assert(torchTensor.dim() != 0,
"The provided tensor has 0 dimension");
66 Field<vector> a(torchTensor.numel() / 3);
67 std::memcpy(& a[0][0], torchTensor.to(torch::kFloat64).data_ptr(),
68 sizeof (
double) * torchTensor.numel());
73Field<scalar> torch2Field(torch::Tensor& torchTensor)
75 std::string error_message(
"The provided tensor has " + std::to_string(
77 " dimensions and with the current implementation only 1-D tensor can be casted in an OpenFOAM field.");
78 M_Assert(torchTensor.dim() <= 2, error_message.c_str());
79 M_Assert(torchTensor.dim() != 0,
"The provided tensor has 0 dimension");
80 Field<scalar> a(torchTensor.numel());
81 std::memcpy(& a[0], torchTensor.to(torch::kFloat64).data_ptr(),
82 sizeof (
double) * torchTensor.numel());
87torch::Tensor ptrList2Torch(PtrList<Field<vector >>& ptrList)
89 int Nrows = ptrList.size();
90 int Ncols = ptrList[0].size() * 3;
91 torch::Tensor out = torch::randn({Nrows, Ncols});
93 for (
auto i = 0; i < ptrList.size(); i++)
95 out.slice(0, i, i + 1) = field2Torch(ptrList[i]);
102torch::Tensor ptrList2Torch(PtrList<Field<scalar >>& ptrList)
104 int Nrows = ptrList.size();
105 int Ncols = ptrList[0].size();
106 torch::Tensor out = torch::randn({Nrows, Ncols});
108 for (
auto i = 0; i < ptrList.size(); i++)
110 out.slice(0, i, i + 1) = field2Torch(ptrList[i]);
116template<
class type_f>
117PtrList<Field<type_f >> torch2PtrList(torch::Tensor& tTensor)
119 PtrList<Field<type_f >> out;
121 for (
auto i = 0; i < tTensor.size(0); i++)
123 torch::Tensor t = tTensor.slice(0, i, i + 1);
124 out.append(tmp<Field<type_f >> (torch2Field<type_f>(t)));
130template PtrList<Field<scalar >> torch2PtrList<scalar>(torch::Tensor& tTensor);
131template PtrList<Field<vector >> torch2PtrList<vector>(torch::Tensor& tTensor);
Header file of the torch2Foam namespace.