Skip to content

Commit 27bab00

Browse files
committed
use scipy double factorial
1 parent 0be87e1 commit 27bab00

File tree

5 files changed

+51
-30
lines changed

5 files changed

+51
-30
lines changed

src/isotropic/e2.py

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,28 @@
66
import jax.random as random
77
from jax import Array
88
from jax.typing import ArrayLike
9+
from scipy.special import factorial2
910

1011
from isotropic.utils.bisection import get_theta
11-
from isotropic.utils.distribution import double_factorial
12+
13+
14+
def double_factorial_ratio(n: int, m: int) -> float:
15+
"""
16+
Compute the ratio of double factorials n!! / m!!.
17+
18+
Parameters
19+
----------
20+
n : int
21+
The numerator double factorial.
22+
m : int
23+
The denominator double factorial.
24+
25+
Returns
26+
-------
27+
float
28+
The ratio of the double factorials.
29+
"""
30+
return factorial2(n) / factorial2(m)
1231

1332

1433
def F_j(theta_j: float, j: int, d: int) -> Array:
@@ -30,12 +49,11 @@ def F_j(theta_j: float, j: int, d: int) -> Array:
3049
The value of the function F_j evaluated at theta_j.
3150
"""
3251
dj = d - j
33-
num = double_factorial(dj - 2)
34-
den = double_factorial(dj - 1)
52+
numoverden = double_factorial_ratio(dj - 2, dj - 1)
3553

3654
def F_odd(_):
37-
C_j = double_factorial(dj - 1) / (2 * double_factorial(dj - 2))
38-
prefactor = C_j * num / den
55+
C_j = (1 / 2) * double_factorial_ratio(dj - 1, dj - 2)
56+
prefactor = C_j * numoverden
3957
k_max = (dj - 2) // 2 # upper bound for k in range
4058
k_vals = jnp.arange(0, k_max + 1)
4159

@@ -54,8 +72,8 @@ def product_term(k):
5472
return prefactor - C_j * jnp.cos(theta_j) * sum_terms
5573

5674
def F_even(_):
57-
C_j = double_factorial(dj - 1) / (jnp.pi * double_factorial(dj - 2))
58-
prefactor = C_j * (num / den) * theta_j
75+
C_j = (1 / jnp.pi) * double_factorial_ratio(dj - 1, dj - 2)
76+
prefactor = C_j * numoverden * theta_j
5977
k_max = (dj - 1) // 2
6078
k_vals = jnp.arange(1, k_max + 1)
6179

src/isotropic/utils/distribution.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ def double_factorial_ratio(num: int, den: int) -> Array:
4646
-----
4747
For very large numbers, this is numerically stable only when |num - den| is ~5.
4848
"""
49+
if abs(num - den) > 4:
50+
raise ValueError("num and den should be close to each other")
4951
num_elems = jnp.arange(num, 0, -2, dtype=jnp.uint64)
5052
den_elems = jnp.arange(den, 0, -2, dtype=jnp.uint64)
5153

tests/test_e2.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import jax.numpy as jnp
22
from jax import Array, random
3+
from scipy.special import factorial2
34

45
from isotropic.e2 import F_j, get_e2_coeffs
5-
from isotropic.utils.distribution import double_factorial
66

77

88
def test_get_e2():
@@ -41,9 +41,9 @@ def test_F_j_even():
4141
result = F_j(theta_j, j, d)
4242

4343
# manually calculate expected result
44-
C_j = double_factorial(d - j - 1) / (jnp.pi * double_factorial(d - j - 2))
45-
num = double_factorial(d - j - 2)
46-
den = double_factorial(d - j - 1)
44+
C_j = factorial2(d - j - 1) / (jnp.pi * factorial2(d - j - 2))
45+
num = factorial2(d - j - 2)
46+
den = factorial2(d - j - 1)
4747
prefactor = C_j * (num / den) * theta_j
4848

4949
# k goes from 1 to (d - j - 1) // 2, i.e., 1 to 2
@@ -65,9 +65,9 @@ def test_F_j_odd():
6565
result = F_j(theta_j, j, d)
6666

6767
# manually calculate expected result
68-
C_j = double_factorial(d - j - 1) / (2 * double_factorial(d - j - 2))
69-
num = double_factorial(d - j - 2)
70-
den = double_factorial(d - j - 1)
68+
C_j = factorial2(d - j - 1) / (2 * factorial2(d - j - 2))
69+
num = factorial2(d - j - 2)
70+
den = factorial2(d - j - 1)
7171
prefactor = C_j * num / den
7272

7373
# k goes from 0 to (d - j - 2) // 2, i.e., 0 to 2

tests/test_error_generation.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,9 @@
1313

1414

1515
def test_add_isotropic_error():
16-
Phi_original = jnp.asarray([1 + 0j, 1 + 0j], dtype=complex) / jnp.sqrt(
17-
2
18-
) # n = 1, d = 3
19-
Phi_spherical = statevector_to_hypersphere(Phi_original) # d+1 = 4
16+
Phi = jnp.ones(4, dtype=complex)
17+
Phi = Phi / jnp.linalg.norm(Phi)
18+
Phi_spherical = statevector_to_hypersphere(Phi) # d+1 = 4
2019
basis = get_orthonormal_basis(
2120
Phi_spherical
2221
) # gives d vectors with d+1 elements each
@@ -27,8 +26,13 @@ def test_add_isotropic_error():
2726
)
2827
e2 = jnp.expand_dims(coeffs, axis=-1) * basis
2928

29+
# orthogonality check
30+
assert jnp.allclose(jnp.abs(jnp.dot(e2, Phi_spherical)), 0.0), (
31+
f"Expected 0.0, got {jnp.abs(jnp.dot(basis, Phi_spherical))}"
32+
)
33+
3034
def g(theta):
31-
return normal_integrand(theta, d=Phi_spherical.shape[0], sigma=0.9)
35+
return normal_integrand(theta, d=Phi_spherical.shape[0], sigma=0.6)
3236

3337
theta_zero = get_theta_zero(x=0.5, g=g)
3438
Psi_spherical = add_isotropic_error(Phi_spherical, e2=e2, theta_zero=theta_zero)
@@ -38,7 +42,3 @@ def g(theta):
3842
assert jnp.isclose(jnp.linalg.norm(Psi), 1.0), (
3943
f"Expected 1.0, got {jnp.linalg.norm(Psi)}"
4044
)
41-
42-
# fidelity check
43-
fidelity = jnp.abs(jnp.vdot(Phi_original, Psi)) ** 2
44-
assert 0.0 <= fidelity <= 1.0, f"Expected fidelity in [0, 1], got {fidelity}"

tests/test_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import jax.numpy as jnp
22
from scipy.linalg import null_space
3+
from scipy.special import factorial2
34

45
from isotropic.utils.bisection import get_theta
56
from isotropic.utils.distribution import (
@@ -61,39 +62,39 @@ def test_double_factorial():
6162
# Test even double factorial
6263
n_even = 6
6364
result_even = double_factorial(n_even)
64-
expected_even = 48.0 # 6!! = 6 * 4 * 2 = 48
65+
expected_even = factorial2(n_even)
6566
assert jnp.isclose(result_even, expected_even), (
6667
f"Expected {expected_even}, got {result_even}"
6768
)
6869

6970
# Test odd double factorial
7071
n_odd = 5
7172
result_odd = double_factorial(n_odd)
72-
expected_odd = 15.0 # 5!! = 5 * 3 * 1 = 15
73+
expected_odd = factorial2(n_odd)
7374
assert jnp.isclose(result_odd, expected_odd), (
7475
f"Expected {expected_odd}, got {result_odd}"
7576
)
7677

7778
# Test zero double factorial
7879
n_zero = 0
7980
result_zero = double_factorial(n_zero)
80-
expected_zero = 1.0
81+
expected_zero = factorial2(n_zero)
8182
assert jnp.isclose(result_zero, expected_zero), (
8283
f"Expected {expected_zero}, got {result_zero}"
8384
)
8485

8586

8687
def test_double_factorial_ratio():
87-
num, den = (2**5) - 1, (2**5) - 2
88+
num, den = (2**8) - 1, (2**8) - 2
8889
ratio_received = double_factorial_ratio(num, den)
89-
ratio_expected = double_factorial(num) / double_factorial(den)
90+
ratio_expected = factorial2(num) / factorial2(den)
9091
assert jnp.isclose(ratio_received, ratio_expected), (
9192
f"Expected {ratio_expected}, got {ratio_received}"
9293
)
9394

94-
num, den = (2**5) - 3, (2**5) - 1
95+
num, den = (2**8) - 3, (2**8) - 1
9596
ratio_received = double_factorial_ratio(num, den)
96-
ratio_expected = double_factorial(num) / double_factorial(den)
97+
ratio_expected = factorial2(num) / factorial2(den)
9798
assert jnp.isclose(ratio_received, ratio_expected), (
9899
f"Expected {ratio_expected}, got {ratio_received}"
99100
)

0 commit comments

Comments
 (0)