3636
3737from arraycontext import (
3838 ArrayContext ,
39- map_array_container ,
40- thaw
39+ map_array_container
4140)
41+ from arraycontext .container import ArrayOrContainerT
42+
4243from functools import partial
44+
4345from meshmode .transform_metadata import FirstAxisIsElementsTag
4446
4547from grudge .discretization import DiscretizationCollection
48+ from grudge .dof_desc import DOFDesc
4649
4750from meshmode .dof_array import DOFArray
4851
@@ -66,8 +69,6 @@ def interp(dcoll: DiscretizationCollection, src, tgt, vec):
6669
6770def volume_quadrature_interpolation_matrix (
6871 actx : ArrayContext , base_element_group , vol_quad_element_group ):
69- """todo.
70- """
7172 @keyed_memoize_in (
7273 actx , volume_quadrature_interpolation_matrix ,
7374 lambda base_grp , vol_quad_grp : (base_grp .discretization_key (),
@@ -85,10 +86,7 @@ def get_volume_vand(base_grp, vol_quad_grp):
8586
8687
8788def surface_quadrature_interpolation_matrix (
88- actx : ArrayContext ,
89- base_element_group , face_quad_element_group , dtype ):
90- """todo.
91- """
89+ actx : ArrayContext , base_element_group , face_quad_element_group ):
9290 @keyed_memoize_in (
9391 actx , surface_quadrature_interpolation_matrix ,
9492 lambda base_grp , face_quad_grp : (base_grp .discretization_key (),
@@ -121,11 +119,7 @@ def get_surface_vand(base_grp, face_quad_grp):
121119
122120def volume_and_surface_interpolation_matrix (
123121 actx : ArrayContext ,
124- base_element_group ,
125- vol_quad_element_group ,
126- face_quad_element_group , dtype ):
127- """todo.
128- """
122+ base_element_group , vol_quad_element_group , face_quad_element_group ):
129123 @keyed_memoize_in (
130124 actx , volume_and_surface_interpolation_matrix ,
131125 lambda base_grp , vol_quad_grp , face_quad_grp : (
@@ -134,21 +128,15 @@ def volume_and_surface_interpolation_matrix(
134128 face_quad_grp .discretization_key ()))
135129 def get_vol_surf_interpolation_matrix (base_grp , vol_quad_grp , face_quad_grp ):
136130 vq_mat = actx .to_numpy (
137- thaw (
138- volume_quadrature_interpolation_matrix (
139- actx , base_grp , vol_quad_grp
140- ),
141- actx
142- )
143- )
131+ volume_quadrature_interpolation_matrix (
132+ actx ,
133+ base_element_group = base_grp ,
134+ vol_quad_element_group = vol_quad_grp ))
144135 vf_mat = actx .to_numpy (
145- thaw (
146- surface_quadrature_interpolation_matrix (
147- actx , base_grp , face_quad_grp , dtype
148- ),
149- actx
150- )
151- )
136+ surface_quadrature_interpolation_matrix (
137+ actx ,
138+ base_element_group = base_grp ,
139+ face_quad_element_group = face_quad_grp ))
152140 return actx .freeze (actx .from_numpy (np .block ([[vq_mat ], [vf_mat ]])))
153141
154142 return get_vol_surf_interpolation_matrix (
@@ -158,49 +146,23 @@ def get_vol_surf_interpolation_matrix(base_grp, vol_quad_grp, face_quad_grp):
158146# }}}
159147
160148
161- def volume_quadrature_interpolation (dcoll : DiscretizationCollection , dq , vec ):
162- if not isinstance (vec , DOFArray ):
163- return map_array_container (
164- partial (volume_quadrature_interpolation , dcoll , dq ), vec
165- )
166-
167- actx = vec .array_context
168- discr = dcoll .discr_from_dd ("vol" )
169- quad_discr = dcoll .discr_from_dd (dq )
170-
171- return DOFArray (
172- actx ,
173- data = tuple (
174- actx .einsum ("ij,ej->ei" ,
175- volume_quadrature_interpolation_matrix (
176- actx ,
177- base_element_group = bgrp ,
178- vol_quad_element_group = qgrp
179- ),
180- vec_i ,
181- arg_names = ("Vq_mat" , "vec" ),
182- tagged = (FirstAxisIsElementsTag (),))
183-
184- for bgrp , qgrp , vec_i in zip (discr .groups , quad_discr .groups , vec )
185- )
186- )
187-
188-
189149def volume_and_surface_quadrature_interpolation (
190- dcoll : DiscretizationCollection , dq , df , vec ):
150+ dcoll : DiscretizationCollection ,
151+ dd_quad : DOFDesc ,
152+ dd_face_quad : DOFDesc ,
153+ vec : ArrayOrContainerT ) -> ArrayOrContainerT :
191154 """todo.
192155 """
193156 if not isinstance (vec , DOFArray ):
194157 return map_array_container (
195158 partial (volume_and_surface_quadrature_interpolation ,
196- dcoll , dq , df ), vec
159+ dcoll , dd_quad , dd_face_quad ), vec
197160 )
198161
199162 actx = vec .array_context
200- dtype = vec .entry_dtype
201163 discr = dcoll .discr_from_dd ("vol" )
202- quad_volm_discr = dcoll .discr_from_dd (dq )
203- quad_face_discr = dcoll .discr_from_dd (df )
164+ quad_volm_discr = dcoll .discr_from_dd (dd_quad )
165+ quad_face_discr = dcoll .discr_from_dd (dd_face_quad )
204166
205167 return DOFArray (
206168 actx ,
@@ -210,8 +172,7 @@ def volume_and_surface_quadrature_interpolation(
210172 actx ,
211173 base_element_group = bgrp ,
212174 vol_quad_element_group = qvgrp ,
213- face_quad_element_group = qfgrp ,
214- dtype = dtype
175+ face_quad_element_group = qfgrp
215176 ),
216177 vec_i ,
217178 arg_names = ("Vh_mat" , "vec" ),
0 commit comments