diff --git a/jax_galsim/bounds.py b/jax_galsim/bounds.py index 69ab8746..c8ba4954 100644 --- a/jax_galsim/bounds.py +++ b/jax_galsim/bounds.py @@ -1,12 +1,12 @@ +import equinox import galsim as _galsim import jax import jax.numpy as jnp -import numpy as np from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_to_float, - cast_to_int, check_is_int_then_cast, ensure_hashable, implements, @@ -14,19 +14,15 @@ from jax_galsim.position import Position, PositionD, PositionI BOUNDS_LAX_DESCR = """\ -The JAX implementation - -- will not always test whether the bounds are valid - -Further, the JAX implementation adds a new method, ``isStatic`` to the -``BoundsI`` class. If JAX-GalSim detects that the ``BoundsI`` instance +The JAX implementation adds a new method, ``isStatic`` to the +``Bounds`` class. If JAX-GalSim detects that a ``BoundsI`` instance has been instantiated with static, known values, ``isStatic()`` will -return ``True``. +return ``True``, otherwise it is ``False``. For ``BoundsD``, ``isStatic()`` +always returns ``False``. -``BoundsI`` objects in JAX-Galsim support an additional initialization -call ``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. In this case, -the values for ``deltax/y`` indicate the width of the bounds and must be -static constants. +``BoundsI`` objects in JAX-Galsim must have a fixed width. To help support +this requirement, JAX-Galsim supports an additional initialization call +``BoundsI(xmin=..., deltax=..., ymin=..., deltay=...)``. When calling ``jax.vmap`` over ``BoundsI`` objects, only ``x/ymin`` are vectorized over. This restriction allows for code that renders @@ -46,11 +42,13 @@ def __init__(self): ) def _parse_args(self, *args, **kwargs): + do_isdefined = True + if len(kwargs) == 0: if len(args) == 4: - self._isdefined = True self.xmin, self.xmax, self.ymin, self.ymax = args elif len(args) == 0: + do_isdefined = False self._isdefined = False self.xmin = 0 self.ymin = 0 @@ -70,7 +68,6 @@ def _parse_args(self, *args, **kwargs): self.ymin = args[0].ymin self.deltay = args[0].deltay + offset elif isinstance(args[0], Position): - self._isdefined = True self.xmin = self.xmax = args[0].x self.ymin = self.ymax = args[0].y else: @@ -78,10 +75,8 @@ def _parse_args(self, *args, **kwargs): "Single argument to %s must be either a Bounds or a Position" % (self.__class__.__name__) ) - self._isdefined = True elif len(args) == 2: if isinstance(args[0], Position) and isinstance(args[1], Position): - self._isdefined = True self.xmin = min(args[0].x, args[1].x) self.xmax = max(args[0].x, args[1].x) self.ymin = min(args[0].y, args[1].y) @@ -103,7 +98,6 @@ def _parse_args(self, *args, **kwargs): ) else: try: - self._isdefined = True self.xmin = kwargs.pop("xmin") self.ymin = kwargs.pop("ymin") except KeyError: @@ -128,17 +122,7 @@ def _parse_args(self, *args, **kwargs): if kwargs: raise TypeError("Got unexpected keyword arguments %s" % kwargs.keys()) - # for simple inputs, we can check if the bounds are valid - if isinstance(self, BoundsD): - max_delta = 0 - else: - max_delta = 1 - if ( - isinstance(self.deltax, (int, float, np.integer, np.floating)) - and isinstance(self.deltay, (int, float, np.integer, np.floating)) - and (self.deltax < max_delta or self.deltay < max_delta) - ): - self._isdefined = False + return do_isdefined @implements(_galsim.Bounds.area) def area(self): @@ -166,58 +150,40 @@ def origin(self): @property @implements(_galsim.Bounds.center) def center(self): - if not self.isDefined(): - raise _galsim.GalSimUndefinedBoundsError( - "center is invalid for an undefined Bounds" + if not isinstance(self._isdefined, jnp.ndarray): + if not self.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "center is invalid for an undefined Bounds" + ) + else: + self._isdefined = equinox.error_if( + self._isdefined, + jnp.any(~self._isdefined), + "center is invalid for an undefined Bounds", ) return self._center @property @implements(_galsim.Bounds.true_center) def true_center(self): - if not self.isDefined(): - raise _galsim.GalSimUndefinedBoundsError( - "true_center is invalid for an undefined Bounds" + if not isinstance(self._isdefined, jnp.ndarray): + if not self.isDefined(): + raise _galsim.GalSimUndefinedBoundsError( + "true_center is invalid for an undefined Bounds" + ) + else: + self._isdefined = equinox.error_if( + self._isdefined, + jnp.any(~self._isdefined), + "true_center is invalid for an undefined Bounds", ) return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) @implements(_galsim.Bounds.includes) def includes(self, *args): - if len(args) == 1: - if isinstance(args[0], Bounds): - b = args[0] - return ( - self.isDefined() - & b.isDefined() - & (self.xmin <= b.xmin) - & (self.xmax >= b.xmax) - & (self.ymin <= b.ymin) - & (self.ymax >= b.ymax) - ) - elif isinstance(args[0], Position): - p = args[0] - return ( - self.isDefined() - & (self.xmin <= p.x) - & (self.ymin <= p.y) - & (p.x <= self.xmax) - & (p.y <= self.ymax) - ) - else: - raise TypeError("Invalid argument %s" % args[0]) - elif len(args) == 2: - x, y = args - return ( - self.isDefined() - & (self.xmin <= float(x)) - & (self.ymin <= float(y)) - & (float(x) <= self.xmax) - & (float(y) <= self.ymax) - ) - elif len(args) == 0: - raise TypeError("include takes at least 1 argument (0 given)") - else: - raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `includes` method!" + ) @implements(_galsim.Bounds.expand) def expand(self, factor_x, factor_y=None): @@ -262,61 +228,24 @@ def shift(self, delta): ) def __and__(self, other): - if not isinstance(other, self.__class__): - raise TypeError("other must be a %s instance" % self.__class__.__name__) - if not self.isDefined() or not other.isDefined(): - return self.__class__() - else: - xmin = jnp.maximum(self.xmin, other.xmin) - xmax = jnp.minimum(self.xmax, other.xmax) - ymin = jnp.maximum(self.ymin, other.ymin) - ymax = jnp.minimum(self.ymax, other.ymax) - if xmin > xmax or ymin > ymax: - return self.__class__() - else: - return self.__class__(xmin, xmax, ymin, ymax) + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__and__` method!" + ) def __add__(self, other): - if isinstance(other, self.__class__): - if not other.isDefined(): - return self - elif self.isDefined(): - xmin = jnp.minimum(self.xmin, other.xmin) - xmax = jnp.maximum(self.xmax, other.xmax) - ymin = jnp.minimum(self.ymin, other.ymin) - ymax = jnp.maximum(self.ymax, other.ymax) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return other - elif isinstance(other, self._pos_class): - if self.isDefined(): - xmin = jnp.minimum(self.xmin, other.x) - xmax = jnp.maximum(self.xmax, other.x) - ymin = jnp.minimum(self.ymin, other.y) - ymax = jnp.maximum(self.ymax, other.y) - return self.__class__(xmin, xmax, ymin, ymax) - else: - return self.__class__(other) - else: - raise TypeError( - "other must be either a %s or a %s" - % (self.__class__.__name__, self._pos_class.__name__) - ) - - def _getinitargs(self): - if self.isDefined(): - return (self.xmin, self.xmax, self.ymin, self.ymax) - else: - return () + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__add__` method!" + ) def __eq__(self, other): - return self is other or ( - isinstance(other, self.__class__) - and self._getinitargs() == other._getinitargs() + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__eq__` method!" ) def __ne__(self, other): - return not self.__eq__(other) + raise NotImplementedError( + "Subclasses of `Bounds` must implement the `__ne__` method!" + ) def __hash__(self): return hash( @@ -333,10 +262,7 @@ def tree_flatten(self): """This function flattens the Bounds into a list of children nodes that will be traced by JAX and auxiliary static data.""" # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - children = (self.xmin, self.deltax, self.ymin, self.deltay) - else: - children = tuple() + children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) # Define auxiliary static data that doesn’t need to be traced aux_data = None return (children, aux_data) @@ -344,15 +270,16 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): """Recreates an instance of the class from flatten representation""" - if children: - return cls( - xmin=children[0], - deltax=children[1], - ymin=children[2], - deltay=children[3], - ) - else: - return cls() + ret = cls.__new__(cls) + ret.xmin = children[0] + ret.deltax = children[1] + ret.ymin = children[2] + ret.deltay = children[3] + ret._isdefined = children[4] + ret._isstatic = False + ret._isstaticshape = False + + return ret @classmethod def from_galsim(cls, galsim_bounds): @@ -401,6 +328,163 @@ def isStatic(self): ``False`` for ``BoundsD``.""" return self._isstatic + def isStaticShape(self): + """Returns ``True`` if the ``BoundsI`` instance + has static, known dimensions. Always returns + ``False`` for ``BoundsD``.""" + return self._isstaticshape + + +def _bounds_and_op_static(self, other): + if not self.isDefined() or not other.isDefined(): + return self.__class__() + else: + xmin = max(self.xmin, other.xmin) + xmax = min(self.xmax, other.xmax) + ymin = max(self.ymin, other.ymin) + ymax = min(self.ymax, other.ymax) + if xmin > xmax or ymin > ymax: + return self.__class__() + else: + return self.__class__(xmin, xmax, ymin, ymax) + + +def _bounds_and_op_dynamic(self, other): + xmin = jnp.maximum(self.xmin, other.xmin) + xmax = jnp.minimum(self.xmax, other.xmax) + ymin = jnp.maximum(self.ymin, other.ymin) + ymax = jnp.minimum(self.ymax, other.ymax) + + is_defined = self.isDefined() & other.isDefined() & (ymin <= ymax) & (xmin <= xmax) + xmin = jnp.where( + is_defined, + xmin, + 0.0, + ) + xmax = jnp.where( + is_defined, + xmax, + 0.0, + ) + ymin = jnp.where( + is_defined, + ymin, + 0.0, + ) + ymax = jnp.where( + is_defined, + ymax, + 0.0, + ) + + cls = self.__class__ + ret = cls.__new__(cls) + ret.xmin = xmin + ret.deltax = xmax - xmin + ret.ymin = ymin + ret.deltay = ymax - ymin + ret._isdefined = is_defined + ret._isstatic = False + ret._isstaticshape = False + + return ret + + +def _bounds_bounds_add_op_static(self, other): + if not other.isDefined(): + return self + elif self.isDefined(): + xmin = min(self.xmin, other.xmin) + xmax = max(self.xmax, other.xmax) + ymin = min(self.ymin, other.ymin) + ymax = max(self.ymax, other.ymax) + return self.__class__(xmin, xmax, ymin, ymax) + else: + return other + + +def _bounds_bounds_add_op_dynamic(self, other, min_delta): + def _ret_correct_attr(self_isdef, self_attr, other_isdef, other_attr, op): + return jnp.where( + ~other_isdef, + self_attr, + jnp.where(self_isdef, op(self_attr, other_attr), other_attr), + ) + + xmin = _ret_correct_attr( + self._isdefined, self.xmin, other._isdefined, other.xmin, jnp.minimum + ) + xmax = _ret_correct_attr( + self._isdefined, self.xmax, other._isdefined, other.xmax, jnp.maximum + ) + ymin = _ret_correct_attr( + self._isdefined, self.ymin, other._isdefined, other.ymin, jnp.minimum + ) + ymax = _ret_correct_attr( + self._isdefined, self.ymax, other._isdefined, other.ymax, jnp.maximum + ) + + cls = self.__class__ + ret = cls.__new__(cls) + + ret.xmin = xmin + ret.deltax = xmax - xmin + min_delta + ret.ymin = ymin + ret.deltay = ymax - ymin + min_delta + ret._isdefined = jnp.where( + ~other._isdefined, + self._isdefined, + jnp.where( + self._isdefined, + (ret.deltax >= min_delta) & (ret.deltay >= min_delta), + other._isdefined, + ), + ) + ret._isstatic = False + ret._isstaticshape = False + + return ret + + +def _bounds_pos_add_op_dynamic(self, other, min_delta): + xmin = jnp.where( + self._isdefined, + jnp.minimum(self.xmin, other.x), + other.x, + ) + xmax = jnp.where( + self._isdefined, + jnp.maximum(self.xmax, other.x), + other.x, + ) + ymin = jnp.where( + self._isdefined, + jnp.minimum(self.ymin, other.y), + other.y, + ) + ymax = jnp.where( + self._isdefined, + jnp.maximum(self.ymax, other.y), + other.y, + ) + + cls = self.__class__ + ret = cls.__new__(cls) + + ret.xmin = xmin + ret.deltax = xmax - xmin + min_delta + ret.ymin = ymin + ret.deltay = ymax - ymin + min_delta + ret._isdefined = jnp.where( + self._isdefined, + (ret.deltax >= min_delta) & (ret.deltay >= min_delta), + jnp.array(True), + ) + ret._isstatic = False + ret._isstaticshape = False + + return ret + @implements(_galsim.BoundsD, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -409,18 +493,22 @@ class BoundsD(Bounds): def __init__(self, *args, **kwargs): self._isstatic = False - self._parse_args(*args, **kwargs) + self._isstaticshape = False + do_isdefined = self._parse_args(*args, **kwargs) self.xmin = cast_to_float(self.xmin) self.deltax = cast_to_float(self.deltax) self.ymin = cast_to_float(self.ymin) self.deltay = cast_to_float(self.deltay) + if do_isdefined: + self._isdefined = (self.deltax >= 0) & (self.deltay >= 0) + self._isdefined = jnp.array(self._isdefined) def _check_scalar(self, x, name): try: if ( isinstance(x, jax.Array) and x.shape == () - and x.dtype.name in ["float32", "float64", "float"] + and jnp.issubdtype(x.dtype, jnp.floating) ): return elif x == float(x): @@ -452,8 +540,56 @@ def _area(self): def _center(self): return PositionD((self.xmax + self.xmin) / 2.0, (self.ymax + self.ymin) / 2.0) + @implements(_galsim.Bounds.includes) + def includes(self, *args): + if len(args) == 1: + if isinstance(args[0], Bounds): + b = args[0] + return ( + self.isDefined() + & b.isDefined() + & (self.xmin <= b.xmin) + & (self.xmax >= b.xmax) + & (self.ymin <= b.ymin) + & (self.ymax >= b.ymax) + ) + elif isinstance(args[0], Position): + p = args[0] + return ( + self.isDefined() + & (self.xmin <= p.x) + & (self.ymin <= p.y) + & (p.x <= self.xmax) + & (p.y <= self.ymax) + ) + else: + raise TypeError("Invalid argument %s" % args[0]) + elif len(args) == 2: + x, y = args + return ( + self.isDefined() + & (self.xmin <= cast_to_float(x)) + & (self.ymin <= cast_to_float(y)) + & (cast_to_float(x) <= self.xmax) + & (cast_to_float(y) <= self.ymax) + ) + elif len(args) == 0: + raise TypeError("include takes at least 1 argument (0 given)") + else: + raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) + def __repr__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(%r, %r, %r, %r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -465,7 +601,17 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(%s,%s,%s,%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -487,6 +633,45 @@ def __hash__(self): ) ) + def _getinitargs(self): + # defined only for galsim test suite + return (self.xmin, self.xmax, self.ymin, self.ymax) + + def __eq__(self, other): + if self is other: + return jnp.array(True) + elif isinstance(other, self.__class__): + return ( + self.isDefined() + & other.isDefined() + & (self.xmin == other.xmin) + & (self.ymin == other.ymin) + & (self.xmax == other.xmax) + & (self.ymax == other.ymax) + ) | ((~self.isDefined()) & (~other.isDefined())) + else: + return jnp.array(False) + + def __ne__(self, other): + return ~self.__eq__(other) + + def __and__(self, other): + if not isinstance(other, self.__class__): + raise TypeError("other must be a %s instance" % self.__class__.__name__) + + return _bounds_and_op_dynamic(self, other) + + def __add__(self, other): + if isinstance(other, self.__class__): + return _bounds_bounds_add_op_dynamic(self, other, 0) + elif isinstance(other, self._pos_class): + return _bounds_pos_add_op_dynamic(self, other, 0) + else: + raise TypeError( + "other must be either a %s or a %s" + % (self.__class__.__name__, self._pos_class.__name__) + ) + @implements(_galsim.BoundsI, lax_description=BOUNDS_LAX_DESCR) @register_pytree_node_class @@ -494,47 +679,48 @@ class BoundsI(Bounds): _pos_class = PositionI def __init__(self, *args, **kwargs): - # initial setting to let stuff pass through freely - self._isstatic = True - self._parse_args(*args, **kwargs) - self.deltax = cast_to_float(self.deltax) - self.deltay = cast_to_float(self.deltay) - if (self.deltax != int(self.deltax)) or (self.deltay != int(self.deltay)): - raise TypeError("BoundsI must be initialized with integer values") - self.deltax = cast_to_int(self.deltax) - self.deltay = cast_to_int(self.deltay) - - if not ( - isinstance( - self._xmin, - (int, float, np.floating, np.integer), - ) - and isinstance( - self._ymin, - (int, float, np.floating, np.integer), - ) - ): - self._isstatic = False - # validate inputs are ints - self._xmin = check_is_int_then_cast( - self._xmin, "BoundsI must be initialized with integer values" + self.deltax = check_is_int_then_cast( + self.deltax, "BoundsI must be initialized with integer values" + ) + self.deltay = check_is_int_then_cast( + self.deltay, "BoundsI must be initialized with integer values" + ) + self.xmin = check_is_int_then_cast( + self.xmin, "BoundsI must be initialized with integer values" ) - self._ymin = check_is_int_then_cast( - self._ymin, "BoundsI must be initialized with integer values" + self.ymin = check_is_int_then_cast( + self.ymin, "BoundsI must be initialized with integer values" ) - if self.deltax < 1 and self.deltay < 1: - self._isdefined = False + if isinstance(self.deltax, int) and isinstance(self.deltay, int): + self._isstaticshape = True + else: + self._isstaticshape = False + + if ( + isinstance(self.xmin, int) + and isinstance(self.ymin, int) + and isinstance(self.deltax, int) + and isinstance(self.deltay, int) + ): + self._isstatic = True + else: + self._isstatic = False + + if self.isStaticShape(): + self._isdefined = self.deltax >= 1 and self.deltay >= 1 + else: + self._isdefined = (self.deltax >= 1) & (self.deltay >= 1) def _check_scalar(self, x, name): try: if ( isinstance(x, jax.Array) and x.shape == () - and x.dtype.name in ["int32", "int64", "int"] + and jnp.issubdtype(x.dtype, jnp.integer) ): return elif x == int(x): @@ -545,24 +731,17 @@ def _check_scalar(self, x, name): def numpyShape(self): "A simple utility function to get the numpy shape that corresponds to this `Bounds` object." - if self.isDefined(): - return self.deltay, self.deltax - else: - return 0, 0 - - @property - def xmin(self): - if self._isstatic: - return self._xmin - else: - return jnp.astype(self._xmin, jnp.int_) - - @xmin.setter - def xmin(self, value): - if self._isstatic: - self._xmin = value + if self._isstaticshape: + if self._isdefined: + return self.deltay, self.deltax + else: + return 0, 0 else: - self._xmin = jnp.astype(value, jnp.float_) + return jax.lax.cond( + self._isdefined, + lambda: (self.deltay, self.deltax), + lambda: (jnp.zeros_like(self.deltay), jnp.zeros_like(self.deltax)), + ) @property def xmax(self): @@ -572,20 +751,6 @@ def xmax(self): def xmax(self, value): self.deltax = value - self.xmin + 1 - @property - def ymin(self): - if self._isstatic: - return self._ymin - else: - return jnp.astype(self._ymin, jnp.int_) - - @ymin.setter - def ymin(self, value): - if self._isstatic: - self._ymin = value - else: - self._ymin = jnp.astype(value, jnp.float_) - @property def ymax(self): return self.ymin + self.deltay - 1 @@ -596,10 +761,17 @@ def ymax(self, value): def _area(self): # Remember the + 1 this time to include the pixels on both edges of the bounds. - if not self.isDefined(): - return 0 + if self._isstaticshape: + if self._isdefined: + return self.deltax * self.deltay + else: + return 0 else: - return self.deltax * self.deltay + return jax.lax.cond( + self._isdefined, + lambda: self.deltax * self.deltay, + lambda: 0.0, + ) @property def _center(self): @@ -613,58 +785,91 @@ def _center(self): self.ymin + self.deltay // 2, ) - def tree_flatten(self): - """This function flattens the Bounds into a list of children - nodes that will be traced by JAX and auxiliary static data.""" - # Define the children nodes of the PyTree that need tracing - if self.isDefined(): - if self._isstatic: - # Define the children nodes of the PyTree that need tracing - children = tuple() - - # Define auxiliary static data that doesn’t need to be traced - aux_data = { - "xmin": self._xmin, - "ymin": self._ymin, - "deltax": self.deltax, - "deltay": self.deltay, - } - else: - children = (self._xmin, self._ymin) - # Define auxiliary static data that doesn’t need to be traced - aux_data = {"deltax": self.deltax, "deltay": self.deltay} - else: - children = tuple() - aux_data = None - - return (children, aux_data) - - @classmethod - def tree_unflatten(cls, aux_data, children): - """Recreates an instance of the class from flatten representation""" - if aux_data is not None: - ret = cls.__new__(cls) - if "xmin" in aux_data and "ymin" in aux_data: - ret._isstatic = True - ret._xmin = aux_data["xmin"] - ret._ymin = aux_data["ymin"] + @implements(_galsim.Bounds.includes) + def includes(self, *args): + if len(args) == 1: + if isinstance(args[0], Bounds): + b = args[0] + if self.isStatic() and b.isStatic(): + return ( + self.isDefined() + and b.isDefined() + and (self.xmin <= b.xmin) + and (self.xmax >= b.xmax) + and (self.ymin <= b.ymin) + and (self.ymax >= b.ymax) + ) + else: + return ( + jnp.array(self.isDefined()) + & jnp.array(b.isDefined()) + & jnp.array(self.xmin <= b.xmin) + & jnp.array(self.xmax >= b.xmax) + & jnp.array(self.ymin <= b.ymin) + & jnp.array(self.ymax >= b.ymax) + ) + elif isinstance(args[0], Position): + p = args[0] + ok_types = STATIC_SCALAR_TYPES + if ( + self._isstatic + and isinstance(p.x, ok_types) + and isinstance(p.y, ok_types) + ): + return ( + self.isDefined() + and (self.xmin <= p.x) + and (self.ymin <= p.y) + and (p.x <= self.xmax) + and (p.y <= self.ymax) + ) + else: + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= p.x) + & jnp.array(self.ymin <= p.y) + & jnp.array(p.x <= self.xmax) + & jnp.array(p.y <= self.ymax) + ) else: - ret._isstatic = False - ret._xmin = children[0] - ret._ymin = children[1] - ret.deltax = aux_data["deltax"] - ret.deltay = aux_data["deltay"] - if ret.deltax < 1 and ret.deltay < 1: - ret._isdefined = False + raise TypeError("Invalid argument %s" % args[0]) + elif len(args) == 2: + x, y = args + x = cast_to_float(x) + y = cast_to_float(y) + if self._isstatic and isinstance(x, float) and isinstance(y, float): + return ( + self.isDefined() + and (self.xmin <= x) + and (self.ymin <= y) + and (x <= self.xmax) + and (y <= self.ymax) + ) else: - ret._isdefined = True + return ( + jnp.array(self.isDefined()) + & jnp.array(self.xmin <= x) + & jnp.array(self.ymin <= y) + & jnp.array(x <= self.xmax) + & jnp.array(y <= self.ymax) + ) + elif len(args) == 0: + raise TypeError("include takes at least 1 argument (0 given)") else: - ret = cls() - - return ret + raise TypeError("include takes at most 2 arguments (%d given)" % len(args)) def __repr__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(xmin=%r, deltax=%r, ymin=%r, deltay=%r)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -676,7 +881,17 @@ def __repr__(self): return "galsim.%s()" % (self.__class__.__name__) def __str__(self): - if self.isDefined(): + # sometimes we will encounter a tracer here + # and so we suppress any boolean conversion errors + try: + if jnp.any(self.isDefined()): + print_full = True + else: + print_full = False + except Exception: + print_full = True + + if print_full: return "galsim.%s(xmin=%s, deltax=%s, ymin=%s, deltay=%s)" % ( self.__class__.__name__, ensure_hashable(self.xmin), @@ -687,17 +902,6 @@ def __str__(self): else: return "galsim.%s()" % (self.__class__.__name__) - def _getinitargs(self): - if self.isDefined(): - return (self.xmin, self.deltax, self.ymin, self.deltay) - else: - return () - - def __eq__(self, other): - return self is other or ( - isinstance(other, BoundsI) and self._getinitargs() == other._getinitargs() - ) - def __hash__(self): return hash( ( @@ -708,3 +912,119 @@ def __hash__(self): ensure_hashable(self.deltay), ) ) + + def __eq__(self, other): + if self is other: + if self.isStatic() and other.isStatic(): + return True + else: + return jnp.array(True) + elif isinstance(other, self.__class__): + if self.isStatic() and other.isStatic(): + min_eq = (self.xmin == other.xmin) and (self.ymin == other.ymin) + self_isdef = self.isDefined() + other_isdef = other.isDefined() + shape_eq = (self.deltax == other.deltax) and ( + self.deltay == other.deltay + ) + return (self_isdef and other_isdef and shape_eq and min_eq) or ( + (not self_isdef) and (not other_isdef) + ) + else: + min_eq = jnp.array(self.xmin == other.xmin) & jnp.array( + self.ymin == other.ymin + ) + self_isdef = jnp.array(self.isDefined()) + other_isdef = jnp.array(other.isDefined()) + shape_eq = jnp.array(self.deltax == other.deltax) & jnp.array( + self.deltay == other.deltay + ) + return (self_isdef & other_isdef & shape_eq & min_eq) | ( + (~self_isdef) & (~other_isdef) + ) + else: + return False + + def __ne__(self, other): + if not isinstance(other, self.__class__): + return True + + if self.isStatic() and other.isStatic(): + return not self.__eq__(other) + else: + return ~self.__eq__(other) + + def __and__(self, other): + if not isinstance(other, self.__class__): + raise TypeError("other must be a %s instance" % self.__class__.__name__) + + if self.isStatic() and other.isStatic(): + return _bounds_and_op_static(self, other) + else: + return _bounds_and_op_dynamic(self, other) + + def __add__(self, other): + if isinstance(other, self.__class__): + if self.isStatic() and other.isStatic(): + return _bounds_bounds_add_op_static(self, other) + else: + return _bounds_bounds_add_op_dynamic(self, other, 1) + elif isinstance(other, self._pos_class): + return _bounds_pos_add_op_dynamic(self, other, 1) + else: + raise TypeError( + "other must be either a %s or a %s" + % (self.__class__.__name__, self._pos_class.__name__) + ) + + def tree_flatten(self): + """This function flattens the Bounds into a list of children + nodes that will be traced by JAX and auxiliary static data.""" + # Define the children nodes of the PyTree that need tracing + aux_data = {"isstatic": self._isstatic, "isstaticshape": self._isstaticshape} + + if self._isstatic: + aux_data["xmin"] = self.xmin + aux_data["ymin"] = self.ymin + + if self._isstaticshape: + aux_data["deltax"] = self.deltax + aux_data["deltay"] = self.deltay + aux_data["isdefined"] = self._isdefined + + if self._isstatic: + children = tuple() + elif self._isstaticshape: + children = (self.xmin, self.ymin) + else: + children = (self.xmin, self.deltax, self.ymin, self.deltay, self._isdefined) + + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Recreates an instance of the class from flatten representation""" + ret = cls.__new__(cls) + ret._isstatic = aux_data["isstatic"] + ret._isstaticshape = aux_data["isstaticshape"] + + if ret._isstatic: + ret.xmin = aux_data["xmin"] + ret.ymin = aux_data["ymin"] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + ret._isdefined = aux_data["isdefined"] + elif ret._isstaticshape: + ret.xmin = children[0] + ret.ymin = children[1] + ret.deltax = aux_data["deltax"] + ret.deltay = aux_data["deltay"] + ret._isdefined = aux_data["isdefined"] + else: + ret.xmin = children[0] + ret.deltax = children[1] + ret.ymin = children[2] + ret.deltay = children[3] + ret._isdefined = children[4] + + return ret diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index 839d0658..9a06975c 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -8,12 +8,14 @@ import jax.numpy as jnp import numpy as np +STATIC_SCALAR_TYPES = (int, float, np.integer, np.floating) + def check_is_int_then_cast(val, msg): """Check if `val` is an integer, raise if not, otherwise cast to int.""" val = cast_to_float(val) - if isinstance(val, (int, float, np.integer, np.floating)): + if isinstance(val, STATIC_SCALAR_TYPES): # for simple inputs, we can check direct in python if val != int(val): raise TypeError(msg) @@ -43,9 +45,7 @@ def cast_numpy_array_to_native_byte_order(arr): def _cast_to_type(x, typ, accept_strings=False): - if isinstance(x, (int, float, np.integer, np.floating)) or ( - accept_strings and isinstance(x, str) - ): + if isinstance(x, STATIC_SCALAR_TYPES) or (accept_strings and isinstance(x, str)): return typ(x) else: return jnp.astype(x, typ) diff --git a/jax_galsim/image.py b/jax_galsim/image.py index 7c78a178..c4a60ed3 100644 --- a/jax_galsim/image.py +++ b/jax_galsim/image.py @@ -7,6 +7,7 @@ from jax_galsim.bounds import Bounds, BoundsD, BoundsI from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_numpy_array_to_native_byte_order, ensure_hashable, implements, @@ -269,6 +270,13 @@ def __init__(self, *args, **kwargs): raise TypeError("wcs parameters must be a galsim.BaseWCS instance") self.wcs = wcs + # raise an error if bounds doesn't have a fixed width + if not self._bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + @staticmethod def _get_xmin_ymin(array, kwargs, check_bounds=True): """A helper function for parsing xmin, ymin, bounds options with a given array""" @@ -280,6 +288,14 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True): b = kwargs.pop("bounds") if not isinstance(b, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not b.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + if check_bounds and b.isDefined(): if b.deltax != array.shape[1]: raise _galsim.GalSimIncompatibleValuesError( @@ -571,6 +587,14 @@ def resize(self, bounds, wcs=None): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype) self._bounds = bounds if wcs is not None: @@ -580,6 +604,14 @@ def resize(self, bounds, wcs=None): def subImage(self, bounds): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access subImage of undefined image" @@ -592,6 +624,13 @@ def subImage(self, bounds): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) if self.bounds.isStatic() and bounds.isStatic(): i1 = bounds.ymin - self.ymin @@ -619,6 +658,14 @@ def setSubImage(self, bounds, rhs): raise GalSimImmutableError("Cannot modify an immutable Image", self) if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + if not self.bounds.isDefined(): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" @@ -631,6 +678,14 @@ def setSubImage(self, bounds, rhs): raise _galsim.GalSimBoundsError( "Attempt to access subImage not (fully) in image", bounds, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(bounds)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access subImage not (fully) in image", + ) + if not isinstance(rhs, Image): raise TypeError("Trying to copyFrom a non-image") if bounds.numpyShape() != rhs.bounds.numpyShape(): @@ -722,6 +777,13 @@ def wrap(self, bounds, hermitian=False): if not isinstance(bounds, BoundsI): raise TypeError("bounds must be a galsim.BoundsI instance") + # raise an error if bounds doesn't have a fixed width + if not bounds.isStaticShape(): + raise RuntimeError( + "JAX-GalSim `Image` objects must have a `BoundsI` instance with " + "a static shape (i.e., `image.bounds.isStaticShape() is True`)." + ) + def _raise_if_nonzero(bnds, x_or_y, msg): if x_or_y == "x": if bnds.isStatic(): @@ -902,12 +964,19 @@ def calculate_inverse_fft(self): raise _galsim.GalSimError( "calculate_inverse_fft requires that the image has a PixelScale wcs." ) - if not self.bounds.includes(0, 0): + if self.bounds.isStatic() and not self.bounds.includes(0, 0): raise _galsim.GalSimBoundsError( "calculate_inverse_fft requires that the image includes (0,0)", PositionI(0, 0), self.bounds, ) + else: + inc_val = jnp.array(self.bounds.includes(0, 0)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "calculate_inverse_fft requires that the image includes (0,0)", + ) No2 = max( max(self.bounds.xmax, -self.bounds.ymin), @@ -1067,12 +1136,25 @@ def getValue(self, x, y): raise _galsim.GalSimUndefinedBoundsError( "Attempt to access values of an undefined image" ) - if not self.bounds.includes(x, y): + if ( + self.bounds.isStatic() + and isinstance(x, STATIC_SCALAR_TYPES) + and isinstance(y, STATIC_SCALAR_TYPES) + and not self.bounds.includes(x, y) + ): raise _galsim.GalSimBoundsError( "Attempt to access position not in bounds of image.", PositionI(x, y), self.bounds, ) + else: + inc_val = jnp.array(self.bounds.includes(x, y)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to access position not in bounds of image.", + ) + return self._getValue(x, y) @implements(_galsim.Image._getValue) @@ -1090,10 +1172,18 @@ def setValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): + if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) + self._setValue(pos.x, pos.y, value) @implements(_galsim.Image._setValue) @@ -1111,10 +1201,18 @@ def addValue(self, *args, **kwargs): pos, value = parse_pos_args( args, kwargs, "x", "y", integer=True, others=["value"] ) - if not self.bounds.includes(pos): + if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos): raise _galsim.GalSimBoundsError( "Attempt to set position not in bounds of image", pos, self.bounds ) + else: + inc_val = jnp.array(self.bounds.includes(pos)) + inc_val = equinox.error_if( + inc_val, + jnp.any(~inc_val), + "Attempt to set position not in bounds of image", + ) + self._addValue(pos.x, pos.y, value) @implements(_galsim.Image._addValue) diff --git a/jax_galsim/position.py b/jax_galsim/position.py index cf36dba8..6b5ffc0d 100644 --- a/jax_galsim/position.py +++ b/jax_galsim/position.py @@ -4,6 +4,7 @@ from jax.tree_util import register_pytree_node_class from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_to_float, check_is_int_then_cast, ensure_hashable, @@ -182,6 +183,13 @@ def to_galsim(self): cast(self.y), ) + def isStatic(self): + """Returns ``True`` if the ``Position`` instance + ``x`` and ``y`` values are not arrays""" + return isinstance(self.x, STATIC_SCALAR_TYPES) and isinstance( + self.y, STATIC_SCALAR_TYPES + ) + @implements(_galsim.PositionD) @register_pytree_node_class diff --git a/jax_galsim/random.py b/jax_galsim/random.py index 2cf2db27..b5e730f3 100644 --- a/jax_galsim/random.py +++ b/jax_galsim/random.py @@ -9,7 +9,7 @@ import numpy as np from jax.tree_util import register_pytree_node_class -from jax_galsim.core.utils import implements +from jax_galsim.core.utils import STATIC_SCALAR_TYPES, implements try: from jax.extend.random import wrap_key_data @@ -95,7 +95,7 @@ def generates_in_pairs(self): def seed(self, seed=None): if seed is None: self._seed(seed=seed) - elif isinstance(seed, (int, float, np.integer, np.floating)): + elif isinstance(seed, STATIC_SCALAR_TYPES): if seed == int(seed): self._seed(seed=int(seed)) else: diff --git a/jax_galsim/wcs.py b/jax_galsim/wcs.py index 6dcec13b..ec1320be 100644 --- a/jax_galsim/wcs.py +++ b/jax_galsim/wcs.py @@ -7,6 +7,7 @@ from jax_galsim.angle import AngleUnit, arcsec, radians from jax_galsim.celestial import CelestialCoord from jax_galsim.core.utils import ( + STATIC_SCALAR_TYPES, cast_to_float, ensure_hashable, implements, @@ -22,7 +23,7 @@ # this kind of casting is only done for writing FITS headers # and should never be done anywhere else in the code base def _cast_to_static_numeric_scalar(x, msg=None): - if isinstance(x, (int, float, np.integer, np.floating)): + if isinstance(x, STATIC_SCALAR_TYPES): return x if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)): diff --git a/tests/GalSim b/tests/GalSim index 549616e8..f3d81a1d 160000 --- a/tests/GalSim +++ b/tests/GalSim @@ -1 +1 @@ -Subproject commit 549616e8ca4bb84142fae6cdb0a006669f92454b +Subproject commit f3d81a1d18a30651d8769818731d4c4ac3541478 diff --git a/tests/jax/test_api.py b/tests/jax/test_api.py index e76b081c..c3afed10 100644 --- a/tests/jax/test_api.py +++ b/tests/jax/test_api.py @@ -360,6 +360,7 @@ def _reg_fun(p): "xmax", "ymax", "isStatic", + "isStaticShape", ]: continue diff --git a/tests/jax/test_bounds_jax.py b/tests/jax/test_bounds_jax.py new file mode 100644 index 00000000..f9481487 --- /dev/null +++ b/tests/jax/test_bounds_jax.py @@ -0,0 +1,47 @@ +import jax +import jax.numpy as jnp +import numpy as np + +import jax_galsim + + +@jax.vmap +@jax.jit +def _make_bounds_int(xmin, ymin, xmax, ymax): + bds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) + return bds, bds.isDefined() + + +def test_bounds_jax_vmap_isdefined_int(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 11, 10]) + bds, isdef = _make_bounds_int(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True) + + # turn a bounds of arrays into a list of bounds + # see https://github.com/jax-ml/jax/discussions/35711 + list_of_bnds = jax.tree.transpose( + jax.tree.structure(bds), None, jax.tree.map(list, bds) + ) + assert list_of_bnds[0] != list_of_bnds[2] + assert list_of_bnds[1] == list_of_bnds[2] + assert list_of_bnds[2] == list_of_bnds[3] + assert all(not bnds.isStatic() for bnds in list_of_bnds) + + +@jax.vmap +@jax.jit +def _make_bounds_float(xmin, ymin, xmax, ymax): + bds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax) + return bds, bds.isDefined() + + +def test_bounds_jax_vmap_isdefined_float(): + xmin = jnp.array([9, 10, 11, 12]) + xmax = jnp.array([12, 11, 10, 9]) + ymin = jnp.array([9, 11, 10, 12]) + ymax = jnp.array([10, 10, 10, 10]) + bds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax) + np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True)