diff --git a/notebooks/jax_likelihood_functions/interferometer/rectangular_sparse.ipynb b/notebooks/jax_likelihood_functions/interferometer/rectangular_sparse.ipynb index 91660f23..a4ccd793 100644 --- a/notebooks/jax_likelihood_functions/interferometer/rectangular_sparse.ipynb +++ b/notebooks/jax_likelihood_functions/interferometer/rectangular_sparse.ipynb @@ -333,7 +333,7 @@ "\n", "np.testing.assert_allclose(\n", " np.array(result),\n", - " -3152.03184792,\n", + " -3164.286252,\n", " rtol=1e-4,\n", " err_msg=\"interferometer/rectangular_sparse: JAX vmap likelihood mismatch\",\n", ")\n" diff --git a/scripts/jax_assertions/sparse_operators.py b/scripts/jax_assertions/sparse_operators.py index c30fb541..55b1271a 100644 --- a/scripts/jax_assertions/sparse_operators.py +++ b/scripts/jax_assertions/sparse_operators.py @@ -235,18 +235,29 @@ pix_weights_for_sub_slim_index = np.ones(shape=(9, 1)) +# Use the unmasked-extent index for the interferometer sparse operator (which +# lives on the extent grid, not the full native grid). For this fully-unmasked +# 3x3 grid the two indices are identical, but the extent form is the convention +# accepted by the operator after the Pmax > 1 fix. +rows_interferometer, cols_interferometer, vals_interferometer = aa.util.mapper.sparse_triplets_from( + pix_indexes_for_sub=pix_indexes_for_sub_slim_index, + pix_weights_for_sub=pix_weights_for_sub_slim_index, + slim_index_for_sub=np.arange(pix_indexes_for_sub_slim_index.shape[0], dtype=np.int32), + fft_index_for_masked_pixel=grid.mask.extent_index_for_masked_pixel, + sub_fraction_slim=np.ones(pix_indexes_for_sub_slim_index.shape[0], dtype=np.float64), + return_rows_slim=False, +) + sparse_operator = aa.InterferometerSparseOperator.from_nufft_precision_operator( nufft_precision_operator=nufft_precision_operator, dirty_image=None, ) -curvature_matrix_via_preload = ( - sparse_operator.curvature_matrix_via_sparse_operator_from( - pix_indexes_for_sub_slim_index=pix_indexes_for_sub_slim_index, - pix_weights_for_sub_slim_index=pix_weights_for_sub_slim_index, - fft_index_for_masked_pixel=grid.mask.fft_index_for_masked_pixel, - pix_pixels=3, - ) +curvature_matrix_via_preload = sparse_operator.curvature_matrix_diag_from( + rows=rows_interferometer, + cols=cols_interferometer, + vals=vals_interferometer, + S=3, ) npt.assert_allclose( diff --git a/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py b/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py index a670a4a4..57f6525f 100644 --- a/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py +++ b/scripts/jax_likelihood_functions/interferometer/rectangular_sparse.py @@ -213,7 +213,7 @@ class in **PyAutoFit**, which pairs the model with likelihood. np.testing.assert_allclose( np.array(result), - -3152.03184792, + -3164.286252, rtol=1e-4, err_msg="interferometer/rectangular_sparse: JAX vmap likelihood mismatch", ) @@ -263,12 +263,12 @@ class in **PyAutoFit**, which pairs the model with likelihood. The Path A run above uses TransformerDFT + `apply_sparse_operator(use_jax=True)` (the cached-precision-matrix accelerator for pixelization). This pass uses the same TransformerDFT but skips the sparse-operator optimization — the -plain direct-DFT pixelization path. The two paths give *different* -log-likelihoods at the ~0.4% level because the sparse-operator -precomputation is a numerical reformulation, not an exact reproduction; -this is expected. Path B's literal must therefore match -`scripts/jax_likelihood_functions/interferometer/rectangular.py` (which is -the same model + DFT-no-sparse path). +plain direct-DFT pixelization path. After the Pmax > 1 / extent-indexing fix +(issue #314), the two paths agree to numerical precision: the sparse-operator +precomputation is mathematically exact, not a "numerical reformulation". +Path B's literal therefore matches Path A and +`scripts/jax_likelihood_functions/interferometer/rectangular.py` (the same +model + DFT-no-sparse path). """ dataset_dft_nosparse = al.Interferometer.from_fits( data_path=path.join(dataset_path, "data.fits"),