Skip to content

Commit 2513cbe

Browse files
committed
refactor to distinguish between jax & scipy double fact
1 parent 63ae077 commit 2513cbe

File tree

3 files changed

+37
-36
lines changed

3 files changed

+37
-36
lines changed

src/isotropic/e2.py

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -6,28 +6,9 @@
66
import jax.random as random
77
from jax import Array
88
from jax.typing import ArrayLike
9-
from scipy.special import factorial2
109

1110
from isotropic.utils.bisection import get_theta
12-
13-
14-
def double_factorial_ratio(n: int, m: int) -> float:
15-
"""
16-
Compute the ratio of double factorials n!! / m!!.
17-
18-
Parameters
19-
----------
20-
n : int
21-
The numerator double factorial.
22-
m : int
23-
The denominator double factorial.
24-
25-
Returns
26-
-------
27-
float
28-
The ratio of the double factorials.
29-
"""
30-
return factorial2(n) / factorial2(m)
11+
from isotropic.utils.distribution import double_factorial_ratio_scipy
3112

3213

3314
def F_j(theta_j: float, j: int, d: int) -> Array:
@@ -49,10 +30,10 @@ def F_j(theta_j: float, j: int, d: int) -> Array:
4930
The value of the function F_j evaluated at theta_j.
5031
"""
5132
dj = d - j
52-
numoverden = double_factorial_ratio(dj - 2, dj - 1)
33+
numoverden = double_factorial_ratio_scipy(dj - 2, dj - 1)
5334

5435
def F_odd(_):
55-
C_j = (1 / 2) * double_factorial_ratio(dj - 1, dj - 2)
36+
C_j = (1 / 2) * double_factorial_ratio_scipy(dj - 1, dj - 2)
5637
prefactor = C_j * numoverden
5738
k_max = (dj - 2) // 2 # upper bound for k in range
5839
k_vals = jnp.arange(0, k_max + 1)
@@ -72,7 +53,7 @@ def product_term(k):
7253
return prefactor - C_j * jnp.cos(theta_j) * sum_terms
7354

7455
def F_even(_):
75-
C_j = (1 / jnp.pi) * double_factorial_ratio(dj - 1, dj - 2)
56+
C_j = (1 / jnp.pi) * double_factorial_ratio_scipy(dj - 1, dj - 2)
7657
prefactor = C_j * numoverden * theta_j
7758
k_max = (dj - 1) // 2
7859
k_vals = jnp.arange(1, k_max + 1)

src/isotropic/utils/distribution.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,29 @@
22

33
import jax.numpy as jnp
44
from jax import Array
5+
from scipy.special import factorial2
56

67

7-
def double_factorial(n: int) -> Array:
8+
def double_factorial_ratio_scipy(num: int, den: int) -> float:
9+
"""
10+
Compute the ratio of double factorials num!! / den!!.
11+
12+
Parameters
13+
----------
14+
num : int
15+
The numerator double factorial.
16+
den : int
17+
The denominator double factorial.
18+
19+
Returns
20+
-------
21+
float
22+
The ratio of the double factorials.
23+
"""
24+
return factorial2(num) / factorial2(den)
25+
26+
27+
def double_factorial_jax(n: int) -> Array:
828
"""
929
Helper function to compute double factorial:
1030
@@ -24,7 +44,7 @@ def double_factorial(n: int) -> Array:
2444
return jnp.where(n <= 0, 1, jnp.prod(jnp.arange(n, 0, -2, dtype=jnp.uint64)))
2545

2646

27-
def double_factorial_ratio(num: int, den: int) -> Array:
47+
def double_factorial_ratio_jax(num: int, den: int) -> Array:
2848
"""
2949
Computes the ratio of double factorials:
3050
@@ -101,8 +121,8 @@ def normal_integrand(theta: float, d: int, sigma: float) -> Array:
101121
# sigma = jnp.asarray(sigma)
102122

103123
# factorial components
104-
numerator_factorial = double_factorial(d - 1)
105-
denominator_factorial = double_factorial(d - 2)
124+
numerator_factorial = factorial2(d - 1)
125+
denominator_factorial = factorial2(d - 2)
106126

107127
# Numerator components
108128
one_minus_sigma_sq = 1.0 - sigma**2

tests/test_utils.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
from isotropic.utils.bisection import get_theta
66
from isotropic.utils.distribution import (
7-
double_factorial,
8-
double_factorial_ratio,
7+
double_factorial_jax,
8+
double_factorial_ratio_jax,
99
normal_integrand,
1010
)
1111
from isotropic.utils.linalg import jax_null_space
@@ -58,42 +58,42 @@ def F(theta):
5858
)
5959

6060

61-
def test_double_factorial():
61+
def test_double_factorial_jax():
6262
# Test even double factorial
6363
n_even = 6
64-
result_even = double_factorial(n_even)
64+
result_even = double_factorial_jax(n_even)
6565
expected_even = factorial2(n_even)
6666
assert jnp.isclose(result_even, expected_even), (
6767
f"Expected {expected_even}, got {result_even}"
6868
)
6969

7070
# Test odd double factorial
7171
n_odd = 5
72-
result_odd = double_factorial(n_odd)
72+
result_odd = double_factorial_jax(n_odd)
7373
expected_odd = factorial2(n_odd)
7474
assert jnp.isclose(result_odd, expected_odd), (
7575
f"Expected {expected_odd}, got {result_odd}"
7676
)
7777

7878
# Test zero double factorial
7979
n_zero = 0
80-
result_zero = double_factorial(n_zero)
80+
result_zero = double_factorial_jax(n_zero)
8181
expected_zero = factorial2(n_zero)
8282
assert jnp.isclose(result_zero, expected_zero), (
8383
f"Expected {expected_zero}, got {result_zero}"
8484
)
8585

8686

87-
def test_double_factorial_ratio():
87+
def test_double_factorial_ratio_jax():
8888
num, den = (2**8) - 1, (2**8) - 2
89-
ratio_received = double_factorial_ratio(num, den)
89+
ratio_received = double_factorial_ratio_jax(num, den)
9090
ratio_expected = factorial2(num) / factorial2(den)
9191
assert jnp.isclose(ratio_received, ratio_expected), (
9292
f"Expected {ratio_expected}, got {ratio_received}"
9393
)
9494

9595
num, den = (2**8) - 3, (2**8) - 1
96-
ratio_received = double_factorial_ratio(num, den)
96+
ratio_received = double_factorial_ratio_jax(num, den)
9797
ratio_expected = factorial2(num) / factorial2(den)
9898
assert jnp.isclose(ratio_received, ratio_expected), (
9999
f"Expected {ratio_expected}, got {ratio_received}"

0 commit comments

Comments
 (0)