|
80 | 80 | from pytools import memoize_in |
81 | 81 |
|
82 | 82 | import grudge.dof_desc as dof_desc |
| 83 | +from grudge.array_context import MPIBasedArrayContext |
83 | 84 | from grudge.discretization import DiscretizationCollection |
84 | 85 |
|
85 | 86 |
|
@@ -128,16 +129,17 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> Scalar: |
128 | 129 | :class:`~arraycontext.ArrayContainer`. |
129 | 130 | :returns: a device scalar denoting the nodal sum. |
130 | 131 | """ |
131 | | - comm = dcoll.mpi_communicator |
132 | | - if comm is None: |
| 132 | + from arraycontext import get_container_context_recursively |
| 133 | + actx = get_container_context_recursively(vec) |
| 134 | + |
| 135 | + if not isinstance(actx, MPIBasedArrayContext): |
133 | 136 | return nodal_sum_loc(dcoll, dd, vec) |
134 | 137 |
|
| 138 | + comm = actx.mpi_communicator |
| 139 | + |
135 | 140 | # NOTE: Do not move, we do not want to import mpi4py in single-rank computations |
136 | 141 | from mpi4py import MPI |
137 | 142 |
|
138 | | - from arraycontext import get_container_context_recursively |
139 | | - actx = get_container_context_recursively(vec) |
140 | | - |
141 | 143 | return actx.from_numpy( |
142 | 144 | comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM)) |
143 | 145 |
|
@@ -174,13 +176,16 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal |
174 | 176 | :arg initial: an optional initial value. Defaults to `numpy.inf`. |
175 | 177 | :returns: a device scalar denoting the nodal minimum. |
176 | 178 | """ |
177 | | - comm = dcoll.mpi_communicator |
178 | | - if comm is None: |
| 179 | + from arraycontext import get_container_context_recursively |
| 180 | + actx = get_container_context_recursively(vec) |
| 181 | + |
| 182 | + if not isinstance(actx, MPIBasedArrayContext): |
179 | 183 | return nodal_min_loc(dcoll, dd, vec, initial=initial) |
180 | 184 |
|
| 185 | + comm = actx.mpi_communicator |
| 186 | + |
181 | 187 | # NOTE: Do not move, we do not want to import mpi4py in single-rank computations |
182 | 188 | from mpi4py import MPI |
183 | | - actx = vec.array_context |
184 | 189 |
|
185 | 190 | return actx.from_numpy( |
186 | 191 | comm.allreduce( |
@@ -231,13 +236,16 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec, *, initial=None) -> Scal |
231 | 236 | :arg initial: an optional initial value. Defaults to `-numpy.inf`. |
232 | 237 | :returns: a device scalar denoting the nodal maximum. |
233 | 238 | """ |
234 | | - comm = dcoll.mpi_communicator |
235 | | - if comm is None: |
| 239 | + from arraycontext import get_container_context_recursively |
| 240 | + actx = get_container_context_recursively(vec) |
| 241 | + |
| 242 | + if not isinstance(actx, MPIBasedArrayContext): |
236 | 243 | return nodal_max_loc(dcoll, dd, vec, initial=initial) |
237 | 244 |
|
| 245 | + comm = actx.mpi_communicator |
| 246 | + |
238 | 247 | # NOTE: Do not move, we do not want to import mpi4py in single-rank computations |
239 | 248 | from mpi4py import MPI |
240 | | - actx = vec.array_context |
241 | 249 |
|
242 | 250 | return actx.from_numpy( |
243 | 251 | comm.allreduce( |
|
0 commit comments