Skip to content

Commit b07bbb9

Browse files
committed
utility to generate and add error to state
1 parent 656704a commit b07bbb9

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

src/isotropic/utils/state_transforms.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,14 @@
33
from math import log
44

55
import jax.numpy as jnp
6+
import jax.random as random
67
from jax import Array
8+
from jax.typing import ArrayLike
9+
10+
from isotropic.e2 import F_j, get_e2_coeffs
11+
from isotropic.orthonormal import get_orthonormal_basis
12+
from isotropic.thetazero import get_theta_zero
13+
from isotropic.utils.distribution import normal_integrand
714

815

916
def statevector_to_hypersphere(Phi: Array) -> Array:
@@ -69,3 +76,44 @@ def add_isotropic_error(Phi_sp: Array, e2: Array, theta_zero: float) -> Array:
6976
(jnp.sum(e2, axis=0)) * jnp.sin(theta_zero)
7077
)
7178
return Psi_sp
79+
80+
81+
def generate_and_add_isotropic_error(
82+
Phi: ArrayLike, sigma: float = 0.9, key: ArrayLike = random.PRNGKey(0)
83+
) -> Array:
84+
"""
85+
Generate and add isotropic error to a given statevector.
86+
87+
Parameters
88+
----------
89+
Phi : ArrayLike
90+
The input statevector as a complex JAX array of dimension 2^n, for n-qubits.
91+
sigma : float, optional
92+
The standard deviation for the isotropic error, by default 0.9.
93+
key : ArrayLike, optional
94+
Random key for reproducibility, by default random.PRNGKey(0).
95+
96+
Returns
97+
-------
98+
Array
99+
The perturbed statevector after adding isotropic error.
100+
"""
101+
Phi_spherical = statevector_to_hypersphere(Phi)
102+
basis = get_orthonormal_basis(
103+
Phi_spherical
104+
) # gives d vectors with d+1 elements each
105+
_, coeffs = get_e2_coeffs(
106+
d=basis.shape[0], # gives d coefficients for the d vectors above
107+
F_j=F_j,
108+
key=key,
109+
)
110+
e2 = jnp.expand_dims(coeffs, axis=-1) * basis
111+
112+
def g(theta):
113+
return normal_integrand(theta, d=Phi_spherical.shape[0], sigma=sigma)
114+
115+
x = random.uniform(key, shape=(), minval=0, maxval=1)
116+
theta_zero = get_theta_zero(x=x, g=g)
117+
Psi_spherical = add_isotropic_error(Phi_spherical, e2=e2, theta_zero=theta_zero)
118+
Psi = hypersphere_to_statevector(Psi_spherical)
119+
return Psi

0 commit comments

Comments
 (0)