Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,35 @@ cpu_gpu_tpu_test(
"@jaxite_deps//numpy",
],
)

cpu_gpu_tpu_test(
name = "key_switching_test",
size = "small",
timeout = "long",
srcs = ["jaxite/jaxite_ckks/key_switching_test.py"],
main = "jaxite/jaxite_ckks/key_switching_test.py",
deps = [
":jaxite_ckks",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
],
)

cpu_gpu_tpu_test(
name = "blind_rotate_utils_test",
size = "small",
timeout = "long",
srcs = ["jaxite/jaxite_ckks/blind_rotate_utils_test.py"],
main = "jaxite/jaxite_ckks/blind_rotate_utils_test.py",
deps = [
":jaxite_ckks",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
],
)
141 changes: 139 additions & 2 deletions jaxite/jaxite_ckks/blind_rotate.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,149 @@
"""Blind rotation implementations for CKKS."""

import jax
import jax.numpy as jnp
from jaxite.jaxite_ckks import barrett
from jaxite.jaxite_ckks import basis_conversion
from jaxite.jaxite_ckks import blind_rotate_utils
from jaxite.jaxite_ckks import key_switching
from jaxite.jaxite_ckks import mul
from jaxite.jaxite_ckks import ntt
from jaxite.jaxite_ckks import rescale
from jaxite.jaxite_ckks import types


def hmuxrot(
ct: types.Ciphertext,
hmrkey: types.HMuxRotKey,
j: int,
bc_kernel: basis_conversion.BasisConversionBarrett,
control_index: int,
p_limbs: jax.Array,
mul_kernel: mul.MulPlaintextCiphertextBarrett,
rescale_kernel: rescale.Rescale,
ntt_q: ntt.NTTBarrett,
ntt_p: ntt.NTTBarrett,
) -> types.Ciphertext:
"""Evaluates HMuxRot^(j)(hmrkey_beta, ct).

Computes: P^-1 * [ a(X^{5^{-j}}) * hmrkey_0 + b(X^{5^{-j}}) * hmrkey_1 ] mod Q
Reference: https://eprint.iacr.org/2025/784 Algorithm 5 (with dnum = 1)

Args:
ct: The input ciphertext (a, b) under Q.
hmrkey: The HMuxRot key under PQ.
j: The rotation index.
bc_kernel: The basis conversion kernel.
control_index: The control index for basis conversion Q -> P.
p_limbs: The limbs of P.
mul_kernel: The multiplication kernel under PQ.
rescale_kernel: The rescaling kernel.
ntt_q: The NTT kernel for Q.
ntt_p: The NTT kernel for P.

Returns:
The resulting ciphertext under Q.
"""
# Algorithm 5, line 4 (Automorphism/rotation)
g = pow(5, -j, 2 * ct.data.shape[1])
alpha_rot = blind_rotate_utils.apply_automorphism_ntt(ct.data[1], g)
beta_rot = blind_rotate_utils.apply_automorphism_ntt(ct.data[0], g)

# Stack into a standard ciphertext of shape (2, degree, num_moduli)
ct_rot = types.Ciphertext(
data=jnp.stack([beta_rot, alpha_rot]), moduli=ct.moduli
)

switcher = key_switching.BATKeySwitcher()
return switcher.key_switch(
ct=ct_rot,
key_matrix_bat=hmrkey.key_matrix_bat,
p_limbs=p_limbs,
bc_kernel=bc_kernel,
control_index=control_index,
mul_kernel=mul_kernel,
rescale_kernel=rescale_kernel,
ntt_q=ntt_q,
ntt_p=ntt_p,
r=rescale_kernel.r,
c=rescale_kernel.c,
)


def brot_mux(
ct_in: types.Ciphertext,
mux_key: types.MuxRotationKey,
p_limbs: jax.Array,
bc_kernel: basis_conversion.BasisConversionBarrett,
control_index: int,
mul_kernel: mul.MulPlaintextCiphertextBarrett,
rescale_kernel: rescale.Rescale,
ntt_q: ntt.NTTBarrett,
ntt_p: ntt.NTTBarrett,
) -> types.Ciphertext:
"""Homomorphic Blind Rotation using the Mux Method (BRotMux).

Sequentially applies the MUX-based conditional rotation for each bit of the
rotation index, resulting in a right-rotation of ct_in by the secret index.
Reference: https://eprint.iacr.org/2025/784 Algorithm 3

Args:
ct_in: The input ciphertext under Q.
mux_key: The MuxRotationKey containing the keys for each bit.
p_limbs: The limbs of the auxiliary modulus P.
bc_kernel: The basis conversion kernel.
control_index: The control index for basis conversion Q -> P.
mul_kernel: The multiplication kernel under PQ.
rescale_kernel: The rescaling kernel.
ntt_q: The NTT kernel for Q.
ntt_p: The NTT kernel for P.

Returns:
A Ciphertext under Q representing the rotated ciphertext.
"""
# Algorithm 3, line 1: ctout = ct
ct_out = ct_in
# Algorithm 3, line 2: for k = 0 to n - 1
for k, (hmrkey_jk_0, hmrkey_not_jk_1) in enumerate(mux_key.keys):
# Algorithm 3, line 3: ct0 = HMuxRot(2^k)(hmrkey_jk, ctout)
ct0 = hmuxrot(
ct=ct_out,
hmrkey=hmrkey_jk_0,
j=2**k,
bc_kernel=bc_kernel,
control_index=control_index,
p_limbs=p_limbs,
mul_kernel=mul_kernel,
rescale_kernel=rescale_kernel,
ntt_q=ntt_q,
ntt_p=ntt_p,
)
# Algorithm 3, line 4: ct1 = HMuxRot(0)(hmrkey_1-jk, ctout)
ct1 = hmuxrot(
ct=ct_out,
hmrkey=hmrkey_not_jk_1,
j=0,
bc_kernel=bc_kernel,
control_index=control_index,
p_limbs=p_limbs,
mul_kernel=mul_kernel,
rescale_kernel=rescale_kernel,
ntt_q=ntt_q,
ntt_p=ntt_p,
)
# Algorithm 3, line 5: ctout = ct0 + ct1 mod Q
moduli_expanded = jnp.array(ct0.moduli, dtype=jnp.uint64).reshape(1, 1, -1)
sum_data = ct0.data.astype(jnp.uint64) + ct1.data.astype(jnp.uint64)
sum_reduced = jnp.where(
sum_data >= moduli_expanded, sum_data - moduli_expanded, sum_data
)
ct_out = types.Ciphertext(
data=sum_reduced.astype(jnp.uint32), moduli=ct0.moduli
)

return ct_out


def brot_cm(
cmkey_j: list[types.Ciphertext],
pt_rot_mu_all: list[types.Plaintext],
Expand Down Expand Up @@ -39,8 +176,8 @@ def brot_cm(
if len(cmkey_j) != len(pt_rot_mu_all):
raise ValueError("Lengths of cmkey_j and pt_rot_mu_all must match.")

if not jnp.array_equal(cmkey_j[0].moduli, pt_rot_mu_all[0].moduli):
raise ValueError("Moduli of cmkey_j and pt_rot_mu_all must match.")
if cmkey_j[0].moduli.shape != pt_rot_mu_all[0].moduli.shape:
raise ValueError("Moduli shapes of cmkey_j and pt_rot_mu_all must match.")

ct_data = jnp.stack([ct.data for ct in cmkey_j])

Expand Down
105 changes: 105 additions & 0 deletions jaxite/jaxite_ckks/blind_rotate_test.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for blind rotation kernels."""

import jax
import jax.numpy as jnp
from jaxite.jaxite_ckks import barrett
from jaxite.jaxite_ckks import basis_conversion
from jaxite.jaxite_ckks import blind_rotate
from jaxite.jaxite_ckks import encode
from jaxite.jaxite_ckks import encrypt
from jaxite.jaxite_ckks import key_gen
from jaxite.jaxite_ckks import mul
from jaxite.jaxite_ckks import ntt
from jaxite.jaxite_ckks import random
from jaxite.jaxite_ckks import rescale
from jaxite.jaxite_ckks import types
Expand Down Expand Up @@ -111,6 +128,94 @@ def test_blind_rotate_cm(self, num_slots, secret_idx):
self.assertAlmostEqual(e.real, d.real, delta=1.0)
self.assertAlmostEqual(e.imag, d.imag, delta=1.0)

@parameterized.named_parameters(
("secret_idx_0", 4, 0),
("secret_idx_1", 4, 1),
("secret_idx_2", 4, 2),
("secret_idx_3", 4, 3),
("dense_8_slots", 8, 5),
("dense_16_slots", 16, 9),
)
def test_brot_mux(self, num_slots, secret_idx):
degree = max(1024, 2 * num_slots)
r = 32
q_limbs = [1073184769]
p_limbs = [1073479681]
all_moduli = q_limbs + p_limbs
scale = 2**22

# 1. Generate keys
test_random_source = random.ZeroNoiseRandomSource()
pk_q, sk_q = key_gen.keygen(
degree, q_limbs, random_source=test_random_source
)
# 2. Generate Mux Rotation Key for secret index
# We choose the secret index bits corresponding to secret_idx
num_bits = int(np.log2(num_slots))
secret_bits = [int((secret_idx >> k) & 1) for k in range(num_bits)]
mux_key = key_gen.gen_mux_rotation_key(
sk=sk_q,
secret_bits=secret_bits,
q_limbs=q_limbs,
p_limbs=p_limbs,
random_source=test_random_source,
)

# 3. Setup input ciphertext ct_in
mu = np.array(
[complex(x % 4 + 1, x % 4 + 2) for x in range(num_slots)], dtype=complex
)

bc_kernel = basis_conversion.BasisConversionBarrett()
bc_kernel.precompute_constants(all_moduli, [([0], [1])])
mul_constants = barrett.precompute_barrett_constants(all_moduli)
mul_kernel = mul.MulPlaintextCiphertextBarrett(mul_constants)

rescale_kernel = rescale.Rescale()
rescale_kernel.precompute_constants(
all_moduli, num_rescales=1, r=r, c=degree // r
)

ntt_q = ntt.NTTBarrett()
ntt_q.precompute_constants(q_limbs, r, degree // r)
ntt_p = ntt.NTTBarrett()
ntt_p.precompute_constants(p_limbs, r, degree // r)

encoder_q = encode.Encode(degree, q_limbs, scale)
encryptor_q = encrypt.Encrypt(pk_q)

plain_mu = encoder_q.encode(mu.tolist())
ct_in = encryptor_q.encrypt(plain_mu, random_source=test_random_source)

# 4. Run homomorphic BRotMux
ct_res = blind_rotate.brot_mux(
ct_in=ct_in,
mux_key=mux_key,
p_limbs=jnp.array(p_limbs, dtype=jnp.uint32),
bc_kernel=bc_kernel,
control_index=0,
mul_kernel=mul_kernel,
rescale_kernel=rescale_kernel,
ntt_q=ntt_q,
ntt_p=ntt_p,
)

# 5. Decrypt and verify result
decryptor_q = encrypt.Decrypt(sk_q)
pt_dec = decryptor_q.decrypt(ct_res)
decoder = encode.Decode(scale, num_slots)
decoded = decoder.decode(pt_dec)

full_slots = degree // 2
mu_full = np.zeros(full_slots, dtype=complex)
mu_full[:num_slots] = mu
expected_full = _negacyclic_roll(mu_full, secret_idx)
expected = expected_full[:num_slots]

for e, d in zip(expected, decoded):
self.assertAlmostEqual(e.real, d.real, delta=1.5)
self.assertAlmostEqual(e.imag, d.imag, delta=1.5)


if __name__ == "__main__":
absltest.main()
Loading
Loading