3030import loopy as lp
3131from meshmode .transform_metadata import (
3232 ConcurrentElementInameTag , ConcurrentDOFInameTag ,
33- DiscretizationElementAxisTag , DiscretizationDOFAxisTag )
33+ DiscretizationElementAxisTag , DiscretizationDOFAxisTag ,
34+ DiscretizationDOFPickListAxisTag )
3435from pytools import memoize_in , keyed_memoize_method
3536from arraycontext import (
3637 ArrayContext , ArrayT , ArrayOrContainerT , NotAnArrayContainerError ,
@@ -547,17 +548,22 @@ def _per_target_group_pick_info(
547548 _FromGroupPickData (
548549 from_group_index = source_group_index ,
549550 dof_pick_lists = actx .freeze (
550- actx .tag (NameHint ("dof_pick_lists" ),
551- actx .from_numpy (dof_pick_lists ))),
551+ actx .tag_axis (0 , DiscretizationDOFPickListAxisTag (),
552+ actx .tag (NameHint ("dof_pick_lists" ),
553+ actx .from_numpy (dof_pick_lists )))),
552554 dof_pick_list_indices = actx .freeze (
553- actx .tag (NameHint ("dof_pick_list_indices" ),
554- actx .from_numpy (dof_pick_list_indices ))),
555+ actx .tag_axis (0 , DiscretizationElementAxisTag (),
556+ actx .tag (NameHint ("dof_pick_list_indices" ),
557+ actx .from_numpy (dof_pick_list_indices )))),
555558 from_el_present = actx .freeze (
556- actx .tag (NameHint ("from_el_present" ),
557- actx .from_numpy (from_el_present .astype (np .int8 )))),
559+ actx .tag_axis (0 , DiscretizationElementAxisTag (),
560+ actx .tag (NameHint ("from_el_present" ),
561+ actx .from_numpy (
562+ from_el_present .astype (np .int8 ))))),
558563 from_element_indices = actx .freeze (
559- actx .tag (NameHint ("from_el_indices" ),
560- actx .from_numpy (from_el_indices ))),
564+ actx .tag_axis (0 , DiscretizationElementAxisTag (),
565+ actx .tag (NameHint ("from_el_indices" ),
566+ actx .from_numpy (from_el_indices )))),
561567 is_surjective = from_el_present .all ()
562568 ))
563569
@@ -726,25 +732,27 @@ def group_pick_knl(is_surjective: bool):
726732 group_pick_info = None
727733
728734 if group_pick_info is not None :
729- group_array_contributions = []
730-
731735 if actx .permits_advanced_indexing and not _force_use_loopy :
732736 for fgpd in group_pick_info :
733737 from_element_indices = actx .thaw (fgpd .from_element_indices )
734738
735739 if ary [fgpd .from_group_index ].size :
736740 grp_ary_contrib = ary [fgpd .from_group_index ][
741+ tag_axes (actx , {
742+ 1 : DiscretizationDOFAxisTag ()},
737743 _reshape_and_preserve_tags (
738- actx , from_element_indices , (- 1 , 1 )),
739- actx .thaw (fgpd .dof_pick_lists )[
740- actx .thaw (fgpd .dof_pick_list_indices )]
741- ]
744+ actx , from_element_indices , (- 1 , 1 ))) ,
745+ actx .thaw (fgpd .dof_pick_lists )[
746+ actx .thaw (fgpd .dof_pick_list_indices )]
747+ ]
742748
743749 if not fgpd .is_surjective :
744750 from_el_present = actx .thaw (fgpd .from_el_present )
745751 grp_ary_contrib = actx .np .where (
746- _reshape_and_preserve_tags (
747- actx , from_el_present , (- 1 , 1 )),
752+ tag_axes (actx , {
753+ 1 : DiscretizationDOFAxisTag ()},
754+ _reshape_and_preserve_tags (
755+ actx , from_el_present , (- 1 , 1 ))),
748756 grp_ary_contrib ,
749757 0 )
750758
@@ -794,8 +802,10 @@ def group_pick_knl(is_surjective: bool):
794802 mat = self ._resample_matrix (actx , i_tgrp , i_batch )
795803 if actx .permits_advanced_indexing and not _force_use_loopy :
796804 batch_result = actx .np .where (
797- _reshape_and_preserve_tags (
798- actx , from_el_present , (- 1 , 1 )),
805+ tag_axes (actx , {
806+ 1 : DiscretizationDOFAxisTag ()},
807+ _reshape_and_preserve_tags (
808+ actx , from_el_present , (- 1 , 1 ))),
799809 actx .einsum ("ij,ej->ei" ,
800810 mat , grp_ary [from_element_indices ]),
801811 0 )
@@ -816,11 +826,15 @@ def group_pick_knl(is_surjective: bool):
816826
817827 if actx .permits_advanced_indexing and not _force_use_loopy :
818828 batch_result = actx .np .where (
819- _reshape_and_preserve_tags (
820- actx , from_el_present , (- 1 , 1 )),
821- from_vec [
829+ tag_axes (actx , {
830+ 1 : DiscretizationDOFAxisTag ()},
822831 _reshape_and_preserve_tags (
823- actx , from_element_indices , (- 1 , 1 )),
832+ actx , from_el_present , (- 1 , 1 ))),
833+ from_vec [
834+ tag_axes (actx , {
835+ 1 : DiscretizationDOFAxisTag ()},
836+ _reshape_and_preserve_tags (
837+ actx , from_element_indices , (- 1 , 1 ))),
824838 pick_list ],
825839 0 )
826840 else :
@@ -847,10 +861,13 @@ def group_pick_knl(is_surjective: bool):
847861 else :
848862 # If no batched data at all, return zeros for this
849863 # particular group array
850- group_array = actx .zeros (
864+ group_array = tag_axes (actx , {
865+ 0 : DiscretizationElementAxisTag (),
866+ 1 : DiscretizationDOFAxisTag ()},
867+ actx .zeros (
851868 shape = (self .to_discr .groups [i_tgrp ].nelements ,
852869 self .to_discr .groups [i_tgrp ].nunit_dofs ),
853- dtype = ary .entry_dtype )
870+ dtype = ary .entry_dtype ))
854871
855872 group_arrays .append (group_array )
856873
0 commit comments