-
Notifications
You must be signed in to change notification settings - Fork 16
Introduce ReductionOperation class, accept 'initial' in reductions
#238
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,7 +27,11 @@ | |
| THE SOFTWARE. | ||
| """ | ||
|
|
||
| from typing import Optional, Tuple, Union, Sequence, Dict, List | ||
| from typing import Any, Optional, Tuple, Union, Sequence, Dict, List | ||
| from abc import ABC, abstractmethod | ||
|
|
||
| import numpy as np | ||
|
|
||
| from pytato.array import ShapeType, Array, make_index_lambda | ||
| from pytato.scalar_expr import ScalarExpression, Reduce, INT_CLASSES | ||
| import pymbolic.primitives as prim | ||
|
|
@@ -43,11 +47,97 @@ | |
| .. autofunction:: prod | ||
| .. autofunction:: all | ||
| .. autofunction:: any | ||
|
|
||
| .. currentmodule:: pytato.reductions | ||
|
|
||
| .. autoclass:: ReductionOperation | ||
| .. autoclass:: SumReductionOperation | ||
| .. autoclass:: ProductReductionOperation | ||
| .. autoclass:: MaxReductionOperation | ||
| .. autoclass:: MinReductionOperation | ||
| .. autoclass:: AllReductionOperation | ||
| .. autoclass:: AnyReductionOperation | ||
| """ | ||
|
|
||
| # }}} | ||
|
|
||
|
|
||
| class _NoValue: | ||
| pass | ||
|
|
||
|
|
||
| # {{{ reduction operations | ||
|
|
||
| class ReductionOperation(ABC): | ||
| """ | ||
| .. automethod:: neutral_element | ||
| .. automethod:: __hash__ | ||
| .. automethod:: __eq__ | ||
| """ | ||
|
|
||
| @abstractmethod | ||
| def neutral_element(self, dtype: np.dtype[Any]) -> Any: | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| def __hash__(self) -> int: | ||
| pass | ||
|
|
||
| @abstractmethod | ||
| def __eq__(self, other: Any) -> bool: | ||
| pass | ||
|
|
||
|
|
||
| class _StatelessReductionOperation(ReductionOperation): | ||
| def __hash__(self) -> int: | ||
| return hash(type(self)) | ||
|
|
||
| def __eq__(self, other: Any) -> bool: | ||
| return type(self) is type(other) | ||
|
|
||
|
|
||
| class SumReductionOperation(_StatelessReductionOperation): | ||
| def neutral_element(self, dtype: np.dtype[Any]) -> Any: | ||
| return 0 | ||
|
|
||
|
|
||
| class ProductReductionOperation(_StatelessReductionOperation): | ||
| def neutral_element(self, dtype: np.dtype[Any]) -> Any: | ||
| return 1 | ||
|
|
||
|
|
||
| class MaxReductionOperation(_StatelessReductionOperation): | ||
| def neutral_element(self, dtype: np.dtype[Any]) -> Any: | ||
| if dtype.kind == "f": | ||
| return dtype.type(float("-inf")) | ||
| elif dtype.kind == "i": | ||
| return np.iinfo(dtype).min | ||
| else: | ||
| raise TypeError(f"unknown neutral element for max and {dtype}") | ||
|
|
||
|
|
||
| class MinReductionOperation(_StatelessReductionOperation): | ||
| def neutral_element(self, dtype: np.dtype[Any]) -> Any: | ||
| if dtype.kind == "f": | ||
| return dtype.type(float("inf")) | ||
| elif dtype.kind == "i": | ||
| return np.iinfo(dtype).max | ||
| else: | ||
| raise TypeError(f"unknown neutral element for min and {dtype}") | ||
|
|
||
|
|
||
| class AllReductionOperation(_StatelessReductionOperation): | ||
| def neutral_element(self, dtype: np.dtype[Any]) -> Any: | ||
| return np.bool_(True) | ||
|
|
||
|
|
||
| class AnyReductionOperation(_StatelessReductionOperation): | ||
| def neutral_element(self, dtype: np.dtype[Any]) -> Any: | ||
| return np.bool_(False) | ||
|
|
||
| # }}} | ||
|
|
||
|
|
||
| # {{{ reductions | ||
|
|
||
| def _normalize_reduction_axes( | ||
|
|
@@ -124,8 +214,9 @@ def _get_reduction_indices_bounds(shape: ShapeType, | |
| return indices, pmap(redn_bounds) # type: ignore | ||
|
|
||
|
|
||
| def _make_reduction_lambda(op: str, a: Array, | ||
| axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | ||
| def _make_reduction_lambda(op: ReductionOperation, a: Array, | ||
| axis: Optional[Union[int, Tuple[int]]], | ||
| initial: Any) -> Array: | ||
| """ | ||
| Return a :class:`IndexLambda` that performs reduction over the *axis* axes | ||
| of *a* with the reduction op *op*. | ||
|
|
@@ -137,9 +228,28 @@ def _make_reduction_lambda(op: str, a: Array, | |
| :arg axis: The axes over which the reduction is to be performed. If axis is | ||
| *None*, perform reduction over all of *a*'s axes. | ||
| """ | ||
| new_shape, axes = _normalize_reduction_axes(a.shape, axis) | ||
| new_shape, reduction_axes = _normalize_reduction_axes(a.shape, axis) | ||
| del axis | ||
| indices, redn_bounds = _get_reduction_indices_bounds(a.shape, axes) | ||
| indices, redn_bounds = _get_reduction_indices_bounds(a.shape, reduction_axes) | ||
|
|
||
| if initial is _NoValue: | ||
| for iax in reduction_axes: | ||
| shape_iax = a.shape[iax] | ||
|
|
||
| from pytato.utils import are_shape_components_equal | ||
| if are_shape_components_equal(shape_iax, 0): | ||
| raise ValueError( | ||
| "zero-size reduction operation with no supplied " | ||
| "'initial' value") | ||
|
|
||
| if isinstance(iax, Array): | ||
| raise NotImplementedError( | ||
| "cannot statically determine emptiness of " | ||
| f"reduction axis {iax} (0-based)") | ||
|
|
||
| elif initial != op.neutral_element(a.dtype): | ||
| raise NotImplementedError("reduction with 'initial' not equal to the " | ||
| "neutral element") | ||
|
|
||
| return make_index_lambda( | ||
| Reduce( | ||
|
|
@@ -151,52 +261,74 @@ def _make_reduction_lambda(op: str, a: Array, | |
| a.dtype) | ||
|
|
||
|
|
||
| def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | ||
| def sum(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, | ||
| initial: Any = 0) -> Array: | ||
| """ | ||
| Sums array *a*'s elements along the *axis* axes. | ||
|
|
||
| :arg a: The :class:`pytato.Array` on which to perform the reduction. | ||
|
|
||
| :arg axis: The axes along which the elements are to be sum-reduced. | ||
| Defaults to all axes of the input array. | ||
| :arg initial: The value returned for an empty array, if supplied. | ||
| This value also serves as the base value onto which any additional | ||
| array entries are accumulated. | ||
| """ | ||
| return _make_reduction_lambda("sum", a, axis) | ||
| return _make_reduction_lambda(SumReductionOperation(), a, axis, initial) | ||
|
|
||
|
|
||
| def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | ||
| def amax(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, *, | ||
| initial: Any = _NoValue) -> Array: | ||
| """ | ||
| Returns the max of array *a*'s elements along the *axis* axes. | ||
|
|
||
| :arg a: The :class:`pytato.Array` on which to perform the reduction. | ||
|
|
||
| :arg axis: The axes along which the elements are to be max-reduced. | ||
| Defaults to all axes of the input array. | ||
| :arg initial: The value returned for an empty array, if supplied. | ||
| This value also serves as the base value onto which any additional | ||
| array entries are accumulated. | ||
| If not supplied, an :exc:`ValueError` will be raised | ||
| if the reduction is empty. | ||
| In that case, the reduction size must not be symbolic. | ||
| """ | ||
| return _make_reduction_lambda("max", a, axis) | ||
| return _make_reduction_lambda(MaxReductionOperation(), a, axis, initial) | ||
|
|
||
|
|
||
| def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | ||
| def amin(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, | ||
| initial: Any = _NoValue) -> Array: | ||
| """ | ||
| Returns the min of array *a*'s elements along the *axis* axes. | ||
|
|
||
| :arg a: The :class:`pytato.Array` on which to perform the reduction. | ||
|
|
||
| :arg axis: The axes along which the elements are to be min-reduced. | ||
| Defaults to all axes of the input array. | ||
| :arg initial: The value returned for an empty array, if supplied. | ||
| This value also serves as the base value onto which any additional | ||
| array entries are accumulated. | ||
| If not supplied, an :exc:`ValueError` will be raised | ||
| if the reduction is empty. | ||
| In that case, the reduction size must not be symbolic. | ||
| """ | ||
| return _make_reduction_lambda("min", a, axis) | ||
| return _make_reduction_lambda(MinReductionOperation(), a, axis, initial) | ||
|
|
||
|
|
||
| def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | ||
| def prod(a: Array, axis: Optional[Union[int, Tuple[int]]] = None, | ||
| initial: Any = 1) -> Array: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be
Owner
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is OK. Numpy will return 1 for an empty |
||
| """ | ||
| Returns the product of array *a*'s elements along the *axis* axes. | ||
|
|
||
| :arg a: The :class:`pytato.Array` on which to perform the reduction. | ||
|
|
||
| :arg axis: The axes along which the elements are to be product-reduced. | ||
| Defaults to all axes of the input array. | ||
| :arg initial: The value returned for an empty array, if supplied. | ||
| This value also serves as the base value onto which any additional | ||
| array entries are accumulated. | ||
| """ | ||
| return _make_reduction_lambda("product", a, axis) | ||
| return _make_reduction_lambda(ProductReductionOperation(), a, axis, initial) | ||
|
|
||
|
|
||
| def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | ||
|
|
@@ -208,7 +340,7 @@ def all(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | |
| :arg axis: The axes along which the elements are to be product-reduced. | ||
| Defaults to all axes of the input array. | ||
| """ | ||
| return _make_reduction_lambda("all", a, axis) | ||
| return _make_reduction_lambda(AllReductionOperation(), a, axis, initial=True) | ||
|
|
||
|
|
||
| def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | ||
|
|
@@ -220,7 +352,7 @@ def any(a: Array, axis: Optional[Union[int, Tuple[int]]] = None) -> Array: | |
| :arg axis: The axes along which the elements are to be product-reduced. | ||
| Defaults to all axes of the input array. | ||
| """ | ||
| return _make_reduction_lambda("any", a, axis) | ||
| return _make_reduction_lambda(AnyReductionOperation(), a, axis, initial=False) | ||
|
|
||
| # }}} | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.