66.. currentmodule:: arraycontext
77
88.. autofunction:: with_container_arithmetic
9+
10+ .. autoclass:: BcastUntilActxArray
911"""
1012
1113
3436"""
3537
3638import enum
39+ import operator
3740from collections .abc import Callable
41+ from dataclasses import dataclass , field
42+ from functools import partialmethod
43+ from numbers import Number
3844from typing import Any , TypeVar
3945from warnings import warn
4046
4147import numpy as np
4248
49+ from arraycontext .container import (
50+ NotAnArrayContainerError ,
51+ deserialize_container ,
52+ serialize_container ,
53+ )
54+ from arraycontext .context import ArrayContext , ArrayOrContainer
55+
4356
4457# {{{ with_container_arithmetic
4558
@@ -142,8 +155,9 @@ def __instancecheck__(cls, instance: Any) -> bool:
142155 warn (
143156 "Broadcasting container against non-object numpy array. "
144157 "This was never documented to work and will now stop working in "
145- "2025. Convert the array to an object array to preserve the "
146- "current semantics." , DeprecationWarning , stacklevel = 3 )
158+ "2025. Convert the array to an object array or use "
159+ "arraycontext.BcastUntilActxArray (or similar) to obtain the desired "
160+ "broadcasting semantics." , DeprecationWarning , stacklevel = 3 )
147161 return True
148162 else :
149163 return False
@@ -207,6 +221,14 @@ class has an ``array_context`` attribute. If so, and if :data:`__debug__`
207221
208222 Each operator class also includes the "reverse" operators if applicable.
209223
224+ .. note::
225+
226+ For the generated binary arithmetic operators, if certain types
227+ should be broadcast over the container (with the container as the
228+ 'outer' structure) but are not handled in this way by their types,
229+ you may wrap them in :class:`BcastUntilActxArray` to achieve
230+ the desired semantics.
231+
210232 .. note::
211233
212234 To generate the code implementing the operators, this function relies on
@@ -402,8 +424,9 @@ def wrap(cls: Any) -> Any:
402424 warn (
403425 f"Broadcasting array context array types across { cls } "
404426 "has been explicitly "
405- "enabled. As of 2025, this will stop working. "
406- "There is no replacement as of right now. "
427+ "enabled. As of 2026, this will stop working. "
428+ "Use arraycontext.Bcast* object wrappers for "
429+ "roughly equivalent functionality. "
407430 "See the discussion in "
408431 "https://github.com/inducer/arraycontext/pull/190. "
409432 "To opt out now (and avoid this warning), "
@@ -413,8 +436,9 @@ def wrap(cls: Any) -> Any:
413436 warn (
414437 f"Broadcasting array context array types across { cls } "
415438 "has been implicitly "
416- "enabled. As of 2025, this will no longer work. "
417- "There is no replacement as of right now. "
439+ "enabled. As of 2026, this will no longer work. "
440+ "Use arraycontext.Bcast* object wrappers for "
441+ "roughly equivalent functionality. "
418442 "See the discussion in "
419443 "https://github.com/inducer/arraycontext/pull/190. "
420444 "To opt out now (and avoid this warning), "
@@ -603,8 +627,9 @@ def {fname}(arg1):
603627 if isinstance(arg2, { tup_str (bcast_actx_ary_types )} ):
604628 warn("Broadcasting { cls } over array "
605629 f"context array type {{type(arg2)}} is deprecated "
606- "and will no longer work in 2025. "
607- "There is no replacement as of right now. "
630+ "and will no longer work in 2026. "
631+ "Use arraycontext.Bcast* object wrappers for "
632+ "roughly equivalent functionality. "
608633 "See the discussion in "
609634 "https://github.com/inducer/arraycontext/"
610635 "pull/190. ",
@@ -654,8 +679,10 @@ def {fname}(arg2, arg1):
654679 warn("Broadcasting { cls } over array "
655680 f"context array type {{type(arg1)}} "
656681 "is deprecated "
657- "and will no longer work in 2025."
658- "There is no replacement as of right now. "
682+ "and will no longer work in 2026."
683+ "Use arraycontext.Bcast* object "
684+ "wrappers for roughly equivalent "
685+ "functionality. "
659686 "See the discussion in "
660687 "https://github.com/inducer/arraycontext/"
661688 "pull/190. ",
@@ -687,4 +714,111 @@ def {fname}(arg2, arg1):
687714# }}}
688715
689716
717+ # {{{ Bcast object-ified broadcast rules
718+
719+ # Possible advantages of the "Bcast" broadcast-rule-as-object design:
720+ #
721+ # - If one rule does not fit the user's need, they can straightforwardly use
722+ # another.
723+ #
724+ # - It's straightforward to find where certain broadcast rules are used.
725+ #
726+ # - The broadcast rule can contain more state. For example, it's now easy
727+ # for the rule to know what array context should be used to determine
728+ # actx array types.
729+ #
730+ # Possible downsides of the "Bcast" broadcast-rule-as-object design:
731+ #
732+ # - User code is a bit more wordy.
733+
734+ @dataclass (frozen = True )
735+ class BcastUntilActxArray :
736+ """
737+ An operator-overloading wrapper around an object (*broadcastee*) that should be
738+ broadcast across array containers until the 'opposite' operand is one of the
739+ :attr:`~arraycontext.ArrayContext.array_types`
740+ of *actx* or a :class:`~numbers.Number`.
741+
742+ Suggested usage pattern::
743+
744+ bcast = functools.partial(BcastUntilActxArray, actx)
745+
746+ container + bcast(actx_array)
747+
748+ .. automethod:: __init__
749+ """
750+
751+ array_context : ArrayContext
752+ broadcastee : ArrayOrContainer
753+
754+ _stop_types : tuple [type , ...] = field (init = False )
755+
756+ def __post_init__ (self ) -> None :
757+ object .__setattr__ (
758+ self , "_stop_types" , (* self .array_context .array_types , Number ))
759+
760+ def _binary_op (self ,
761+ op : Callable [
762+ [ArrayOrContainer , ArrayOrContainer ],
763+ ArrayOrContainer
764+ ],
765+ right : ArrayOrContainer
766+ ) -> ArrayOrContainer :
767+ try :
768+ serialized = serialize_container (right )
769+ except NotAnArrayContainerError :
770+ return op (self .broadcastee , right )
771+
772+ return deserialize_container (right , [
773+ (k , op (self .broadcastee , right_v )
774+ if isinstance (right_v , self ._stop_types ) else
775+ self ._binary_op (op , right_v )
776+ )
777+ for k , right_v in serialized ])
778+
779+ def _rev_binary_op (self ,
780+ op : Callable [
781+ [ArrayOrContainer , ArrayOrContainer ],
782+ ArrayOrContainer
783+ ],
784+ left : ArrayOrContainer
785+ ) -> ArrayOrContainer :
786+ try :
787+ serialized = serialize_container (left )
788+ except NotAnArrayContainerError :
789+ return op (left , self .broadcastee )
790+
791+ return deserialize_container (left , [
792+ (k , op (left_v , self .broadcastee )
793+ if isinstance (left_v , self ._stop_types ) else
794+ self ._rev_binary_op (op , left_v )
795+ )
796+ for k , left_v in serialized ])
797+
798+ __add__ = partialmethod (_binary_op , operator .add )
799+ __radd__ = partialmethod (_rev_binary_op , operator .add )
800+ __sub__ = partialmethod (_binary_op , operator .sub )
801+ __rsub__ = partialmethod (_rev_binary_op , operator .sub )
802+ __mul__ = partialmethod (_binary_op , operator .mul )
803+ __rmul__ = partialmethod (_rev_binary_op , operator .mul )
804+ __truediv__ = partialmethod (_binary_op , operator .truediv )
805+ __rtruediv__ = partialmethod (_rev_binary_op , operator .truediv )
806+ __floordiv__ = partialmethod (_binary_op , operator .floordiv )
807+ __rfloordiv__ = partialmethod (_rev_binary_op , operator .floordiv )
808+ __mod__ = partialmethod (_binary_op , operator .mod )
809+ __rmod__ = partialmethod (_rev_binary_op , operator .mod )
810+ __pow__ = partialmethod (_binary_op , operator .pow )
811+ __rpow__ = partialmethod (_rev_binary_op , operator .pow )
812+
813+ __lshift__ = partialmethod (_binary_op , operator .lshift )
814+ __rlshift__ = partialmethod (_rev_binary_op , operator .lshift )
815+ __rshift__ = partialmethod (_binary_op , operator .rshift )
816+ __rrshift__ = partialmethod (_rev_binary_op , operator .rshift )
817+ __and__ = partialmethod (_binary_op , operator .and_ )
818+ __rand__ = partialmethod (_rev_binary_op , operator .and_ )
819+ __or__ = partialmethod (_binary_op , operator .or_ )
820+ __ror__ = partialmethod (_rev_binary_op , operator .or_ )
821+
822+ # }}}
823+
690824# vim: foldmethod=marker
0 commit comments