Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pytato/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,9 @@ def map_einsum(self, expr: Einsum) -> Array:
args_as_pym_expr[0])

if redn_bounds:
from pytato.reductions import SumReductionOperation
inner_expr = Reduce(inner_expr,
"sum",
SumReductionOperation(),
redn_bounds)

return IndexLambda(expr=inner_expr,
Expand Down
162 changes: 147 additions & 15 deletions pytato/reductions.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@
THE SOFTWARE.
"""

from typing import Optional, Tuple, Union, Sequence, Dict, List
from typing import Any, Optional, Tuple, Union, Sequence, Dict, List
from abc import ABC, abstractmethod

import numpy as np

from pytato.array import ShapeType, Array, make_index_lambda
from pytato.scalar_expr import ScalarExpression, Reduce, INT_CLASSES
import pymbolic.primitives as prim
Expand All @@ -43,11 +47,97 @@
.. autofunction:: prod
.. autofunction:: all
.. autofunction:: any

.. currentmodule:: pytato.reductions

.. autoclass:: ReductionOperation
.. autoclass:: SumReductionOperation
.. autoclass:: ProductReductionOperation
.. autoclass:: MaxReductionOperation
.. autoclass:: MinReductionOperation
.. autoclass:: AllReductionOperation
.. autoclass:: AnyReductionOperation
"""

# }}}


class _NoValue:
pass


# {{{ reduction operations

class ReductionOperation(ABC):
"""
.. automethod:: neutral_element
.. automethod:: __hash__
.. automethod:: __eq__
"""

@abstractmethod
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
pass

@abstractmethod
def __hash__(self) -> int:
pass

@abstractmethod
def __eq__(self, other: Any) -> bool:
pass


class _StatelessReductionOperation(ReductionOperation):
def __hash__(self) -> int:
return hash(type(self))

def __eq__(self, other: Any) -> bool:
return type(self) is type(other)


class SumReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return 0


class ProductReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return 1


class MaxReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
if dtype.kind == "f":
return dtype.type(float("-inf"))
elif dtype.kind == "i":
return np.iinfo(dtype).min
else:
raise TypeError(f"unknown neutral element for max and {dtype}")


class MinReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
if dtype.kind == "f":
return dtype.type(float("inf"))
elif dtype.kind == "i":
return np.iinfo(dtype).max
else:
raise TypeError(f"unknown neutral element for min and {dtype}")


class AllReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return np.bool_(True)


class AnyReductionOperation(_StatelessReductionOperation):
def neutral_element(self, dtype: np.dtype[Any]) -> Any:
return np.bool_(False)

# }}}


# {{{ reductions

def _normalize_reduction_axes(
Expand Down Expand Up @@ -124,8 +214,9 @@ def _get_reduction_indices_bounds(shape: ShapeType,
return indices, pmap(redn_bounds) # type: ignore


def _make_reduction_lambda(op: str, a: Array,
axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def _make_reduction_lambda(op: ReductionOperation, a: Array,
axis: Optional[Union[int, Tuple[int]]],
initial: Any) -> Array:
"""
Return a :class:`IndexLambda` that performs reduction over the *axis* axes
of *a* with the reduction op *op*.
Expand All @@ -137,9 +228,28 @@ def _make_reduction_lambda(op: str, a: Array,
:arg axis: The axes over which the reduction is to be performed. If axis is
*None*, perform reduction over all of *a*'s axes.
"""
new_shape, axes = _normalize_reduction_axes(a.shape, axis)
new_shape, reduction_axes = _normalize_reduction_axes(a.shape, axis)
del axis
indices, redn_bounds = _get_reduction_indices_bounds(a.shape, axes)
indices, redn_bounds = _get_reduction_indices_bounds(a.shape, reduction_axes)

if initial is _NoValue:
for iax in reduction_axes:
shape_iax = a.shape[iax]

from pytato.utils import are_shape_components_equal
if are_shape_components_equal(shape_iax, 0):
raise ValueError(
"zero-size reduction operation with no supplied "
"'initial' value")

if isinstance(iax, Array):
raise NotImplementedError(
"cannot statically determine emptiness of "
f"reduction axis {iax} (0-based)")

elif initial != op.neutral_element(a.dtype):
raise NotImplementedError("reduction with 'initial' not equal to the "
"neutral element")

return make_index_lambda(
Reduce(
Expand All @@ -151,52 +261,74 @@ def _make_reduction_lambda(op: str, a: Array,
a.dtype)


def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
initial: Any = 0) -> Array:
"""
Sums array *a*'s elements along the *axis* axes.

:arg a: The :class:`pytato.Array` on which to perform the reduction.

:arg axis: The axes along which the elements are to be sum-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
"""
return _make_reduction_lambda("sum", a, axis)
return _make_reduction_lambda(SumReductionOperation(), a, axis, initial)


def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, *,
initial: Any = _NoValue) -> Array:
"""
Returns the max of array *a*'s elements along the *axis* axes.

:arg a: The :class:`pytato.Array` on which to perform the reduction.

:arg axis: The axes along which the elements are to be max-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
If not supplied, an :exc:`ValueError` will be raised
if the reduction is empty.
In that case, the reduction size must not be symbolic.
"""
return _make_reduction_lambda("max", a, axis)
return _make_reduction_lambda(MaxReductionOperation(), a, axis, initial)


def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
initial: Any = _NoValue) -> Array:
"""
Returns the min of array *a*'s elements along the *axis* axes.

:arg a: The :class:`pytato.Array` on which to perform the reduction.

:arg axis: The axes along which the elements are to be min-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
If not supplied, an :exc:`ValueError` will be raised
if the reduction is empty.
In that case, the reduction size must not be symbolic.
"""
return _make_reduction_lambda("min", a, axis)
return _make_reduction_lambda(MinReductionOperation(), a, axis, initial)


def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None,
initial: Any = 1) -> Array:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be _NoValue as well? Numpy does it, see https://numpy.org/doc/stable/reference/generated/numpy.prod.html#numpy.prod.

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is OK. Numpy will return 1 for an empty prod. I think for an immutable object like 1 it does not matter whether that's done via _NoValue and then setting internally, or directly like this.

"""
Returns the product of array *a*'s elements along the *axis* axes.

:arg a: The :class:`pytato.Array` on which to perform the reduction.

:arg axis: The axes along which the elements are to be product-reduced.
Defaults to all axes of the input array.
:arg initial: The value returned for an empty array, if supplied.
This value also serves as the base value onto which any additional
array entries are accumulated.
"""
return _make_reduction_lambda("product", a, axis)
return _make_reduction_lambda(ProductReductionOperation(), a, axis, initial)


def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
Expand All @@ -208,7 +340,7 @@ def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
:arg axis: The axes along which the elements are to be product-reduced.
Defaults to all axes of the input array.
"""
return _make_reduction_lambda("all", a, axis)
return _make_reduction_lambda(AllReductionOperation(), a, axis, initial=True)


def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
Expand All @@ -220,7 +352,7 @@ def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array:
:arg axis: The axes along which the elements are to be product-reduced.
Defaults to all axes of the input array.
"""
return _make_reduction_lambda("any", a, axis)
return _make_reduction_lambda(AnyReductionOperation(), a, axis, initial=False)

# }}}

Expand Down
18 changes: 11 additions & 7 deletions pytato/scalar_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@
"""

from numbers import Number
from typing import Any, Union, Mapping, FrozenSet, Set, Tuple, Optional
from typing import (
Any, Union, Mapping, FrozenSet, Set, Tuple, Optional, TYPE_CHECKING)

from pymbolic.mapper import (WalkMapper as WalkMapperBase, IdentityMapper as
IdentityMapperBase)
Expand All @@ -44,6 +45,10 @@
import numpy as np
import re

if TYPE_CHECKING:
from pytato.reductions import ReductionOperation


__doc__ = """
.. currentmodule:: pytato.scalar_expr

Expand Down Expand Up @@ -232,21 +237,20 @@ class Reduce(ExpressionBase):

.. attribute:: op

One of ``"sum"``, ``"product"``, ``"max"``, ``"min"``,``"all"``, ``"any"``.
A :class:`pytato.reductions.ReductionOperation`.

.. attribute:: bounds

A mapping from reduction inames to tuples ``(lower_bound, upper_bound)``
identifying half-open bounds intervals. Must be hashable.
"""
inner_expr: ScalarExpression
op: str
op: ReductionOperation
bounds: Mapping[str, Tuple[ScalarExpression, ScalarExpression]]

def __init__(self, inner_expr: ScalarExpression, op: str, bounds: Any) -> None:
def __init__(self, inner_expr: ScalarExpression,
op: ReductionOperation, bounds: Any) -> None:
self.inner_expr = inner_expr
if op not in {"sum", "product", "max", "min", "all", "any"}:
raise ValueError(f"unsupported op: {op}")
self.op = op
self.bounds = bounds

Expand All @@ -256,7 +260,7 @@ def __hash__(self) -> int:
tuple(self.bounds.keys()),
tuple(self.bounds.values())))

def __getinitargs__(self) -> Tuple[ScalarExpression, str, Any]:
def __getinitargs__(self) -> Tuple[ScalarExpression, ReductionOperation, Any]:
return (self.inner_expr, self.op, self.bounds)

mapper_method = "map_reduce"
Expand Down
27 changes: 14 additions & 13 deletions pytato/target/loopy/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from pytato.loopy import LoopyCall
from pytato.tags import ImplStored, _BaseNameTag, Named, PrefixNamed
from pytools.tag import Tag
import pytato.reductions as red

# set in doc/conf.py
if getattr(sys, "PYTATO_BUILDING_SPHINX_DOCS", False):
Expand Down Expand Up @@ -537,13 +538,13 @@ def _get_sub_array_ref(array: Array, name: str) -> "lp.symbolic.SubArrayRef":
REDUCTION_INDEX_RE = re.compile("_r(0|([1-9][0-9]*))")

# Maps Pytato reduction types to the corresponding Loopy reduction types.
PYTATO_REDUCTION_TO_LOOPY_REDUCTION = {
"sum": "sum",
"product": "product",
"max": "max",
"min": "min",
"all": "all",
"any": "any",
PYTATO_REDUCTION_TO_LOOPY_REDUCTION: Mapping[Type[red.ReductionOperation], str] = {
red.SumReductionOperation: "sum",
red.ProductReductionOperation: "product",
red.MaxReductionOperation: "max",
red.MinReductionOperation: "min",
red.AllReductionOperation: "all",
red.AnyReductionOperation: "any",
}


Expand Down Expand Up @@ -620,8 +621,13 @@ def map_reduce(self, expr: scalar_expr.Reduce,
from loopy.symbolic import Reduction as LoopyReduction
state = prstnt_ctx.state

try:
loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[type(expr.op)]
except KeyError:
raise NotImplementedError(expr.op)

unique_names_mapping = {
old_name: state.var_name_gen(f"_pt_{expr.op}" + old_name)
old_name: state.var_name_gen(f"_pt_{loopy_redn}" + old_name)
for old_name in expr.bounds}

inner_expr = loopy_substitute(expr.inner_expr,
Expand All @@ -633,11 +639,6 @@ def map_reduce(self, expr: scalar_expr.Reduce,
inner_expr = self.rec(inner_expr, prstnt_ctx,
local_ctx.copy(reduction_bounds=new_bounds))

try:
loopy_redn = PYTATO_REDUCTION_TO_LOOPY_REDUCTION[expr.op]
except KeyError:
raise NotImplementedError(expr.op)

inner_expr = LoopyReduction(loopy_redn,
tuple(unique_names_mapping.values()),
inner_expr)
Expand Down