diff --git a/scripts/imaging/simulator_use_jax_parity.py b/scripts/imaging/simulator_use_jax_parity.py new file mode 100644 index 00000000..2c301246 --- /dev/null +++ b/scripts/imaging/simulator_use_jax_parity.py @@ -0,0 +1,120 @@ +""" +SimulatorImaging use_jax parity test +===================================== + +Runs ``al.SimulatorImaging(use_jax=False)`` and ``al.SimulatorImaging(use_jax=True)`` +against the same tracer + grid and asserts the noise-free simulated image agrees +to machine precision. Cross-xp numerical validation for ``SimulatorImaging.use_jax=True`` +— library unit tests stay NumPy-only. + +Noise is disabled on both paths because JAX and NumPy use different RNG algorithms; +the same seed produces different draws. The simulator's xp-threaded noise path is +tested separately (``test_autoarray/dataset/imaging/test_simulator_use_jax.py`` +covers the constructor wiring; this script covers numerical agreement of the +deterministic image-generation path). + +Also tests the @jax.jit roundtrip: the simulator under jit produces a dataset +whose data array agrees with the eager path. This exercises the +``register_tracer_classes(tracer)`` registration walker shipped in PR 1 +(PyAutoLens#538) — without prior registration the @jax.jit decoration would fail +to flatten the Tracer at the JIT boundary. +""" + +from autoconf import jax_wrapper # Sets JAX float64 before other imports + +import jax +import numpy as np + +import autolens as al + + +grid = al.Grid2D.uniform(shape_native=(100, 100), pixel_scales=0.1) + +lens_galaxy = al.Galaxy( + redshift=0.5, + light=al.lp.Sersic( + centre=(0.0, 0.0), + ell_comps=(0.01, 0.01), + intensity=1.0, + effective_radius=0.5, + sersic_index=4.0, + ), + mass=al.mp.Isothermal( + centre=(0.0, 0.0), + ell_comps=(0.01, 0.01), + einstein_radius=1.6, + ), +) +source_galaxy = al.Galaxy( + redshift=1.0, + light=al.lp.Sersic( + centre=(0.1, 0.1), + ell_comps=(0.01, 0.01), + intensity=0.5, + effective_radius=0.2, + sersic_index=1.0, + ), +) +tracer = al.Tracer(galaxies=[lens_galaxy, source_galaxy]) + +psf = al.Convolver.from_gaussian( + shape_native=(11, 11), sigma=0.05, pixel_scales=grid.pixel_scales +) + +# Noise-free configuration: both backends must produce identical deterministic images. +common_kwargs = dict( + exposure_time=300.0, + psf=psf, + background_sky_level=0.1, + add_poisson_noise_to_data=False, + include_poisson_noise_in_noise_map=False, +) + +# NumPy path. +simulator_np = al.SimulatorImaging(**common_kwargs) +dataset_np = simulator_np.via_tracer_from(tracer=tracer, grid=grid) +data_np = np.asarray(dataset_np.data.array) + +# JAX path (eager, not yet under @jax.jit). +simulator_jax = al.SimulatorImaging(use_jax=True, **common_kwargs) +dataset_jax = simulator_jax.via_tracer_from(tracer=tracer, grid=grid) +data_jax = np.asarray(dataset_jax.data.array) + +assert data_np.shape == data_jax.shape, ( + f"Shape mismatch: numpy={data_np.shape} vs jax={data_jax.shape}" +) +np.testing.assert_allclose( + data_np, + data_jax, + atol=1e-8, + err_msg="SimulatorImaging(use_jax=True) data differs from use_jax=False", +) +print( + f"PASS: SimulatorImaging(use_jax=True) data matches use_jax=False " + f"to atol=1e-8 (shape {data_np.shape})." +) + +# @jax.jit roundtrip: register tracer classes so JAX can flatten/unflatten +# the Tracer pytree at the JIT boundary, then run the simulator under jit +# and verify the output matches the eager JAX path. +al.util.register_tracer_classes(tracer) + + +@jax.jit +def simulate_jit(tracer): + dataset = simulator_jax.via_tracer_from(tracer=tracer, grid=grid) + return dataset.data._array + + +data_jit = np.asarray(simulate_jit(tracer)) + +np.testing.assert_allclose( + data_jax, + data_jit, + atol=1e-8, + err_msg="SimulatorImaging @jax.jit data differs from eager JAX path", +) +print( + f"PASS: SimulatorImaging @jax.jit roundtrip matches eager JAX " + f"to atol=1e-8 (shape {data_jit.shape})." +) diff --git a/scripts/interferometer/simulator_use_jax_parity.py b/scripts/interferometer/simulator_use_jax_parity.py new file mode 100644 index 00000000..1b211ef4 --- /dev/null +++ b/scripts/interferometer/simulator_use_jax_parity.py @@ -0,0 +1,111 @@ +""" +SimulatorInterferometer use_jax parity test +============================================ + +Runs ``al.SimulatorInterferometer(use_jax=False)`` and ``al.SimulatorInterferometer(use_jax=True)`` +against the same tracer + grid and asserts the noise-free simulated visibilities +agree to machine precision. Cross-xp numerical validation for +``SimulatorInterferometer.use_jax=True`` — library unit tests stay NumPy-only. + +Noise is disabled (``noise_sigma=None``) because JAX and NumPy use different RNG +algorithms; the same seed produces different draws. This test covers the +deterministic Fourier-transform path. + +Also tests the @jax.jit roundtrip for the interferometer simulator. +""" + +from autoconf import jax_wrapper # Sets JAX float64 before other imports + +import jax +import numpy as np + +import autolens as al + + +grid = al.Grid2D.uniform(shape_native=(64, 64), pixel_scales=0.1) + +lens_galaxy = al.Galaxy( + redshift=0.5, + light=al.lp.Sersic( + centre=(0.0, 0.0), + ell_comps=(0.01, 0.01), + intensity=1.0, + effective_radius=0.5, + sersic_index=4.0, + ), + mass=al.mp.Isothermal( + centre=(0.0, 0.0), + ell_comps=(0.01, 0.01), + einstein_radius=1.6, + ), +) +source_galaxy = al.Galaxy( + redshift=1.0, + light=al.lp.Sersic( + centre=(0.1, 0.1), + ell_comps=(0.01, 0.01), + intensity=0.5, + effective_radius=0.2, + sersic_index=1.0, + ), +) +tracer = al.Tracer(galaxies=[lens_galaxy, source_galaxy]) + +uv_wavelengths = np.random.RandomState(seed=0).uniform( + low=-1000.0, high=1000.0, size=(200, 2) +) + +# Noise-free configuration: both backends produce identical deterministic visibilities. +common_kwargs = dict( + uv_wavelengths=uv_wavelengths, + exposure_time=300.0, + noise_sigma=None, # disable noise for deterministic parity +) + +# NumPy path. +simulator_np = al.SimulatorInterferometer(**common_kwargs) +dataset_np = simulator_np.via_tracer_from(tracer=tracer, grid=grid) +data_np = np.asarray(dataset_np.data) + +# JAX path (eager). +simulator_jax = al.SimulatorInterferometer(use_jax=True, **common_kwargs) +dataset_jax = simulator_jax.via_tracer_from(tracer=tracer, grid=grid) +data_jax = np.asarray(dataset_jax.data) + +assert data_np.shape == data_jax.shape, ( + f"Shape mismatch: numpy={data_np.shape} vs jax={data_jax.shape}" +) +np.testing.assert_allclose( + data_np, + data_jax, + atol=1e-8, + err_msg="SimulatorInterferometer(use_jax=True) visibilities differ from use_jax=False", +) + +print( + f"PASS: SimulatorInterferometer(use_jax=True) visibilities match use_jax=False " + f"to atol=1e-8 ({data_np.shape[0]} visibilities)." +) + +# @jax.jit roundtrip. +al.util.register_tracer_classes(tracer) + + +@jax.jit +def simulate_jit(tracer): + dataset = simulator_jax.via_tracer_from(tracer=tracer, grid=grid) + return dataset.data._array + + +data_jit = np.asarray(simulate_jit(tracer)) + +np.testing.assert_allclose( + data_jax, + data_jit, + atol=1e-8, + err_msg="SimulatorInterferometer @jax.jit visibilities differ from eager JAX path", +) +print( + f"PASS: SimulatorInterferometer @jax.jit roundtrip matches eager JAX " + f"to atol=1e-8 ({data_jit.shape[0]} visibilities)." +)