Skip to content

Commit 9277241

Browse files
Merge pull request #113 from EXP-code/WrapperFix
`orthoCheck` Python wrapper is missing its argument. This is fixed in this PR. Plus: bug fixes for velocity field calculations.
2 parents 03caf5e + be1514b commit 9277241

7 files changed

Lines changed: 73 additions & 25 deletions

File tree

expui/BasisFactory.H

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,10 @@ namespace BasisClasses
259259
if (Naccel > 0) pseudo = currentAccel(time);
260260
}
261261

262+
//! Get the field label vector
263+
std::vector<std::string> getFieldLabels(void)
264+
{ return getFieldLabels(coordinates); }
265+
262266
};
263267

264268
using BasisPtr = std::shared_ptr<Basis>;

expui/Coefficients.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2792,7 +2792,7 @@ namespace CoefClasses
27922792

27932793
for (int t=0; t<ntim; t++) {
27942794
auto & cof = *(coefs[roundTime(times[t])]->coefs);
2795-
for (int i=0; i<4; i++) {
2795+
for (int i=0; i<Nfld; i++) {
27962796
for (int l=0; l<(Lmax+2)*(Lmax+1)/2; l++) {
27972797
for (int n=0; n<Nmax; n++) {
27982798
ret(i, l, n, t) = cof(i, l, n);

expui/FieldBasis.H

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,13 @@ namespace BasisClasses
9898
virtual std::vector<double>
9999
crt_eval(double x, double y, double z);
100100

101+
//! Get the field labels
102+
std::vector<std::string> getFieldLabels(const Coord ctype)
103+
{ return fieldLabels; }
104+
101105
public:
102106

103-
//! Constructor from YAML node
107+
//! Constructor from YAML node
104108
FieldBasis(const YAML::Node& conf,
105109
const std::string& name="FieldBasis") : Basis(conf, name)
106110
{ configure(); }
@@ -158,10 +162,6 @@ namespace BasisClasses
158162
{
159163
}
160164

161-
//! Get the field labels
162-
std::vector<std::string> getFieldLabels(const Coord ctype)
163-
{ return fieldLabels; }
164-
165165
//! Return current maximum harmonic order in expansion
166166
int getLmax() { return lmax; }
167167

expui/FieldBasis.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,8 +731,8 @@ namespace BasisClasses
731731
double r = sqrt(R*R + z*z);
732732

733733
double vr = (u*x + v*y + w*z)/r;
734-
double vt = (u*z*x + v*z*y - w*R)/R/r;
735-
double vp = (u*y - v*x)/R;
734+
double vt = (u*z*x + v*z*y - w*R*R)/R/r;
735+
double vp = (v*x - u*y)/R;
736736

737737
return {vr, vt, vp, vr*vr, vt*vt, vp*vp};
738738
}

exputil/EmpCylSL.cc

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
#include <thread>
1717
#include <exp_thread.h>
1818
#include <EXPException.H>
19+
#include <exputils.H>
1920

2021
#include <Eigen/Eigenvalues>
2122

@@ -391,6 +392,8 @@ void EmpCylSL::reset(int numr, int lmax, int mmax, int nord,
391392
ortho = std::make_shared<SLGridSph>(make_sl(), LMAX, NMAX, NUMR,
392393
RMIN, RMAX*0.99, false, 1, 1.0);
393394

395+
orthoTest(ortho->orthoCheck(std::max<int>(NMAX*50, 200)), "EmpCylSL[SLGridSph]", "l");
396+
394397
// Resize (should not be necessary) but just in case some future
395398
// feature changes mulitstep on the fly
396399
//
@@ -868,6 +871,8 @@ int EmpCylSL::read_eof_file(const string& eof_file)
868871
ortho = std::make_shared<SLGridSph>(make_sl(), LMAX, NMAX, NUMR,
869872
RMIN, RMAX*0.99, false, 1, 1.0);
870873

874+
orthoTest(ortho->orthoCheck(std::max<int>(NMAX*50, 200)), "EmpCylSL[SLGridSph]", "l");
875+
871876
setup_eof();
872877
setup_accumulation();
873878

@@ -1434,10 +1439,13 @@ void EmpCylSL::compute_eof_grid(int request_id, int m)
14341439
{
14351440
// Check for existence of ortho and create if necessary
14361441
//
1437-
if (not ortho)
1442+
if (not ortho) {
14381443
ortho = std::make_shared<SLGridSph>(make_sl(), LMAX, NMAX, NUMR,
14391444
RMIN, RMAX*0.99, false, 1, 1.0);
14401445

1446+
orthoTest(ortho->orthoCheck(std::max<int>(NMAX*50, 200)), "EmpCylSL[SLGridSph]", "l");
1447+
}
1448+
14411449

14421450
// Read in coefficient matrix or
14431451
// make grid if needed
@@ -1613,11 +1621,14 @@ void EmpCylSL::compute_even_odd(int request_id, int m)
16131621
{
16141622
// check for ortho
16151623
//
1616-
if (not ortho)
1624+
if (not ortho) {
16171625
ortho = std::make_shared<SLGridSph>(make_sl(),
16181626
LMAX, NMAX, NUMR, RMIN, RMAX*0.99,
16191627
false, 1, 1.0);
16201628

1629+
orthoTest(ortho->orthoCheck(std::max<int>(NMAX*50, 200)), "EmpCylSL[SLGridSph]", "l");
1630+
}
1631+
16211632
double dens, potl, potr, pott;
16221633

16231634
int icnt, off;
@@ -2293,9 +2304,14 @@ void EmpCylSL::generate_eof(int numr, int nump, int numt,
22932304

22942305
// Create spherical orthogonal basis if necessary
22952306
//
2296-
if (not ortho)
2307+
if (not ortho) {
22972308
ortho = std::make_shared<SLGridSph>(make_sl(), LMAX, NMAX, NUMR,
22982309
RMIN, RMAX*0.99, false, 1, 1.0);
2310+
2311+
orthoTest(ortho->orthoCheck(std::max<int>(NMAX*50, 200)), "EmpCylSL[SLGridSph]", "l");
2312+
}
2313+
2314+
22992315
// Initialize fixed variables and storage
23002316
//
23012317
setup_eof();
@@ -2585,9 +2601,13 @@ void EmpCylSL::generate_eof(int numr, int nump, int numt,
25852601
void EmpCylSL::accumulate_eof(double r, double z, double phi, double mass,
25862602
int id, int mlevel)
25872603
{
2588-
if (not ortho)
2604+
if (not ortho) {
25892605
ortho = std::make_shared<SLGridSph>
25902606
(make_sl(), LMAX, NMAX, NUMR, RMIN, RMAX*0.99, false, 1, 1.0);
2607+
2608+
orthoTest(ortho->orthoCheck(std::max<int>(NMAX*50, 200)), "EmpCylSL[SLGridSph]", "l");
2609+
}
2610+
25912611
if (eof_made) {
25922612
if (VFLAG & 2)
25932613
cerr << "accumulate_eof: Process " << setw(4) << myid << ", Thread "

pyEXP/BasisWrappers.cc

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -928,8 +928,25 @@ void BasisFactoryClasses(py::module &m)
928928
Returns
929929
-------
930930
None
931-
)");
931+
)")
932+
.def("getFieldLabels",
933+
[](BasisClasses::Basis& A)
934+
{
935+
return A.getFieldLabels();
936+
},
937+
R"(
938+
Provide the field labels for the basis functions
932939
940+
Parameters
941+
----------
942+
None
943+
944+
Returns
945+
-------
946+
list: str
947+
list of basis function labels
948+
)"
949+
);
933950

934951
py::class_<BasisClasses::BiorthBasis, std::shared_ptr<BasisClasses::BiorthBasis>, PyBiorthBasis, BasisClasses::Basis>
935952
(m, "BiorthBasis")
@@ -1279,9 +1296,9 @@ void BasisFactoryClasses(py::module &m)
12791296
// orthoCheck is not in the base class and needs to have different
12801297
// parameters depending on the basis type. Here, the quadrature
12811298
// is determined by the scale of the meridional grid.
1282-
.def("orthoCheck", [](BasisClasses::Cylindrical& A)
1299+
.def("orthoCheck", [](BasisClasses::Cylindrical& A, int knots)
12831300
{
1284-
return A.orthoCheck();
1301+
return A.orthoCheck(knots);
12851302
},
12861303
R"(
12871304
Check orthgonality of basis functions by quadrature
@@ -1298,7 +1315,7 @@ void BasisFactoryClasses(py::module &m)
12981315
-------
12991316
list(numpy.ndarray)
13001317
list of numpy.ndarrays from [0, ... , Mmax]
1301-
)")
1318+
)", py::arg("knots")=400)
13021319
.def_static("cacheInfo", [](std::string cachefile)
13031320
{
13041321
return BasisClasses::Cylindrical::cacheInfo(cachefile);
@@ -2044,11 +2061,11 @@ void BasisFactoryClasses(py::module &m)
20442061
// orthoCheck is not in the base class and needs to have
20452062
// different parameters depending on the basis type. Here the
20462063
// user can and will often need to specify a quadrature value.
2047-
.def("orthoCheck", [](BasisClasses::FieldBasis& A)
2064+
.def("orthoCheck", [](BasisClasses::FieldBasis& A)
20482065
{
20492066
return A.orthoCheck();
20502067
},
2051-
R"(
2068+
R"(
20522069
Check orthgonality of basis functions by quadrature
20532070
20542071
Inner-product matrix of orthogonal functions
@@ -2062,10 +2079,10 @@ void BasisFactoryClasses(py::module &m)
20622079
numpy.ndarray
20632080
orthogonality matrix
20642081
)"
2065-
);
2082+
);
20662083

20672084
py::class_<BasisClasses::VelocityBasis, std::shared_ptr<BasisClasses::VelocityBasis>, BasisClasses::FieldBasis>(m, "VelocityBasis")
2068-
.def(py::init<const std::string&>(),
2085+
.def(py::init<const std::string&>(),
20692086
R"(
20702087
Create a orthogonal velocity-field basis
20712088

pyEXP/TensorToArray.H

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ py::array_t<T> make_ndarray3(Eigen::Tensor<T, 3>& mat)
1111
// Check rank
1212
if (dims.size() != 3) {
1313
std::ostringstream sout;
14-
sout << "make_ndarray: tensor rank must be 3, found " << dims.size();
14+
sout << "make_ndarray3: tensor rank must be 3, found " << dims.size();
1515
throw std::runtime_error(sout.str());
1616
}
1717

@@ -37,10 +37,17 @@ py::array_t<T> make_ndarray4(Eigen::Tensor<T, 4>& mat)
3737
// Check rank
3838
if (dims.size() != 4) {
3939
std::ostringstream sout;
40-
sout << "make_ndarray: tensor rank must be 4, found " << dims.size();
40+
sout << "make_ndarray4: tensor rank must be 4, found " << dims.size();
4141
throw std::runtime_error(sout.str());
4242
}
4343

44+
// Sanity check
45+
for (int i=0; i<mat.size(); i++) {
46+
if (isnan(std::abs(mat.data()[i]))) {
47+
throw std::runtime_error("make_ndarray4: NaN encountered");
48+
}
49+
}
50+
4451
// Make the memory mapping
4552
return py::array_t<T>
4653
(
@@ -107,11 +114,11 @@ Eigen::Tensor<T, 4> make_tensor4(py::array_t<T> array)
107114

108115
// Build result tensor with col-major ordering
109116
Eigen::Tensor<T, 4> tensor(shape[0], shape[1], shape[2], shape[3]);
110-
for (int i=0, l=0; i < shape[0]; i++) {
117+
for (int i=0, c=0; i < shape[0]; i++) {
111118
for (int j=0; j < shape[1]; j++) {
112119
for (int k=0; k < shape[2]; k++) {
113-
for (int l=0; l < shape[3]; k++) {
114-
tensor(i, j, k, l) = data[l++];
120+
for (int l=0; l < shape[3]; l++, c++) {
121+
tensor(i, j, k, l) = data[c];
115122
}
116123
}
117124
}

0 commit comments

Comments
 (0)