Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 31 additions & 40 deletions grudge/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,12 @@
"""


from functools import reduce, partial
from functools import partial

from arraycontext import (
make_loopy_program,
map_array_container,
serialize_container,
get_container_context_recursively,
DeviceScalar
)
from arraycontext.container import ArrayOrContainerT
Expand Down Expand Up @@ -94,7 +94,6 @@ def norm(dcoll: DiscretizationCollection, vec, p, dd=None) -> "DeviceScalar":
if dd is None:
dd = dof_desc.DD_VOLUME

from arraycontext import get_container_context_recursively
actx = get_container_context_recursively(vec)

dd = dof_desc.as_dofdesc(dd)
Expand Down Expand Up @@ -128,7 +127,7 @@ def nodal_sum(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":

# NOTE: Don't move this
from mpi4py import MPI
actx = vec.array_context
actx = get_container_context_recursively(vec)

return actx.from_numpy(
comm.allreduce(actx.to_numpy(nodal_sum_loc(dcoll, dd, vec)), op=MPI.SUM))
Expand All @@ -143,15 +142,13 @@ def nodal_sum_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
:class:`~arraycontext.container.ArrayContainer` of them.
:returns: a scalar denoting the rank-local nodal sum.
"""
if not isinstance(vec, DOFArray):
return sum(
nodal_sum_loc(dcoll, dd, comp)
for _, comp in serialize_container(vec)
)

actx = vec.array_context

return sum([actx.np.sum(grp_ary) for grp_ary in vec])
actx = get_container_context_recursively(vec)
result = actx.np.sum(vec)
# Fix actx._force_device_scalars == False case
if np.isscalar(result):
return actx.from_numpy(result)
else:
return result


def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
Expand All @@ -169,7 +166,7 @@ def nodal_min(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":

# NOTE: Don't move this
from mpi4py import MPI
actx = vec.array_context
actx = get_container_context_recursively(vec)

return actx.from_numpy(
comm.allreduce(actx.to_numpy(nodal_min_loc(dcoll, dd, vec)), op=MPI.MIN))
Expand All @@ -185,17 +182,13 @@ def nodal_min_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
:class:`~arraycontext.container.ArrayContainer` of them.
:returns: a scalar denoting the rank-local nodal minimum.
"""
if not isinstance(vec, DOFArray):
return min(
nodal_min_loc(dcoll, dd, comp)
for _, comp in serialize_container(vec)
)

actx = vec.array_context

return reduce(
lambda acc, grp_ary: actx.np.minimum(acc, actx.np.min(grp_ary)),
vec, actx.from_numpy(np.array(np.inf)))
actx = get_container_context_recursively(vec)
result = actx.np.min(vec)
# Fix actx._force_device_scalars == False case
if np.isscalar(result):
return actx.from_numpy(result)
else:
return result


def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
Expand All @@ -213,7 +206,7 @@ def nodal_max(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":

# NOTE: Don't move this
from mpi4py import MPI
actx = vec.array_context
actx = get_container_context_recursively(vec)

return actx.from_numpy(
comm.allreduce(actx.to_numpy(nodal_max_loc(dcoll, dd, vec)), op=MPI.MAX))
Expand All @@ -229,17 +222,13 @@ def nodal_max_loc(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
:class:`~arraycontext.container.ArrayContainer`.
:returns: a scalar denoting the rank-local nodal maximum.
"""
if not isinstance(vec, DOFArray):
return max(
nodal_max_loc(dcoll, dd, comp)
for _, comp in serialize_container(vec)
)

actx = vec.array_context

return reduce(
lambda acc, grp_ary: actx.np.maximum(acc, actx.np.max(grp_ary)),
vec, actx.from_numpy(np.array(-np.inf)))
actx = get_container_context_recursively(vec)
result = actx.np.max(vec)
# Fix actx._force_device_scalars == False case
if np.isscalar(result):
return actx.from_numpy(result)
else:
return result


def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
Expand All @@ -253,9 +242,10 @@ def integral(dcoll: DiscretizationCollection, dd, vec) -> "DeviceScalar":
"""
from grudge.op import _apply_mass_operator

actx = get_container_context_recursively(vec)
dd = dof_desc.as_dofdesc(dd)

ones = dcoll.discr_from_dd(dd).zeros(vec.array_context) + 1.0
ones = dcoll.discr_from_dd(dd).zeros(actx) + 1.0
return nodal_sum(
dcoll, dd, vec * _apply_mass_operator(dcoll, dd, dd, ones)
)
Expand Down Expand Up @@ -295,7 +285,7 @@ def _apply_elementwise_reduction(
partial(_apply_elementwise_reduction, op_name, dcoll, dd), vec
)

actx = vec.array_context
actx = get_container_context_recursively(vec)

if actx.supports_nonscalar_broadcasting:
return DOFArray(
Expand Down Expand Up @@ -456,11 +446,12 @@ def elementwise_integral(
else:
raise TypeError("invalid number of arguments")

actx = get_container_context_recursively(vec)
dd = dof_desc.as_dofdesc(dd)

from grudge.op import _apply_mass_operator

ones = dcoll.discr_from_dd(dd).zeros(vec.array_context) + 1.0
ones = dcoll.discr_from_dd(dd).zeros(actx) + 1.0
return elementwise_sum(
dcoll, dd, vec * _apply_mass_operator(dcoll, dd, dd, ones)
)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ git+https://github.com/inducer/dagrt.git#egg=dagrt
git+https://github.com/inducer/leap.git#egg=leap
git+https://github.com/inducer/meshpy.git#egg=meshpy
git+https://github.com/inducer/modepy.git#egg=modepy
git+https://github.com/inducer/arraycontext.git#egg=arraycontext
git+https://github.com/majosm/arraycontext.git@empty-subcontainers#egg=arraycontext
git+https://github.com/inducer/meshmode.git#egg=meshmode
git+https://github.com/inducer/pyvisfile.git#egg=pyvisfile
git+https://github.com/inducer/pymetis.git#egg=pymetis
Expand Down