Skip to content

Commit e6d5f17

Browse files
committed
Add test_random_search
1 parent cf23d5b commit e6d5f17

File tree

3 files changed

+19
-4
lines changed

3 files changed

+19
-4
lines changed

src/pso.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@ population number of particles in swarm
2424
2525
Returns
2626
-------
27+
f(X) objective corresponding to best particle seen so far
2728
X best particle seen so far
28-
Returns objective corresponding to best particle seen so far
29+
2930
3031
See also
3132
--------

src/random_search.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#include <npe.h>
22
#include <typedefs.h>
3-
3+
#include <pybind11/functional.h>
44
#include <pybind11/stl.h>
55

66

@@ -25,8 +25,8 @@ iters number of iterations
2525
2626
Returns
2727
-------
28+
f(X)
2829
X #X optimal parameter vector
29-
Returns f(X)
3030
3131
See also
3232
--------
@@ -57,4 +57,3 @@ npe_begin_code()
5757
return std::make_tuple(obj, npe::move(x));
5858

5959
npe_end_code()
60-

tests/test_basic.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2045,6 +2045,21 @@ def banana(x):
20452045
self.assertTrue(xopt.dtype == lb.dtype)
20462046
self.assertTrue(xopt.shape == (2, ))
20472047

2048+
def test_random_search(self):
2049+
def banana(x):
2050+
x1 = x[0]
2051+
x2 = x[1]
2052+
return x1**4 - 2*x2*x1**2 + x2**2 + x1**2 - 2*x1 + 5
2053+
2054+
lb = np.array([-3.0, -1.0])
2055+
ub = np.array([2.0, 6.0])
2056+
2057+
fopt, xopt = igl.random_search(banana, lb, ub, iters=10)
2058+
2059+
self.assertTrue(xopt.flags.c_contiguous)
2060+
self.assertTrue(xopt.dtype == lb.dtype)
2061+
self.assertTrue(xopt.shape == (2, ))
2062+
20482063

20492064
if __name__ == '__main__':
20502065
unittest.main()

0 commit comments

Comments
 (0)