Skip to content

Commit 610d006

Browse files
committed
tests pick list fusion
1 parent 3e1b5a1 commit 610d006

File tree

1 file changed

+224
-0
lines changed

1 file changed

+224
-0
lines changed

test/test_pytato_transforms.py

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
import numpy as np # noqa: F401
2+
import pyopencl as cl
3+
from typing import Union
4+
from meshmode.mesh import BTAG_ALL
5+
from meshmode.mesh.generation import generate_regular_rect_mesh
6+
from arraycontext.metadata import NameHint
7+
from meshmode.array_context import (PytatoPyOpenCLArrayContext,
8+
PyOpenCLArrayContext)
9+
from pytato.transform import CombineMapper
10+
from pytato.array import (Placeholder, DataWrapper, SizeParam, IndexBase,
11+
Array, DictOfNamedArrays)
12+
from meshmode.discretization.connection import (FACE_RESTR_INTERIOR,
13+
FACE_RESTR_ALL)
14+
from pytools.obj_array import make_obj_array
15+
from pyopencl.tools import ( # noqa
16+
pytest_generate_tests_for_pyopencl as pytest_generate_tests)
17+
import grudge
18+
import grudge.op as op
19+
20+
21+
# {{{ utilities for test_push_indirections_*
22+
23+
class _IndexeeArraysMaterializedChecker(CombineMapper[bool]):
24+
def combine(self, *args: bool) -> bool:
25+
return all(args)
26+
27+
def map_placeholder(self, expr: Placeholder) -> bool:
28+
return True
29+
30+
def map_data_wrapper(self, expr: DataWrapper) -> bool:
31+
return True
32+
33+
def map_size_param(self, expr: SizeParam) -> bool:
34+
return True
35+
36+
def _map_index_base(self, expr: IndexBase) -> bool:
37+
from pytato.transform.indirections import _is_materialized
38+
return self.combine(
39+
_is_materialized(expr.array) or isinstance(expr.array, IndexBase),
40+
self.rec(expr.array)
41+
)
42+
43+
44+
def are_all_indexees_materialized_nodes(
45+
expr: Union[Array, DictOfNamedArrays]) -> bool:
46+
"""
47+
Returns *True* only if all indexee arrays are either materialized nodes,
48+
OR, other indexing nodes that have materialized indexees.
49+
"""
50+
return _IndexeeArraysMaterializedChecker()(expr)
51+
52+
53+
class _IndexerArrayDatawrapperChecker(CombineMapper[bool]):
54+
def combine(self, *args: bool) -> bool:
55+
return all(args)
56+
57+
def map_placeholder(self, expr: Placeholder) -> bool:
58+
return True
59+
60+
def map_data_wrapper(self, expr: DataWrapper) -> bool:
61+
return True
62+
63+
def map_size_param(self, expr: SizeParam) -> bool:
64+
return True
65+
66+
def _map_index_base(self, expr: IndexBase) -> bool:
67+
return self.combine(
68+
*[isinstance(idx, DataWrapper)
69+
for idx in expr.indices
70+
if isinstance(idx, Array)],
71+
super()._map_index_base(expr),
72+
)
73+
74+
75+
def are_all_indexer_arrays_datawrappers(
76+
expr: Union[Array, DictOfNamedArrays]) -> bool:
77+
"""
78+
Returns *True* only if all indexer arrays are instances of
79+
:class:`~pytato.array.DataWrapper`.
80+
"""
81+
return _IndexerArrayDatawrapperChecker()(expr)
82+
83+
# }}}
84+
85+
86+
def _evaluate_dict_of_named_arrays(actx, dict_of_named_arrays):
87+
container = make_obj_array([dict_of_named_arrays._data[name]
88+
for name in sorted(dict_of_named_arrays.keys())])
89+
90+
evaluated_container = actx.thaw(actx.freeze(container))
91+
92+
return {name: evaluated_container[i]
93+
for i, name in enumerate(sorted(dict_of_named_arrays.keys()))}
94+
95+
96+
class FluxOptimizerActx(PytatoPyOpenCLArrayContext):
97+
def __init__(self, *args, **kwargs):
98+
super().__init__(*args, **kwargs)
99+
self.check_completed = False
100+
101+
def transform_dag(self, dag):
102+
from grudge.pytato_transforms.pytato_indirection_transforms import (
103+
fuse_dof_pick_lists, fold_constant_indirections)
104+
from pytato.tags import PrefixNamed
105+
106+
if (
107+
len(dag) == 1
108+
and PrefixNamed("flux_container") in list(dag._data.values())[0].tags
109+
):
110+
assert not are_all_indexer_arrays_datawrappers(dag)
111+
self.check_completed = True
112+
113+
dag = fuse_dof_pick_lists(dag)
114+
dag = fold_constant_indirections(
115+
dag, lambda x: _evaluate_dict_of_named_arrays(self, x))
116+
117+
if (
118+
len(dag) == 1
119+
and PrefixNamed("flux_container") in list(dag._data.values())[0].tags
120+
):
121+
assert are_all_indexer_arrays_datawrappers(dag)
122+
self.check_completed = True
123+
124+
return dag
125+
126+
127+
# {{{ test_resampling_indirections_are_fused_0
128+
129+
def _compute_flux_0(dcoll, actx, u):
130+
u_interior_tpair, = op.interior_trace_pairs(dcoll, u)
131+
flux_on_interior_faces = u_interior_tpair.avg
132+
flux_on_all_faces = op.project(
133+
dcoll, FACE_RESTR_INTERIOR, FACE_RESTR_ALL, flux_on_interior_faces)
134+
135+
flux_on_all_faces = actx.tag(NameHint("flux_container"), flux_on_all_faces)
136+
return flux_on_all_faces
137+
138+
139+
def test_resampling_indirections_are_fused_0(ctx_factory):
140+
cl_ctx = ctx_factory()
141+
cq = cl.CommandQueue(cl_ctx)
142+
143+
ref_actx = PyOpenCLArrayContext(cq)
144+
actx = FluxOptimizerActx(cq)
145+
146+
dim = 3
147+
nel_1d = 4
148+
mesh = generate_regular_rect_mesh(
149+
a=(-0.5,)*dim,
150+
b=(0.5,)*dim,
151+
nelements_per_axis=(nel_1d,)*dim,
152+
boundary_tag_to_face={"bdry": ["-x", "+x",
153+
"-y", "+y",
154+
"-z", "+z"]}
155+
)
156+
dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2)
157+
158+
x, _, _ = dcoll.nodes()
159+
compiled_flux_0 = actx.compile(lambda ary: _compute_flux_0(dcoll, actx, ary))
160+
161+
ref_output = ref_actx.to_numpy(
162+
_compute_flux_0(dcoll, ref_actx, ref_actx.thaw(x)))
163+
output = actx.to_numpy(
164+
compiled_flux_0(actx.thaw(x)))
165+
166+
np.testing.assert_allclose(ref_output[0], output[0])
167+
assert actx.check_completed
168+
169+
# }}}
170+
171+
172+
# {{{ test_resampling_indirections_are_fused_1
173+
174+
def _compute_flux_1(dcoll, actx, u):
175+
u_interior_tpair, = op.interior_trace_pairs(dcoll, u)
176+
flux_on_interior_faces = u_interior_tpair.avg
177+
flux_on_bdry = op.project(dcoll, "vol", BTAG_ALL, u)
178+
flux_on_all_faces = (
179+
op.project(dcoll,
180+
FACE_RESTR_INTERIOR,
181+
FACE_RESTR_ALL,
182+
flux_on_interior_faces)
183+
+ op.project(dcoll, BTAG_ALL, FACE_RESTR_ALL, flux_on_bdry)
184+
)
185+
186+
result = op.inverse_mass(dcoll, op.face_mass(dcoll, flux_on_all_faces))
187+
188+
result = actx.tag(NameHint("flux_container"), result)
189+
return result
190+
191+
192+
def test_resampling_indirections_are_fused_1(ctx_factory):
193+
cl_ctx = ctx_factory()
194+
cq = cl.CommandQueue(cl_ctx)
195+
196+
ref_actx = PyOpenCLArrayContext(cq)
197+
actx = FluxOptimizerActx(cq)
198+
199+
dim = 3
200+
nel_1d = 4
201+
mesh = generate_regular_rect_mesh(
202+
a=(-0.5,)*dim,
203+
b=(0.5,)*dim,
204+
nelements_per_axis=(nel_1d,)*dim,
205+
boundary_tag_to_face={"bdry": ["-x", "+x",
206+
"-y", "+y",
207+
"-z", "+z"]}
208+
)
209+
dcoll = grudge.make_discretization_collection(ref_actx, mesh, order=2)
210+
211+
x, _, _ = dcoll.nodes()
212+
compiled_flux_1 = actx.compile(lambda ary: _compute_flux_1(dcoll, actx, ary))
213+
214+
ref_output = ref_actx.to_numpy(
215+
_compute_flux_1(dcoll, ref_actx, ref_actx.thaw(x)))
216+
output = actx.to_numpy(
217+
compiled_flux_1(actx.thaw(x)))
218+
219+
np.testing.assert_allclose(ref_output[0], output[0])
220+
assert actx.check_completed
221+
222+
# }}}
223+
224+
# vim: fdm=marker

0 commit comments

Comments
 (0)