|
1 | 1 | import jax |
2 | 2 | import jax.numpy as jnp |
3 | 3 |
|
4 | | -from isotropic.e2 import F_j, get_e2_coeffs |
5 | | -from isotropic.orthonormal import get_orthonormal_basis |
6 | | -from isotropic.thetazero import get_theta_zero |
7 | | -from isotropic.utils.distribution import normal_integrand |
8 | | -from isotropic.utils.state_transforms import ( |
9 | | - add_isotropic_error, |
10 | | - hypersphere_to_statevector, |
11 | | - statevector_to_hypersphere, |
12 | | -) |
| 4 | +from isotropic.utils.state_transforms import generate_and_add_isotropic_error |
13 | 5 |
|
14 | 6 |
|
15 | 7 | def test_add_isotropic_error(): |
16 | 8 | Phi = jnp.ones(4, dtype=complex) |
17 | 9 | Phi = Phi / jnp.linalg.norm(Phi) |
18 | | - Phi_spherical = statevector_to_hypersphere(Phi) |
19 | | - basis = get_orthonormal_basis( |
20 | | - Phi_spherical |
21 | | - ) # gives d vectors with d+1 elements each |
22 | | - _, coeffs = get_e2_coeffs( |
23 | | - d=basis.shape[0], # gives d coefficients for the d vectors above |
24 | | - F_j=F_j, |
25 | | - key=jax.random.PRNGKey(0), |
26 | | - ) |
27 | | - e2 = jnp.expand_dims(coeffs, axis=-1) * basis |
28 | 10 |
|
29 | | - # orthogonality check |
30 | | - assert jnp.allclose(jnp.abs(jnp.dot(e2, Phi_spherical)), 0.0), ( |
31 | | - f"Expected 0.0, got {jnp.abs(jnp.dot(basis, Phi_spherical))}" |
| 11 | + Psi = generate_and_add_isotropic_error( |
| 12 | + Phi=Phi, |
| 13 | + sigma=0.9, |
| 14 | + key=jax.random.PRNGKey(0), |
32 | 15 | ) |
33 | 16 |
|
34 | | - def g(theta): |
35 | | - return normal_integrand(theta, d=Phi_spherical.shape[0], sigma=0.96) |
36 | | - |
37 | | - theta_zero = get_theta_zero(x=0.5, g=g) |
38 | | - Psi_spherical = add_isotropic_error(Phi_spherical, e2=e2, theta_zero=theta_zero) |
39 | | - Psi = hypersphere_to_statevector(Psi_spherical) |
40 | | - |
41 | 17 | # normalization check |
42 | 18 | assert jnp.isclose(jnp.linalg.norm(Psi), 1.0), ( |
43 | 19 | f"Expected 1.0, got {jnp.linalg.norm(Psi)}" |
|
0 commit comments