// NOTE(review): only the signature of removeConstValues is visible in this
// listing; its body is missing, so its behavior cannot be documented here.
// Judging only by the parameter names, it presumably strips constant-valued
// parts of `input`, recording their positions in `indices` and their values
// in `constValues` (both taken by non-const reference, so presumably filled
// in by the function) — confirm against the real implementation.
35torch::Tensor removeConstValues(torch::Tensor input, std::vector<int>& indices,
36 std::vector<double>& constValues)
42bool isConst(torch::Tensor& tTensor)
44 const float* ptr = (
float*)tTensor.data_ptr();
45 bool all_equal =
true;
47 for (std::size_t i = 1, s = tTensor.numel(); i < s && all_equal; i++)
49 all_equal = * ptr == * (ptr + i);
55void save(
const torch::Tensor& torchTensor,
const std::string fname)
57 unsigned int dim = torchTensor.dim();
58 std::vector<size_t> shape = {dim};
60 float* data_p = torchTensor.data_ptr<
float>();
62 for (
unsigned int i = 0; i < torchTensor.dim(); i++)
64 shape[i] = (
unsigned int) torchTensor.size(i);
67 cnpy::npy_save(fname, data_p, shape);
70torch::Tensor load(
const std::string fname)
72 cnpy::NpyArray arr = cnpy::npy_load(fname);
73 at::IntArrayRef shape[arr.shape.size()];
74 std::vector<int64_t> dims(arr.shape.size());
76 for (
int i = 0; i < arr.shape.size(); i++)
78 dims[i] = (int64_t) arr.shape[i];
81 torch::Tensor tensor = torch::randn({2, 2, 3});
82 return torch::from_blob(arr.data<
float>(), dims).clone();
Header file for the torchUTILITIES module.