diff --git a/jaxite/jaxite_ckks/conjugate.py b/jaxite/jaxite_ckks/conjugate.py index 8048083..92f1166 100644 --- a/jaxite/jaxite_ckks/conjugate.py +++ b/jaxite/jaxite_ckks/conjugate.py @@ -17,35 +17,10 @@ import jax import jax.numpy as jnp from jaxite.jaxite_ckks import basis_conversion -from jaxite.jaxite_ckks import key_gen from jaxite.jaxite_ckks import key_switching from jaxite.jaxite_ckks import mul -from jaxite.jaxite_ckks import random from jaxite.jaxite_ckks import rescale from jaxite.jaxite_ckks import types -import numpy as np - - -def gen_conjugate_key( - sk: types.SecretKey, - q_limbs: list[int], - p_limbs: list[int], - dnum: int, - random_source: random.RandomSource | None = None, -) -> types.EvaluationKeys: - """Generates the key switching key for the conjugate automorphism.""" - # Conjugate of secret key: s(X^-1) is represented by reversing the index - # of the secret key in NTT domain. - s_conj_data = np.flip(sk.data, axis=0) - s_conj = types.SecretKey(s_conj_data, sk.moduli) - return key_gen.gen_key_switching_key( - source_key=s_conj, - dest_key=sk, - q_limbs=q_limbs, - p_limbs=p_limbs, - dnum=dnum, - random_source=random_source, - ) @jax.tree_util.register_pytree_node_class diff --git a/jaxite/jaxite_ckks/conjugate_test.py b/jaxite/jaxite_ckks/conjugate_test.py index ca89725..43fbca7 100644 --- a/jaxite/jaxite_ckks/conjugate_test.py +++ b/jaxite/jaxite_ckks/conjugate_test.py @@ -4,7 +4,6 @@ import hypothesis from hypothesis import strategies as st -import jax from jaxite.jaxite_ckks import barrett from jaxite.jaxite_ckks import basis_conversion from jaxite.jaxite_ckks import conjugate @@ -12,9 +11,9 @@ from jaxite.jaxite_ckks import encrypt from jaxite.jaxite_ckks import key_gen from jaxite.jaxite_ckks import mul +from jaxite.jaxite_ckks import ntt_cpu from jaxite.jaxite_ckks import random from jaxite.jaxite_ckks import rescale -from jaxite.jaxite_ckks import types import numpy as np from absl.testing import absltest @@ -38,7 +37,7 @@ def _run_conjugate_test( degree, q_limbs, random_source=test_random_source ) - conj_key = conjugate.gen_conjugate_key( + conj_key = key_gen.gen_conjugate_key( sk=sk_q, q_limbs=q_limbs, p_limbs=p_limbs, @@ -184,6 +183,106 @@ def test_conjugate_hypothesis(self, slots): mu=np.array(slots, dtype=complex), ) + def test_conjugate_of_conjugate(self): + degree = 16 + num_slots = 8 + q_limbs = [1073184769] + p_limbs = [1073479681, 1073741953] + + all_moduli = q_limbs + p_limbs + scale = 2**20 + + test_random_source = random.ZeroNoiseRandomSource() + + pk_q, sk_q = key_gen.keygen( + degree, q_limbs, random_source=test_random_source + ) + + conj_key = key_gen.gen_conjugate_key( + sk=sk_q, + q_limbs=q_limbs, + p_limbs=p_limbs, + dnum=1, + random_source=test_random_source, + ) + + bc_kernel = basis_conversion.BasisConversionBarrett() + bc_kernel.precompute_constants(all_moduli, [([0], [1, 2])]) + + barrett_constants_pq = barrett.precompute_barrett_constants(all_moduli) + mul_kernel = mul.MulPlaintextCiphertextBarrett(barrett_constants_pq) + rescale_kernel = rescale.Rescale() + rescale_kernel.precompute_constants( + all_moduli, num_rescales=len(p_limbs), r=4, c=4 + ) + + conjugate_kernel = conjugate.Conjugation() + conjugate_kernel.precompute_constants( + q_limbs=q_limbs, p_limbs=p_limbs, dnum=1, r=4, c=4 + ) + + mu = np.array( + [complex(x % 4 + 1.5, x % 4 - 2.5) for x in range(num_slots)], + dtype=complex, + ) + encoder_q = encode.Encode(degree, q_limbs, scale) + encryptor_q = encrypt.Encrypt(pk_q) + + plain_mu = encoder_q.encode(mu.tolist()) + ct_in = encryptor_q.encrypt(plain_mu, random_source=test_random_source) + + ct_conj = conjugate_kernel.conjugate( + ct=ct_in, + conj_key=conj_key, + p_limbs=p_limbs, + bc_kernel=bc_kernel, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + start_control_index=0, + ) + + ct_conj_conj = conjugate_kernel.conjugate( + ct=ct_conj, + conj_key=conj_key, + p_limbs=p_limbs, + bc_kernel=bc_kernel, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + start_control_index=0, + ) + + decryptor_q = encrypt.Decrypt(sk_q) + pt_dec = decryptor_q.decrypt(ct_conj_conj) + + decoder = encode.Decode(scale, num_slots) + decoded = decoder.decode(pt_dec) + + for e, d in zip(mu, decoded): + self.assertAlmostEqual(e.real, d.real, delta=0.5) + self.assertAlmostEqual(e.imag, d.imag, delta=0.5) + + def test_conjugation_ntt_preservation(self): + degree = 8 + q = 1073184769 + + np.random.seed(42) + poly = np.random.randint(0, q, size=degree).astype(np.uint64) + + poly_rot = np.zeros_like(poly) + poly_rot[0] = poly[0] + for i in range(1, degree): + val = poly[i] + target_idx = degree - i + poly_rot[target_idx] = (q - val) % q + + ntt_poly = ntt_cpu.ntt_negacyclic_poly(poly.reshape(degree, 1), [q]) + ntt_poly_rot = ntt_cpu.ntt_negacyclic_poly(poly_rot.reshape(degree, 1), [q]) + + ntt_poly_flipped = np.flip(ntt_poly, axis=0) + + np.testing.assert_array_equal(ntt_poly_flipped, ntt_poly_rot) + if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_ckks/key_gen.py b/jaxite/jaxite_ckks/key_gen.py index 8e8408f..39666f2 100644 --- a/jaxite/jaxite_ckks/key_gen.py +++ b/jaxite/jaxite_ckks/key_gen.py @@ -214,3 +214,25 @@ def gen_cm_keys( cm_keys[i, j] = encryptor.encrypt(plain_0) return cm_keys + + +def gen_conjugate_key( + sk: types.SecretKey, + q_limbs: list[int], + p_limbs: list[int], + dnum: int, + random_source: random.RandomSource | None = None, +) -> types.EvaluationKeys: + """Generates the key switching key for the conjugate automorphism.""" + # Conjugate of secret key: s(X^-1) is represented by reversing the index + # of the secret key in NTT domain. + s_conj_data = np.flip(sk.data, axis=0) + s_conj = types.SecretKey(s_conj_data, sk.moduli) + return gen_key_switching_key( + source_key=s_conj, + dest_key=sk, + q_limbs=q_limbs, + p_limbs=p_limbs, + dnum=dnum, + random_source=random_source, + )