34template<
class Type,
template<
class>
class PatchField,
class GeoMesh>
36 PtrList<GeometricField<Type, PatchField, GeoMesh >>& snapshots):
37 _snapshots(snapshots),
38 mesh(snapshots[0].mesh()),
39 convDict(autoPtr<IOdictionary>
54 flt = autoPtr<Filter>(Filter::New(word(convDict().lookup(
"Filter")),
56 domainDivision = Vector<label>(convDict().lookup(
"domainDivision"));
57 filterSize = Vector<scalar>(convDict().lookup(
"filterSize"));
58 domainSize = mesh.bounds().max() - mesh.bounds().min();
59 setDomainDivision(domainDivision[0], domainDivision[1], domainDivision[2]);
60 setFilterSize(filterSize[0], filterSize[1], filterSize[2]);
61 weights = flt->apply(cellsInBoxes, convPoints, mesh);
64template<
class Type,
template<
class>
class PatchField,
class GeoMesh>
65void ConvLayer<Type, PatchField, GeoMesh>::setDomainDivision(label Nx, label Ny,
68 M_Assert(((Nx != 1 && mesh.solutionD()[0] != -1 ) || (Nx == 1 &&
69 mesh.solutionD()[0] == -1)),
70 "The mesh has valid components only along the y and z directions, set Nx = 1");
71 M_Assert(((Ny != 1 && mesh.solutionD()[1] != -1 ) || (Ny == 1 &&
72 mesh.solutionD()[1] == -1)),
73 "The mesh has valid components only along the x and z directions, set Ny = 1");
74 M_Assert(((Nz != 1 && mesh.solutionD()[2] != -1 ) || (Nz == 1 &&
75 mesh.solutionD()[2] == -1)),
76 "The mesh has valid components only along the x and y directions, set Nz = 1");
78 for (label i = 0; i < ds.size(); i++)
80 if (mesh.solutionD()[i] != -1 && domainDivision[i] != 1)
82 ds[i] = domainSize[i] / (domainDivision[i] - 1);
86 ds[i] = domainSize[i] / 2;
90 convPoints = List<point>(Nx * Ny * Nz);
93 for (label i = 0; i < Nx; i++)
95 for (label j = 0; j < Ny; j++)
97 for (label k = 0; k < Nz; k++)
99 if (i == 0 && Nx == 1)
104 if (j == 0 && Ny == 1)
109 if (k == 0 && Nz == 1)
114 convPoints[index] = mesh.bounds().min() + cmptMultiply((ds * i), vector(1, 0,
115 0)) + cmptMultiply((ds * j), vector(0, 1, 0)) + cmptMultiply((ds * k), vector(0,
122 isDomainDivisionSet =
true;
125template<
class Type,
template<
class>
class PatchField,
class GeoMesh>
126void ConvLayer<Type, PatchField, GeoMesh>::setFilterSize(
double dx,
double dy,
129 M_Assert(isDomainDivisionSet,
"You need to set the division before.");
133 cellsInBoxes.resize(convPoints.size());
135 for (label i = 0; i < convPoints.size(); i++)
137 cellSet a(mesh,
"set", 0);
138 point mini = convPoints[i] - filterSize / 2;
139 point maxi = convPoints[i] + filterSize / 2;
140 treeBoundBox boxi(mini, maxi);
141 List<treeBoundBox> l;
143 boxToCell finding(mesh, l);
145 finding.verbose(
false);
147 finding.applyToSet(topoSetSource::ADD, a);
148 cellsInBoxes[i] = a.toc();
151 isFilterSizeSet =
true;
165torch::Tensor ConvLayer<scalar, fvPatchField, volMesh>::filter()
167 M_Assert(isDomainDivisionSet &&
169 "You need to set domainDivision and filterSize before calling the filter funtion.");
170 label Nx = domainDivision[0];
171 label Ny = domainDivision[1];
172 label Nz = domainDivision[2];
173 torch::Tensor output = torch::zeros({_snapshots.size(), 1, Nx, Ny, Nz});
174 auto foo_a = output.accessor<float, 5>();
176 for (label i = 0; i < _snapshots.size(); i++)
180 for (label j = 0; j < Nx; j++)
182 for (label k = 0; k < Ny; k++)
184 for (label l = 0; l < Nz; l++)
186 for (label p = 0; p < cellsInBoxes[index].size(); p++)
188 foo_a[i][0][j][k][l] += _snapshots[i][cellsInBoxes[index][p]] *
203torch::Tensor ConvLayer<vector, fvPatchField, volMesh>::filter()
205 M_Assert(isDomainDivisionSet &&
207 "You need to set domainDivision and filterSize before calling the filter funtion.");
208 label Nx = domainDivision[0];
209 label Ny = domainDivision[1];
210 label Nz = domainDivision[2];
211 torch::Tensor output = torch::zeros({_snapshots.size(), 3, Nx, Ny, Nz});
212 auto foo_a = output.accessor<float, 5>();
214 for (label i = 0; i < _snapshots.size(); i++)
218 for (label j = 0; j < Nx; j++)
220 for (label k = 0; k < Ny; k++)
222 for (label l = 0; l < Nz; l++)
224 for (label p = 0; p < cellsInBoxes[index].size(); p++)
226 foo_a[i][0][j][k][l] += _snapshots[i][cellsInBoxes[index][p]][0] *
228 foo_a[i][1][j][k][l] += _snapshots[i][cellsInBoxes[index][p]][1] *
230 foo_a[i][2][j][k][l] += _snapshots[i][cellsInBoxes[index][p]][2] *
Header file of the ConvLayer class.
ConvLayer(PtrList< GeometricField< Type, PatchField, GeoMesh > > &snapshots)
Construct from a list of field snapshots; the mesh is taken from the first snapshot.