31using namespace ITHACAtorch::torch2Eigen;
39torch::Tensor eigenMatrix2torchTensor(
40 Eigen::Matrix<type, Eigen::Dynamic, Eigen::Dynamic> _eigenMatrix)
42 Eigen::MatrixXf eigenMatrix = _eigenMatrix.template cast <float> ();
43 int rows = eigenMatrix.rows();
44 int cols = eigenMatrix.cols();
46 if (!eigenMatrix.IsRowMajor)
48 eigenMatrix.transposeInPlace();
51 return torch::from_blob(eigenMatrix.data(), {rows, cols}).clone();
55Eigen::Matrix<type, Eigen::Dynamic, Eigen::Dynamic> torchTensor2eigenMatrix(
56 torch::Tensor& torchTensor)
58 std::string error_message(
"The provided tensor has " + std::to_string(
59 torchTensor.dim()) +
" dimensions and cannot be casted in a Matrix.");
60 M_Assert(torchTensor.dim() <= 2, error_message.c_str());
61 M_Assert(torchTensor.dim() != 0,
"The provided tensor has 0 dimension");
64 int nElem = torchTensor.size(0);
66 for (
int i = 1; i < torchTensor.dim(); i++)
68 nElem *= torchTensor.size(i);
71 if (torchTensor.dim() == 1)
73 rows = torchTensor.size(0);
78 rows = torchTensor.size(0);
79 cols = torchTensor.size(1);
82 typedef Eigen::Matrix<type, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
84 type* data_p = torchTensor.data_ptr<type>();
85 std::vector<type> raw(nElem);
87 for (
int i = 0; i < nElem; i++)
89 type d(* (data_p + i));
90 type a =
static_cast <type
>(d);
94 Eigen::Map<MatrixXf_rm> eigenMatrix(& raw[0], rows,
99template Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic>
100torchTensor2eigenMatrix<int>(torch::Tensor& torchTensor);
102template Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic>
103torchTensor2eigenMatrix<double>(torch::Tensor& torchTensor);
105template Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic>
106torchTensor2eigenMatrix<float>(torch::Tensor& torchTensor);
108template torch::Tensor eigenMatrix2torchTensor<float>(
109 Eigen::Matrix<float, Eigen::Dynamic, Eigen::Dynamic> eigenMatrix);
111template torch::Tensor eigenMatrix2torchTensor<double>(
112 Eigen::Matrix<double, Eigen::Dynamic, Eigen::Dynamic> eigenMatrix);
114template torch::Tensor eigenMatrix2torchTensor<int>(
115 Eigen::Matrix<int, Eigen::Dynamic, Eigen::Dynamic> eigenMatrix);
Header file of the torch2Eigen namespace.