Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
834 changes: 577 additions & 257 deletions jax_galsim/bounds.py

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import jax.numpy as jnp
import numpy as np

STATIC_SCALAR_TYPES = (int, float, np.integer, np.floating)


def check_is_int_then_cast(val, msg):
"""Check if `val` is an integer, raise if not, otherwise cast to int."""
val = cast_to_float(val)

if isinstance(val, (int, float, np.integer, np.floating)):
if isinstance(val, STATIC_SCALAR_TYPES):
# for simple inputs, we can check direct in python
if val != int(val):
raise TypeError(msg)
Expand Down Expand Up @@ -43,9 +45,7 @@ def cast_numpy_array_to_native_byte_order(arr):


def _cast_to_type(x, typ, accept_strings=False):
if isinstance(x, (int, float, np.integer, np.floating)) or (
accept_strings and isinstance(x, str)
):
if isinstance(x, STATIC_SCALAR_TYPES) or (accept_strings and isinstance(x, str)):
return typ(x)
else:
return jnp.astype(x, typ)
Expand Down
106 changes: 102 additions & 4 deletions jax_galsim/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from jax_galsim.bounds import Bounds, BoundsD, BoundsI
from jax_galsim.core.utils import (
STATIC_SCALAR_TYPES,
cast_numpy_array_to_native_byte_order,
ensure_hashable,
implements,
Expand Down Expand Up @@ -269,6 +270,13 @@ def __init__(self, *args, **kwargs):
raise TypeError("wcs parameters must be a galsim.BaseWCS instance")
self.wcs = wcs

# raise an error if bounds doesn't have a fixed width
if not self._bounds.isStaticShape():
raise RuntimeError(
"JAX-GalSim `Image` objects must have a `BoundsI` instance with "
"a static shape (i.e., `image.bounds.isStaticShape() is True`)."
)

@staticmethod
def _get_xmin_ymin(array, kwargs, check_bounds=True):
"""A helper function for parsing xmin, ymin, bounds options with a given array"""
Expand All @@ -280,6 +288,14 @@ def _get_xmin_ymin(array, kwargs, check_bounds=True):
b = kwargs.pop("bounds")
if not isinstance(b, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

# raise an error if bounds doesn't have a fixed width
if not b.isStaticShape():
raise RuntimeError(
"JAX-GalSim `Image` objects must have a `BoundsI` instance with "
"a static shape (i.e., `image.bounds.isStaticShape() is True`)."
)

if check_bounds and b.isDefined():
if b.deltax != array.shape[1]:
raise _galsim.GalSimIncompatibleValuesError(
Expand Down Expand Up @@ -571,6 +587,14 @@ def resize(self, bounds, wcs=None):
raise GalSimImmutableError("Cannot modify an immutable Image", self)
if not isinstance(bounds, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

# raise an error if bounds doesn't have a fixed width
if not bounds.isStaticShape():
raise RuntimeError(
"JAX-GalSim `Image` objects must have a `BoundsI` instance with "
"a static shape (i.e., `image.bounds.isStaticShape() is True`)."
)

self._array = self._make_empty(shape=bounds.numpyShape(), dtype=self.dtype)
self._bounds = bounds
if wcs is not None:
Expand All @@ -580,6 +604,14 @@ def resize(self, bounds, wcs=None):
def subImage(self, bounds):
if not isinstance(bounds, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

# raise an error if bounds doesn't have a fixed width
if not bounds.isStaticShape():
raise RuntimeError(
"JAX-GalSim `Image` objects must have a `BoundsI` instance with "
"a static shape (i.e., `image.bounds.isStaticShape() is True`)."
)

if not self.bounds.isDefined():
raise _galsim.GalSimUndefinedBoundsError(
"Attempt to access subImage of undefined image"
Expand All @@ -592,6 +624,13 @@ def subImage(self, bounds):
raise _galsim.GalSimBoundsError(
"Attempt to access subImage not (fully) in image", bounds, self.bounds
)
else:
inc_val = jnp.array(self.bounds.includes(bounds))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to access subImage not (fully) in image",
)

if self.bounds.isStatic() and bounds.isStatic():
i1 = bounds.ymin - self.ymin
Expand Down Expand Up @@ -619,6 +658,14 @@ def setSubImage(self, bounds, rhs):
raise GalSimImmutableError("Cannot modify an immutable Image", self)
if not isinstance(bounds, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

# raise an error if bounds doesn't have a fixed width
if not bounds.isStaticShape():
raise RuntimeError(
"JAX-GalSim `Image` objects must have a `BoundsI` instance with "
"a static shape (i.e., `image.bounds.isStaticShape() is True`)."
)

if not self.bounds.isDefined():
raise _galsim.GalSimUndefinedBoundsError(
"Attempt to access values of an undefined image"
Expand All @@ -631,6 +678,14 @@ def setSubImage(self, bounds, rhs):
raise _galsim.GalSimBoundsError(
"Attempt to access subImage not (fully) in image", bounds, self.bounds
)
else:
inc_val = jnp.array(self.bounds.includes(bounds))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to access subImage not (fully) in image",
)

if not isinstance(rhs, Image):
raise TypeError("Trying to copyFrom a non-image")
if bounds.numpyShape() != rhs.bounds.numpyShape():
Expand Down Expand Up @@ -722,6 +777,13 @@ def wrap(self, bounds, hermitian=False):
if not isinstance(bounds, BoundsI):
raise TypeError("bounds must be a galsim.BoundsI instance")

# raise an error if bounds doesn't have a fixed width
if not bounds.isStaticShape():
raise RuntimeError(
"JAX-GalSim `Image` objects must have a `BoundsI` instance with "
"a static shape (i.e., `image.bounds.isStaticShape() is True`)."
)

def _raise_if_nonzero(bnds, x_or_y, msg):
if x_or_y == "x":
if bnds.isStatic():
Expand Down Expand Up @@ -902,12 +964,19 @@ def calculate_inverse_fft(self):
raise _galsim.GalSimError(
"calculate_inverse_fft requires that the image has a PixelScale wcs."
)
if not self.bounds.includes(0, 0):
if self.bounds.isStatic() and not self.bounds.includes(0, 0):
raise _galsim.GalSimBoundsError(
"calculate_inverse_fft requires that the image includes (0,0)",
PositionI(0, 0),
self.bounds,
)
else:
inc_val = jnp.array(self.bounds.includes(0, 0))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"calculate_inverse_fft requires that the image includes (0,0)",
)

No2 = max(
max(self.bounds.xmax, -self.bounds.ymin),
Expand Down Expand Up @@ -1067,12 +1136,25 @@ def getValue(self, x, y):
raise _galsim.GalSimUndefinedBoundsError(
"Attempt to access values of an undefined image"
)
if not self.bounds.includes(x, y):
if (
self.bounds.isStatic()
and isinstance(x, STATIC_SCALAR_TYPES)
and isinstance(y, STATIC_SCALAR_TYPES)
and not self.bounds.includes(x, y)
):
raise _galsim.GalSimBoundsError(
"Attempt to access position not in bounds of image.",
PositionI(x, y),
self.bounds,
)
else:
inc_val = jnp.array(self.bounds.includes(x, y))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to access position not in bounds of image.",
)

return self._getValue(x, y)

@implements(_galsim.Image._getValue)
Expand All @@ -1090,10 +1172,18 @@ def setValue(self, *args, **kwargs):
pos, value = parse_pos_args(
args, kwargs, "x", "y", integer=True, others=["value"]
)
if not self.bounds.includes(pos):
if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos):
raise _galsim.GalSimBoundsError(
"Attempt to set position not in bounds of image", pos, self.bounds
)
else:
inc_val = jnp.array(self.bounds.includes(pos))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to set position not in bounds of image",
)

self._setValue(pos.x, pos.y, value)

@implements(_galsim.Image._setValue)
Expand All @@ -1111,10 +1201,18 @@ def addValue(self, *args, **kwargs):
pos, value = parse_pos_args(
args, kwargs, "x", "y", integer=True, others=["value"]
)
if not self.bounds.includes(pos):
if self.bounds.isStatic() and pos.isStatic() and not self.bounds.includes(pos):
raise _galsim.GalSimBoundsError(
"Attempt to set position not in bounds of image", pos, self.bounds
)
else:
inc_val = jnp.array(self.bounds.includes(pos))
inc_val = equinox.error_if(
inc_val,
jnp.any(~inc_val),
"Attempt to set position not in bounds of image",
)

self._addValue(pos.x, pos.y, value)

@implements(_galsim.Image._addValue)
Expand Down
8 changes: 8 additions & 0 deletions jax_galsim/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import (
STATIC_SCALAR_TYPES,
cast_to_float,
check_is_int_then_cast,
ensure_hashable,
Expand Down Expand Up @@ -182,6 +183,13 @@ def to_galsim(self):
cast(self.y),
)

def isStatic(self):
"""Returns ``True`` if the ``Position`` instance
``x`` and ``y`` values are not arrays"""
return isinstance(self.x, STATIC_SCALAR_TYPES) and isinstance(
self.y, STATIC_SCALAR_TYPES
)


@implements(_galsim.PositionD)
@register_pytree_node_class
Expand Down
4 changes: 2 additions & 2 deletions jax_galsim/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
from jax.tree_util import register_pytree_node_class

from jax_galsim.core.utils import implements
from jax_galsim.core.utils import STATIC_SCALAR_TYPES, implements

try:
from jax.extend.random import wrap_key_data
Expand Down Expand Up @@ -95,7 +95,7 @@ def generates_in_pairs(self):
def seed(self, seed=None):
if seed is None:
self._seed(seed=seed)
elif isinstance(seed, (int, float, np.integer, np.floating)):
elif isinstance(seed, STATIC_SCALAR_TYPES):
if seed == int(seed):
self._seed(seed=int(seed))
else:
Expand Down
3 changes: 2 additions & 1 deletion jax_galsim/wcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax_galsim.angle import AngleUnit, arcsec, radians
from jax_galsim.celestial import CelestialCoord
from jax_galsim.core.utils import (
STATIC_SCALAR_TYPES,
cast_to_float,
ensure_hashable,
implements,
Expand All @@ -22,7 +23,7 @@
# this kind of casting is only done for writing FITS headers
# and should never be done anywhere else in the code base
def _cast_to_static_numeric_scalar(x, msg=None):
if isinstance(x, (int, float, np.integer, np.floating)):
if isinstance(x, STATIC_SCALAR_TYPES):
return x

if isinstance(x, (np.ndarray, jax.Array, jnp.ndarray)):
Expand Down
2 changes: 1 addition & 1 deletion tests/GalSim
1 change: 1 addition & 0 deletions tests/jax/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,7 @@ def _reg_fun(p):
"xmax",
"ymax",
"isStatic",
"isStaticShape",
]:
continue

Expand Down
47 changes: 47 additions & 0 deletions tests/jax/test_bounds_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import jax
import jax.numpy as jnp
import numpy as np

import jax_galsim


@jax.vmap
@jax.jit
def _make_bounds_int(xmin, ymin, xmax, ymax):
bds = jax_galsim.BoundsI(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
return bds, bds.isDefined()


def test_bounds_jax_vmap_isdefined_int():
xmin = jnp.array([9, 10, 11, 12])
xmax = jnp.array([12, 11, 10, 9])
ymin = jnp.array([9, 11, 10, 12])
ymax = jnp.array([10, 10, 11, 10])
bds, isdef = _make_bounds_int(xmin, ymin, xmax, ymax)
np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True)

# turn a bounds of arrays into a list of bounds
# see https://github.com/jax-ml/jax/discussions/35711
list_of_bnds = jax.tree.transpose(
jax.tree.structure(bds), None, jax.tree.map(list, bds)
)
assert list_of_bnds[0] != list_of_bnds[2]
assert list_of_bnds[1] == list_of_bnds[2]
assert list_of_bnds[2] == list_of_bnds[3]
assert all(not bnds.isStatic() for bnds in list_of_bnds)


@jax.vmap
@jax.jit
def _make_bounds_float(xmin, ymin, xmax, ymax):
bds = jax_galsim.BoundsD(xmin=xmin, ymin=ymin, xmax=xmax, ymax=ymax)
return bds, bds.isDefined()


def test_bounds_jax_vmap_isdefined_float():
xmin = jnp.array([9, 10, 11, 12])
xmax = jnp.array([12, 11, 10, 9])
ymin = jnp.array([9, 11, 10, 12])
ymax = jnp.array([10, 10, 10, 10])
bds, isdef = _make_bounds_float(xmin, ymin, xmax, ymax)
np.testing.assert_array_equal(bds.isDefined(), isdef, strict=True)
Loading