Skip to content

Commit 730e0b4

Browse files
committed
forget about heterogeneous inter-volume trace pairs for now
1 parent 0d2a013 commit 730e0b4

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

grudge/trace_pair.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -697,47 +697,47 @@ def __init__(self,
697697
self.local_part_id = local_part_id
698698
self.remote_part_id = remote_part_id
699699

700-
from pytato import make_distributed_recv, staple_distributed_send
700+
from pytato import (
701+
make_distributed_recv,
702+
make_distributed_send,
703+
DistributedSendRefHolder)
704+
705+
# TODO: This currently assumes that local_bdry_data and
706+
# remote_bdry_data_template have the same structure. This is not true
707+
# in general. Find a way to staple the sends appropriately when the number
708+
# of recvs is not equal to the number of sends
709+
assert type(local_bdry_data) == type(remote_bdry_data_template)
710+
711+
sends = {}
701712

702-
# Staple the sends to a bunch of dummy arrays of zeros
703713
def send_single_array(key, local_subary):
704714
if isinstance(local_subary, Number):
705-
return 0
715+
return
706716
else:
707717
ary_tag = (comm_tag, key)
708-
return staple_distributed_send(
709-
local_subary, dest_rank=remote_rank, comm_tag=ary_tag,
710-
stapled_to=actx.zeros_like(local_subary))
718+
sends[key] = make_distributed_send(
719+
local_subary, dest_rank=remote_rank, comm_tag=ary_tag)
711720

712721
def recv_single_array(key, remote_subary_template):
713722
if isinstance(remote_subary_template, Number):
714723
# NOTE: Assumes that the same number is passed on every rank
715-
return remote_subary_template
724+
return Number
716725
else:
717726
ary_tag = (comm_tag, key)
718-
return make_distributed_recv(
719-
src_rank=remote_rank, comm_tag=ary_tag,
720-
shape=remote_subary_template.shape,
721-
dtype=remote_subary_template.dtype)
727+
return DistributedSendRefHolder(
728+
sends[key],
729+
make_distributed_recv(
730+
src_rank=remote_rank, comm_tag=ary_tag,
731+
shape=remote_subary_template.shape,
732+
dtype=remote_subary_template.dtype))
722733

723734
from arraycontext.container.traversal import rec_keyed_map_array_container
724-
zeros_like_local_bdry_data = rec_keyed_map_array_container(
725-
send_single_array, local_bdry_data)
726-
unswapped_remote_bdry_data = rec_keyed_map_array_container(
727-
recv_single_array, remote_bdry_data_template)
728735

729-
# Sum up the dummy zeros
730-
zero = actx.np.sum(zeros_like_local_bdry_data)
736+
rec_keyed_map_array_container(send_single_array, local_bdry_data)
737+
self.local_bdry_data = local_bdry_data
731738

732-
# Add the dummy zeros and hope that the caller proceeds to actually
733-
# use some of this data on every rank...
734-
from arraycontext import rec_map_array_container
735-
self.local_bdry_data = rec_map_array_container(
736-
lambda x: x + zero,
737-
local_bdry_data)
738-
self.unswapped_remote_bdry_data = rec_map_array_container(
739-
lambda x: x + zero,
740-
unswapped_remote_bdry_data)
739+
self.unswapped_remote_bdry_data = rec_keyed_map_array_container(
740+
recv_single_array, remote_bdry_data_template)
741741

742742
def finish(self):
743743
remote_to_local = self.dcoll._inter_partition_connections[

0 commit comments

Comments
 (0)