diff --git a/jax_profiling/imaging/pixelization_sparse_cpu.py b/jax_profiling/imaging/pixelization_sparse_cpu.py index 2778bb8..8425684 100644 --- a/jax_profiling/imaging/pixelization_sparse_cpu.py +++ b/jax_profiling/imaging/pixelization_sparse_cpu.py @@ -14,8 +14,10 @@ ``pixelization.py`` to within ``rtol=1e-4`` (asserted at the bottom). * ``fit_sparse`` — ``FitImaging(xp=np)`` on a dataset with an attached CPU sparse operator (``apply_sparse_operator_cpu()``). Its ``log_evidence`` - is reported next to the gold reference but *not* asserted — the script's - job is to surface any divergence, not hide it. + is also asserted to match ``fit_no_sparse`` (within ``rtol=1e-4``) so this + script doubles as a CPU-sparse regression test. (PyAutoArray PR #296 fixed + an out-of-bounds read in ``psf_precision_value_from`` that previously made + the sparse path return ``NaN`` for masks touching the noise-map boundary.) Per-step timings are eager-only: numba functions are compiled on first call and reused on subsequent calls, so each section reports a first-call cost @@ -668,10 +670,11 @@ def _log_evidence(): # point — they share the same mathematical formulation and only differ in # their array backend. Asserting parity catches drift on either side. # -# The sparse path is intentionally NOT asserted: this script is a diagnostic -# for surfacing any divergence between sparse and non-sparse, so failing on -# divergence would defeat its purpose. The numbers above + the JSON output -# are the evidence the user needs to debug the sparse path. +# Both paths are asserted post-fix (PyAutoArray PR #296 fixed the OOB read in +# psf_precision_value_from). Keeping the sparse-path assertion here means the +# script doubles as a regression test: if the sparse-CPU bug recurs, this +# script will fail loudly instead of silently writing a "false" flag into the +# JSON. np.testing.assert_allclose( log_evidence_no_sparse, @@ -688,16 +691,20 @@ def _log_evidence(): f"log_evidence matches JAX reference {EXPECTED_LOG_EVIDENCE_HST:.6f}" ) -if sparse_matches_non_sparse: - print( - f" Sparse path also agrees with non-sparse to rtol=1e-4 " - f"(observed rtol = {sparse_vs_non_sparse_rtol:.3e})." - ) -else: - print( - f" *** SPARSE PATH DIVERGES from non-sparse: observed rtol = " - f"{sparse_vs_non_sparse_rtol:.3e} (>= 1e-4). " - f"This script's purpose is to surface this divergence; investigate " - f"the sparse-CPU path in PyAutoArray/autoarray/inversion/inversion/" - f"imaging_numba/." - ) +np.testing.assert_allclose( + log_evidence_sparse, + log_evidence_no_sparse, + rtol=1e-4, + err_msg=( + f"imaging/pixelization_sparse_cpu[{instrument}]: regression — CPU " + f"sparse-operator log_evidence diverges from non-sparse " + f"(sparse={log_evidence_sparse}, non_sparse={log_evidence_no_sparse}, " + f"observed rtol={sparse_vs_non_sparse_rtol:.3e}). Likely regression " + f"in PyAutoArray/autoarray/inversion/inversion/imaging_numba/." + ), +) +print( + f" Sparse regression assertion PASSED: " + f"sparse log_evidence matches non-sparse (rtol = " + f"{sparse_vs_non_sparse_rtol:.3e})." +) diff --git a/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.json b/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.json index a37f352..684b0ce 100644 --- a/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.json +++ b/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.json @@ -13,23 +13,23 @@ "source_pixels": 784 }, "steps": { - "Ray-trace grids": 0.007880670000304235, - "Lens light images (pre-PSF)": 0.004302589999861084, - "Blurred image (PSF convolution)": 0.01615301999991061, - "Profile-subtracted image": 1.0900000052060933e-05, - "Inversion setup (steps 4-7 combined)": 1.9999999494757505e-06, - "Data vector (D)": 0.05178085000006831, - "Curvature matrix (F)": 0.3827769700001227, - "Regularization matrix (H)": 0.04594268000000738, - "Regularized reconstruction": 0.456709669999691, - "Mapped recon + log evidence": 0.422671039999841 + "Ray-trace grids": 0.010212379999575204, + "Lens light images (pre-PSF)": 0.005444229999557138, + "Blurred image (PSF convolution)": 0.024003369999991264, + "Profile-subtracted image": 1.8070000078296288e-05, + "Inversion setup (steps 4-7 combined)": 2.540000423323363e-06, + "Data vector (D)": 0.06494419000009657, + "Curvature matrix (F)": 0.37986498000027497, + "Regularization matrix (H)": 0.049308830000518354, + "Regularized reconstruction": 0.43148055000056046, + "Mapped recon + log evidence": 0.5270483899999817 }, - "total_step_by_step": 1.3882303899998079, - "fit_no_sparse_full_likelihood_per_call": 3.598550679999607, + "total_step_by_step": 1.4923275300010572, + "fit_no_sparse_full_likelihood_per_call": 3.6985298600004173, "log_evidence_no_sparse": 26232.068573757562, - "log_evidence_sparse": NaN, + "log_evidence_sparse": 26232.06854278734, "log_evidence_jax_expected": 26232.068573757562, - "delta_sparse_minus_no_sparse": NaN, - "delta_sparse_minus_jax_expected": NaN, - "sparse_matches_non_sparse_rtol_1em4": false + "delta_sparse_minus_no_sparse": -3.0970222724135965e-05, + "delta_sparse_minus_jax_expected": -3.0970222724135965e-05, + "sparse_matches_non_sparse_rtol_1em4": true } \ No newline at end of file diff --git a/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.png b/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.png index 734cfd4..a1e308e 100644 Binary files a/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.png and b/jax_profiling/imaging/results/pixelization_sparse_cpu_likelihood_summary_hst_v2026.5.1.4.png differ