Skip to content

Commit 44394e3

Browse files
committed
use lazy eval array context for setup actx
1 parent 1aa585f commit 44394e3

File tree

1 file changed

+12
-9
lines changed

1 file changed

+12
-9
lines changed

test/test_pytato_transforms.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def _map_index_base(self, expr: IndexBase) -> bool:
3737
from grudge.pytato_transforms.pytato_indirection_transforms import (
3838
_is_materialized)
3939
return self.combine(
40-
_is_materialized(expr.array) or isinstance(expr.array, BasicIndex),
40+
_is_materialized(expr.array) or isinstance(expr, BasicIndex),
4141
self.rec(expr.array)
4242
)
4343

@@ -227,6 +227,7 @@ def _compute_flux_2(dcoll, actx, u):
227227
normal_on_bdry_faces = actx.thaw(dcoll.normal(BTAG_ALL))
228228
flux_on_interior_faces = u_interior_tpair.avg * normal_on_interior_faces
229229
flux_on_bdry = op.project(dcoll, "vol", BTAG_ALL, u) * normal_on_bdry_faces
230+
230231
flux_on_all_faces = (
231232
op.project(dcoll,
232233
FACE_RESTR_INTERIOR,
@@ -245,21 +246,23 @@ def test_resampling_indirections_are_fused_2(ctx_factory):
245246
cl_ctx = ctx_factory()
246247
cq = cl.CommandQueue(cl_ctx)
247248

248-
ref_actx = PyOpenCLArrayContext(cq)
249+
from grudge.array_context import get_reasonable_array_context_class
250+
251+
ref_actx = get_reasonable_array_context_class(lazy=True, distributed=False)(cq)
249252
actx = FluxOptimizerActx(cq)
250253

251-
dim = 2
252-
nel_1d = 4
254+
dim = 3
255+
nel_1d = 16
256+
order = 4
253257
mesh = generate_regular_rect_mesh(
254258
a=(-0.5,)*dim,
255259
b=(0.5,)*dim,
256260
nelements_per_axis=(nel_1d,)*dim,
257-
boundary_tag_to_face={"bdry": ["-x", "+x",
258-
"-y", "+y"]}
259261
)
260-
dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2)
261-
262-
x, _ = dcoll.nodes()
262+
dcoll = grudge.make_discretization_collection(
263+
ref_actx, mesh,
264+
order=order)
265+
x, _, _ = dcoll.nodes()
263266
compiled_flux_2 = actx.compile(lambda ary: _compute_flux_2(dcoll, actx, ary))
264267

265268
ref_output = ref_actx.to_numpy(

0 commit comments

Comments
 (0)