diff --git a/BUILD b/BUILD index bfb25b4..6fd6af8 100644 --- a/BUILD +++ b/BUILD @@ -680,3 +680,35 @@ cpu_gpu_tpu_test( "@jaxite_deps//numpy", ], ) + +cpu_gpu_tpu_test( + name = "key_switching_test", + size = "small", + timeout = "long", + srcs = ["jaxite/jaxite_ckks/key_switching_test.py"], + main = "jaxite/jaxite_ckks/key_switching_test.py", + deps = [ + ":jaxite_ckks", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + ], +) + +cpu_gpu_tpu_test( + name = "blind_rotate_utils_test", + size = "small", + timeout = "long", + srcs = ["jaxite/jaxite_ckks/blind_rotate_utils_test.py"], + main = "jaxite/jaxite_ckks/blind_rotate_utils_test.py", + deps = [ + ":jaxite_ckks", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + ], +) diff --git a/jaxite/jaxite_ckks/blind_rotate.py b/jaxite/jaxite_ckks/blind_rotate.py index 8a29657..e606c78 100644 --- a/jaxite/jaxite_ckks/blind_rotate.py +++ b/jaxite/jaxite_ckks/blind_rotate.py @@ -1,12 +1,149 @@ """Blind rotation implementations for CKKS.""" +import jax import jax.numpy as jnp from jaxite.jaxite_ckks import barrett +from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import blind_rotate_utils +from jaxite.jaxite_ckks import key_switching from jaxite.jaxite_ckks import mul +from jaxite.jaxite_ckks import ntt from jaxite.jaxite_ckks import rescale from jaxite.jaxite_ckks import types +def hmuxrot( + ct: types.Ciphertext, + hmrkey: types.HMuxRotKey, + j: int, + bc_kernel: basis_conversion.BasisConversionBarrett, + control_index: int, + p_limbs: jax.Array, + mul_kernel: mul.MulPlaintextCiphertextBarrett, + rescale_kernel: rescale.Rescale, + ntt_q: ntt.NTTBarrett, + ntt_p: ntt.NTTBarrett, +) -> types.Ciphertext: + """Evaluates HMuxRot^(j)(hmrkey_beta, ct). + + Computes: P^-1 * [ a(X^{5^{-j}}) * hmrkey_0 + b(X^{5^{-j}}) * hmrkey_1 ] mod Q + Reference: https://eprint.iacr.org/2025/784 Algorithm 5 (with dnum = 1) + + Args: + ct: The input ciphertext (a, b) under Q. + hmrkey: The HMuxRot key under PQ. + j: The rotation index. + bc_kernel: The basis conversion kernel. + control_index: The control index for basis conversion Q -> P. + p_limbs: The limbs of P. + mul_kernel: The multiplication kernel under PQ. + rescale_kernel: The rescaling kernel. + ntt_q: The NTT kernel for Q. + ntt_p: The NTT kernel for P. + + Returns: + The resulting ciphertext under Q. + """ + # Algorithm 5, line 4 (Automorphism/rotation) + g = pow(5, -j, 2 * ct.data.shape[1]) + alpha_rot = blind_rotate_utils.apply_automorphism_ntt(ct.data[1], g) + beta_rot = blind_rotate_utils.apply_automorphism_ntt(ct.data[0], g) + + # Stack into a standard ciphertext of shape (2, degree, num_moduli) + ct_rot = types.Ciphertext( + data=jnp.stack([beta_rot, alpha_rot]), moduli=ct.moduli + ) + + switcher = key_switching.BATKeySwitcher() + return switcher.key_switch( + ct=ct_rot, + key_matrix_bat=hmrkey.key_matrix_bat, + p_limbs=p_limbs, + bc_kernel=bc_kernel, + control_index=control_index, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ntt_q=ntt_q, + ntt_p=ntt_p, + r=rescale_kernel.r, + c=rescale_kernel.c, + ) + + +def brot_mux( + ct_in: types.Ciphertext, + mux_key: types.MuxRotationKey, + p_limbs: jax.Array, + bc_kernel: basis_conversion.BasisConversionBarrett, + control_index: int, + mul_kernel: mul.MulPlaintextCiphertextBarrett, + rescale_kernel: rescale.Rescale, + ntt_q: ntt.NTTBarrett, + ntt_p: ntt.NTTBarrett, +) -> types.Ciphertext: + """Homomorphic Blind Rotation using the Mux Method (BRotMux). + + Sequentially applies the MUX-based conditional rotation for each bit of the + rotation index, resulting in a right-rotation of ct_in by the secret index. + Reference: https://eprint.iacr.org/2025/784 Algorithm 3 + + Args: + ct_in: The input ciphertext under Q. + mux_key: The MuxRotationKey containing the keys for each bit. + p_limbs: The limbs of the auxiliary modulus P. + bc_kernel: The basis conversion kernel. + control_index: The control index for basis conversion Q -> P. + mul_kernel: The multiplication kernel under PQ. + rescale_kernel: The rescaling kernel. + ntt_q: The NTT kernel for Q. + ntt_p: The NTT kernel for P. + + Returns: + A Ciphertext under Q representing the rotated ciphertext. + """ + # Algorithm 3, line 1: ctout = ct + ct_out = ct_in + # Algorithm 3, line 2: for k = 0 to n - 1 + for k, (hmrkey_jk_0, hmrkey_not_jk_1) in enumerate(mux_key.keys): + # Algorithm 3, line 3: ct0 = HMuxRot(2^k)(hmrkey_jk, ctout) + ct0 = hmuxrot( + ct=ct_out, + hmrkey=hmrkey_jk_0, + j=2**k, + bc_kernel=bc_kernel, + control_index=control_index, + p_limbs=p_limbs, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ntt_q=ntt_q, + ntt_p=ntt_p, + ) + # Algorithm 3, line 4: ct1 = HMuxRot(0)(hmrkey_1-jk, ctout) + ct1 = hmuxrot( + ct=ct_out, + hmrkey=hmrkey_not_jk_1, + j=0, + bc_kernel=bc_kernel, + control_index=control_index, + p_limbs=p_limbs, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ntt_q=ntt_q, + ntt_p=ntt_p, + ) + # Algorithm 3, line 5: ctout = ct0 + ct1 mod Q + moduli_expanded = jnp.array(ct0.moduli, dtype=jnp.uint64).reshape(1, 1, -1) + sum_data = ct0.data.astype(jnp.uint64) + ct1.data.astype(jnp.uint64) + sum_reduced = jnp.where( + sum_data >= moduli_expanded, sum_data - moduli_expanded, sum_data + ) + ct_out = types.Ciphertext( + data=sum_reduced.astype(jnp.uint32), moduli=ct0.moduli + ) + + return ct_out + + def brot_cm( cmkey_j: list[types.Ciphertext], pt_rot_mu_all: list[types.Plaintext], @@ -39,8 +176,8 @@ def brot_cm( if len(cmkey_j) != len(pt_rot_mu_all): raise ValueError("Lengths of cmkey_j and pt_rot_mu_all must match.") - if not jnp.array_equal(cmkey_j[0].moduli, pt_rot_mu_all[0].moduli): - raise ValueError("Moduli of cmkey_j and pt_rot_mu_all must match.") + if cmkey_j[0].moduli.shape != pt_rot_mu_all[0].moduli.shape: + raise ValueError("Moduli shapes of cmkey_j and pt_rot_mu_all must match.") ct_data = jnp.stack([ct.data for ct in cmkey_j]) diff --git a/jaxite/jaxite_ckks/blind_rotate_test.py b/jaxite/jaxite_ckks/blind_rotate_test.py index 8f98a5b..96051b6 100644 --- a/jaxite/jaxite_ckks/blind_rotate_test.py +++ b/jaxite/jaxite_ckks/blind_rotate_test.py @@ -1,12 +1,29 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tests for blind rotation kernels.""" import jax +import jax.numpy as jnp from jaxite.jaxite_ckks import barrett +from jaxite.jaxite_ckks import basis_conversion from jaxite.jaxite_ckks import blind_rotate from jaxite.jaxite_ckks import encode from jaxite.jaxite_ckks import encrypt from jaxite.jaxite_ckks import key_gen from jaxite.jaxite_ckks import mul +from jaxite.jaxite_ckks import ntt from jaxite.jaxite_ckks import random from jaxite.jaxite_ckks import rescale from jaxite.jaxite_ckks import types @@ -111,6 +128,94 @@ def test_blind_rotate_cm(self, num_slots, secret_idx): self.assertAlmostEqual(e.real, d.real, delta=1.0) self.assertAlmostEqual(e.imag, d.imag, delta=1.0) + @parameterized.named_parameters( + ("secret_idx_0", 4, 0), + ("secret_idx_1", 4, 1), + ("secret_idx_2", 4, 2), + ("secret_idx_3", 4, 3), + ("dense_8_slots", 8, 5), + ("dense_16_slots", 16, 9), + ) + def test_brot_mux(self, num_slots, secret_idx): + degree = max(1024, 2 * num_slots) + r = 32 + q_limbs = [1073184769] + p_limbs = [1073479681] + all_moduli = q_limbs + p_limbs + scale = 2**22 + + # 1. Generate keys + test_random_source = random.ZeroNoiseRandomSource() + pk_q, sk_q = key_gen.keygen( + degree, q_limbs, random_source=test_random_source + ) + # 2. Generate Mux Rotation Key for secret index + # We choose the secret index bits corresponding to secret_idx + num_bits = int(np.log2(num_slots)) + secret_bits = [int((secret_idx >> k) & 1) for k in range(num_bits)] + mux_key = key_gen.gen_mux_rotation_key( + sk=sk_q, + secret_bits=secret_bits, + q_limbs=q_limbs, + p_limbs=p_limbs, + random_source=test_random_source, + ) + + # 3. Setup input ciphertext ct_in + mu = np.array( + [complex(x % 4 + 1, x % 4 + 2) for x in range(num_slots)], dtype=complex + ) + + bc_kernel = basis_conversion.BasisConversionBarrett() + bc_kernel.precompute_constants(all_moduli, [([0], [1])]) + mul_constants = barrett.precompute_barrett_constants(all_moduli) + mul_kernel = mul.MulPlaintextCiphertextBarrett(mul_constants) + + rescale_kernel = rescale.Rescale() + rescale_kernel.precompute_constants( + all_moduli, num_rescales=1, r=r, c=degree // r + ) + + ntt_q = ntt.NTTBarrett() + ntt_q.precompute_constants(q_limbs, r, degree // r) + ntt_p = ntt.NTTBarrett() + ntt_p.precompute_constants(p_limbs, r, degree // r) + + encoder_q = encode.Encode(degree, q_limbs, scale) + encryptor_q = encrypt.Encrypt(pk_q) + + plain_mu = encoder_q.encode(mu.tolist()) + ct_in = encryptor_q.encrypt(plain_mu, random_source=test_random_source) + + # 4. Run homomorphic BRotMux + ct_res = blind_rotate.brot_mux( + ct_in=ct_in, + mux_key=mux_key, + p_limbs=jnp.array(p_limbs, dtype=jnp.uint32), + bc_kernel=bc_kernel, + control_index=0, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ntt_q=ntt_q, + ntt_p=ntt_p, + ) + + # 5. Decrypt and verify result + decryptor_q = encrypt.Decrypt(sk_q) + pt_dec = decryptor_q.decrypt(ct_res) + decoder = encode.Decode(scale, num_slots) + decoded = decoder.decode(pt_dec) + + full_slots = degree // 2 + mu_full = np.zeros(full_slots, dtype=complex) + mu_full[:num_slots] = mu + expected_full = _negacyclic_roll(mu_full, secret_idx) + expected = expected_full[:num_slots] + + for e, d in zip(expected, decoded): + self.assertAlmostEqual(e.real, d.real, delta=1.5) + self.assertAlmostEqual(e.imag, d.imag, delta=1.5) + if __name__ == "__main__": absltest.main() diff --git a/jaxite/jaxite_ckks/blind_rotate_utils.py b/jaxite/jaxite_ckks/blind_rotate_utils.py new file mode 100644 index 0000000..aac0dbf --- /dev/null +++ b/jaxite/jaxite_ckks/blind_rotate_utils.py @@ -0,0 +1,97 @@ +"""Utility functions for homomorphic blind rotation.""" + +import math +import jax +import jax.numpy as jnp +from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import ntt +from jaxite.jaxite_ckks import types + + +def apply_automorphism_ntt(data: jax.Array, g: int) -> jax.Array: + """Applies the automorphism X -> X^g to a polynomial in the NTT domain. + + Handles the bit-reversed layout of jaxite's NTT representation. + + Args: + data: The polynomial in NTT domain. Shape (..., degree, num_moduli). + g: The automorphism generator (must be odd). + + Returns: + The permuted polynomial in NTT domain with the same shape. + """ + degree = data.shape[-2] + bits = int(math.log2(degree)) + indices = jnp.arange(degree, dtype=jnp.uint32) + + # Bit-reverse indices to map the bit-reversed layout of jaxite's NTT + def bit_reverse(x): + rev = jnp.zeros_like(x) + temp = x + for _ in range(bits): + rev = (rev << 1) | (temp & 1) + temp >>= 1 + return rev + + br_indices = bit_reverse(indices) + g_u32 = jnp.array(g, dtype=jnp.uint32) + target_roots = (((2 * br_indices + 1) * g_u32 - 1) // 2) % degree + target_indices = bit_reverse(target_roots) + + return jnp.take(data, target_indices, axis=-2) + + +def lift_ciphertext( + ct: types.Ciphertext, + bc_kernel: basis_conversion.BasisConversionBarrett, + control_index: int, + p_limbs: jax.Array, + ntt_q: ntt.NTTBarrett, + ntt_p: ntt.NTTBarrett, + r: int, + c: int, +) -> types.Ciphertext: + """Lifts a ciphertext from Q to PQ using basis conversion. + + Args: + ct: The input ciphertext under Q. Shape (num_elements, degree, num_Q). + bc_kernel: The precomputed basis conversion kernel from Q to P. + control_index: The control index specifying the Q -> P conversion. + p_limbs: The limbs of modulus P. + ntt_q: The NTT kernel for modulus Q. + ntt_p: The NTT kernel for modulus P. + r: The row dimension of the NTT layout. + c: The column dimension of the NTT layout. + + Returns: + A lifted Ciphertext under PQ. Shape (num_elements, degree, num_Q + num_P). + """ + # 1. Reshape ct.data to (num_elements, r, c, num_q) for INTT + num_elements, degree, num_q = ct.data.shape + ct_data_reshaped = ct.data.reshape(num_elements, r, c, num_q) + + # 2. Convert ct.data to coefficient domain modulo Q + ct_coef_q = ntt_q.intt(ct_data_reshaped) + # Reshape back to (num_elements, degree, num_q) + ct_coef_q_flat = ct_coef_q.reshape(num_elements, degree, num_q) + + # 3. Do basis conversion in coefficient domain: Q -> P + data_p_coef = bc_kernel.basis_change( + ct_coef_q_flat, control_index=control_index + ) + + # 4. Reshape data_p_coef to (num_elements, r, c, num_p) for NTT + num_p = len(p_limbs) + data_p_coef_reshaped = data_p_coef.reshape(num_elements, r, c, num_p) + + # 5. Convert data_p_coef to NTT domain modulo P + data_p_ntt = ntt_p.ntt(data_p_coef_reshaped) + # Reshape back to (num_elements, degree, num_p) + data_p_ntt_flat = data_p_ntt.reshape(num_elements, degree, num_p) + + # 6. Concatenate Q (NTT) and P (NTT) + data_pq = jnp.concatenate([ct.data, data_p_ntt_flat], axis=-1) + moduli_pq = jnp.concatenate([ct.moduli, jnp.asarray(p_limbs)]).astype( + jnp.uint32 + ) + return types.Ciphertext(data=data_pq, moduli=moduli_pq) diff --git a/jaxite/jaxite_ckks/blind_rotate_utils_test.py b/jaxite/jaxite_ckks/blind_rotate_utils_test.py new file mode 100644 index 0000000..6938903 --- /dev/null +++ b/jaxite/jaxite_ckks/blind_rotate_utils_test.py @@ -0,0 +1,99 @@ +"""Tests for blind rotation utilities.""" + +import jax +import jax.numpy as jnp +from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import blind_rotate_utils +from jaxite.jaxite_ckks import ntt +from jaxite.jaxite_ckks import ntt_cpu +from jaxite.jaxite_ckks import types +import numpy as np + +from absl.testing import absltest +from absl.testing import parameterized + +jax.config.update("jax_enable_x64", True) + + +class BlindRotateUtilsTest(parameterized.TestCase): + + def test_apply_automorphism_ntt(self): + # Test with a simple identity automorphism (g = 1) + degree = 8 + data = jnp.arange(degree, dtype=jnp.float64).reshape(degree, 1) + res = blind_rotate_utils.apply_automorphism_ntt(data, 1) + np.testing.assert_array_equal(res, data) + + def test_lift_ciphertext(self): + degree = 8 + q_limbs = [1073184769] + p_limbs = [1073479681] + all_moduli = q_limbs + p_limbs + + # ntt of zero is zero, e is zero + # b = - a * sk + a_slots = jnp.zeros((degree, len(q_limbs)), dtype=jnp.uint32) + b_slots = jnp.zeros((degree, len(q_limbs)), dtype=jnp.uint32) + ct = types.Ciphertext( + data=jnp.stack([b_slots, a_slots]), + moduli=jnp.array(q_limbs, dtype=jnp.uint32), + ) + + bc_kernel = basis_conversion.BasisConversionBarrett() + bc_kernel.precompute_constants(all_moduli, [([0], [1])]) + + r, c = 4, 2 + ntt_q = ntt.NTTBarrett() + ntt_q.precompute_constants(q_limbs, r=r, c=c) + ntt_p = ntt.NTTBarrett() + ntt_p.precompute_constants(p_limbs, r=r, c=c) + + lifted_ct = blind_rotate_utils.lift_ciphertext( + ct, + bc_kernel, + control_index=0, + p_limbs=jnp.array(p_limbs, dtype=jnp.uint32), + ntt_q=ntt_q, + ntt_p=ntt_p, + r=r, + c=c, + ) + + # Output dimensions should match PQ towers + self.assertEqual(lifted_ct.data.shape, (2, degree, 2)) + np.testing.assert_array_equal( + lifted_ct.moduli, np.array(all_moduli, dtype=np.uint32) + ) + + @parameterized.parameters(3, 5, 7) + def test_apply_automorphism_ntt_non_trivial(self, g): + degree = 8 + q = 1073184769 + # Create a non-trivial polynomial in coefficient domain + poly = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.uint64).reshape( + degree, 1 + ) + + # Compute automorphism in coefficient domain + poly_rot = np.zeros_like(poly) + for i in range(degree): + target_pow = (i * g) % (2 * degree) + val = poly[i, 0] + if target_pow >= degree: + target_pow -= degree + poly_rot[target_pow, 0] = (q - val) % q + else: + poly_rot[target_pow, 0] = val + + # Convert to NTT domain + ntt_poly = ntt_cpu.ntt_negacyclic_poly(poly, [q]) + expected_ntt_poly_rot = ntt_cpu.ntt_negacyclic_poly(poly_rot, [q]) + + # Apply automorphism in NTT domain + res = blind_rotate_utils.apply_automorphism_ntt(jnp.array(ntt_poly), g) + + np.testing.assert_array_equal(np.array(res), expected_ntt_poly_rot) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_ckks/key_gen.py b/jaxite/jaxite_ckks/key_gen.py index 39666f2..6b447c4 100644 --- a/jaxite/jaxite_ckks/key_gen.py +++ b/jaxite/jaxite_ckks/key_gen.py @@ -1,8 +1,9 @@ """Key generation utilities for CKKS.""" import math - import jax.numpy as jnp +from jaxite.jaxite_ckks import bat_utils +from jaxite.jaxite_ckks import blind_rotate_utils from jaxite.jaxite_ckks import encode from jaxite.jaxite_ckks import encrypt from jaxite.jaxite_ckks import ntt_cpu @@ -43,6 +44,8 @@ def keygen( return PublicKey(pk_data, np.array(moduli, dtype=np.uint64)), SecretKey( s_ntt, np.array(moduli, dtype=np.uint64) ) + + def extend_secret_key( secret_key: SecretKey, target_moduli: list[int], @@ -113,6 +116,7 @@ def gen_key_switching_key( p_limbs: list[int], dnum: int, random_source: random.RandomSource | None = None, + dest_key_ext: SecretKey | None = None, ) -> types.EvaluationKeys: """Generate key switching keys to switch from source_key to dest_key.""" random_source = random_source or random.SecureRandomSource() @@ -122,7 +126,8 @@ def gen_key_switching_key( all_moduli = q_limbs + p_limbs all_moduli_u64 = np.array(all_moduli, dtype=np.uint64).reshape(1, -1) - dest_key_ext = extend_secret_key(dest_key, all_moduli) + if dest_key_ext is None: + dest_key_ext = extend_secret_key(dest_key, all_moduli) s_dst_ext_slots = dest_key_ext.data p_val = math.prod(p_limbs) @@ -214,8 +219,6 @@ def gen_cm_keys( cm_keys[i, j] = encryptor.encrypt(plain_0) return cm_keys - - def gen_conjugate_key( sk: types.SecretKey, q_limbs: list[int], @@ -236,3 +239,151 @@ def gen_conjugate_key( dnum=dnum, random_source=random_source, ) + + +def gen_hmuxrot_key( + sk: types.SecretKey, + beta: int, + j: int, + q_limbs: list[int], + p_limbs: list[int], + random_source: random.RandomSource | None = None, + sk_ext: types.SecretKey | None = None, +) -> types.HMuxRotKey: + """Generates an HMuxRotKey symmetrically. + + The key consists of: + - key0: Symmetrically encrypts target0 = P * beta * sk(X^{5^{-j}}) under sk. + Constructed as a key-switching key from beta * sk(X^{5^{-j}}) to sk. + - key1: Symmetrically encrypts target1 = P * beta under sk. + + Args: + sk: The secret key under Q. + beta: The selector bit (0 or 1). + j: The rotation index. + q_limbs: The limbs of the ciphertext modulus Q. + p_limbs: The limbs of the auxiliary modulus P. + random_source: Optional random source. Defaults to SecureRandomSource. + sk_ext: Optional pre-extended secret key. If not provided, it is computed. + + Returns: + The generated HMuxRotKey. + """ + random_source = random_source or random.SecureRandomSource() + degree = sk.data.shape[0] + all_moduli = q_limbs + p_limbs + all_moduli_u64 = np.array(all_moduli, dtype=np.uint64).reshape(1, -1) + q_limbs_u64 = np.array(q_limbs, dtype=np.uint64).reshape(1, -1) + + if sk_ext is None: + dest_sk_ext = extend_secret_key(sk, all_moduli) + else: + dest_sk_ext = sk_ext + + p_val = math.prod(p_limbs) + scaled_val0 = (p_val * beta) % all_moduli_u64 + + # key0 is a key-switching key from beta * sk(X^{5^{-j}}) to sk. + g = pow(5, -j, 2 * degree) + sk_rot_data = blind_rotate_utils.apply_automorphism_ntt(jnp.array(sk.data), g) + # Scale by beta (modulo Q) + sk_rot_beta_data = (sk_rot_data * beta) % q_limbs_u64 + sk_rot_beta = types.SecretKey(np.array(sk_rot_beta_data), sk.moduli) + + ksk = gen_key_switching_key( + source_key=sk_rot_beta, + dest_key=sk, + q_limbs=q_limbs, + p_limbs=p_limbs, + dnum=1, + random_source=random_source, + dest_key_ext=dest_sk_ext, + ) + key0 = types.Ciphertext( + data=jnp.stack([ksk.b[0], ksk.a[0]]), + moduli=ksk.moduli, + ) + + # key1 encrypts P * beta under sk. Since this is a constant polynomial and not + # a secret key, we perform direct symmetric encryption inline. + target1 = np.ones((degree, len(all_moduli)), dtype=np.uint64) * scaled_val0 + target1 = target1 % all_moduli_u64 + + a_coeffs = random_source.gen_uniform_poly(degree, all_moduli) + e_coeffs = random_source.gen_gaussian_poly(degree, all_moduli) + a_slots = ntt_cpu.ntt_negacyclic_poly(a_coeffs, all_moduli) + e_slots = ntt_cpu.ntt_negacyclic_poly(e_coeffs, all_moduli) + + prod = (a_slots * dest_sk_ext.data) % all_moduli_u64 + b_slots = (e_slots + target1 + all_moduli_u64 - prod) % all_moduli_u64 + + key1 = types.Ciphertext( + data=jnp.array(np.stack([b_slots, a_slots]), dtype=jnp.uint32), + moduli=jnp.array(all_moduli, dtype=jnp.uint32), + ) + + return types.HMuxRotKey(key0, key1) + + +def gen_mux_rotation_key( + sk: types.SecretKey, + secret_bits: list[int], + q_limbs: list[int], + p_limbs: list[int], + random_source: random.RandomSource | None = None, +) -> types.MuxRotationKey: + """Generates a MuxRotationKey for the bits of the rotation index. + + For each bit k from 0 to len(secret_bits) - 1, generates: + - hmrkey_jk_0: HMuxRotKey for beta = secret_bits[k], rotation amount = 2^k. + - hmrkey_not_jk_1: HMuxRotKey for beta = 1 - secret_bits[k], rotation = 0. + + Args: + sk: The secret key under Q. + secret_bits: The list of bits representing the secret rotation index. + q_limbs: The limbs of the ciphertext modulus Q. + p_limbs: The limbs of the auxiliary modulus P. + random_source: Optional random source. + + Returns: + The generated MuxRotationKey. + """ + all_moduli = q_limbs + p_limbs + sk_ext = extend_secret_key(sk, all_moduli) + + keys = [] + for k, bit in enumerate(secret_bits): + hmrkey_jk_0 = gen_hmuxrot_key( + sk=sk, + beta=bit, + j=2**k, + q_limbs=q_limbs, + p_limbs=p_limbs, + random_source=random_source, + sk_ext=sk_ext, + ) + hmrkey_not_jk_1 = gen_hmuxrot_key( + sk=sk, + beta=1 - bit, + j=0, + q_limbs=q_limbs, + p_limbs=p_limbs, + random_source=random_source, + sk_ext=sk_ext, + ) + keys.append((hmrkey_jk_0, hmrkey_not_jk_1)) + return types.MuxRotationKey(keys) + + +def basis_aligned_transform_key( + key_matrix: jnp.ndarray, moduli: jnp.ndarray | list[int] +) -> jnp.ndarray: + """Transforms a key matrix of shape (degree, num_moduli, 2, 2) for BAT.""" + # Move moduli axis to the end: (degree, 2, 2, num_moduli) + key_matrix_transposed = jnp.transpose(key_matrix, (0, 2, 3, 1)) + key_matrix_bat_raw = bat_utils.basis_aligned_transformation( + key_matrix_transposed, moduli + ) + # Transpose back to shape (degree, num_moduli, 2, 2, 4, 4) from raw + # (4, degree, 2, 2, num_moduli, 4). + return jnp.transpose(key_matrix_bat_raw, (1, 4, 2, 3, 0, 5)) diff --git a/jaxite/jaxite_ckks/key_switching.py b/jaxite/jaxite_ckks/key_switching.py index ecaeadc..4164ef2 100644 --- a/jaxite/jaxite_ckks/key_switching.py +++ b/jaxite/jaxite_ckks/key_switching.py @@ -17,9 +17,13 @@ import math import jax import jax.numpy as jnp +from jaxite.jaxite_ckks import barrett from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import blind_rotate_utils +from jaxite.jaxite_ckks import math as ckks_math from jaxite.jaxite_ckks import mul from jaxite.jaxite_ckks import ntt +from jaxite.jaxite_ckks import rescale from jaxite.jaxite_ckks import types @@ -39,6 +43,7 @@ def precompute_constants( r: int, c: int, ): + """Precomputes the NTT constants for all modulus partitions.""" limbs_per_part = math.ceil(len(q_limbs) / dnum) all_moduli = q_limbs + p_limbs @@ -71,6 +76,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, aux_data, children): + del aux_data obj = cls() obj.ntt_kernels_q = children[0] obj.ntt_kernels_out = children[1] @@ -80,7 +86,7 @@ def key_switch( self, ct: types.Ciphertext, ksk: types.EvaluationKeys, - p_limbs: list[int], + p_limbs: jax.Array | list[int], bc_kernel: basis_conversion.BasisConversionBarrett, mul_kernel: mul.MulPlaintextCiphertextBarrett, start_control_index: int, @@ -88,25 +94,26 @@ def key_switch( """Switch ciphertext from source key to destination key modulo QP.""" c0 = ct.data[0] c1 = ct.data[1] - q_limbs = ct.moduli.tolist() + q_limbs = jnp.asarray(ct.moduli, dtype=jnp.uint32) + p_limbs = jnp.asarray(p_limbs, dtype=jnp.uint32) degree = c1.shape[0] dnum = len(self.ntt_kernels_q) - limbs_per_part = math.ceil(len(q_limbs) / dnum) - all_moduli = q_limbs + p_limbs - all_moduli_arr = jnp.array(all_moduli, dtype=jnp.uint32) + limbs_per_part = math.ceil(q_limbs.shape[0] / dnum) + all_moduli = jnp.concatenate([q_limbs, p_limbs]) - c0_ks = jnp.zeros((degree, len(all_moduli)), dtype=jnp.uint64) - c1_ks = jnp.zeros((degree, len(all_moduli)), dtype=jnp.uint64) - all_moduli_u64 = jnp.array(all_moduli, dtype=jnp.uint64).reshape(1, -1) + c0_ks = jnp.zeros((degree, all_moduli.shape[0]), dtype=jnp.uint64) + c1_ks = jnp.zeros((degree, all_moduli.shape[0]), dtype=jnp.uint64) + all_moduli_u64 = all_moduli.astype(jnp.uint64).reshape(1, -1) for i in range(dnum): start_idx = i * limbs_per_part - end_idx = min(start_idx + limbs_per_part, len(q_limbs)) + end_idx = min(start_idx + limbs_per_part, q_limbs.shape[0]) q_part = q_limbs[start_idx:end_idx] in_indices = list(range(start_idx, end_idx)) - out_indices = [j for j in range(len(all_moduli)) if j not in in_indices] - out_moduli = [all_moduli[j] for j in out_indices] + out_indices = [ + j for j in range(all_moduli.shape[0]) if j not in in_indices + ] # Extract partition and convert to coefficient domain c1_part = c1[:, start_idx:end_idx] @@ -114,12 +121,12 @@ def key_switch( 1, self.ntt_kernels_q[i].constants.r, self.ntt_kernels_q[i].constants.c, - len(q_part), + q_part.shape[0], ) c1_part_coeffs = self.ntt_kernels_q[i].intt( c1_part_reshaped.astype(jnp.uint32) ) - c1_part_coeffs = c1_part_coeffs.reshape(degree, len(q_part)) + c1_part_coeffs = c1_part_coeffs.reshape(degree, q_part.shape[0]) # Basis change to out_moduli control_index = start_control_index + i @@ -132,15 +139,15 @@ def key_switch( 1, self.ntt_kernels_out[i].constants.r, self.ntt_kernels_out[i].constants.c, - len(out_moduli), + len(out_indices), ) c1_part_out_ntt = self.ntt_kernels_out[i].ntt( c1_part_out_coeffs_reshaped.astype(jnp.uint32) ) - c1_part_out = c1_part_out_ntt.reshape(degree, len(out_moduli)) + c1_part_out = c1_part_out_ntt.reshape(degree, len(out_indices)) # Merge into full all_moduli representation - c1_part_qp = jnp.zeros((degree, len(all_moduli)), dtype=jnp.uint32) + c1_part_qp = jnp.zeros((degree, all_moduli.shape[0]), dtype=jnp.uint32) c1_part_qp = c1_part_qp.at[:, in_indices].set(c1_part) c1_part_qp = c1_part_qp.at[:, out_indices].set(c1_part_out) @@ -149,12 +156,12 @@ def key_switch( ksk_a_part = ksk.a[i] c0_ks_part = mul_kernel.mul( - types.Plaintext(data=ksk_b_part, moduli=all_moduli_arr), - types.Plaintext(data=c1_part_qp, moduli=all_moduli_arr), + types.Plaintext(data=ksk_b_part, moduli=all_moduli), + types.Plaintext(data=c1_part_qp, moduli=all_moduli), ) c1_ks_part = mul_kernel.mul( - types.Plaintext(data=ksk_a_part, moduli=all_moduli_arr), - types.Plaintext(data=c1_part_qp, moduli=all_moduli_arr), + types.Plaintext(data=ksk_a_part, moduli=all_moduli), + types.Plaintext(data=c1_part_qp, moduli=all_moduli), ) # Sum modulo all_moduli @@ -162,12 +169,15 @@ def key_switch( c1_ks = (c1_ks + c1_ks_part.data.astype(jnp.uint64)) % all_moduli_u64 # Scale c0 by P - p_val = math.prod(p_limbs) - p_mod_q = jnp.array([p_val % q for q in q_limbs], dtype=jnp.uint64) - c0_scaled_q = (c0.astype(jnp.uint64) * p_mod_q.reshape(1, -1)) % jnp.array( - q_limbs, dtype=jnp.uint64 - ).reshape(1, -1) - c0_scaled_p = jnp.zeros((degree, len(p_limbs)), dtype=jnp.uint32) + p_limbs_u64 = p_limbs.astype(jnp.uint64) + q_limbs_u64 = q_limbs.astype(jnp.uint64) + p_mod_q = p_limbs_u64[0] % q_limbs_u64 + for j in range(1, p_limbs.shape[0]): + p_mod_q = (p_mod_q * (p_limbs_u64[j] % q_limbs_u64)) % q_limbs_u64 + c0_scaled_q = ( + c0.astype(jnp.uint64) * p_mod_q.astype(jnp.uint64).reshape(1, -1) + ) % q_limbs_u64.reshape(1, -1) + c0_scaled_p = jnp.zeros((degree, p_limbs.shape[0]), dtype=jnp.uint32) c0_scaled_qp = jnp.concatenate( [c0_scaled_q.astype(jnp.uint32), c0_scaled_p], axis=-1 ) @@ -180,5 +190,121 @@ def key_switch( data=jnp.stack( [c0_prime.astype(jnp.uint32), c1_prime.astype(jnp.uint32)] ), - moduli=jnp.array(all_moduli, dtype=jnp.uint32), + moduli=all_moduli, + ) + + +@jax.tree_util.register_pytree_node_class +class BATKeySwitcher: + """Kernel for BAT-based key switching on TPU.""" + + def __init__(self): + pass + + @staticmethod + def transform_key( + key_matrix: jax.Array, moduli: jax.Array | list[int] + ) -> jax.Array: + """Prepares the 4D key matrix of shape (degree, num_moduli, 2, 2) for BAT.""" + matrix_u64 = key_matrix.astype(jnp.uint64) + num_bytes = 4 + matrix_u64_byteshifted = jnp.array( + [matrix_u64 << (8 * byte_idx) for byte_idx in range(num_bytes)], + dtype=jnp.uint64, + ) + + moduli_expanded = jnp.array(moduli, dtype=jnp.uint64).reshape( + 1, 1, -1, 1, 1 + ) + + matrix_u64_byteshifted_mod_modulus = ( + matrix_u64_byteshifted % moduli_expanded + ).astype(jnp.uint32) + + matrix_u8 = jax.lax.bitcast_convert_type( + matrix_u64_byteshifted_mod_modulus, jnp.uint8 + ) + + matrix_u8_transposed = jnp.transpose(matrix_u8, (1, 2, 3, 4, 0, 5)) + + degree = key_matrix.shape[0] + block_size, num_blocks = ckks_math.compute_tpu_block_sizes(degree) + return matrix_u8_transposed.reshape( + num_blocks, block_size, *matrix_u8_transposed.shape[1:] + ) + + def key_switch( + self, + ct: types.Ciphertext, + key_matrix_bat: jax.Array, + p_limbs: jax.Array, + bc_kernel: basis_conversion.BasisConversionBarrett, + control_index: int, + mul_kernel: mul.MulPlaintextCiphertextBarrett, + rescale_kernel: rescale.Rescale, + ntt_q: ntt.NTTBarrett, + ntt_p: ntt.NTTBarrett, + r: int, + c: int, + ) -> types.Ciphertext: + """Switch ciphertext from source key to destination key using BAT.""" + c0 = ct.data[0] + c1 = ct.data[1] + + c0_ct = types.Ciphertext(data=jnp.expand_dims(c0, axis=0), moduli=ct.moduli) + c1_ct = types.Ciphertext(data=jnp.expand_dims(c1, axis=0), moduli=ct.moduli) + + c0_lifted = blind_rotate_utils.lift_ciphertext( + c0_ct, bc_kernel, control_index, p_limbs, ntt_q, ntt_p, r, c ) + c1_lifted = blind_rotate_utils.lift_ciphertext( + c1_ct, bc_kernel, control_index, p_limbs, ntt_q, ntt_p, r, c + ) + + c0_lifted_data = jnp.squeeze(c0_lifted.data, axis=0) + c1_lifted_data = jnp.squeeze(c1_lifted.data, axis=0) + + vector_v = jnp.stack([c1_lifted_data, c0_lifted_data], axis=-1) + + prod = self.mul(vector_v, key_matrix_bat) + + reduced = barrett.modular_reduction(prod, mul_kernel.barrett_constants) + + ct_out = types.Ciphertext(data=reduced, moduli=c0_lifted.moduli) + + rescale_kernel.rescale(ct_out) + return ct_out + + def mul(self, vector_v: jax.Array, key_matrix_bat: jax.Array) -> jax.Array: + """Computes BAT-based matrix-vector product for 2x2 key multiplication.""" + degree = vector_v.shape[-3] + num_moduli = vector_v.shape[-2] + + block_size, num_blocks = ckks_math.compute_tpu_block_sizes(degree) + + v_reshaped = vector_v.reshape( + *vector_v.shape[:-3], num_blocks, block_size, num_moduli, 2 + ) + + v_u8 = jax.lax.bitcast_convert_type(v_reshaped, jnp.uint8) + + i8_products = jnp.einsum( + "...ikjvq,ikjuvqp->...ikjup", + v_u8, + key_matrix_bat, + preferred_element_type=jnp.uint32, + ) + + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + summed = jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=-1) + summed_flat = summed.reshape(*summed.shape[:-4], degree, num_moduli, 2) + return jnp.moveaxis(summed_flat, -1, 0) + + def tree_flatten(self): + return (), None + + @classmethod + def tree_unflatten(cls, aux_data, children): + del aux_data + del children + return cls() diff --git a/jaxite/jaxite_ckks/key_switching_test.py b/jaxite/jaxite_ckks/key_switching_test.py new file mode 100644 index 0000000..a46dca8 --- /dev/null +++ b/jaxite/jaxite_ckks/key_switching_test.py @@ -0,0 +1,238 @@ +"""Tests for key switching kernels.""" + +import math +import jax +import jax.numpy as jnp +from jaxite.jaxite_ckks import barrett +from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import encode +from jaxite.jaxite_ckks import encrypt +from jaxite.jaxite_ckks import key_gen +from jaxite.jaxite_ckks import key_switching +from jaxite.jaxite_ckks import mul +from jaxite.jaxite_ckks import ntt +from jaxite.jaxite_ckks import ntt_cpu +from jaxite.jaxite_ckks import random +from jaxite.jaxite_ckks import rescale +from jaxite.jaxite_ckks import types +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +jax.config.update("jax_enable_x64", True) + +TEST_PRIMES = ( + 1_073_692_673, + 1_073_643_521, + 1_073_479_681, + 1_073_430_529, +) + + +class KeySwitchingTest(parameterized.TestCase): + + @parameterized.named_parameters( + ("dnum_1", 1), + ("dnum_2", 2), + ) + def test_key_switcher_rns(self, dnum): + degree = 16 + q_limbs = [TEST_PRIMES[0], TEST_PRIMES[1]] + p_limbs = [TEST_PRIMES[2]] + all_moduli = q_limbs + p_limbs + scale = 2**20 + + test_random_source = random.ZeroNoiseRandomSource() + pk_src, sk_src = key_gen.keygen( + degree, q_limbs, random_source=test_random_source + ) + _, sk_dst = key_gen.keygen( + degree, q_limbs, random_source=test_random_source + ) + + ksk = key_gen.gen_key_switching_key( + source_key=sk_src, + dest_key=sk_dst, + q_limbs=q_limbs, + p_limbs=p_limbs, + dnum=dnum, + random_source=test_random_source, + ) + + bc_kernel = basis_conversion.BasisConversionBarrett() + ks_control_indices = mul.Mul.compute_control_indices(q_limbs, p_limbs, dnum) + print("DEBUG: ks_control_indices =", ks_control_indices) + bc_kernel.precompute_constants(all_moduli, ks_control_indices) + for idx, consts in enumerate(bc_kernel.precomputed_constants): + print(f"DEBUG: constants[{idx}]:") + print(f"DEBUG: q_hat_inv_mod_q shape: {consts.q_hat_inv_mod_q.shape}") + print(f"DEBUG: q_hat_mod_p_bat shape: {consts.q_hat_mod_p_bat.shape}") + + barrett_constants_pq = barrett.precompute_barrett_constants(all_moduli) + mul_kernel = mul.MulPlaintextCiphertextBarrett(barrett_constants_pq) + + # 1. Encrypt message under sk_src + mu = np.array( + [complex(x % 4 + 1, x % 4 + 2) for x in range(degree // 2)], + dtype=complex, + ) + encoder = encode.Encode(degree, q_limbs, scale) + encryptor_src = encrypt.Encrypt(pk_src) + ct_in = encryptor_src.encrypt( + encoder.encode(mu.tolist()), random_source=test_random_source + ) + + # 2. Key switch to sk_dst + switcher = key_switching.KeySwitcher() + switcher.precompute_constants(q_limbs, p_limbs, dnum, r=4, c=4) + + ct_switched_qp = switcher.key_switch( + ct=ct_in, + ksk=ksk, + p_limbs=jnp.array(p_limbs, dtype=jnp.uint32), + bc_kernel=bc_kernel, + mul_kernel=mul_kernel, + start_control_index=1, + ) + + # 3. Rescale ct_switched_qp to ct_switched_q + rescale_kernel = rescale.Rescale() + rescale_kernel.precompute_constants( + all_moduli, num_rescales=len(p_limbs), r=4, c=4 + ) + rescale_kernel.rescale(ct_switched_qp) + + # 4. Decrypt under sk_dst and decode + decryptor_dst = encrypt.Decrypt(sk_dst) + pt_dec = decryptor_dst.decrypt(ct_switched_qp) + + decoder = encode.Decode(scale, degree // 2) + decoded = decoder.decode(pt_dec) + + for e, d in zip(mu, decoded): + self.assertAlmostEqual(e.real, d.real, delta=1e-1) + self.assertAlmostEqual(e.imag, d.imag, delta=1e-1) + + def test_key_switcher_bat(self): + degree = 16 + q_limbs = [TEST_PRIMES[0], TEST_PRIMES[1]] + p_limbs = [TEST_PRIMES[2]] + all_moduli = q_limbs + p_limbs + scale = 2**20 + + test_random_source = random.ZeroNoiseRandomSource() + pk_src, sk_src = key_gen.keygen( + degree, q_limbs, random_source=test_random_source + ) + _, sk_dst = key_gen.keygen( + degree, q_limbs, random_source=test_random_source + ) + + # Generate key switching key (source to dest) + ksk = key_gen.gen_key_switching_key( + source_key=sk_src, + dest_key=sk_dst, + q_limbs=q_limbs, + p_limbs=p_limbs, + dnum=1, + random_source=test_random_source, + ) + + # key0 is ksk ciphertext + key0 = types.Ciphertext( + data=jnp.stack([ksk.b[0], ksk.a[0]]), + moduli=ksk.moduli, + ) + + # key1 encrypts P under sk_dst + p_val = math.prod(p_limbs) + scaled_val = p_val % np.array(all_moduli, dtype=np.uint64).reshape(1, -1) + target1 = np.ones((degree, len(all_moduli)), dtype=np.uint64) * scaled_val + + a_coeffs = test_random_source.gen_uniform_poly(degree, all_moduli) + e_coeffs = test_random_source.gen_gaussian_poly(degree, all_moduli) + a_slots = ntt_cpu.ntt_negacyclic_poly(a_coeffs, all_moduli) + e_slots = ntt_cpu.ntt_negacyclic_poly(e_coeffs, all_moduli) + + sk_dst_ext = key_gen.extend_secret_key(sk_dst, all_moduli) + prod = (a_slots * sk_dst_ext.data) % np.array( + all_moduli, dtype=np.uint64 + ).reshape(1, -1) + b_slots = ( + e_slots + + target1 + + np.array(all_moduli, dtype=np.uint64).reshape(1, -1) + - prod + ) % np.array(all_moduli, dtype=np.uint64).reshape(1, -1) + + key1 = types.Ciphertext( + data=jnp.array(np.stack([b_slots, a_slots]), dtype=jnp.uint32), + moduli=jnp.array(all_moduli, dtype=jnp.uint32), + ) + + # Stack key0 and key1 to get key_matrix of shape (degree, num_moduli, 2, 2) + stacked = jnp.stack([key0.data, key1.data], axis=1) + key_matrix = jnp.transpose(stacked, (2, 3, 0, 1)) + + # Precompute BAT representation + key_matrix_bat = key_switching.BATKeySwitcher.transform_key( + key_matrix, key0.moduli + ) + + bc_kernel = basis_conversion.BasisConversionBarrett() + # BC from Q (2 limbs) to P (1 limb) + bc_kernel.precompute_constants(all_moduli, [([0, 1], [2])]) + + barrett_constants_pq = barrett.precompute_barrett_constants(all_moduli) + mul_kernel = mul.MulPlaintextCiphertextBarrett(barrett_constants_pq) + + rescale_kernel = rescale.Rescale() + rescale_kernel.precompute_constants(all_moduli, num_rescales=1, r=4, c=4) + + # 1. Encrypt message under sk_src + mu = np.array( + [complex(x % 4 + 1, x % 4 + 2) for x in range(degree // 2)], + dtype=complex, + ) + encoder = encode.Encode(degree, q_limbs, scale) + encryptor_src = encrypt.Encrypt(pk_src) + ct_in = encryptor_src.encrypt( + encoder.encode(mu.tolist()), random_source=test_random_source + ) + + # 2. Key switch using BATKeySwitcher + r, c = 4, 4 + ntt_q = ntt.NTTBarrett() + ntt_q.precompute_constants(q_limbs, r=r, c=c) + ntt_p = ntt.NTTBarrett() + ntt_p.precompute_constants(p_limbs, r=r, c=c) + + switcher = key_switching.BATKeySwitcher() + ct_switched_q = switcher.key_switch( + ct=ct_in, + key_matrix_bat=key_matrix_bat, + p_limbs=jnp.array(p_limbs, dtype=jnp.uint32), + bc_kernel=bc_kernel, + control_index=0, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ntt_q=ntt_q, + ntt_p=ntt_p, + r=r, + c=c, + ) + + # 3. Decrypt under sk_dst and decode + decryptor_dst = encrypt.Decrypt(sk_dst) + pt_dec = decryptor_dst.decrypt(ct_switched_q) + + decoder = encode.Decode(scale, degree // 2) + decoded = decoder.decode(pt_dec) + + for e, d in zip(mu, decoded): + self.assertAlmostEqual(e.real, d.real, delta=1e-1) + self.assertAlmostEqual(e.imag, d.imag, delta=1e-1) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_ckks/math.py b/jaxite/jaxite_ckks/math.py index d86c48b..4450daa 100644 --- a/jaxite/jaxite_ckks/math.py +++ b/jaxite/jaxite_ckks/math.py @@ -104,3 +104,17 @@ def get_bit_reverse_perm(n: int) -> list[int]: temp >>= 1 perm[i] = r return perm + + +def compute_tpu_block_sizes(degree: int) -> tuple[int, int]: + """Computes block_size and num_blocks for TPU vectorization optimization. + + Args: + degree: The polynomial degree. + + Returns: + A tuple of (block_size, num_blocks). + """ + if degree >= 128: + return 128, degree // 128 + return degree, 1 diff --git a/jaxite/jaxite_ckks/types.py b/jaxite/jaxite_ckks/types.py index 9035f0b..fe91446 100644 --- a/jaxite/jaxite_ckks/types.py +++ b/jaxite/jaxite_ckks/types.py @@ -2,6 +2,7 @@ import dataclasses import jax +import jax.numpy as jnp import numpy as np @@ -18,7 +19,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, _, children): - return cls(*children) + return cls(children[0], children[1]) @jax.tree_util.register_pytree_node_class @@ -34,7 +35,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, _, children): - return cls(*children) + return cls(children[0], children[1]) @dataclasses.dataclass(frozen=True) @@ -60,3 +61,61 @@ class EvaluationKeys: a: jax.Array b: jax.Array moduli: jax.Array + + +@jax.tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class HMuxRotKey: + """A key used in a single HMuxRot step. + + Consists of two ciphertexts symmetrically encrypted under the destination key + sk modulo PQ: + - key0: encrypts P * beta * sk(X^{5^{-j}}) + - key1: encrypts P * beta + """ + + key0: Ciphertext + key1: Ciphertext + key_matrix_bat: jax.Array = dataclasses.field(init=False) + + def __post_init__(self): + # Stack key0.data and key1.data along axis 1 to get + # (2, 2, degree, num_moduli). + # Then transposing to (degree, num_moduli, 2, 2) + stacked = jnp.stack([self.key0.data, self.key1.data], axis=1) + key_matrix = jnp.transpose(stacked, (2, 3, 0, 1)) + + # Precompute BAT representation using blind_rotate.HMuxRot helper + # pylint: disable=g-import-not-at-top + from jaxite.jaxite_ckks import key_switching + + key_matrix_bat = key_switching.BATKeySwitcher.transform_key( + key_matrix, self.key0.moduli + ) + object.__setattr__(self, "key_matrix_bat", key_matrix_bat) + + def tree_flatten(self): + return (self.key0, self.key1), None + + @classmethod + def tree_unflatten(cls, _, children): + return cls(children[0], children[1]) + + +@jax.tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) +class MuxRotationKey: + """A set of HMuxRot keys for all bits of a secret rotation index. + + Contains a list of pairs of keys (hmrkey_jk_0, hmrkey_not_jk_1) for each bit k + from 0 to log2(num_slots) - 1. + """ + + keys: list[tuple[HMuxRotKey, HMuxRotKey]] + + def tree_flatten(self): + return (self.keys,), None + + @classmethod + def tree_unflatten(cls, _, children): + return cls(children[0])