Skip to content

Commit cf23d5b

Browse files
committed
Fix PSO, add test
1 parent f3534f8 commit cf23d5b

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

src/pso.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include <npe.h>
22
#include <common.h>
3+
#include <pybind11/functional.h>
34
#include <typedefs.h>
45
#include <pybind11/stl.h>
56
#include <igl/pso.h>
@@ -113,5 +114,3 @@ npe_begin_code()
113114
return std::make_tuple(obj, npe::move(x));
114115

115116
npe_end_code()
116-
117-

tests/test_basic.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,7 +1081,7 @@ def test_signed_distance(self):
10811081
max_v = np.max(self.v1, axis=0)
10821082
n = 16
10831083
g = np.mgrid[min_v[0]:max_v[0]:complex(n), min_v[1]:max_v[1]:complex(n), min_v[2]:max_v[2]:complex(n)]
1084-
p = np.vstack(map(np.ravel, g)).T
1084+
p = np.vstack(list(map(np.ravel, g))).T
10851085
s, i, c = igl.signed_distance(p, self.v1, self.f1)
10861086

10871087
self.assertEqual(s.shape[0], p.shape[0])
@@ -2029,6 +2029,21 @@ def test_connected_components(self):
20292029
self.assertTrue(k.dtype == self.f1.dtype)
20302030

20312031
self.assertTrue(c.shape[0] == a.shape[0])
2032+
2033+
def test_pso(self):
2034+
def banana(x):
2035+
x1 = x[0]
2036+
x2 = x[1]
2037+
return x1**4 - 2*x2*x1**2 + x2**2 + x1**2 - 2*x1 + 5
2038+
2039+
lb = np.array([-3.0, -1.0])
2040+
ub = np.array([2.0, 6.0])
2041+
2042+
fopt, xopt = igl.pso(banana, lb, ub, max_iters=10, population=10)
2043+
2044+
self.assertTrue(xopt.flags.c_contiguous)
2045+
self.assertTrue(xopt.dtype == lb.dtype)
2046+
self.assertTrue(xopt.shape == (2, ))
20322047

20332048

20342049
if __name__ == '__main__':

0 commit comments

Comments
 (0)