Loading...
Searching...
No Matches
torchUTILITIES.C
Go to the documentation of this file.
1/*---------------------------------------------------------------------------*\
2 ██╗████████╗██╗ ██╗ █████╗ ██████╗ █████╗ ███████╗██╗ ██╗
3 ██║╚══██╔══╝██║ ██║██╔══██╗██╔════╝██╔══██╗ ██╔════╝██║ ██║
4 ██║ ██║ ███████║███████║██║ ███████║█████╗█████╗ ██║ ██║
5 ██║ ██║ ██╔══██║██╔══██║██║ ██╔══██║╚════╝██╔══╝ ╚██╗ ██╔╝
6 ██║ ██║ ██║ ██║██║ ██║╚██████╗██║ ██║ ██║ ╚████╔╝
7 ╚═╝ ╚═╝ ╚═╝ ╚═╝╚═╝ ╚═╝ ╚═════╝╚═╝ ╚═╝ ╚═╝ ╚═══╝
8
9 * In real Time Highly Advanced Computational Applications for Finite Volumes
10 * Copyright (C) 2017 by the ITHACA-FV authors
11-------------------------------------------------------------------------------
12
13 License
14 This file is part of ITHACA-FV
15
16 ITHACA-FV is free software: you can redistribute it and/or modify
17 it under the terms of the GNU Lesser General Public License as published by
18 the Free Software Foundation, either version 3 of the License, or
19 (at your option) any later version.
20
21 ITHACA-FV is distributed in the hope that it will be useful,
22 but WITHOUT ANY WARRANTY; without even the implied warranty of
23 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
24 GNU Lesser General Public License for more details.
25
26 You should have received a copy of the GNU Lesser General Public License
27 along with ITHACA-FV. If not, see <http://www.gnu.org/licenses/>.
28
29\*---------------------------------------------------------------------------*/
30
31#include "torchUTILITIES.H"
32
33namespace ITHACAtorch
34{
35torch::Tensor removeConstValues(torch::Tensor input, std::vector<int>& indices,
36 std::vector<double>& constValues)
37{
38 torch::Tensor output;
39 return output;
40}
41
42bool isConst(torch::Tensor& tTensor)
43{
44 const float* ptr = (float*)tTensor.data_ptr();
45 bool all_equal = true;
46
47 for (std::size_t i = 1, s = tTensor.numel(); i < s && all_equal; i++)
48 {
49 all_equal = *ptr == *(ptr + i);
50 }
51
52 return all_equal;
53}
54
55void save(const torch::Tensor& torchTensor, const std::string fname)
56{
57 unsigned int dim = torchTensor.dim();
58 std::vector<size_t> shape = {dim};
59 // unsigned int shape[dim];
60 float* data_p = torchTensor.data_ptr<float>();
61
62 for (unsigned int i = 0; i < torchTensor.dim(); i++)
63 {
64 shape[i] = (unsigned int) torchTensor.size(i);
65 }
66
67 cnpy::npy_save(fname, data_p, shape);
68}
69
70torch::Tensor load(const std::string fname)
71{
72 cnpy::NpyArray arr = cnpy::npy_load(fname);
73 at::IntArrayRef shape[arr.shape.size()];
74 std::vector<int64_t> dims(arr.shape.size());
75
76 for (int i = 0; i < arr.shape.size(); i++)
77 {
78 dims[i] = (int64_t) arr.shape[i];
79 }
80
81 torch::Tensor tensor = torch::randn({2, 2, 3});
82 return torch::from_blob(arr.data<float>(), dims).clone();
83 //return tensor;
84}
85
86
87}
torch::Tensor load(const std::string fname)
torch::Tensor removeConstValues(torch::Tensor input, std::vector< int > &indices, std::vector< double > &constValues)
bool isConst(torch::Tensor &tTensor)
void save(const torch::Tensor &torchTensor, const std::string fname)
void npy_save(std::string fname, const T *data, const std::vector< size_t > shape, std::string mode="w")
Definition cnpy.H:198
NpyArray npy_load(std::string fname)
Definition cnpy.C:464
label i
Definition pEqn.H:46
T * data()
Definition cnpy.H:52
std::vector< size_t > shape
Definition cnpy.H:156
Header file of the torchUTILITIES file.