Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 0 additions & 25 deletions jaxite/jaxite_ckks/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,35 +17,10 @@
import jax
import jax.numpy as jnp
from jaxite.jaxite_ckks import basis_conversion
from jaxite.jaxite_ckks import key_gen
from jaxite.jaxite_ckks import key_switching
from jaxite.jaxite_ckks import mul
from jaxite.jaxite_ckks import random
from jaxite.jaxite_ckks import rescale
from jaxite.jaxite_ckks import types
import numpy as np


def gen_conjugate_key(
sk: types.SecretKey,
q_limbs: list[int],
p_limbs: list[int],
dnum: int,
random_source: random.RandomSource | None = None,
) -> types.EvaluationKeys:
"""Generates the key switching key for the conjugate automorphism."""
# Conjugate of secret key: s(X^-1) is represented by reversing the index
# of the secret key in NTT domain.
s_conj_data = np.flip(sk.data, axis=0)
s_conj = types.SecretKey(s_conj_data, sk.moduli)
return key_gen.gen_key_switching_key(
source_key=s_conj,
dest_key=sk,
q_limbs=q_limbs,
p_limbs=p_limbs,
dnum=dnum,
random_source=random_source,
)


@jax.tree_util.register_pytree_node_class
Expand Down
105 changes: 102 additions & 3 deletions jaxite/jaxite_ckks/conjugate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,16 @@

import hypothesis
from hypothesis import strategies as st
import jax
from jaxite.jaxite_ckks import barrett
from jaxite.jaxite_ckks import basis_conversion
from jaxite.jaxite_ckks import conjugate
from jaxite.jaxite_ckks import encode
from jaxite.jaxite_ckks import encrypt
from jaxite.jaxite_ckks import key_gen
from jaxite.jaxite_ckks import mul
from jaxite.jaxite_ckks import ntt_cpu
from jaxite.jaxite_ckks import random
from jaxite.jaxite_ckks import rescale
from jaxite.jaxite_ckks import types
import numpy as np

from absl.testing import absltest
Expand All @@ -38,7 +37,7 @@ def _run_conjugate_test(
degree, q_limbs, random_source=test_random_source
)

conj_key = conjugate.gen_conjugate_key(
conj_key = key_gen.gen_conjugate_key(
sk=sk_q,
q_limbs=q_limbs,
p_limbs=p_limbs,
Expand Down Expand Up @@ -184,6 +183,106 @@ def test_conjugate_hypothesis(self, slots):
mu=np.array(slots, dtype=complex),
)

def test_conjugate_of_conjugate(self):
degree = 16
num_slots = 8
q_limbs = [1073184769]
p_limbs = [1073479681, 1073741953]

all_moduli = q_limbs + p_limbs
scale = 2**20

test_random_source = random.ZeroNoiseRandomSource()

pk_q, sk_q = key_gen.keygen(
degree, q_limbs, random_source=test_random_source
)

conj_key = key_gen.gen_conjugate_key(
sk=sk_q,
q_limbs=q_limbs,
p_limbs=p_limbs,
dnum=1,
random_source=test_random_source,
)

bc_kernel = basis_conversion.BasisConversionBarrett()
bc_kernel.precompute_constants(all_moduli, [([0], [1, 2])])

barrett_constants_pq = barrett.precompute_barrett_constants(all_moduli)
mul_kernel = mul.MulPlaintextCiphertextBarrett(barrett_constants_pq)
rescale_kernel = rescale.Rescale()
rescale_kernel.precompute_constants(
all_moduli, num_rescales=len(p_limbs), r=4, c=4
)

conjugate_kernel = conjugate.Conjugation()
conjugate_kernel.precompute_constants(
q_limbs=q_limbs, p_limbs=p_limbs, dnum=1, r=4, c=4
)

mu = np.array(
[complex(x % 4 + 1.5, x % 4 - 2.5) for x in range(num_slots)],
dtype=complex,
)
encoder_q = encode.Encode(degree, q_limbs, scale)
encryptor_q = encrypt.Encrypt(pk_q)

plain_mu = encoder_q.encode(mu.tolist())
ct_in = encryptor_q.encrypt(plain_mu, random_source=test_random_source)

ct_conj = conjugate_kernel.conjugate(
ct=ct_in,
conj_key=conj_key,
p_limbs=p_limbs,
bc_kernel=bc_kernel,
mul_kernel=mul_kernel,
rescale_kernel=rescale_kernel,
start_control_index=0,
)

ct_conj_conj = conjugate_kernel.conjugate(
ct=ct_conj,
conj_key=conj_key,
p_limbs=p_limbs,
bc_kernel=bc_kernel,
mul_kernel=mul_kernel,
rescale_kernel=rescale_kernel,
start_control_index=0,
)

decryptor_q = encrypt.Decrypt(sk_q)
pt_dec = decryptor_q.decrypt(ct_conj_conj)

decoder = encode.Decode(scale, num_slots)
decoded = decoder.decode(pt_dec)

for e, d in zip(mu, decoded):
self.assertAlmostEqual(e.real, d.real, delta=0.5)
self.assertAlmostEqual(e.imag, d.imag, delta=0.5)

def test_conjugation_ntt_preservation(self):
degree = 8
q = 1073184769

np.random.seed(42)
poly = np.random.randint(0, q, size=degree).astype(np.uint64)

poly_rot = np.zeros_like(poly)
poly_rot[0] = poly[0]
for i in range(1, degree):
val = poly[i]
target_idx = degree - i
poly_rot[target_idx] = (q - val) % q

ntt_poly = ntt_cpu.ntt_negacyclic_poly(poly.reshape(degree, 1), [q])
ntt_poly_rot = ntt_cpu.ntt_negacyclic_poly(poly_rot.reshape(degree, 1), [q])

ntt_poly_flipped = np.flip(ntt_poly, axis=0)

np.testing.assert_array_equal(ntt_poly_flipped, ntt_poly_rot)


if __name__ == "__main__":

absltest.main()
22 changes: 22 additions & 0 deletions jaxite/jaxite_ckks/key_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,25 @@ def gen_cm_keys(
cm_keys[i, j] = encryptor.encrypt(plain_0)

return cm_keys


def gen_conjugate_key(
sk: types.SecretKey,
q_limbs: list[int],
p_limbs: list[int],
dnum: int,
random_source: random.RandomSource | None = None,
) -> types.EvaluationKeys:
"""Generates the key switching key for the conjugate automorphism."""
# Conjugate of secret key: s(X^-1) is represented by reversing the index
# of the secret key in NTT domain.
s_conj_data = np.flip(sk.data, axis=0)
s_conj = types.SecretKey(s_conj_data, sk.moduli)
return gen_key_switching_key(
source_key=s_conj,
dest_key=sk,
q_limbs=q_limbs,
p_limbs=p_limbs,
dnum=dnum,
random_source=random_source,
)
Loading