Skip to content

Commit 5850af2

Browse files
authored
Merge pull request #286 from Jammy2211/feature/jaxify_gnfw_conc
feature/jaxify_gnfw_conc
2 parents b9a2e50 + 9d1b9d8 commit 5850af2

1 file changed

Lines changed: 122 additions & 34 deletions

File tree

autogalaxy/profiles/mass/dark/gnfw_virial_mass_conc.py

Lines changed: 122 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -5,53 +5,141 @@
55
import numpy as np
66
from autogalaxy import cosmology as cosmo
77

8+
def is_jax(x):
9+
try:
10+
import jax
11+
from jax import Array
12+
from jax.core import Tracer
13+
return isinstance(x, (Array, Tracer))
14+
except Exception:
15+
return False
816

917
def kappa_s_and_scale_radius(
10-
cosmology, virial_mass, c_2, overdens, redshift_object, redshift_source, inner_slope
18+
cosmology,
19+
virial_mass,
20+
c_2,
21+
overdens,
22+
redshift_object,
23+
redshift_source,
24+
inner_slope,
1125
):
12-
from scipy.integrate import quad
13-
14-
concentration = (2.0 - inner_slope) * c_2 # gNFW concentration
15-
16-
critical_density = cosmology.critical_density(
17-
redshift_object, xp=np
18-
) # Msun / kpc^3
19-
20-
critical_surface_density = (
21-
cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from(
22-
redshift_0=redshift_object,
23-
redshift_1=redshift_source,
24-
xp=np,
25-
)
26+
"""
27+
Compute the characteristic convergence and scale radius of a spherical gNFW halo
28+
parameterised by virial mass and concentration.
29+
30+
This routine converts a halo defined by its virial mass and concentration into
31+
the equivalent gNFW parameters (`kappa_s`, `scale_radius`) used in lensing
32+
calculations. The normalization is computed analytically using the closed-form
33+
hypergeometric expression for the enclosed mass integral, ensuring compatibility
34+
with both NumPy and JAX backends (e.g. within `jax.jit`).
35+
36+
The virial radius is defined via:
37+
38+
M_vir = (4/3) π Δ ρ_crit(z_lens) r_vir^3
39+
40+
where Δ is the overdensity with respect to the critical density. If `overdens`
41+
is set to zero, the Bryan & Norman (1998) redshift-dependent overdensity is used.
42+
43+
The gNFW normalization constant is computed as:
44+
45+
d_e = (Δ / 3) (3 − γ) c^γ /
46+
₂F₁(3 − γ, 3 − γ; 4 − γ; −c)
47+
48+
where γ is the inner slope and c is the gNFW concentration.
49+
50+
Parameters
51+
----------
52+
cosmology
53+
Cosmology object providing critical density, angular diameter distance
54+
conversions, and surface mass density calculations. Must support an `xp`
55+
argument for NumPy/JAX interoperability.
56+
virial_mass
57+
Virial mass of the halo in units of solar masses.
58+
c_2
59+
Concentration-like parameter, converted internally to the gNFW
60+
concentration via `(2 - inner_slope) * c_2`.
61+
overdens
62+
Overdensity with respect to the critical density. If zero, the
63+
Bryan & Norman (1998) redshift-dependent overdensity is used.
64+
redshift_object
65+
Redshift of the lens (halo).
66+
redshift_source
67+
Redshift of the background source.
68+
inner_slope
69+
Inner logarithmic density slope γ of the gNFW profile.
70+
xp
71+
Array backend module (`numpy` or `jax.numpy`). All array operations
72+
are dispatched through this module to ensure compatibility with
73+
both standard NumPy execution and JAX tracing / JIT compilation.
74+
75+
Returns
76+
-------
77+
kappa_s
78+
Dimensionless characteristic convergence of the gNFW profile.
79+
scale_radius
80+
Angular scale radius in arcseconds.
81+
virial_radius
82+
Virial radius in kiloparsecs.
83+
overdens
84+
Final overdensity value used in the calculation.
85+
86+
Notes
87+
-----
88+
- This implementation is fully JIT-compatible when `xp=jax.numpy`.
89+
- No Python-side branching depends on traced values; conditional logic
90+
is implemented via backend array operations.
91+
- The analytic normalization avoids numerical quadrature, improving
92+
both performance and differentiability.
93+
"""
94+
is_jax_bool = is_jax(virial_mass)
95+
96+
if not is_jax_bool:
97+
xp = np
98+
else:
99+
from jax import numpy as jnp
100+
xp = jnp
101+
102+
if xp is np:
103+
from scipy.special import hyp2f1
104+
else:
105+
try:
106+
from jax.scipy.special import hyp2f1 # noqa: F401
107+
except Exception as e:
108+
raise RuntimeError(
109+
"This feature requires jax.scipy.special.hyp2f1, which is available in "
110+
"JAX >= 0.6.1. Please upgrade `jax` and `jaxlib`."
111+
) from e
112+
113+
gamma = inner_slope
114+
concentration = (2.0 - gamma) * c_2 # gNFW concentration (your definition)
115+
116+
critical_density = cosmology.critical_density(redshift_object, xp=xp) # Msun / kpc^3
117+
118+
critical_surface_density = cosmology.critical_surface_density_between_redshifts_solar_mass_per_kpc2_from(
119+
redshift_0=redshift_object,
120+
redshift_1=redshift_source,
121+
xp=xp,
26122
) # Msun / kpc^2
27123

28-
kpc_per_arcsec = cosmology.kpc_per_arcsec_from(
29-
redshift=redshift_object, xp=np
30-
) # kpc / arcsec
124+
kpc_per_arcsec = cosmology.kpc_per_arcsec_from(redshift=redshift_object, xp=xp) # kpc / arcsec
31125

32-
if overdens == 0:
33-
x = cosmology.Om(redshift_object, xp=np) - 1.0
34-
overdens = 18.0 * np.pi**2 + 82.0 * x - 39.0 * x**2 # Bryan & Norman (1998)
126+
# Bryan & Norman (1998) overdensity if overdens == 0
127+
x = cosmology.Om(redshift_object, xp=xp) - 1.0
128+
overdens_bn98 = 18.0 * xp.pi**2 + 82.0 * x - 39.0 * x**2
129+
overdens = xp.where(overdens == 0, overdens_bn98, overdens)
35130

36131
# r_vir in kpc
37-
virial_radius = (
38-
virial_mass / (overdens * critical_density * (4.0 * np.pi / 3.0))
39-
) ** (1.0 / 3.0)
132+
virial_radius = (virial_mass / (overdens * critical_density * (4.0 * xp.pi / 3.0))) ** (1.0 / 3.0)
40133

41134
# scale radius in kpc
42135
scale_radius_kpc = virial_radius / concentration
43136

44-
# Normalization integral for gNFW
45-
def integrand(r):
46-
return (r**2 / r**inner_slope) * (1.0 + r / scale_radius_kpc) ** (
47-
inner_slope - 3.0
48-
)
137+
# c = rvir/rs is exactly "concentration" by definition
138+
c = concentration
49139

50-
de_c = (
51-
(overdens / 3.0)
52-
* (virial_radius**3 / scale_radius_kpc**inner_slope)
53-
/ quad(integrand, 0.0, virial_radius)[0]
54-
)
140+
# Analytic normalization
141+
a = 3.0 - gamma
142+
de_c = (overdens / 3.0) * a * (c**gamma) / hyp2f1(a, a, a + 1.0, -c)
55143

56144
rho_s = critical_density * de_c # Msun / kpc^3
57145
kappa_s = rho_s * scale_radius_kpc / critical_surface_density # dimensionless

0 commit comments

Comments
 (0)