Skip to content

Commit 5cb55f4

Browse files
committed
port examples to arraycontext
1 parent 5559945 commit 5cb55f4

File tree

3 files changed

+85
-51
lines changed

3 files changed

+85
-51
lines changed

examples/curve-pot.py

Lines changed: 49 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
1-
import pyopencl as cl
21
import numpy as np
32
import numpy.linalg as la
43

4+
import pyopencl as cl
5+
56
try:
67
import matplotlib.pyplot as plt
7-
except ModuleNotFoundError:
8-
plt = None
8+
USE_MATPLOTLIB = True
9+
except ImportError:
10+
USE_MATPLOTLIB = False
911

1012
try:
1113
from mayavi import mlab
12-
except ModuleNotFoundError:
13-
mlab = None
14+
USE_MAYAVI = True
15+
except ImportError:
16+
USE_MAYAVI = False
17+
18+
import logging
19+
logging.basicConfig(level=logging.INFO)
1420

1521

1622
def process_kernel(knl, what_operator):
@@ -45,17 +51,16 @@ def draw_pot_figure(aspect_ratio,
4551
ovsmp_center_exp=0.66,
4652
force_center_side=None):
4753

48-
import logging
49-
logging.basicConfig(level=logging.INFO)
50-
5154
if novsmp is None:
5255
novsmp = 4*nsrc
5356

5457
if what_operator_lpot is None:
5558
what_operator_lpot = what_operator
5659

60+
from sumpy.array_context import PyOpenCLArrayContext
5761
ctx = cl.create_some_context()
5862
queue = cl.CommandQueue(ctx)
63+
actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
5964

6065
# {{{ make plot targets
6166

@@ -86,16 +91,18 @@ def draw_pot_figure(aspect_ratio,
8691
knl_kwargs = {}
8792

8893
vol_source_knl, vol_target_knl = process_kernel(knl, what_operator)
89-
p2p = P2P(ctx, source_kernels=(vol_source_knl,),
94+
p2p = P2P(actx.context, source_kernels=(vol_source_knl,),
9095
target_kernels=(vol_target_knl,),
9196
exclude_self=False,
9297
value_dtypes=np.complex128)
9398

9499
lpot_source_knl, lpot_target_knl = process_kernel(knl, what_operator_lpot)
95100

96101
from sumpy.qbx import LayerPotential
97-
lpot = LayerPotential(ctx, expansion=expn_class(knl, order=order),
98-
source_kernels=(lpot_source_knl,), target_kernels=(lpot_target_knl,),
102+
lpot = LayerPotential(actx.context,
103+
expansion=expn_class(knl, order=order),
104+
source_kernels=(lpot_source_knl,),
105+
target_kernels=(lpot_target_knl,),
99106
value_dtypes=np.complex128)
100107

101108
# }}}
@@ -142,8 +149,9 @@ def map_to_curve(t):
142149
+ center_side[:, np.newaxis]
143150
* center_dist*native_curve.normal)
144151

145-
#native_curve.plot()
146-
#plt.show()
152+
if 0:
153+
native_curve.plot()
154+
plt.show()
147155

148156
volpot_kwargs = knl_kwargs.copy()
149157
lpot_kwargs = knl_kwargs.copy()
@@ -169,7 +177,9 @@ def map_to_curve(t):
169177

170178
def apply_lpot(x):
171179
xovsmp = np.dot(fim, x)
172-
evt, (y,) = lpot(queue, native_curve.pos, ovsmp_curve.pos,
180+
evt, (y,) = lpot(actx.queue,
181+
native_curve.pos,
182+
ovsmp_curve.pos,
173183
centers,
174184
[xovsmp * ovsmp_curve.speed * ovsmp_weights],
175185
expansion_radii=np.ones(centers.shape[1]),
@@ -191,18 +201,22 @@ def apply_lpot(x):
191201
mode_nr = 0
192202
density = np.cos(mode_nr*2*np.pi*native_t).astype(np.complex128)
193203
ovsmp_density = np.cos(mode_nr*2*np.pi*ovsmp_t).astype(np.complex128)
194-
evt, (vol_pot,) = p2p(queue, fp.points, native_curve.pos,
204+
evt, (vol_pot,) = p2p(actx.queue,
205+
fp.points,
206+
native_curve.pos,
195207
[native_curve.speed*native_weights*density], **volpot_kwargs)
196208

197-
evt, (curve_pot,) = lpot(queue, native_curve.pos, ovsmp_curve.pos,
209+
evt, (curve_pot,) = lpot(actx.queue,
210+
native_curve.pos,
211+
ovsmp_curve.pos,
198212
centers,
199213
[ovsmp_density * ovsmp_curve.speed * ovsmp_weights],
200214
expansion_radii=np.ones(centers.shape[1]),
201215
**lpot_kwargs)
202216

203217
# }}}
204218

205-
if 0:
219+
if USE_MATPLOTLIB:
206220
# {{{ plot on-surface potential in 2D
207221

208222
plt.plot(curve_pot, label="pot")
@@ -216,7 +230,7 @@ def apply_lpot(x):
216230
("potential", vol_pot.real)
217231
])
218232

219-
if 0:
233+
if USE_MATPLOTLIB:
220234
# {{{ 2D false-color plot
221235

222236
plt.clf()
@@ -230,12 +244,8 @@ def apply_lpot(x):
230244
# close the curve
231245
plt.plot(src[-1::-len(src)+1, 0], src[-1::-len(src)+1, 1], "o-k")
232246

233-
#plt.gca().set_aspect("equal", "datalim")
234247
cb = plt.colorbar(shrink=0.9)
235248
cb.set_label(r"$\log_{10}(\mathdefault{Error})$")
236-
#from matplotlib.ticker import NullFormatter
237-
#plt.gca().xaxis.set_major_formatter(NullFormatter())
238-
#plt.gca().yaxis.set_major_formatter(NullFormatter())
239249
fp.set_matplotlib_limits()
240250

241251
# }}}
@@ -261,7 +271,7 @@ def apply_lpot(x):
261271
plotval_vol[outlier_flag] = sum(
262272
nb[outlier_flag] for nb in neighbors)/len(neighbors)
263273

264-
if mlab is not None:
274+
if USE_MAYAVI:
265275
fp.show_scalar_in_mayavi(scale*plotval_vol, max_val=1)
266276
mlab.colorbar()
267277
if 1:
@@ -275,17 +285,23 @@ def apply_lpot(x):
275285

276286

277287
if __name__ == "__main__":
278-
draw_pot_figure(aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=(35+4j)*0.3,
288+
draw_pot_figure(
289+
aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=(35+4j)*0.3,
279290
what_operator="D", what_operator_lpot="D", force_center_side=1)
291+
if USE_MATPLOTLIB:
292+
plt.savefig("eigvals-ext-nsrc100-novsmp100.pdf")
293+
plt.clf()
280294

281-
# plt.savefig("eigvals-ext-nsrc100-novsmp100.pdf")
282-
#plt.clf()
283-
#draw_pot_figure(aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=0,
284-
# what_operator="D", what_operator_lpot="D", force_center_side=-1)
285-
#plt.savefig("eigvals-int-nsrc100-novsmp100.pdf")
286-
#plt.clf()
287-
#draw_pot_figure(aspect_ratio=1, nsrc=100, novsmp=200, helmholtz_k=0,
288-
# what_operator="D", what_operator_lpot="D", force_center_side=-1)
289-
#plt.savefig("eigvals-int-nsrc100-novsmp200.pdf")
295+
# draw_pot_figure(
296+
# aspect_ratio=1, nsrc=100, novsmp=100, helmholtz_k=0,
297+
# what_operator="D", what_operator_lpot="D", force_center_side=-1)
298+
# plt.savefig("eigvals-int-nsrc100-novsmp100.pdf")
299+
# plt.clf()
300+
301+
# draw_pot_figure(
302+
# aspect_ratio=1, nsrc=100, novsmp=200, helmholtz_k=0,
303+
# what_operator="D", what_operator_lpot="D", force_center_side=-1)
304+
# plt.savefig("eigvals-int-nsrc100-novsmp200.pdf")
305+
# plt.clf()
290306

291307
# vim: fdm=marker

examples/expansion-toys.py

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,32 @@
1+
import numpy as np
2+
13
import pyopencl as cl
4+
25
import sumpy.toys as t
3-
import numpy as np
46
from sumpy.visualization import FieldPlotter
7+
from sumpy.kernel import ( # noqa: F401
8+
YukawaKernel,
9+
HelmholtzKernel,
10+
LaplaceKernel)
11+
512
try:
613
import matplotlib.pyplot as plt
7-
except ModuleNotFoundError:
8-
plt = None
14+
USE_MATPLOTLIB = True
15+
except ImportError:
16+
USE_MATPLOTLIB = False
917

1018

1119
def main():
12-
from sumpy.kernel import ( # noqa: F401
13-
YukawaKernel, HelmholtzKernel, LaplaceKernel)
20+
from sumpy.array_context import PyOpenCLArrayContext
21+
ctx = cl.create_some_context()
22+
queue = cl.CommandQueue(ctx)
23+
actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
24+
1425
tctx = t.ToyContext(
15-
cl.create_some_context(),
16-
#LaplaceKernel(2),
26+
actx.context,
27+
# LaplaceKernel(2),
1728
YukawaKernel(2), extra_kernel_kwargs={"lam": 5},
18-
#HelmholtzKernel(2), extra_kernel_kwargs={"k": 0.3},
29+
# HelmholtzKernel(2), extra_kernel_kwargs={"k": 0.3},
1930
)
2031

2132
pt_src = t.PointSources(
@@ -25,7 +36,7 @@ def main():
2536

2637
fp = FieldPlotter([3, 0], extent=8)
2738

28-
if 0 and plt is not None:
39+
if USE_MATPLOTLIB:
2940
t.logplot(fp, pt_src, cmap="jet")
3041
plt.colorbar()
3142
plt.show()
@@ -35,12 +46,12 @@ def main():
3546
lexp = t.local_expand(mexp, [3, 0])
3647
lexp2 = t.local_expand(lexp, [3, 1], 3)
3748

38-
#diff = mexp - pt_src
39-
#diff = mexp2 - pt_src
49+
# diff = mexp - pt_src
50+
# diff = mexp2 - pt_src
4051
diff = lexp2 - pt_src
4152

4253
print(t.l_inf(diff, 1.2, center=lexp2.center))
43-
if 1 and plt is not None:
54+
if USE_MATPLOTLIB:
4455
t.logplot(fp, diff, cmap="jet", vmin=-3, vmax=0)
4556
plt.colorbar()
4657
plt.show()

examples/sym-exp-complexity.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import numpy as np
2-
import pyopencl as cl
32
import loopy as lp
3+
4+
import pyopencl as cl
5+
46
from sumpy.kernel import LaplaceKernel, HelmholtzKernel
57
from sumpy.expansion.local import (
68
LinearPDEConformingVolumeTaylorLocalExpansion,
@@ -9,14 +11,19 @@
911
LinearPDEConformingVolumeTaylorMultipoleExpansion,
1012
)
1113
from sumpy.e2e import E2EFromCSR
14+
1215
try:
1316
import matplotlib.pyplot as plt
14-
except ModuleNotFoundError:
15-
plt = None
17+
USE_MATPLOTLIB = True
18+
except ImportError:
19+
USE_MATPLOTLIB = False
1620

1721

1822
def find_flops():
23+
from sumpy.array_context import PyOpenCLArrayContext
1924
ctx = cl.create_some_context()
25+
queue = cl.CommandQueue(ctx)
26+
actx = PyOpenCLArrayContext(queue, force_device_scalars=True)
2027

2128
if 0:
2229
knl = LaplaceKernel(2)
@@ -35,7 +42,7 @@ def find_flops():
3542
print(order)
3643
m_expn = m_expn_cls(knl, order)
3744
l_expn = l_expn_cls(knl, order)
38-
m2l = E2EFromCSR(ctx, m_expn, l_expn)
45+
m2l = E2EFromCSR(actx.context, m_expn, l_expn)
3946

4047
loopy_knl = m2l.get_kernel()
4148
loopy_knl = lp.add_and_infer_dtypes(
@@ -74,7 +81,7 @@ def plot_flops():
7481
flops = [45, 194, 474, 931, 1650, 2632, 3925, 5591, 7706, 10272]
7582
filename = "helmholtz-m2l-complexity-2d.pdf"
7683

77-
if plt is not None:
84+
if USE_MATPLOTLIB:
7885
plt.rc("font", size=16)
7986
plt.title(case)
8087
plt.ylabel("Flop count")
@@ -86,5 +93,5 @@ def plot_flops():
8693

8794

8895
if __name__ == "__main__":
89-
#find_flops()
96+
# find_flops()
9097
plot_flops()

0 commit comments

Comments
 (0)