Skip to content

Commit 8c86c98

Browse files
committed
handle all-Number cases
1 parent ab3f43a commit 8c86c98

File tree

1 file changed

+28
-9
lines changed

1 file changed

+28
-9
lines changed

grudge/trace_pair.py

Lines changed: 28 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -840,7 +840,14 @@ def cross_rank_trace_pairs(
840840
{dcoll._part_id_helper.get_mpi_rank(part_id) for part_id in remote_part_ids})
841841

842842
actx = get_container_context_recursively(ary)
843-
assert actx is not None
843+
844+
if actx is None:
845+
# NOTE: Assumes that the same number is passed on every rank
846+
return [
847+
TracePair(
848+
volume_dd.trace(BTAG_PARTITION(remote_part_id)),
849+
interior=ary, exterior=ary)
850+
for remote_part_id in remote_part_ids]
844851

845852
from grudge.array_context import MPIPytatoArrayContextBase
846853

@@ -923,14 +930,6 @@ def cross_rank_inter_volume_trace_pairs(
923930
break
924931
if actx is not None:
925932
break
926-
assert actx is not None
927-
928-
from grudge.array_context import MPIPytatoArrayContextBase
929-
930-
if isinstance(actx, MPIPytatoArrayContextBase):
931-
rbc_class = _RankBoundaryCommunicationLazy
932-
else:
933-
rbc_class = _RankBoundaryCommunicationEager
934933

935934
def get_remote_connected_partitions(local_vol_dd, remote_vol_dd):
936935
connected_part_ids = _connected_partitions(
@@ -941,6 +940,26 @@ def get_remote_connected_partitions(local_vol_dd, remote_vol_dd):
941940
for part_id in connected_part_ids
942941
if dcoll._part_id_helper.get_mpi_rank(part_id) != rank]
943942

943+
if actx is None:
944+
# NOTE: Assumes that the same number is passed on every rank for a
945+
# given volume
946+
return {
947+
(remote_vol_dd, local_vol_dd): [
948+
TracePair(
949+
local_vol_dd.trace(BTAG_PARTITION(remote_part_id)),
950+
interior=local_vol_ary, exterior=remote_vol_ary)
951+
for remote_part_id in get_remote_connected_partitions(
952+
local_vol_dd, remote_vol_dd)]
953+
for (remote_vol_dd, local_vol_dd), (remote_vol_ary, local_vol_ary)
954+
in pairwise_volume_data.items()}
955+
956+
from grudge.array_context import MPIPytatoArrayContextBase
957+
958+
if isinstance(actx, MPIPytatoArrayContextBase):
959+
rbc_class = _RankBoundaryCommunicationLazy
960+
else:
961+
rbc_class = _RankBoundaryCommunicationEager
962+
944963
rank_bdry_communicators = {}
945964

946965
for vol_dd_pair, vol_data_pair in pairwise_volume_data.items():

0 commit comments

Comments
 (0)