Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 29 additions & 6 deletions BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,7 @@ tpu_test(
size = "large",
timeout = "moderate",
srcs = ["jaxite_ec/finite_field_test.py"],
python_version = "PY3",
shard_count = 3,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
Expand Down Expand Up @@ -200,9 +198,7 @@ tpu_test(
"jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv",
"jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv",
],
python_version = "PY3",
shard_count = 3,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
Expand All @@ -221,9 +217,7 @@ tpu_test(
size = "large",
timeout = "long",
srcs = ["jaxite_ec/elliptic_curve_test.py"],
python_version = "PY3",
shard_count = 16,
srcs_version = "PY3ONLY",
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
Expand Down Expand Up @@ -485,6 +479,20 @@ cpu_gpu_tpu_test(
],
)

# Test target for jaxite/jaxite_ckks/bat_utils_test.py.
# NOTE(review): the cpu_gpu_tpu_test macro presumably expands to one test per
# backend -- confirm against its definition.
cpu_gpu_tpu_test(
name = "bat_utils_test",
size = "small",
srcs = ["jaxite/jaxite_ckks/bat_utils_test.py"],
deps = [
":jaxite_ckks",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
],
)

cpu_gpu_tpu_test(
name = "barrett_test",
size = "small",
Expand Down Expand Up @@ -617,6 +625,20 @@ py_test(
],
)

# Test target for jaxite/jaxite_ckks/blind_rotate_utils_test.py, declared as a
# plain py_test (unlike the neighbouring cpu_gpu_tpu_test targets).
py_test(
name = "blind_rotate_utils_test",
size = "small",
srcs = ["jaxite/jaxite_ckks/blind_rotate_utils_test.py"],
deps = [
":jaxite_ckks",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
],
)

cpu_gpu_tpu_test(
name = "blind_rotate_ckks_test",
size = "small",
Expand All @@ -628,6 +650,7 @@ cpu_gpu_tpu_test(
":jaxite_ckks",
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@jaxite_deps//absl/logging",
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
Expand Down
81 changes: 28 additions & 53 deletions jaxite/jaxite_ckks/basis_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import jax
import jax.numpy as jnp
from jaxite.jaxite_ckks import barrett
from jaxite.jaxite_ckks import bat_utils
from jaxite.jaxite_ckks import rns_utils

# Enable 64-bit precision for large integer arithmetic
Expand Down Expand Up @@ -85,48 +86,6 @@ def __hash__(self):
return id(self)


def matmul_bat_einsum(
    lhs: jax.Array, rhs: jax.Array, subscripts: str
) -> jax.Array:
  """Multiplies `lhs` by a BAT-preprocessed matrix through a byte-level einsum.

  `lhs` is reinterpreted as uint8 with `.view`, which expands its last axis
  instead of appending a new one, so `subscripts` must describe that expanded
  layout.

  Args:
    lhs: Input array, reinterpreted as uint8 before the contraction.
    rhs: Byte-decomposed (BAT) matrix operand, e.g. a twiddle-factor matrix.
    subscripts: Einsum subscripts applied to the uint8 view of `lhs` and `rhs`.

  Returns:
    The einsum result with its last axis (four byte lanes) recombined into
    uint64 values.
  """
  byte_shifts = jnp.array([0, 8, 16, 24], dtype=jnp.uint32)
  lhs_bytes = lhs.view(jnp.uint8)
  # Accumulate the uint8 x uint8 products in uint32.
  partial_products = jnp.einsum(
      subscripts, lhs_bytes, rhs, preferred_element_type=jnp.uint32
  )
  # Lane b carries weight 2**(8*b); widen first so the shift cannot overflow.
  widened = partial_products.astype(jnp.uint64)
  return (widened << byte_shifts).sum(axis=-1)


def _basis_aligned_transformation(
    matrix: jnp.ndarray, moduli: list[int]
) -> jnp.ndarray:
  """Byte-decomposes a matrix for Basis Aligned Transformation (BAT).

  For each byte position k in 0..3 the matrix is shifted left by 8*k bits and
  reduced modulo `moduli` (broadcast along the last axis). The 32-bit residues
  are then split into their four bytes.

  Args:
    matrix: Matrix whose last axis lines up with `moduli`.
    moduli: One modulus per entry of the last axis of `matrix`.

  Returns:
    A uint8 array of shape (4, ..., moduli, 4).
  """
  wide = matrix.astype(jnp.uint64)
  # One shifted copy per input byte position, stacked on a new leading axis.
  shifted = jnp.stack([wide << bit_shift for bit_shift in (0, 8, 16, 24)])
  residues = shifted % jnp.array(moduli, dtype=jnp.uint64)
  # Narrow to 32 bits so the bitcast yields exactly four bytes per value.
  return jax.lax.bitcast_convert_type(
      residues.astype(jnp.uint32), jnp.uint8
  )


@jax.tree_util.register_pytree_node_class
class BasisConversionBarrett(BasisConversion):
"""Kernel for Basis Conversion with Barrett reduction."""
Expand Down Expand Up @@ -167,13 +126,11 @@ def precompute_constants(
dtype=jnp.uint64,
)

# BAT Preprocessing
q_hat_mod_p_bat_raw = _basis_aligned_transformation(
q_hat_mod_p_bat_raw = bat_utils.basis_aligned_transformation(
q_hat_mod_p, target_moduli
)
q_hat_mod_p_bat = q_hat_mod_p_bat_raw.transpose(1, 0, 2, 3).reshape(
-1, q_hat_mod_p_bat_raw.shape[2], 4
)
# Shape: (4, num_Q, num_P, 4) -> (num_Q, 4, num_P, 4)
q_hat_mod_p_bat = q_hat_mod_p_bat_raw.transpose(1, 0, 2, 3)

constants = BarrettBasisConversionConstants(
q_hat_inv_mod_q=q_hat_inv_mod_q,
Expand All @@ -191,21 +148,39 @@ def basis_change(
constants = self.precomputed_constants[control_index]
in_tower = jnp.asarray(in_tower, dtype=jnp.uint64)

degree = in_tower.shape[-2]
num_Q = in_tower.shape[-1]

if degree >= 128:
block_size = 128
num_blocks = degree // 128
else:
block_size = degree
num_blocks = 1

# Reshape degree dimension to (num_blocks, block_size) to optimize TPU vectorization
in_tower_reshaped = in_tower.reshape(
*in_tower.shape[:-2], num_blocks, block_size, num_Q
)

# Step 1: Compute c_unreduced = in_tower * QHatInvModq
# Ensure constants.q_hat_inv_mod_q broadcasts over leading dimensions.
# q_hat_inv_mod_q has shape (sizeQ,). in_tower has shape (..., sizeQ).
c_unreduced = in_tower * constants.q_hat_inv_mod_q
c_unreduced = in_tower_reshaped * constants.q_hat_inv_mod_q

# Step 2: Modular Reduction
c = barrett.modular_reduction(c_unreduced, constants.origin_barrett)

# Step 3: BAT-based matrix multiplication
summed_terms = matmul_bat_einsum(
c, constants.q_hat_mod_p_bat, "...q,qpb->...pb"
summed_terms = bat_utils.matmul_bat_einsum(
c, constants.q_hat_mod_p_bat, "...mq,mqpb->...pb"
)

# Flatten the degree dimension back
summed_terms_flat = summed_terms.reshape(
*summed_terms.shape[:-3], degree, summed_terms.shape[-1]
)

# Step 4: Final Modular Reduction
out_tower = barrett.modular_reduction(
summed_terms, constants.target_barrett
summed_terms_flat, constants.target_barrett
)
return out_tower
161 changes: 161 additions & 0 deletions jaxite/jaxite_ckks/bat_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
"""Basis Aligned Transformation (BAT) utilities for CKKS on TPU."""

import jax
import jax.numpy as jnp

# Enable 64-bit precision for large integer arithmetic: the BAT helpers in this
# module cast intermediates to jnp.uint64, which JAX only provides when x64 mode
# is on. NOTE: this is a global JAX setting applied as an import side effect.
jax.config.update("jax_enable_x64", True)


def matmul_bat_einsum(
    lhs: jax.Array, rhs: jax.Array, subscripts: str
) -> jax.Array:
  """Multiplies `lhs` by a BAT-preprocessed matrix through a byte-level einsum.

  Args:
    lhs: Input array; it is bit-cast to uint8 before the contraction, so a
      32-bit input gains a trailing axis of size 4.
    rhs: Byte-decomposed (BAT) matrix operand, e.g. a twiddle-factor matrix.
    subscripts: Einsum subscripts applied to the uint8 form of `lhs` and `rhs`.

  Returns:
    The einsum result with its last axis (four byte lanes) recombined into
    uint64 values.
  """
  byte_shifts = jnp.array([0, 8, 16, 24], dtype=jnp.uint32)
  lhs_bytes = jax.lax.bitcast_convert_type(lhs, jnp.uint8)
  # Accumulate the uint8 x uint8 products in uint32.
  partial_products = jnp.einsum(
      subscripts, lhs_bytes, rhs, preferred_element_type=jnp.uint32
  )
  # Lane b carries weight 2**(8*b); widen first so the shift cannot overflow.
  widened = partial_products.astype(jnp.uint64)
  return (widened << byte_shifts).sum(axis=-1)


def basis_aligned_transformation(
    matrix: jnp.ndarray, moduli: list[int]
) -> jnp.ndarray:
  """Byte-decomposes a matrix for Basis Aligned Transformation (BAT).

  For each byte position k in 0..3 the matrix is shifted left by 8*k bits and
  reduced modulo `moduli` (broadcast along the last axis). The 32-bit residues
  are then split into their four bytes.

  Args:
    matrix: Matrix whose last axis lines up with `moduli`.
    moduli: One modulus per entry of the last axis of `matrix`.

  Returns:
    A uint8 array of shape (4, ..., moduli, 4).
  """
  wide = matrix.astype(jnp.uint64)
  # One shifted copy per input byte position, stacked on a new leading axis.
  shifted = jnp.stack([wide << bit_shift for bit_shift in (0, 8, 16, 24)])
  residues = shifted % jnp.array(moduli, dtype=jnp.uint64)
  # Narrow to 32 bits so the bitcast yields exactly four bytes per value.
  return jax.lax.bitcast_convert_type(
      residues.astype(jnp.uint32), jnp.uint8
  )


def basis_aligned_transform_key(
key_matrix: jax.Array, moduli: jax.Array | list[int]
) -> jax.Array:
"""Prepares the 4D key matrix of shape (degree, num_moduli, 2, dnum) for BAT.

Args:
key_matrix: The key matrix of shape (degree, num_moduli, 2, dnum).
moduli: The moduli of the key matrix.

Returns:
The transformed key matrix of shape (num_blocks, block_size, num_moduli,
2, dnum, 4, 4) in uint8.
"""
matrix_u64 = key_matrix.astype(jnp.uint64)
num_bytes = 4
matrix_u64_byteshifted = jnp.array(
[matrix_u64 << (8 * byte_idx) for byte_idx in range(num_bytes)],
dtype=jnp.uint64,
) # Shape: (4, degree, num_moduli, 2, dnum)

moduli_expanded = jnp.array(moduli, dtype=jnp.uint64).reshape(1, 1, -1, 1, 1)

matrix_u64_byteshifted_mod_modulus = (
matrix_u64_byteshifted % moduli_expanded
).astype(jnp.uint32)

# Bitcast to uint8: shape becomes (4, degree, num_moduli, 2, dnum, 4)
matrix_u8 = jax.lax.bitcast_convert_type(
matrix_u64_byteshifted_mod_modulus, jnp.uint8
)

# Transpose to (degree, num_moduli, u, v, q, p)
# Axes mapping:
# 0: byte_idx (q, size 4)
# 1: degree
# 2: num_moduli
# 3: u (size 2)
# 4: v (size dnum)
# 5: b (p, size 4)
matrix_u8_transposed = jnp.transpose(matrix_u8, (1, 2, 3, 4, 0, 5))

# Reshape degree dimension to (num_blocks, block_size) to optimize TPU vectorization
degree = key_matrix.shape[0]
if degree >= 128:
block_size = 128
num_blocks = degree // 128
else:
block_size = degree
num_blocks = 1
return matrix_u8_transposed.reshape(
num_blocks, block_size, *matrix_u8_transposed.shape[1:]
)


def matmul_bat_key_vector(
    vector_v: jax.Array, key_matrix_bat: jax.Array
) -> jax.Array:
  """Computes BAT-based matrix-vector product for 2x2 or 2xdnum key multiplication.

  The degree axis of `vector_v` is split into (num_blocks, block_size). Blocks
  of 128 are used when the degree is a positive multiple of 128; any other
  degree is treated as a single block instead of failing in the reshape. The
  number of output components is taken from `key_matrix_bat` rather than being
  hard-coded to 2.

  Args:
    vector_v: The input vector of shape (..., degree, num_moduli, dnum). It is
      bit-cast to uint8 and must yield four bytes per element to match the key.
    key_matrix_bat: The pre-transformed key matrix of shape (num_blocks,
      block_size, num_moduli, 2, dnum, 4, 4).

  Returns:
    The matrix-vector product of shape (2, ..., degree, num_moduli) in uint64.
  """
  degree = vector_v.shape[-3]
  num_moduli = vector_v.shape[-2]
  dnum = vector_v.shape[-1]

  # Only tile by 128 when it divides the degree exactly.
  if degree >= 128 and degree % 128 == 0:
    block_size = 128
    num_blocks = degree // 128
  else:
    block_size = degree
    num_blocks = 1

  # Reshape degree dimension to (num_blocks, block_size) to optimize TPU
  # vectorization.
  v_reshaped = vector_v.reshape(
      *vector_v.shape[:-3], num_blocks, block_size, num_moduli, dnum
  )

  # Bitcast to uint8 -> (..., num_blocks, block_size, num_moduli, dnum, 4)
  v_u8 = jax.lax.bitcast_convert_type(v_reshaped, jnp.uint8)

  # Index legend: i = num_blocks, k = block_size, j = num_moduli,
  # u = output component, v = dnum, q = input byte, p = output byte.
  # The contraction runs over v and q; i, k and j are batch dimensions.
  i8_products = jnp.einsum(
      "...ikjvq,ikjuvqp->...ikjup",
      v_u8,
      key_matrix_bat,
      preferred_element_type=jnp.uint32,
  )

  shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32)
  # Weight output byte p by 2**(8*p) and sum to rebuild uint64 values.
  # Shape after sum: (..., num_blocks, block_size, num_moduli, u)
  summed = jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=-1)

  # Flatten (num_blocks, block_size) back to degree:
  # (..., degree, num_moduli, u)
  num_components = summed.shape[-1]
  summed_flat = summed.reshape(
      *summed.shape[:-4], degree, num_moduli, num_components
  )

  # Move the component axis to the front: (u, ..., degree, num_moduli)
  return jnp.moveaxis(summed_flat, -1, 0)
Loading
Loading