|
5 | 5 | import numpy as np |
6 | 6 | from autogalaxy import cosmology as cosmo |
7 | 7 |
|
| 8 | +def is_jax(x): |
| 9 | + try: |
| 10 | + import jax |
| 11 | + from jax import Array |
| 12 | + from jax.core import Tracer |
| 13 | + return isinstance(x, (Array, Tracer)) |
| 14 | + except Exception: |
| 15 | + return False |
8 | 16 |
|
9 | 17 | def kappa_s_and_scale_radius( |
10 | | - cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope |
| 18 | + cosmology, |
| 19 | + virial_mass, |
| 20 | + c_2, |
| 21 | + overdens, |
| 22 | + redshift_object, |
| 23 | + redshift_source, |
| 24 | + inner_slope, |
11 | 25 | ): |
12 | | - from scipy.integrate import quad |
13 | | - |
14 | | - concentration = (2.0 - inner_slope) * c_2 # gNFW concentration |
15 | | - |
16 | | - critical_density = cosmology.critical_density( |
17 | | - redshift_object, xp=np |
18 | | - ) # Msun / kpc^3 |
19 | | - |
20 | | - critical_surface_density = ( |
21 | | - cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( |
22 | | - redshift_0=redshift_object, |
23 | | - redshift_1=redshift_source, |
24 | | - xp=np, |
25 | | - ) |
| 26 | + """ |
| 27 | + Compute the characteristic convergence and scale radius of a spherical gNFW halo |
| 28 | + parameterised by virial mass and concentration. |
| 29 | +
|
| 30 | + This routine converts a halo defined by its virial mass and concentration into |
| 31 | + the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing |
| 32 | + calculations. The normalization is computed analytically using the closed-form |
| 33 | + hypergeometric expression for the enclosed mass integral, ensuring compatibility |
| 34 | + with both NumPy and JAX backends (e.g. within `jax.jit`). |
| 35 | +
|
| 36 | + The virial radius is defined via: |
| 37 | +
|
| 38 | + M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3 |
| 39 | +
|
| 40 | + where Δ is the overdensity with respect to the critical density. If `overdens` |
| 41 | + is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used. |
| 42 | +
|
| 43 | + The gNFW normalization constant is computed as: |
| 44 | +
|
| 45 | + d_e = (Δ / 3) (3 − γ) c^γ / |
| 46 | + ₂F₁(3 − γ, 3 − γ; 4 − γ; −c) |
| 47 | +
|
| 48 | + where γ is the inner slope and c is the gNFW concentration. |
| 49 | +
|
| 50 | + Parameters |
| 51 | + ---------- |
| 52 | + cosmology |
| 53 | + Cosmology object providing critical density, angular diameter distance |
| 54 | + conversions, and surface mass density calculations. Must support an `xp` |
| 55 | + argument for NumPy/JAX interoperability. |
| 56 | + virial_mass |
| 57 | + Virial mass of the halo in units of solar masses. |
| 58 | + c_2 |
| 59 | + Concentration-like parameter, converted internally to the gNFW |
| 60 | + concentration via `(2 - inner_slope) * c_2`. |
| 61 | + overdens |
| 62 | + Overdensity with respect to the critical density. If zero, the |
| 63 | + Bryan & Norman (1998) redshift-dependent overdensity is used. |
| 64 | + redshift_object |
| 65 | + Redshift of the lens (halo). |
| 66 | + redshift_source |
| 67 | + Redshift of the background source. |
| 68 | + inner_slope |
| 69 | + Inner logarithmic density slope γ of the gNFW profile. |
| 70 | + xp |
| 71 | + Array backend module (`numpy` or `jax.numpy`). All array operations |
| 72 | + are dispatched through this module to ensure compatibility with |
| 73 | + both standard NumPy execution and JAX tracing / JIT compilation. |
| 74 | +
|
| 75 | + Returns |
| 76 | + ------- |
| 77 | + kappa_s |
| 78 | + Dimensionless characteristic convergence of the gNFW profile. |
| 79 | + scale_radius |
| 80 | + Angular scale radius in arcseconds. |
| 81 | + virial_radius |
| 82 | + Virial radius in kiloparsecs. |
| 83 | + overdens |
| 84 | + Final overdensity value used in the calculation. |
| 85 | +
|
| 86 | + Notes |
| 87 | + ----- |
| 88 | + - This implementation is fully JIT-compatible when `xp=jax.numpy`. |
| 89 | + - No Python-side branching depends on traced values; conditional logic |
| 90 | + is implemented via backend array operations. |
| 91 | + - The analytic normalization avoids numerical quadrature, improving |
| 92 | + both performance and differentiability. |
| 93 | + """ |
| 94 | + is_jax_bool = is_jax(virial_mass) |
| 95 | + |
| 96 | + if not is_jax_bool: |
| 97 | + xp = np |
| 98 | + else: |
| 99 | + from jax import numpy as jnp |
| 100 | + xp = jnp |
| 101 | + |
| 102 | + if xp is np: |
| 103 | + from scipy.special import hyp2f1 |
| 104 | + else: |
| 105 | + try: |
| 106 | + from jax.scipy.special import hyp2f1 # noqa: F401 |
| 107 | + except Exception as e: |
| 108 | + raise RuntimeError( |
| 109 | + "This feature requires jax.scipy.special.hyp2f1, which is available in " |
| 110 | + "JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`." |
| 111 | + ) from e |
| 112 | + |
| 113 | + gamma = inner_slope |
| 114 | + concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition) |
| 115 | + |
| 116 | + critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3 |
| 117 | + |
| 118 | + critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from( |
| 119 | + redshift_0=redshift_object, |
| 120 | + redshift_1=redshift_source, |
| 121 | + xp=xp, |
26 | 122 | ) # Msun / kpc^2 |
27 | 123 |
|
28 | | - kpc_per_arcsec = cosmology.kpc_per_arcsec_from( |
29 | | - redshift=redshift_object, xp=np |
30 | | - ) # kpc / arcsec |
| 124 | + kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec |
31 | 125 |
|
32 | | - if overdens == 0: |
33 | | - x = cosmology.Om(redshift_object, xp=np) - 1.0 |
34 | | - overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998) |
| 126 | + # Bryan & Norman (1998) overdensity if overdens == 0 |
| 127 | + x = cosmology.Om(redshift_object, xp=xp) - 1.0 |
| 128 | + overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2 |
| 129 | + overdens = xp.where(overdens == 0, overdens_bn98, overdens) |
35 | 130 |
|
36 | 131 | # r_vir in kpc |
37 | | - virial_radius = ( |
38 | | - virial_mass / (overdens * critical_density * (4.0 * np.pi / 3.0)) |
39 | | - ) ** (1.0 / 3.0) |
| 132 | + virial_radius = (virial_mass / (overdens * critical_density * (4.0 * xp.pi / 3.0))) ** (1.0 / 3.0) |
40 | 133 |
|
41 | 134 | # scale radius in kpc |
42 | 135 | scale_radius_kpc = virial_radius / concentration |
43 | 136 |
|
44 | | - # Normalization integral for gNFW |
45 | | - def integrand(r): |
46 | | - return (r**2 / r**inner_slope) * (1.0 + r / scale_radius_kpc) ** ( |
47 | | - inner_slope - 3.0 |
48 | | - ) |
| 137 | + # c = rvir/rs is exactly "concentration" by definition |
| 138 | + c = concentration |
49 | 139 |
|
50 | | - de_c = ( |
51 | | - (overdens / 3.0) |
52 | | - * (virial_radius**3 / scale_radius_kpc**inner_slope) |
53 | | - / quad(integrand, 0.0, virial_radius)[0] |
54 | | - ) |
| 140 | + # Analytic normalization |
| 141 | + a = 3.0 - gamma |
| 142 | + de_c = (overdens / 3.0) * a * (c**gamma) / hyp2f1(a, a, a + 1.0, -c) |
55 | 143 |
|
56 | 144 | rho_s = critical_density * de_c # Msun / kpc^3 |
57 | 145 | kappa_s = rho_s * scale_radius_kpc / critical_surface_density # dimensionless |
|
0 commit comments