Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions scripts/jax_likelihood_functions/datacube/delaunay.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,10 +235,11 @@
print("JAX Time Taken per Likelihood:", (time.time() - start) / batch_size)

"""
Cube log-likelihood ≈ N × single-channel log-likelihood (-3165.42388511) for
identical channels. Pinned empirically below.
Cube log-likelihood ≈ N × single-channel log-likelihood for identical
channels. Pinned empirically below.
"""
EXPECTED_VMAP_LOG_LIKELIHOOD = n_channels * -3165.42388511
EXPECTED_SINGLE_CHANNEL_LOG_LIKELIHOOD = -3165.42388511
EXPECTED_VMAP_LOG_LIKELIHOOD = n_channels * EXPECTED_SINGLE_CHANNEL_LOG_LIKELIHOOD

np.testing.assert_allclose(
np.array(result),
Expand Down Expand Up @@ -278,9 +279,14 @@ def log_l_jit_fn(parameters):
"""
__Path B: TransformerNUFFT cross-check__

Re-run the same cube vmap with ``TransformerNUFFT`` and confirm the result
matches the ``TransformerDFT`` value.
Re-run a single-channel vmap with ``TransformerNUFFT`` and confirm the result
matches the single-channel ``TransformerDFT`` value. The 4-channel DFT path above
already validates datacube factor-graph summation and JIT routing; duplicating the
same Delaunay NUFFT graph across four identical channels can exceed the release
runner memory budget.
"""
nufft_channels = 1

dataset_list_nufft = [
al.Interferometer.from_fits(
data_path=path.join(dataset_path, "data.fits"),
Expand All @@ -289,7 +295,7 @@ def log_l_jit_fn(parameters):
real_space_mask=real_space_mask,
transformer_class=al.TransformerNUFFT,
)
for _ in range(n_channels)
for _ in range(nufft_channels)
]

analysis_list_nufft = [
Expand Down Expand Up @@ -329,7 +335,7 @@ def log_l_jit_fn(parameters):

np.testing.assert_allclose(
np.array(result_nufft),
EXPECTED_VMAP_LOG_LIKELIHOOD,
nufft_channels * EXPECTED_SINGLE_CHANNEL_LOG_LIKELIHOOD,
rtol=1e-4,
err_msg="datacube/delaunay: TransformerNUFFT cube vmap disagrees with TransformerDFT",
)
Expand Down
Loading