Skip to content

Commit 61ccd37

Browse files
Add MPINumpyArrayContext
1 parent 9674514 commit 61ccd37

File tree

1 file changed

+23
-0
lines changed

1 file changed

+23
-0
lines changed

grudge/array_context.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
.. autoclass:: PytatoPyOpenCLArrayContext
44
.. autoclass:: MPIBasedArrayContext
55
.. autoclass:: MPIPyOpenCLArrayContext
6+
.. autoclass:: MPINumpyArrayContext
67
.. class:: MPIPytatoArrayContext
78
.. autofunction:: get_reasonable_array_context_class
89
"""
@@ -98,6 +99,8 @@
9899
from arraycontext.container import ArrayContainer
99100
from arraycontext.impl.pytato.compile import LazilyPyOpenCLCompilingFunctionCaller
100101

102+
from arraycontext import NumpyArrayContext
103+
101104
if TYPE_CHECKING:
102105
import pytato as pt
103106
from pytato.partition import PartId
@@ -459,6 +462,26 @@ def clone(self):
459462
# }}}
460463

461464

465+
# {{{ distributed + numpy
466+
467+
class MPINumpyArrayContext(NumpyArrayContext, MPIBasedArrayContext):
468+
"""An array context for using distributed computation with :mod:`numpy`
469+
eager evaluation.
470+
471+
.. autofunction:: __init__
472+
"""
473+
474+
def __init__(self, mpi_communicator) -> None:
475+
super().__init__()
476+
477+
self.mpi_communicator = mpi_communicator
478+
479+
def clone(self):
480+
return type(self)(self.mpi_communicator)
481+
482+
# }}}
483+
484+
462485
# {{{ distributed + pytato array context subclasses
463486

464487
class MPIBasePytatoPyOpenCLArrayContext(

0 commit comments

Comments
 (0)