|
3 | 3 | from math import log |
4 | 4 |
|
5 | 5 | import jax.numpy as jnp |
| 6 | +import jax.random as random |
6 | 7 | from jax import Array |
| 8 | +from jax.typing import ArrayLike |
| 9 | + |
| 10 | +from isotropic.e2 import F_j, get_e2_coeffs |
| 11 | +from isotropic.orthonormal import get_orthonormal_basis |
| 12 | +from isotropic.thetazero import get_theta_zero |
| 13 | +from isotropic.utils.distribution import normal_integrand |
7 | 14 |
|
8 | 15 |
|
9 | 16 | def statevector_to_hypersphere(Phi: Array) -> Array: |
@@ -69,3 +76,44 @@ def add_isotropic_error(Phi_sp: Array, e2: Array, theta_zero: float) -> Array: |
69 | 76 | (jnp.sum(e2, axis=0)) * jnp.sin(theta_zero) |
70 | 77 | ) |
71 | 78 | return Psi_sp |
| 79 | + |
| 80 | + |
| 81 | +def generate_and_add_isotropic_error( |
| 82 | + Phi: ArrayLike, sigma: float = 0.9, key: ArrayLike = random.PRNGKey(0) |
| 83 | +) -> Array: |
| 84 | + """ |
| 85 | + Generate and add isotropic error to a given statevector. |
| 86 | +
|
| 87 | + Parameters |
| 88 | + ---------- |
| 89 | + Phi : ArrayLike |
| 90 | + The input statevector as a complex JAX array of dimension 2^n, for n-qubits. |
| 91 | + sigma : float, optional |
| 92 | + The standard deviation for the isotropic error, by default 0.9. |
| 93 | + key : ArrayLike, optional |
| 94 | + Random key for reproducibility, by default random.PRNGKey(0). |
| 95 | +
|
| 96 | + Returns |
| 97 | + ------- |
| 98 | + Array |
| 99 | + The perturbed statevector after adding isotropic error. |
| 100 | + """ |
| 101 | + Phi_spherical = statevector_to_hypersphere(Phi) |
| 102 | + basis = get_orthonormal_basis( |
| 103 | + Phi_spherical |
| 104 | + ) # gives d vectors with d+1 elements each |
| 105 | + _, coeffs = get_e2_coeffs( |
| 106 | + d=basis.shape[0], # gives d coefficients for the d vectors above |
| 107 | + F_j=F_j, |
| 108 | + key=key, |
| 109 | + ) |
| 110 | + e2 = jnp.expand_dims(coeffs, axis=-1) * basis |
| 111 | + |
| 112 | + def g(theta): |
| 113 | + return normal_integrand(theta, d=Phi_spherical.shape[0], sigma=sigma) |
| 114 | + |
| 115 | + x = random.uniform(key, shape=(), minval=0, maxval=1) |
| 116 | + theta_zero = get_theta_zero(x=x, g=g) |
| 117 | + Psi_spherical = add_isotropic_error(Phi_spherical, e2=e2, theta_zero=theta_zero) |
| 118 | + Psi = hypersphere_to_statevector(Psi_spherical) |
| 119 | + return Psi |
0 commit comments