From 8dafed090c5e6796cfd3c6a2e6f1780d3f20ebea Mon Sep 17 00:00:00 2001 From: Akhil Samavedam Date: Tue, 23 Jun 2026 00:37:06 -0700 Subject: [PATCH] Add TPU optimizations for basis conversion and BAT PiperOrigin-RevId: 936482359 --- BUILD | 35 +++- jaxite/jaxite_ckks/basis_conversion.py | 81 +++------ jaxite/jaxite_ckks/bat_utils.py | 161 ++++++++++++++++++ jaxite/jaxite_ckks/bat_utils_test.py | 79 +++++++++ jaxite/jaxite_ckks/blind_rotate.py | 151 +++++++++++++++- jaxite/jaxite_ckks/blind_rotate_test.py | 98 +++++++++++ jaxite/jaxite_ckks/blind_rotate_utils.py | 63 +++++++ jaxite/jaxite_ckks/blind_rotate_utils_test.py | 87 ++++++++++ jaxite/jaxite_ckks/key_gen.py | 142 ++++++++++++++- jaxite/jaxite_ckks/ntt.py | 53 +----- jaxite/jaxite_ckks/types.py | 61 ++++++- 11 files changed, 900 insertions(+), 111 deletions(-) create mode 100644 jaxite/jaxite_ckks/bat_utils.py create mode 100644 jaxite/jaxite_ckks/bat_utils_test.py create mode 100644 jaxite/jaxite_ckks/blind_rotate_utils.py create mode 100644 jaxite/jaxite_ckks/blind_rotate_utils_test.py diff --git a/BUILD b/BUILD index 192295d..85e669e 100644 --- a/BUILD +++ b/BUILD @@ -161,9 +161,7 @@ tpu_test( size = "large", timeout = "moderate", srcs = ["jaxite_ec/finite_field_test.py"], - python_version = "PY3", shard_count = 3, - srcs_version = "PY3ONLY", deps = [ ":jaxite", # copybara: xprof_analysis_client # buildcleaner: keep @@ -200,9 +198,7 @@ tpu_test( "jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv", "jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv", ], - python_version = "PY3", shard_count = 3, - srcs_version = "PY3ONLY", deps = [ ":jaxite", # copybara: xprof_analysis_client # buildcleaner: keep @@ -221,9 +217,7 @@ tpu_test( size = "large", timeout = "long", srcs = ["jaxite_ec/elliptic_curve_test.py"], - python_version = "PY3", shard_count = 16, - srcs_version = "PY3ONLY", deps = [ ":jaxite", # copybara: xprof_analysis_client # buildcleaner: keep @@ -485,6 +479,20 @@ 
cpu_gpu_tpu_test( ], ) +cpu_gpu_tpu_test( + name = "bat_utils_test", + size = "small", + srcs = ["jaxite/jaxite_ckks/bat_utils_test.py"], + deps = [ + ":jaxite_ckks", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + ], +) + cpu_gpu_tpu_test( name = "barrett_test", size = "small", @@ -617,6 +625,20 @@ py_test( ], ) +py_test( + name = "blind_rotate_utils_test", + size = "small", + srcs = ["jaxite/jaxite_ckks/blind_rotate_utils_test.py"], + deps = [ + ":jaxite_ckks", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + ], +) + cpu_gpu_tpu_test( name = "blind_rotate_ckks_test", size = "small", @@ -628,6 +650,7 @@ cpu_gpu_tpu_test( ":jaxite_ckks", "@abseil-py//absl/testing:absltest", "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//absl/logging", "@jaxite_deps//jax", "@jaxite_deps//jaxlib", "@jaxite_deps//numpy", diff --git a/jaxite/jaxite_ckks/basis_conversion.py b/jaxite/jaxite_ckks/basis_conversion.py index 9806337..6e0deb0 100644 --- a/jaxite/jaxite_ckks/basis_conversion.py +++ b/jaxite/jaxite_ckks/basis_conversion.py @@ -7,6 +7,7 @@ import jax import jax.numpy as jnp from jaxite.jaxite_ckks import barrett +from jaxite.jaxite_ckks import bat_utils from jaxite.jaxite_ckks import rns_utils # Enable 64-bit precision for large integer arithmetic @@ -85,48 +86,6 @@ def __hash__(self): return id(self) -def matmul_bat_einsum( - lhs: jax.Array, rhs: jax.Array, subscripts: str -) -> jax.Array: - """Basis Aligned Transformation (BAT) based matrix multiplication. - - Args: - lhs: input - rhs: twiddle factor matrix - subscripts: einsum subscripts - - Returns: - The matrix multiplication result. 
- """ - lhs_u8 = lhs.view(jnp.uint8) - shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) - i8_products = jnp.einsum( - subscripts, lhs_u8, rhs, preferred_element_type=jnp.uint32 - ) - return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) - - -def _basis_aligned_transformation( - matrix: jnp.ndarray, moduli: list[int] -) -> jnp.ndarray: - """Prepares a matrix for Basis Aligned Transformation (BAT).""" - matrix_u64 = matrix.astype(jnp.uint64) - num_bytes = 4 - matrix_u64_byteshifted = jnp.array( - [matrix_u64 << (8 * byte_idx) for byte_idx in range(num_bytes)], - dtype=jnp.uint64, - ) - moduli_arr = jnp.array(moduli, dtype=jnp.uint64) - matrix_u64_byteshifted_mod_modulus = ( - matrix_u64_byteshifted % moduli_arr - ).astype(jnp.uint32) - # Output shape: (4, ..., moduli, 4) - matrix_u8 = jax.lax.bitcast_convert_type( - matrix_u64_byteshifted_mod_modulus, jnp.uint8 - ) - return matrix_u8 - - @jax.tree_util.register_pytree_node_class class BasisConversionBarrett(BasisConversion): """Kernel for Basis Conversion with Barrett reduction.""" @@ -167,13 +126,11 @@ def precompute_constants( dtype=jnp.uint64, ) - # BAT Preprocessing - q_hat_mod_p_bat_raw = _basis_aligned_transformation( + q_hat_mod_p_bat_raw = bat_utils.basis_aligned_transformation( q_hat_mod_p, target_moduli ) - q_hat_mod_p_bat = q_hat_mod_p_bat_raw.transpose(1, 0, 2, 3).reshape( - -1, q_hat_mod_p_bat_raw.shape[2], 4 - ) + # Shape: (4, num_Q, num_P, 4) -> (num_Q, 4, num_P, 4) + q_hat_mod_p_bat = q_hat_mod_p_bat_raw.transpose(1, 0, 2, 3) constants = BarrettBasisConversionConstants( q_hat_inv_mod_q=q_hat_inv_mod_q, @@ -191,21 +148,39 @@ def basis_change( constants = self.precomputed_constants[control_index] in_tower = jnp.asarray(in_tower, dtype=jnp.uint64) + degree = in_tower.shape[-2] + num_Q = in_tower.shape[-1] + + if degree >= 128: + block_size = 128 + num_blocks = degree // 128 + else: + block_size = degree + num_blocks = 1 + + # Reshape degree dimension to (num_blocks, 
block_size) to optimize TPU vectorization + in_tower_reshaped = in_tower.reshape( + *in_tower.shape[:-2], num_blocks, block_size, num_Q + ) + # Step 1: Compute c_unreduced = in_tower * QHatInvModq - # Ensure constants.q_hat_inv_mod_q broadcasts over leading dimensions. - # q_hat_inv_mod_q has shape (sizeQ,). in_tower has shape (..., sizeQ). - c_unreduced = in_tower * constants.q_hat_inv_mod_q + c_unreduced = in_tower_reshaped * constants.q_hat_inv_mod_q # Step 2: Modular Reduction c = barrett.modular_reduction(c_unreduced, constants.origin_barrett) # Step 3: BAT-based matrix multiplication - summed_terms = matmul_bat_einsum( - c, constants.q_hat_mod_p_bat, "...q,qpb->...pb" + summed_terms = bat_utils.matmul_bat_einsum( + c, constants.q_hat_mod_p_bat, "...mq,mqpb->...pb" + ) + + # Flatten the degree dimension back + summed_terms_flat = summed_terms.reshape( + *summed_terms.shape[:-3], degree, summed_terms.shape[-1] ) # Step 4: Final Modular Reduction out_tower = barrett.modular_reduction( - summed_terms, constants.target_barrett + summed_terms_flat, constants.target_barrett ) return out_tower diff --git a/jaxite/jaxite_ckks/bat_utils.py b/jaxite/jaxite_ckks/bat_utils.py new file mode 100644 index 0000000..786839a --- /dev/null +++ b/jaxite/jaxite_ckks/bat_utils.py @@ -0,0 +1,161 @@ +"""Basis Aligned Transformation (BAT) utilities for CKKS on TPU.""" + +import jax +import jax.numpy as jnp + +# Enable 64-bit precision for large integer arithmetic +jax.config.update("jax_enable_x64", True) + + +def matmul_bat_einsum( + lhs: jax.Array, rhs: jax.Array, subscripts: str +) -> jax.Array: + """Basis Aligned Transformation (BAT) based matrix multiplication. + + Args: + lhs: input + rhs: twiddle factor matrix + subscripts: einsum subscripts + + Returns: + The matrix multiplication result. 
+ """ + lhs_u8 = jax.lax.bitcast_convert_type(lhs, jnp.uint8) + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + i8_products = jnp.einsum( + subscripts, lhs_u8, rhs, preferred_element_type=jnp.uint32 + ) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) + + +def basis_aligned_transformation( + matrix: jnp.ndarray, moduli: list[int] +) -> jnp.ndarray: + """Prepares a matrix for Basis Aligned Transformation (BAT).""" + matrix_u64 = matrix.astype(jnp.uint64) + num_bytes = 4 + matrix_u64_byteshifted = jnp.array( + [matrix_u64 << (8 * byte_idx) for byte_idx in range(num_bytes)], + dtype=jnp.uint64, + ) + moduli_arr = jnp.array(moduli, dtype=jnp.uint64) + matrix_u64_byteshifted_mod_modulus = ( + matrix_u64_byteshifted % moduli_arr + ).astype(jnp.uint32) + # Output shape: (4, ..., moduli, 4) + matrix_u8 = jax.lax.bitcast_convert_type( + matrix_u64_byteshifted_mod_modulus, jnp.uint8 + ) + return matrix_u8 + + +def basis_aligned_transform_key( + key_matrix: jax.Array, moduli: jax.Array | list[int] +) -> jax.Array: + """Prepares the 4D key matrix of shape (degree, num_moduli, 2, dnum) for BAT. + + Args: + key_matrix: The key matrix of shape (degree, num_moduli, 2, dnum). + moduli: The moduli of the key matrix. + + Returns: + The transformed key matrix of shape (num_blocks, block_size, num_moduli, + 2, dnum, 4, 4) in uint8. 
+ """ + matrix_u64 = key_matrix.astype(jnp.uint64) + num_bytes = 4 + matrix_u64_byteshifted = jnp.array( + [matrix_u64 << (8 * byte_idx) for byte_idx in range(num_bytes)], + dtype=jnp.uint64, + ) # Shape: (4, degree, num_moduli, 2, dnum) + + moduli_expanded = jnp.array(moduli, dtype=jnp.uint64).reshape(1, 1, -1, 1, 1) + + matrix_u64_byteshifted_mod_modulus = ( + matrix_u64_byteshifted % moduli_expanded + ).astype(jnp.uint32) + + # Bitcast to uint8: shape becomes (4, degree, num_moduli, 2, dnum, 4) + matrix_u8 = jax.lax.bitcast_convert_type( + matrix_u64_byteshifted_mod_modulus, jnp.uint8 + ) + + # Transpose to (degree, num_moduli, u, v, q, p) + # Axes mapping: + # 0: byte_idx (q, size 4) + # 1: degree + # 2: num_moduli + # 3: u (size 2) + # 4: v (size dnum) + # 5: b (p, size 4) + matrix_u8_transposed = jnp.transpose(matrix_u8, (1, 2, 3, 4, 0, 5)) + + # Reshape degree dimension to (num_blocks, block_size) to optimize TPU vectorization + degree = key_matrix.shape[0] + if degree >= 128: + block_size = 128 + num_blocks = degree // 128 + else: + block_size = degree + num_blocks = 1 + return matrix_u8_transposed.reshape( + num_blocks, block_size, *matrix_u8_transposed.shape[1:] + ) + + +def matmul_bat_key_vector( + vector_v: jax.Array, key_matrix_bat: jax.Array +) -> jax.Array: + """Computes BAT-based matrix-vector product for 2x2 or 2xdnum key multiplication. + + Args: + vector_v: The input vector of shape (..., degree, num_moduli, dnum). + key_matrix_bat: The pre-transformed key matrix of shape (num_blocks, + block_size, num_moduli, 2, dnum, 4, 4). + + Returns: + The matrix-vector product of shape (2, ..., degree, num_moduli) in uint64. 
+ """ + degree = vector_v.shape[-3] + num_moduli = vector_v.shape[-2] + dnum = vector_v.shape[-1] + + if degree >= 128: + block_size = 128 + num_blocks = degree // 128 + else: + block_size = degree + num_blocks = 1 + + # Reshape degree dimension to (num_blocks, block_size) to optimize TPU vectorization + v_reshaped = vector_v.reshape( + *vector_v.shape[:-3], num_blocks, block_size, num_moduli, dnum + ) + + # View-cast vector_v to uint8 -> (..., num_blocks, block_size, num_moduli, dnum, 4) + v_u8 = jax.lax.bitcast_convert_type(v_reshaped, jnp.uint8) + + # einsum subscripts to compute the 2x2 (or 2xdnum) matrix-vector multiplication + # v_u8: ...ikjvq (where i is num_blocks, k is block_size, j is num_moduli, v is dnum, q is 4) + # key_matrix_bat: ikjuvqp (where i is num_blocks, k is block_size, j is num_moduli, u is 2, v is dnum, q is 4, p is 4) + # output: ...ikjup (where u is 2, p is 4) + i8_products = jnp.einsum( + "...ikjvq,ikjuvqp->...ikjup", + v_u8, + key_matrix_bat, + preferred_element_type=jnp.uint32, + ) + + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + # Shift and sum over the last dimension (p, size 4) + # to reconstruct uint64 values + # Shape after sum: (..., num_blocks, block_size, num_moduli, 2) + summed = jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=-1) + + # Reshape to flatten num_blocks and block_size back to degree + # Shape becomes: (..., degree, num_moduli, 2) + summed_flat = summed.reshape(*summed.shape[:-4], degree, num_moduli, 2) + + # Transpose to (2, ..., degree, num_moduli) + # where the components (size 2) is the first dimension + return jnp.moveaxis(summed_flat, -1, 0) diff --git a/jaxite/jaxite_ckks/bat_utils_test.py b/jaxite/jaxite_ckks/bat_utils_test.py new file mode 100644 index 0000000..9978696 --- /dev/null +++ b/jaxite/jaxite_ckks/bat_utils_test.py @@ -0,0 +1,79 @@ +"""Tests for bat_utils.""" + +import jax +import jax.numpy as jnp +from jaxite.jaxite_ckks import bat_utils +import numpy as np +from 
absl.testing import absltest +from absl.testing import parameterized + +# Enable 64-bit precision for large integer arithmetic +jax.config.update("jax_enable_x64", True) + + +class BatUtilsTest(parameterized.TestCase): + + def test_bat_key_vector_matmul(self): + degree = 8 + num_moduli = 2 + moduli = jnp.array([1073184769, 1073479681], dtype=jnp.uint32) + + key = jax.random.key(0) + k0, k1, k2, k3 = jax.random.split(key, 4) + + # 1. Generate random key0 and key1 of shape (2, degree, num_moduli) + key0 = jax.random.randint( + k0, + shape=(2, degree, num_moduli), + minval=0, + maxval=2**30, + dtype=jnp.uint32, + ) + key1 = jax.random.randint( + k1, + shape=(2, degree, num_moduli), + minval=0, + maxval=2**30, + dtype=jnp.uint32, + ) + + # Pack into key_matrix (degree, num_moduli, 2, 2) + # key0 corresponds to column 0, key1 to column 1 + # row 0 has key0[0] and key1[0]; row 1 has key0[1] and key1[1] + stacked = jnp.stack( + [key0, key1], axis=1 + ) # Shape: (2, 2, degree, num_moduli) + key_matrix = jnp.transpose( + stacked, (2, 3, 0, 1) + ) # Shape: (degree, num_moduli, 2, 2) + + # 2. Generate random plaintexts a and b of shape (degree, num_moduli) + a = jax.random.randint( + k2, shape=(degree, num_moduli), minval=0, maxval=2**30, dtype=jnp.uint32 + ) + b = jax.random.randint( + k3, shape=(degree, num_moduli), minval=0, maxval=2**30, dtype=jnp.uint32 + ) + vector_v = jnp.stack([a, b], axis=-1) # Shape: (degree, num_moduli, 2) + + # 3. Compute expected product using exact modular arithmetic + # prod0 = key0 * a (element-wise over degree and moduli) + # prod1 = key1 * b + # expected = (prod0 + prod1) % moduli + moduli_expanded = moduli.reshape(1, 1, -1) + prod0 = (key0.astype(jnp.uint64) * a.astype(jnp.uint64)) % moduli_expanded + prod1 = (key1.astype(jnp.uint64) * b.astype(jnp.uint64)) % moduli_expanded + expected = (prod0 + prod1) % moduli_expanded + + # 4. 
Perform BAT pre-transformation on key_matrix + key_matrix_bat = bat_utils.basis_aligned_transform_key(key_matrix, moduli) + + # 5. Run BAT matrix-vector multiplication + actual_uint64 = bat_utils.matmul_bat_key_vector(vector_v, key_matrix_bat) + actual = (actual_uint64 % moduli_expanded).astype(jnp.uint32) + + np.testing.assert_array_equal(np.array(actual), np.array(expected)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_ckks/blind_rotate.py b/jaxite/jaxite_ckks/blind_rotate.py index 8a29657..7854738 100644 --- a/jaxite/jaxite_ckks/blind_rotate.py +++ b/jaxite/jaxite_ckks/blind_rotate.py @@ -1,10 +1,159 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Blind rotation implementations for CKKS.""" +import jax import jax.numpy as jnp from jaxite.jaxite_ckks import barrett +from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import bat_utils +from jaxite.jaxite_ckks import blind_rotate_utils from jaxite.jaxite_ckks import mul from jaxite.jaxite_ckks import rescale from jaxite.jaxite_ckks import types +import numpy as np + + +def hmuxrot( + ct: types.Ciphertext, + hmrkey: types.HMuxRotKey, + j: int, + bc_kernel: basis_conversion.BasisConversionBarrett, + control_index: int, + p_limbs: jax.Array, + mul_kernel: mul.MulPlaintextCiphertextBarrett, + rescale_kernel: rescale.Rescale, +) -> types.Ciphertext: + """Evaluates HMuxRot^(j)(hmrkey_beta, ct). 
+ + Computes: P^-1 * [ a(X^{5^{-j}}) * hmrkey_0 + b(X^{5^{-j}}) * hmrkey_1 ] mod Q + + Args: + ct: The input ciphertext (a, b) under Q. + hmrkey: The HMuxRot key under PQ. + j: The rotation index. + bc_kernel: The basis conversion kernel. + control_index: The control index for basis conversion Q -> P. + p_limbs: The limbs of P. + mul_kernel: The multiplication kernel under PQ. + rescale_kernel: The rescaling kernel. + + Returns: + The resulting ciphertext under Q. + """ + g = pow(5, -j, 2 * ct.data.shape[1]) + a_rot = blind_rotate_utils.apply_automorphism_ntt(ct.data[1], g) + b_rot = blind_rotate_utils.apply_automorphism_ntt(ct.data[0], g) + + a_rot_ct = types.Ciphertext( + data=jnp.expand_dims(a_rot, axis=0), moduli=ct.moduli + ) + b_rot_ct = types.Ciphertext( + data=jnp.expand_dims(b_rot, axis=0), moduli=ct.moduli + ) + + a_lifted_ct = blind_rotate_utils.lift_ciphertext( + a_rot_ct, bc_kernel, control_index, p_limbs + ) + b_lifted_ct = blind_rotate_utils.lift_ciphertext( + b_rot_ct, bc_kernel, control_index, p_limbs + ) + + a_lifted_pt = types.Plaintext( + data=jnp.squeeze(a_lifted_ct.data, axis=0), moduli=a_lifted_ct.moduli + ) + b_lifted_pt = types.Plaintext( + data=jnp.squeeze(b_lifted_ct.data, axis=0), moduli=b_lifted_ct.moduli + ) + + # Stack a_lifted_pt and b_lifted_pt into vector_v of shape + # (degree, num_moduli, 2) + vector_v = jnp.stack([a_lifted_pt.data, b_lifted_pt.data], axis=-1) + + # Compute matrix multiplication using the BAT matrix-vector multiplication + # kernel + prod = bat_utils.matmul_bat_key_vector(vector_v, hmrkey.key_matrix_bat) + + # Perform modular reduction + reduced = barrett.modular_reduction(prod, mul_kernel.barrett_constants) + + ctout = types.Ciphertext(data=reduced, moduli=a_lifted_pt.moduli) + + rescale_kernel.rescale(ctout) + + return ctout + + +def brot_mux( + ct_in: types.Ciphertext, + mux_key: types.MuxRotationKey, + p_limbs: jax.Array, + bc_kernel: basis_conversion.BasisConversionBarrett, + control_index: int, + 
mul_kernel: mul.MulPlaintextCiphertextBarrett, + rescale_kernel: rescale.Rescale, +) -> types.Ciphertext: + """Homomorphic Blind Rotation using the Mux Method (BRotMux). + + Sequentially applies the MUX-based conditional rotation for each bit of the + rotation index, resulting in a right-rotation of ct_in by the secret index. + + Args: + ct_in: The input ciphertext under Q. + mux_key: The MuxRotationKey containing the keys for each bit. + p_limbs: The limbs of the auxiliary modulus P. + bc_kernel: The basis conversion kernel. + control_index: The control index for basis conversion Q -> P. + mul_kernel: The multiplication kernel under PQ. + rescale_kernel: The rescaling kernel. + + Returns: + A Ciphertext under Q representing the rotated ciphertext. + """ + ct_out = ct_in + for k, (hmrkey_jk_0, hmrkey_not_jk_1) in enumerate(mux_key.keys): + ct0 = hmuxrot( + ct=ct_out, + hmrkey=hmrkey_jk_0, + j=2**k, + bc_kernel=bc_kernel, + control_index=control_index, + p_limbs=p_limbs, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ) + ct1 = hmuxrot( + ct=ct_out, + hmrkey=hmrkey_not_jk_1, + j=0, + bc_kernel=bc_kernel, + control_index=control_index, + p_limbs=p_limbs, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ) + moduli_expanded = jnp.array(ct0.moduli, dtype=jnp.uint64).reshape(1, 1, -1) + sum_data = ct0.data.astype(jnp.uint64) + ct1.data.astype(jnp.uint64) + sum_reduced = jnp.where( + sum_data >= moduli_expanded, sum_data - moduli_expanded, sum_data + ) + ct_out = types.Ciphertext( + data=sum_reduced.astype(jnp.uint32), moduli=ct0.moduli + ) + + return ct_out def brot_cm( @@ -39,7 +188,7 @@ def brot_cm( if len(cmkey_j) != len(pt_rot_mu_all): raise ValueError("Lengths of cmkey_j and pt_rot_mu_all must match.") - if not jnp.array_equal(cmkey_j[0].moduli, pt_rot_mu_all[0].moduli): + if not np.array_equal(cmkey_j[0].moduli, pt_rot_mu_all[0].moduli): raise ValueError("Moduli of cmkey_j and pt_rot_mu_all must match.") ct_data = jnp.stack([ct.data for ct in 
cmkey_j]) diff --git a/jaxite/jaxite_ckks/blind_rotate_test.py b/jaxite/jaxite_ckks/blind_rotate_test.py index 8f98a5b..a37515f 100644 --- a/jaxite/jaxite_ckks/blind_rotate_test.py +++ b/jaxite/jaxite_ckks/blind_rotate_test.py @@ -1,7 +1,23 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + """Tests for blind rotation kernels.""" import jax +import jax.numpy as jnp from jaxite.jaxite_ckks import barrett +from jaxite.jaxite_ckks import basis_conversion from jaxite.jaxite_ckks import blind_rotate from jaxite.jaxite_ckks import encode from jaxite.jaxite_ckks import encrypt @@ -111,6 +127,88 @@ def test_blind_rotate_cm(self, num_slots, secret_idx): self.assertAlmostEqual(e.real, d.real, delta=1.0) self.assertAlmostEqual(e.imag, d.imag, delta=1.0) + @parameterized.named_parameters( + ("secret_idx_0", 4, 0), + ("secret_idx_1", 4, 1), + ("secret_idx_2", 4, 2), + ("secret_idx_3", 4, 3), + ("dense_8_slots", 8, 5), + ("dense_16_slots", 16, 9), + ) + def test_brot_mux(self, num_slots, secret_idx): + degree = max(1024, 2 * num_slots) + r = 32 + q_limbs = [1073184769] + p_limbs = [1073479681] + all_moduli = q_limbs + p_limbs + scale = 2**22 + + # 1. Generate keys + test_random_source = random.ZeroNoiseRandomSource() + pk_q, sk_q = key_gen.keygen( + degree, q_limbs, random_source=test_random_source + ) + # 2. 
Generate Mux Rotation Key for secret index + # We choose the secret index bits corresponding to secret_idx + num_bits = int(np.log2(num_slots)) + secret_bits = [int((secret_idx >> k) & 1) for k in range(num_bits)] + mux_key = key_gen.gen_mux_rotation_key( + sk=sk_q, + secret_bits=secret_bits, + q_limbs=q_limbs, + p_limbs=p_limbs, + random_source=test_random_source, + ) + + # 3. Setup input ciphertext ct_in + mu = np.array( + [complex(x % 4 + 1, x % 4 + 2) for x in range(num_slots)], dtype=complex + ) + + bc_kernel = basis_conversion.BasisConversionBarrett() + bc_kernel.precompute_constants(all_moduli, [([0], [1])]) + mul_constants = barrett.precompute_barrett_constants(all_moduli) + mul_kernel = mul.MulPlaintextCiphertextBarrett(mul_constants) + + rescale_kernel = rescale.Rescale() + rescale_kernel.precompute_constants( + all_moduli, num_rescales=1, r=r, c=degree // r + ) + + encoder_q = encode.Encode(degree, q_limbs, scale) + encryptor_q = encrypt.Encrypt(pk_q) + + plain_mu = encoder_q.encode(mu.tolist()) + ct_in = encryptor_q.encrypt(plain_mu, random_source=test_random_source) + + # 4. Run homomorphic BRotMux + ct_res = blind_rotate.brot_mux( + ct_in=ct_in, + mux_key=mux_key, + p_limbs=jnp.array(p_limbs, dtype=jnp.uint32), + bc_kernel=bc_kernel, + control_index=0, + mul_kernel=mul_kernel, + rescale_kernel=rescale_kernel, + ) + + # 5. 
Decrypt and verify result + decryptor_q = encrypt.Decrypt(sk_q) + pt_dec = decryptor_q.decrypt(ct_res) + + decoder = encode.Decode(scale, num_slots) + decoded = decoder.decode(pt_dec) + + full_slots = degree // 2 + mu_full = np.zeros(full_slots, dtype=complex) + mu_full[:num_slots] = mu + expected_full = _negacyclic_roll(mu_full, secret_idx) + expected = expected_full[:num_slots] + + for e, d in zip(expected, decoded): + self.assertAlmostEqual(e.real, d.real, delta=1.5) + self.assertAlmostEqual(e.imag, d.imag, delta=1.5) + if __name__ == "__main__": absltest.main() diff --git a/jaxite/jaxite_ckks/blind_rotate_utils.py b/jaxite/jaxite_ckks/blind_rotate_utils.py new file mode 100644 index 0000000..09dd81b --- /dev/null +++ b/jaxite/jaxite_ckks/blind_rotate_utils.py @@ -0,0 +1,63 @@ +"""Utility functions for homomorphic blind rotation.""" + +import math +import jax +import jax.numpy as jnp +from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import types + + +def apply_automorphism_ntt(data: jax.Array, g: int) -> jax.Array: + """Applies the automorphism X -> X^g to a polynomial in the NTT domain. + + Handles the bit-reversed layout of jaxite's NTT representation. + + Args: + data: The polynomial in NTT domain. Shape (..., degree, num_moduli). + g: The automorphism generator (must be odd). + + Returns: + The permuted polynomial in NTT domain with the same shape. 
+ """ + degree = data.shape[-2] + bits = int(math.log2(degree)) + indices = jnp.arange(degree, dtype=jnp.uint32) + + # Bit-reverse indices to map the bit-reversed layout of jaxite's NTT + def bit_reverse(x): + rev = jnp.zeros_like(x) + temp = x + for _ in range(bits): + rev = (rev << 1) | (temp & 1) + temp >>= 1 + return rev + + br_indices = bit_reverse(indices) + g_u32 = jnp.array(g, dtype=jnp.uint32) + target_roots = (((2 * br_indices + 1) * g_u32 - 1) // 2) % degree + target_indices = bit_reverse(target_roots) + + return jnp.take(data, target_indices, axis=-2) + + +def lift_ciphertext( + ct: types.Ciphertext, + bc_kernel: basis_conversion.BasisConversionBarrett, + control_index: int, + p_limbs: jax.Array, +) -> types.Ciphertext: + """Lifts a ciphertext from Q to PQ using basis conversion. + + Args: + ct: The input ciphertext under Q. Shape (num_elements, degree, num_Q). + bc_kernel: The precomputed basis conversion kernel from Q to P. + control_index: The control index specifying the Q -> P conversion. + p_limbs: The limbs of modulus P. + + Returns: + A lifted Ciphertext under PQ. Shape (num_elements, degree, num_Q + num_P). 
+ """ + data_p = bc_kernel.basis_change(ct.data, control_index=control_index) + data_pq = jnp.concatenate([ct.data, data_p], axis=-1) + moduli_pq = jnp.concatenate([ct.moduli, p_limbs]).astype(jnp.uint32) + return types.Ciphertext(data=data_pq, moduli=moduli_pq) diff --git a/jaxite/jaxite_ckks/blind_rotate_utils_test.py b/jaxite/jaxite_ckks/blind_rotate_utils_test.py new file mode 100644 index 0000000..5d27a75 --- /dev/null +++ b/jaxite/jaxite_ckks/blind_rotate_utils_test.py @@ -0,0 +1,87 @@ +"""Tests for blind rotation utilities.""" + +import jax +import jax.numpy as jnp +from jaxite.jaxite_ckks import basis_conversion +from jaxite.jaxite_ckks import blind_rotate_utils +from jaxite.jaxite_ckks import ntt_cpu +from jaxite.jaxite_ckks import types +import numpy as np +from absl.testing import absltest +from absl.testing import parameterized + +jax.config.update("jax_enable_x64", True) + + +class BlindRotateUtilsTest(parameterized.TestCase): + + def test_apply_automorphism_ntt(self): + # Test with a simple identity automorphism (g = 1) + degree = 8 + data = jnp.arange(degree, dtype=jnp.float64).reshape(degree, 1) + res = blind_rotate_utils.apply_automorphism_ntt(data, 1) + np.testing.assert_array_equal(res, data) + + def test_lift_ciphertext(self): + degree = 8 + q_limbs = [1073184769] + p_limbs = [1073479681] + all_moduli = q_limbs + p_limbs + + # ntt of zero is zero, e is zero + # b = - a * sk + a_slots = jnp.zeros((degree, len(q_limbs)), dtype=jnp.uint32) + b_slots = jnp.zeros((degree, len(q_limbs)), dtype=jnp.uint32) + ct = types.Ciphertext( + data=jnp.stack([b_slots, a_slots]), + moduli=jnp.array(q_limbs, dtype=jnp.uint32), + ) + + bc_kernel = basis_conversion.BasisConversionBarrett() + bc_kernel.precompute_constants(all_moduli, [([0], [1])]) + + lifted_ct = blind_rotate_utils.lift_ciphertext( + ct, + bc_kernel, + control_index=0, + p_limbs=jnp.array(p_limbs, dtype=jnp.uint32), + ) + + # Output dimensions should match PQ towers + 
self.assertEqual(lifted_ct.data.shape, (2, degree, 2)) + np.testing.assert_array_equal( + lifted_ct.moduli, np.array(all_moduli, dtype=np.uint32) + ) + + @parameterized.parameters(3, 5, 7) + def test_apply_automorphism_ntt_non_trivial(self, g): + degree = 8 + q = 1073184769 + # Create a non-trivial polynomial in coefficient domain + poly = np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.uint64).reshape( + degree, 1 + ) + + # Compute automorphism in coefficient domain + poly_rot = np.zeros_like(poly) + for i in range(degree): + target_pow = (i * g) % (2 * degree) + val = poly[i, 0] + if target_pow >= degree: + target_pow -= degree + poly_rot[target_pow, 0] = (q - val) % q + else: + poly_rot[target_pow, 0] = val + + # Convert to NTT domain + ntt_poly = ntt_cpu.ntt_negacyclic_poly(poly, [q]) + expected_ntt_poly_rot = ntt_cpu.ntt_negacyclic_poly(poly_rot, [q]) + + # Apply automorphism in NTT domain + res = blind_rotate_utils.apply_automorphism_ntt(jnp.array(ntt_poly), g) + + np.testing.assert_array_equal(np.array(res), expected_ntt_poly_rot) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_ckks/key_gen.py b/jaxite/jaxite_ckks/key_gen.py index 8e8408f..9f72fdc 100644 --- a/jaxite/jaxite_ckks/key_gen.py +++ b/jaxite/jaxite_ckks/key_gen.py @@ -1,8 +1,8 @@ """Key generation utilities for CKKS.""" import math - import jax.numpy as jnp +from jaxite.jaxite_ckks import blind_rotate_utils from jaxite.jaxite_ckks import encode from jaxite.jaxite_ckks import encrypt from jaxite.jaxite_ckks import ntt_cpu @@ -43,6 +43,8 @@ def keygen( return PublicKey(pk_data, np.array(moduli, dtype=np.uint64)), SecretKey( s_ntt, np.array(moduli, dtype=np.uint64) ) + + def extend_secret_key( secret_key: SecretKey, target_moduli: list[int], @@ -113,6 +115,7 @@ def gen_key_switching_key( p_limbs: list[int], dnum: int, random_source: random.RandomSource | None = None, + dest_key_ext: SecretKey | None = None, ) -> types.EvaluationKeys: """Generate key switching keys to switch 
def gen_hmuxrot_key(
    sk: types.SecretKey,
    beta: int,
    j: int,
    q_limbs: list[int],
    p_limbs: list[int],
    random_source: random.RandomSource | None = None,
    sk_ext: types.SecretKey | None = None,
) -> types.HMuxRotKey:
  """Generates an HMuxRotKey symmetrically.

  The key consists of:
  - key0: Symmetrically encrypts target0 = P * beta * sk(X^{5^{-j}}) under sk.
    Constructed as a key-switching key from beta * sk(X^{5^{-j}}) to sk.
  - key1: Symmetrically encrypts target1 = P * beta under sk.

  Args:
    sk: The secret key under Q.
    beta: The selector bit (0 or 1).
    j: The rotation index.
    q_limbs: The limbs of the ciphertext modulus Q.
    p_limbs: The limbs of the auxiliary modulus P.
    random_source: Optional random source. Defaults to SecureRandomSource.
    sk_ext: Optional pre-extended secret key. If not provided, it is computed.

  Returns:
    The generated HMuxRotKey.
  """
  random_source = random_source or random.SecureRandomSource()
  degree = sk.data.shape[0]
  all_moduli = q_limbs + p_limbs
  all_moduli_u64 = np.array(all_moduli, dtype=np.uint64).reshape(1, -1)
  q_limbs_u64 = np.array(q_limbs, dtype=np.uint64).reshape(1, -1)

  dest_sk_ext = extend_secret_key(sk, all_moduli) if sk_ext is None else sk_ext

  # P is the product of every auxiliary limb and easily exceeds 64 bits.
  # Reduce P * beta modulo each limb with Python's arbitrary-precision ints
  # and only then move to uint64; mixing the big int with a uint64 array
  # directly would overflow (or raise) as soon as P >= 2**64.
  p_val = math.prod(int(p) for p in p_limbs)
  scaled_val0 = np.array(
      [(p_val * int(beta)) % int(m) for m in all_moduli], dtype=np.uint64
  ).reshape(1, -1)

  # key0 is a key-switching key from beta * sk(X^{5^{-j}}) to sk.
  # pow() with a negative exponent yields the modular inverse; it exists
  # whenever 5 is coprime to 2 * degree.
  g = pow(5, -j, 2 * degree)
  sk_rot_data = blind_rotate_utils.apply_automorphism_ntt(jnp.array(sk.data), g)
  # Scale by beta (modulo Q)
  sk_rot_beta_data = (sk_rot_data * beta) % q_limbs_u64
  sk_rot_beta = types.SecretKey(np.array(sk_rot_beta_data), sk.moduli)

  ksk = gen_key_switching_key(
      source_key=sk_rot_beta,
      dest_key=sk,
      q_limbs=q_limbs,
      p_limbs=p_limbs,
      dnum=1,
      random_source=random_source,
      dest_key_ext=dest_sk_ext,
  )
  # dnum=1 was requested above, so only index 0 of the key is used.
  key0 = types.Ciphertext(
      data=jnp.stack([ksk.b[0], ksk.a[0]]),
      moduli=ksk.moduli,
  )

  # key1 encrypts P * beta under sk. Since this is a constant polynomial and not
  # a secret key, we perform direct symmetric encryption inline.
  # The residues are already reduced, so they are simply repeated per slot.
  target1 = np.tile(scaled_val0, (degree, 1))

  a_coeffs = random_source.gen_uniform_poly(degree, all_moduli)
  e_coeffs = random_source.gen_gaussian_poly(degree, all_moduli)
  a_slots = ntt_cpu.ntt_negacyclic_poly(a_coeffs, all_moduli)
  e_slots = ntt_cpu.ntt_negacyclic_poly(e_coeffs, all_moduli)

  # b = e + target1 - a * sk (mod each limb); the modulus is added before the
  # subtraction so the unsigned intermediate never goes negative.
  prod = (a_slots * dest_sk_ext.data) % all_moduli_u64
  b_slots = (e_slots + target1 + all_moduli_u64 - prod) % all_moduli_u64

  key1 = types.Ciphertext(
      data=jnp.array(np.stack([b_slots, a_slots]), dtype=jnp.uint32),
      moduli=jnp.array(all_moduli, dtype=jnp.uint32),
  )

  return types.HMuxRotKey(key0, key1)


def gen_mux_rotation_key(
    sk: types.SecretKey,
    secret_bits: list[int],
    q_limbs: list[int],
    p_limbs: list[int],
    random_source: random.RandomSource | None = None,
) -> types.MuxRotationKey:
  """Generates a MuxRotationKey for the bits of the rotation index.

  For each bit k from 0 to len(secret_bits) - 1, generates:
  - hmrkey_jk_0: HMuxRotKey for beta = secret_bits[k], rotation amount = 2^k.
  - hmrkey_not_jk_1: HMuxRotKey for beta = 1 - secret_bits[k], rotation = 0.

  Args:
    sk: The secret key under Q.
    secret_bits: The list of bits representing the secret rotation index.
    q_limbs: The limbs of the ciphertext modulus Q.
    p_limbs: The limbs of the auxiliary modulus P.
    random_source: Optional random source. Defaults to SecureRandomSource.

  Returns:
    The generated MuxRotationKey.
  """
  # Resolve the default once so every per-bit key draws from the same source
  # instead of constructing a new one inside each gen_hmuxrot_key call.
  random_source = random_source or random.SecureRandomSource()

  # The extended secret key is identical for every bit; compute it once.
  all_moduli = q_limbs + p_limbs
  sk_ext = extend_secret_key(sk, all_moduli)

  shared_kwargs = dict(
      sk=sk,
      q_limbs=q_limbs,
      p_limbs=p_limbs,
      random_source=random_source,
      sk_ext=sk_ext,
  )

  keys = []
  for k, bit in enumerate(secret_bits):
    # Selected when the bit is set: rotate by 2^k.
    hmrkey_jk_0 = gen_hmuxrot_key(beta=bit, j=2**k, **shared_kwargs)
    # Selected when the bit is clear: no rotation.
    hmrkey_not_jk_1 = gen_hmuxrot_key(beta=1 - bit, j=0, **shared_kwargs)
    keys.append((hmrkey_jk_0, hmrkey_not_jk_1))
  return types.MuxRotationKey(keys)
- """ - lhs_u8 = jax.lax.bitcast_convert_type(lhs, jnp.uint8) - shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) - i8_products = jnp.einsum( - subscripts, lhs_u8, rhs, preferred_element_type=jnp.uint32 - ) - return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) - - -def _basis_aligned_transformation( - matrix: jnp.ndarray, moduli: list[int] -) -> jnp.ndarray: - """Prepares a matrix for Basis Aligned Transformation (BAT).""" - matrix_u64 = matrix.astype(jnp.uint64) - num_bytes = 4 - matrix_u64_byteshifted = jnp.array( - [matrix_u64 << (8 * byte_idx) for byte_idx in range(num_bytes)], - dtype=jnp.uint64, - ) - moduli_arr = jnp.array(moduli, dtype=jnp.uint64) - matrix_u64_byteshifted_mod_modulus = ( - matrix_u64_byteshifted % moduli_arr - ).astype(jnp.uint32) - # Output shape: (4, ..., moduli, 4) - matrix_u8 = jax.lax.bitcast_convert_type( - matrix_u64_byteshifted_mod_modulus, jnp.uint8 - ) - return matrix_u8 - - @jax.tree_util.register_pytree_node_class class NTTBarrett(NTTBase): """Kernel for NTT with Barrett reduction.""" @@ -261,7 +220,7 @@ def precompute_constants( def to_bat(tf, moduli): # tf: (R, R, M) - raw_bat = _basis_aligned_transformation(tf, moduli) + raw_bat = bat_utils.basis_aligned_transformation(tf, moduli) # raw_bat shape: (4_byte_shift, rows, cols, moduli, 4_u8_bytes) # We want (rows, 4_byte_shift, cols, 4_u8_bytes, moduli) # matching subscripts q=shift, p=u8 @@ -286,7 +245,7 @@ def ntt(self, v: jnp.ndarray) -> jnp.ndarray: # q=u8 (lhs axis 4), q=shift (rhs axis 1). Summed. # z=target row (axis 0), r=source row (axis 2). Sum over r. # p=u8 target (axis 3). Becomes axis 4 of result. - res1 = _matmul_bat_einsum( + res1 = bat_utils.matmul_bat_einsum( v, self.constants.ntt_bat_tf_step1, "...rcmq,zqrpm->...zcmp" ) res1 = barrett.modular_reduction(res1, self.constants.barrett_constants) @@ -300,7 +259,7 @@ def ntt(self, v: jnp.ndarray) -> jnp.ndarray: # For ntt_tf_step3, axis 0 is source, axis 1 is target. 
# So to_bat axis 0 is source, axis 2 is target. # Subscripts: "cqnpm" -> c=0 (source), n=2 (target). - res3 = _matmul_bat_einsum( + res3 = bat_utils.matmul_bat_einsum( res2, self.constants.ntt_bat_tf_step3, "...rcmq,cqnpm->...rnmp" ) return barrett.modular_reduction(res3, self.constants.barrett_constants) @@ -311,7 +270,7 @@ def intt(self, v: jnp.ndarray) -> jnp.ndarray: # itf1 axis 0 is source, axis 1 is target. # to_bat axis 0 is source, axis 2 is target. # Subscripts: "cqlpm" -> c=0 (source), l=2 (target). - res1 = _matmul_bat_einsum( + res1 = bat_utils.matmul_bat_einsum( v, self.constants.intt_bat_tf_step1, "...rcmq,cqlpm->...rlmp" ) res1 = barrett.modular_reduction(res1, self.constants.barrett_constants) @@ -323,7 +282,7 @@ def intt(self, v: jnp.ndarray) -> jnp.ndarray: # itf3 axis 0 is target, axis 1 is source. # to_bat axis 0 is target, axis 2 is source. # Subscripts: "lqrpm" -> l=0 (target), r=2 (source). - res3 = _matmul_bat_einsum( + res3 = bat_utils.matmul_bat_einsum( res2, self.constants.intt_bat_tf_step3, "...rcmq,lqrpm->...lcmp" ) return barrett.modular_reduction(res3, self.constants.barrett_constants) diff --git a/jaxite/jaxite_ckks/types.py b/jaxite/jaxite_ckks/types.py index 9035f0b..1712ed4 100644 --- a/jaxite/jaxite_ckks/types.py +++ b/jaxite/jaxite_ckks/types.py @@ -2,6 +2,8 @@ import dataclasses import jax +import jax.numpy as jnp +from jaxite.jaxite_ckks import bat_utils import numpy as np @@ -18,7 +20,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, _, children): - return cls(*children) + return cls(children[0], children[1]) @jax.tree_util.register_pytree_node_class @@ -34,7 +36,7 @@ def tree_flatten(self): @classmethod def tree_unflatten(cls, _, children): - return cls(*children) + return cls(children[0], children[1]) @dataclasses.dataclass(frozen=True) @@ -60,3 +62,58 @@ class EvaluationKeys: a: jax.Array b: jax.Array moduli: jax.Array + + +@jax.tree_util.register_pytree_node_class +@dataclasses.dataclass(frozen=True) 
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class HMuxRotKey:
  """A key used in a single HMuxRot step.

  Consists of two ciphertexts symmetrically encrypted under the destination key
  sk modulo PQ:
  - key0: encrypts P * beta * sk(X^{5^{-j}})
  - key1: encrypts P * beta

  `key_matrix_bat` is derived from key0 and key1 at construction time and is
  carried through pytree flatten/unflatten as a regular child, so it is
  computed once rather than on every unflatten.
  """

  key0: Ciphertext
  key1: Ciphertext
  # Derived in __post_init__; not accepted by the constructor.
  key_matrix_bat: jax.Array = dataclasses.field(init=False)

  def __post_init__(self):
    # Stack key0.data and key1.data along axis 1 to get
    # (2, 2, degree, num_moduli), then transpose to
    # (degree, num_moduli, 2, 2).
    stacked = jnp.stack([self.key0.data, self.key1.data], axis=1)
    key_matrix = jnp.transpose(stacked, (2, 3, 0, 1))

    # Precompute the BAT representation using the bat_utils helper.
    key_matrix_bat = bat_utils.basis_aligned_transform_key(
        key_matrix, self.key0.moduli
    )
    # The dataclass is frozen, so the derived field is set via object.
    object.__setattr__(self, "key_matrix_bat", key_matrix_bat)

  def tree_flatten(self):
    """Returns (key0, key1, key_matrix_bat) as children and no aux data."""
    return (self.key0, self.key1, self.key_matrix_bat), None

  @classmethod
  def tree_unflatten(cls, _, children):
    """Rebuilds the key from its children without re-running __post_init__.

    JAX may unflatten with tracers, batched arrays or placeholder objects as
    leaves; running array code from the constructor there would either fail
    or silently recompute the BAT matrix on every trace. The instance is
    therefore assembled directly from the supplied children.
    """
    key0, key1, key_matrix_bat = children
    restored = object.__new__(cls)
    object.__setattr__(restored, "key0", key0)
    object.__setattr__(restored, "key1", key1)
    object.__setattr__(restored, "key_matrix_bat", key_matrix_bat)
    return restored


@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass(frozen=True)
class MuxRotationKey:
  """A set of HMuxRot keys for all bits of a secret rotation index.

  Contains a list of pairs of keys (hmrkey_jk_0, hmrkey_not_jk_1) for each bit k
  from 0 to log2(num_slots) - 1.
  """

  keys: list[tuple[HMuxRotKey, HMuxRotKey]]

  def tree_flatten(self):
    """Returns the list of key pairs as the single child and no aux data."""
    return (self.keys,), None

  @classmethod
  def tree_unflatten(cls, _, children):
    """Rebuilds the key set from the single child produced by tree_flatten."""
    return cls(children[0])