@@ -697,47 +697,47 @@ def __init__(self,
697697 self .local_part_id = local_part_id
698698 self .remote_part_id = remote_part_id
699699
700- from pytato import make_distributed_recv , staple_distributed_send
700+ from pytato import (
701+ make_distributed_recv ,
702+ make_distributed_send ,
703+ DistributedSendRefHolder )
704+
705+ # TODO: This currently assumes that local_bdry_data and
706+ # remote_bdry_data_template have the same structure. This is not true
707+ # in general. Find a way to staple the sends appropriately when the number
708+ # of recvs is not equal to the number of sends
709+ assert type (local_bdry_data ) == type (remote_bdry_data_template )
710+
711+ sends = {}
701712
702- # Staple the sends to a bunch of dummy arrays of zeros
def send_single_array(key, local_subary):
    """Register the outgoing send for one leaf of the local boundary data.

    :arg key: the key identifying this leaf within the container; it is
        combined with ``comm_tag`` to form the tag of the send and is also
        the key under which the send is stored in ``sends``.
    :arg local_subary: the leaf value. Plain :class:`~numbers.Number`
        leaves are skipped, so no send is created for them.

    Always returns *None*; the result of interest is the entry added to the
    enclosing ``sends`` dictionary.
    """
    if not isinstance(local_subary, Number):
        sends[key] = make_distributed_send(
            local_subary, dest_rank=remote_rank, comm_tag=(comm_tag, key))
711720
def recv_single_array(key, remote_subary_template):
    """Build the receive for one leaf of the remote boundary data.

    :arg key: the key identifying this leaf within the container; it is
        combined with ``comm_tag`` to form the tag of the receive and is
        used to look up the matching send in ``sends``.
    :arg remote_subary_template: a stand-in for the remote leaf. For array
        leaves only its ``shape`` and ``dtype`` are read.

    :returns: *remote_subary_template* itself when it is a plain
        :class:`~numbers.Number`; otherwise a ``DistributedSendRefHolder``
        wrapping the send registered under *key* together with a newly
        created receive from ``remote_rank``.
    :raises KeyError: if no send was registered under *key* for an array
        leaf.
    """
    if isinstance(remote_subary_template, Number):
        # NOTE: Assumes that the same number is passed on every rank
        # Return the scalar value itself, not the ``Number`` class.
        return remote_subary_template

    ary_tag = (comm_tag, key)
    # Holding the send alongside the receive keeps the send referenced by
    # the data that callers go on to use.
    return DistributedSendRefHolder(
        sends[key],
        make_distributed_recv(
            src_rank=remote_rank, comm_tag=ary_tag,
            shape=remote_subary_template.shape,
            dtype=remote_subary_template.dtype))
722733
723734 from arraycontext .container .traversal import rec_keyed_map_array_container
724- zeros_like_local_bdry_data = rec_keyed_map_array_container (
725- send_single_array , local_bdry_data )
726- unswapped_remote_bdry_data = rec_keyed_map_array_container (
727- recv_single_array , remote_bdry_data_template )
728735
729- # Sum up the dummy zeros
730- zero = actx . np . sum ( zeros_like_local_bdry_data )
736+ rec_keyed_map_array_container ( send_single_array , local_bdry_data )
737+ self . local_bdry_data = local_bdry_data
731738
732- # Add the dummy zeros and hope that the caller proceeds to actually
733- # use some of this data on every rank...
734- from arraycontext import rec_map_array_container
735- self .local_bdry_data = rec_map_array_container (
736- lambda x : x + zero ,
737- local_bdry_data )
738- self .unswapped_remote_bdry_data = rec_map_array_container (
739- lambda x : x + zero ,
740- unswapped_remote_bdry_data )
739+ self .unswapped_remote_bdry_data = rec_keyed_map_array_container (
740+ recv_single_array , remote_bdry_data_template )
741741
742742 def finish (self ):
743743 remote_to_local = self .dcoll ._inter_partition_connections [
0 commit comments