Skip to content

Commit d22b588

Browse files
committed
add axes for reshapes, etc. in direction connection
1 parent e4685ed commit d22b588

File tree

2 files changed

+52
-25
lines changed

2 files changed

+52
-25
lines changed

meshmode/discretization/connection/direct.py

Lines changed: 42 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@
3030
import loopy as lp
3131
from meshmode.transform_metadata import (
3232
ConcurrentElementInameTag, ConcurrentDOFInameTag,
33-
DiscretizationElementAxisTag, DiscretizationDOFAxisTag)
33+
DiscretizationElementAxisTag, DiscretizationDOFAxisTag,
34+
DiscretizationDOFPickListAxisTag)
3435
from pytools import memoize_in, keyed_memoize_method
3536
from arraycontext import (
3637
ArrayContext, ArrayT, ArrayOrContainerT, NotAnArrayContainerError,
@@ -547,17 +548,22 @@ def _per_target_group_pick_info(
547548
_FromGroupPickData(
548549
from_group_index=source_group_index,
549550
dof_pick_lists=actx.freeze(
550-
actx.tag(NameHint("dof_pick_lists"),
551-
actx.from_numpy(dof_pick_lists))),
551+
actx.tag_axis(0, DiscretizationDOFPickListAxisTag(),
552+
actx.tag(NameHint("dof_pick_lists"),
553+
actx.from_numpy(dof_pick_lists)))),
552554
dof_pick_list_indices=actx.freeze(
553-
actx.tag(NameHint("dof_pick_list_indices"),
554-
actx.from_numpy(dof_pick_list_indices))),
555+
actx.tag_axis(0, DiscretizationElementAxisTag(),
556+
actx.tag(NameHint("dof_pick_list_indices"),
557+
actx.from_numpy(dof_pick_list_indices)))),
555558
from_el_present=actx.freeze(
556-
actx.tag(NameHint("from_el_present"),
557-
actx.from_numpy(from_el_present.astype(np.int8)))),
559+
actx.tag_axis(0, DiscretizationElementAxisTag(),
560+
actx.tag(NameHint("from_el_present"),
561+
actx.from_numpy(
562+
from_el_present.astype(np.int8))))),
558563
from_element_indices=actx.freeze(
559-
actx.tag(NameHint("from_el_indices"),
560-
actx.from_numpy(from_el_indices))),
564+
actx.tag_axis(0, DiscretizationElementAxisTag(),
565+
actx.tag(NameHint("from_el_indices"),
566+
actx.from_numpy(from_el_indices)))),
561567
is_surjective=from_el_present.all()
562568
))
563569

@@ -726,25 +732,27 @@ def group_pick_knl(is_surjective: bool):
726732
group_pick_info = None
727733

728734
if group_pick_info is not None:
729-
group_array_contributions = []
730-
731735
if actx.permits_advanced_indexing and not _force_use_loopy:
732736
for fgpd in group_pick_info:
733737
from_element_indices = actx.thaw(fgpd.from_element_indices)
734738

735739
if ary[fgpd.from_group_index].size:
736740
grp_ary_contrib = ary[fgpd.from_group_index][
741+
tag_axes(actx, {
742+
1: DiscretizationDOFAxisTag()},
737743
_reshape_and_preserve_tags(
738-
actx, from_element_indices, (-1, 1)),
739-
actx.thaw(fgpd.dof_pick_lists)[
740-
actx.thaw(fgpd.dof_pick_list_indices)]
741-
]
744+
actx, from_element_indices, (-1, 1))),
745+
actx.thaw(fgpd.dof_pick_lists)[
746+
actx.thaw(fgpd.dof_pick_list_indices)]
747+
]
742748

743749
if not fgpd.is_surjective:
744750
from_el_present = actx.thaw(fgpd.from_el_present)
745751
grp_ary_contrib = actx.np.where(
746-
_reshape_and_preserve_tags(
747-
actx, from_el_present, (-1, 1)),
752+
tag_axes(actx, {
753+
1: DiscretizationDOFAxisTag()},
754+
_reshape_and_preserve_tags(
755+
actx, from_el_present, (-1, 1))),
748756
grp_ary_contrib,
749757
0)
750758

@@ -794,8 +802,10 @@ def group_pick_knl(is_surjective: bool):
794802
mat = self._resample_matrix(actx, i_tgrp, i_batch)
795803
if actx.permits_advanced_indexing and not _force_use_loopy:
796804
batch_result = actx.np.where(
797-
_reshape_and_preserve_tags(
798-
actx, from_el_present, (-1, 1)),
805+
tag_axes(actx, {
806+
1: DiscretizationDOFAxisTag()},
807+
_reshape_and_preserve_tags(
808+
actx, from_el_present, (-1, 1))),
799809
actx.einsum("ij,ej->ei",
800810
mat, grp_ary[from_element_indices]),
801811
0)
@@ -816,11 +826,15 @@ def group_pick_knl(is_surjective: bool):
816826

817827
if actx.permits_advanced_indexing and not _force_use_loopy:
818828
batch_result = actx.np.where(
819-
_reshape_and_preserve_tags(
820-
actx, from_el_present, (-1, 1)),
821-
from_vec[
829+
tag_axes(actx, {
830+
1: DiscretizationDOFAxisTag()},
822831
_reshape_and_preserve_tags(
823-
actx, from_element_indices, (-1, 1)),
832+
actx, from_el_present, (-1, 1))),
833+
from_vec[
834+
tag_axes(actx, {
835+
1: DiscretizationDOFAxisTag()},
836+
_reshape_and_preserve_tags(
837+
actx, from_element_indices, (-1, 1))),
824838
pick_list],
825839
0)
826840
else:
@@ -847,10 +861,13 @@ def group_pick_knl(is_surjective: bool):
847861
else:
848862
# If no batched data at all, return zeros for this
849863
# particular group array
850-
group_array = actx.zeros(
864+
group_array = tag_axes(actx, {
865+
0: DiscretizationElementAxisTag(),
866+
1: DiscretizationDOFAxisTag()},
867+
actx.zeros(
851868
shape=(self.to_discr.groups[i_tgrp].nelements,
852869
self.to_discr.groups[i_tgrp].nunit_dofs),
853-
dtype=ary.entry_dtype)
870+
dtype=ary.entry_dtype))
854871

855872
group_arrays.append(group_array)
856873

meshmode/transform_metadata.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
.. autoclass:: DiscretizationDOFAxisTag
99
.. autoclass:: DiscretizationAmbientDimAxisTag
1010
.. autoclass:: DiscretizationTopologicalDimAxisTag
11+
.. autoclass:: DiscretizationDOFPickListAxisTag
1112
"""
1213

1314
__copyright__ = """
@@ -121,3 +122,12 @@ class DiscretizationTopologicalDimAxisTag(DiscretizationDimAxisTag):
121122
Array dimensions tagged with this tag type describe an axis indexing over
122123
the discretization's physical coordinate dimensions.
123124
"""
125+
126+
127+
@tag_dataclass
128+
class DiscretizationDOFPickListAxisTag(DiscretizationEntityAxisTag):
129+
"""
130+
Array dimensions tagged with this tag type describe an axis indexing over
131+
DOF pick lists. See :mod:`meshmode.discretization.connection.direct` for
132+
details.
133+
"""

0 commit comments

Comments
 (0)