diff --git a/BUILD b/BUILD index e0bb789..236a80d 100644 --- a/BUILD +++ b/BUILD @@ -41,12 +41,14 @@ py_library( deps = [ # copybara: xprof_analysis_client # buildcleaner: keep # copybara: xprof_session # buildcleaner: keep + "@jaxite_deps_absl//:pkg:app", "@jaxite_deps_gmpy2//:pkg", "@jaxite_deps_jax//:pkg", "@jaxite_deps_jaxlib//:pkg", # copybara: jax/experimental:pallas_lib # copybara: jax/experimental:pallas_tpu "@jaxite_deps_numpy//:pkg", + "@jaxite_deps_pandas//:pkg", ], ) @@ -180,60 +182,6 @@ tpu_test( ], ) -tpu_test( - name = "jaxite_word_ntt_test", - size = "large", - timeout = "eternal", - srcs = ["jaxite/jaxite_word/ntt_test.py"], - shard_count = 3, - deps = [ - ":jaxite", - # copybara: xprof_analysis_client # buildcleaner: keep - # copybara: xprof_session # buildcleaner: keep - "@com_google_absl_py//absl/testing:absltest", - "@com_google_absl_py//absl/testing:parameterized", - "@jaxite_deps_gmpy2//:pkg", - "@jaxite_deps_jax//:pkg", - "@jaxite_deps_jaxlib//:pkg", - "@jaxite_deps_numpy//:pkg", - ], -) - -tpu_test( - name = "jaxite_word_sub_test", - size = "large", - timeout = "eternal", - srcs = ["jaxite/jaxite_word/sub_test.py"], - shard_count = 3, - deps = [ - ":jaxite", - # copybara: xprof_analysis_client # buildcleaner: keep - # copybara: xprof_session # buildcleaner: keep - "@com_google_absl_py//absl/testing:absltest", - "@com_google_absl_py//absl/testing:parameterized", - "@jaxite_deps_gmpy2//:pkg", - "@jaxite_deps_jax//:pkg", - "@jaxite_deps_jaxlib//:pkg", - "@jaxite_deps_numpy//:pkg", - ], -) - -tpu_test( - name = "add_test", - size = "large", - timeout = "eternal", - srcs = ["jaxite/jaxite_word/add_test.py"], - shard_count = 3, - deps = [ - ":jaxite", - "@com_google_absl_py//absl/testing:absltest", - "@com_google_absl_py//absl/testing:parameterized", - "@jaxite_deps_jax//:pkg", - "@jaxite_deps_jaxlib//:pkg", - "@jaxite_deps_numpy//:pkg", - ], -) - cpu_gpu_tpu_test( name = "decomposition_test", size = "small", @@ -440,7 +388,7 @@ py_test( name = "rns_test", size = "small", timeout = "moderate", - srcs = ["jaxite/jaxite_ckks/rns_test.py"], + srcs = ["jaxite/jaxite_word/rns_test.py"], deps = [ ":jaxite", ":test_utils", @@ -454,20 +402,97 @@ py_test( ], ) -gpu_tpu_test( - name = "ckks_test", +py_test( + name = "ntt_sm_test", size = "small", timeout = "moderate", - srcs = ["jaxite/jaxite_ckks/ckks_test.py"], + srcs = ["jaxite/jaxite_word/ntt_sm_test.py"], + deps = [ + ":jaxite", + ":test_utils", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@jaxite_deps_jax//:pkg", + "@jaxite_deps_jaxlib//:pkg", + "@jaxite_deps_numpy//:pkg", + ], +) + +py_test( + name = "ntt_mm_test", + size = "small", + timeout = "moderate", + srcs = ["jaxite/jaxite_word/ntt_mm_test.py"], + deps = [ + ":jaxite", + ":test_utils", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@jaxite_deps_jax//:pkg", + "@jaxite_deps_jaxlib//:pkg", + "@jaxite_deps_numpy//:pkg", + ], +) + +py_test( + name = "finite_field_test", + size = "small", + timeout = "moderate", + srcs = ["jaxite/jaxite_word/finite_field_test.py"], + deps = [ + ":jaxite", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@jaxite_deps_jax//:pkg", + "@jaxite_deps_jaxlib//:pkg", + "@jaxite_deps_numpy//:pkg", + ], +) + +py_test( + name = "ckks_ctx_test", + size = "small", + timeout = "moderate", + srcs = ["jaxite/jaxite_word/ckks_ctx_test.py"], + deps = [ + ":jaxite", + ":test_utils", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@jaxite_deps_jax//:pkg", + "@jaxite_deps_jaxlib//:pkg", + "@jaxite_deps_numpy//:pkg", + ], +) + +py_test( + name = "ciphertext_test", + size = "small", + timeout = "moderate", + srcs = ["jaxite/jaxite_word/ciphertext_test.py"], + deps = [ + ":jaxite", + ":test_utils", + "@com_google_absl_py//absl/testing:absltest", + "@com_google_absl_py//absl/testing:parameterized", + "@jaxite_deps_jax//:pkg", + "@jaxite_deps_jaxlib//:pkg", + "@jaxite_deps_numpy//:pkg", + ], +) + +py_test( + name = "bconv_test", + size = "small", + timeout = "moderate", + srcs = ["jaxite/jaxite_word/bconv_test.py"], deps = [ ":jaxite", ":test_utils", "@com_google_absl_py//absl/testing:absltest", "@com_google_absl_py//absl/testing:parameterized", - "@jaxite_deps_hypothesis//:pkg", "@jaxite_deps_jax//:pkg", "@jaxite_deps_jaxlib//:pkg", "@jaxite_deps_numpy//:pkg", - "@jaxite_deps_parameterized//:pkg", ], ) diff --git a/jaxite/jaxite_ckks/ckks.py b/jaxite/jaxite_ckks/ckks.py deleted file mode 100644 index ed0892e..0000000 --- a/jaxite/jaxite_ckks/ckks.py +++ /dev/null @@ -1,419 +0,0 @@ -"""CKKS Homomorphic Encryption for Approximate Numbers.""" - -import dataclasses -import functools -import math -import secrets -from typing import Tuple - -import jax -import jax.numpy as jnp -from jaxite.jaxite_ckks import rns -from jaxite.jaxite_ckks import rns_utils - -RnsPolynomial = rns.RnsPolynomial -RnsParams = rns.RnsParams -gen_rns_polynomial = rns.gen_rns_polynomial - - -def _rep(x: int, q: int) -> int: - """Balanced representative of x mod q.""" - x = x % q - return x if x <= q // 2 else -(q - x) - - -@dataclasses.dataclass -class CkksEncoder: - """A CKKS encoder. - - In CKKS a complex vector z in CC^{N/2} is identified as the image of a real - polynomial f(X) under mapping rho * tau: - - tau is the canonical embedding map which takes f(X) to its evaluation - f(omega_j) under all primitive 2N'th root of unity omega_j; - - rho maps pairs of conjugate complex numbers (a + bI, a - bI) to a + bI. - - To gain enough precision and represent f(X) as a discrete object, we further - scale up f(X) and round its coefficients to integers. Specifically: - - Encode(z) = round(scaling_factor * tau^-1(rho^-1(z))); - - Decode(f) = 1/scaling_factor * rho(tau(f)). - The rounded coefficients are represented as a polynomial in R_Q, which is the - plaintext in CKKS. We call the input complex vector z as the slot values, and - the CKKS encoding admits slot-wise additive and multiplicative homomorphism. - - Note that Decode is not exactly the inverse of Encode, and precision loss - will happen due to rounding error and float point arithmetic. - """ - - degree: int - moduli: list[int] - scaling_factor: int - - psis_bitrev: list[complex] = dataclasses.field(init=False) - psis_bitrev_inv: list[complex] = dataclasses.field(init=False) - - def __post_init__(self): - if not rns_utils.is_power_of_two(self.degree): - raise ValueError('`degree` must be a power of two.') - - n = self.degree - - # Generate the powers of primitive 2N'th root exp(pi * I / N) and rearrange - # them in the bit reversed order to run Cooley-Tukey and Gentleman-Sande. - theta = math.pi / n - self.psis_bitrev = [ - complex(math.cos(theta * i), math.sin(theta * i)) for i in range(n) - ] - rns_utils.bit_reversal_array(self.psis_bitrev) - self.psis_bitrev_inv = [complex(math.cos(0), -math.sin(0))] + [ - complex(math.cos(theta * i), -math.sin(theta * i)) - for i in range(n - 1, 0, -1) - ] - rns_utils.bit_reversal_array(self.psis_bitrev_inv) - self.psis_bitrev_inv = self.psis_bitrev_inv[::-1] - - def encode(self, values: list[complex]) -> RnsPolynomial: - """Encode an array of complex values to a CKKS plaintext polynomial. - - Args: - values: A list of complex numbers to be encoded. The length of `values` - can be at most N/2 where N is the ring degree of the CKKS scheme. - - Returns: - An RNS polynomial encoding the given values in its slots. - """ - num_slots = self.degree >> 1 - if len(values) > num_slots: - raise ValueError(f'`values` can have at most {num_slots} elements.') - - # Move values to their slot positions. - coeff_values = [complex(0, 0)] * num_slots - power = 1 - for j in range(num_slots): - coeff_values[(power - 1) // 4] = values[j] - power = (power * 5) % (2 * self.degree) - - # The encoded polynomial is round(scaling_factor * DFT^-1(vs)), where - # DFT^-1(vs) computes normalized half Gentleman-Sande inverse FFT. - rns_utils.bit_reversal_array(coeff_values) - self._half_gentleman_sande(coeff_values) - # Normalize the transformation and then round to integers. - factor = self.scaling_factor / num_slots - coeffs = [0] * self.degree - for i in range(num_slots): - coeffs[i] = round(coeff_values[i].real * factor) - coeffs[i + num_slots] = round(coeff_values[i].imag * factor) - - # Convert to RNS representation. - coeffs_qi = [[0] * self.degree for _ in range(len(self.moduli))] - for i, qi in enumerate(self.moduli): - for j in range(self.degree): - coeffs_qi[i][j] = coeffs[j] % qi - - return RnsPolynomial(self.degree, self.moduli, coeffs_qi, is_ntt=False) - - def decode(self, plaintext: RnsPolynomial) -> list[complex]: - """Decode the given plaintext polynomial to its slot values. - - Args: - plaintext: The plaintext polynomial to be decoded. - - Returns: - A list of complex numbers encoded in the plaintext polynomial. - """ - if plaintext.degree != self.degree: - raise ValueError(f'`plaintext` must have degree = {self.degree}.') - if plaintext.is_ntt: - raise ValueError('`plaintext` must be in the coefficient form.') - - # Convert RNS to integer representation. - num_slots = self.degree >> 1 - coeffs = self._crt(plaintext.coeffs, plaintext.moduli) - coeff_values = [ - complex(coeffs[i], coeffs[i + num_slots]) / self.scaling_factor - for i in range(num_slots) - ] - self._half_cooley_tukey(coeff_values) - rns_utils.bit_reversal_array(coeff_values) - - # Move the slot values to their original positions. - slots = [complex(0, 0)] * num_slots - power = 1 - for j in range(num_slots): - slots[j] = coeff_values[(power - 1) // 4] - power = (power * 5) % (2 * self.degree) - return slots - - def _half_cooley_tukey(self, coeffs: list[complex]) -> None: - """Cooley-Tukey FFT but assume the coeffs are half of the input vector. - - This is used to decode a CKKS plaintext polynomial to its slot values. - """ - num_coeffs = len(coeffs) - if num_coeffs * 2 != self.degree: - raise ValueError('`coeffs` must have length degree / 2.') - log_len = rns_utils.num_bits(num_coeffs) - for i in range(log_len - 1, -1, -1): - half_m = 1 << i - m = half_m << 1 - index_psi = 1 << (log_len - i) - for k in range(0, num_coeffs, m): - for j in range(half_m): - t = coeffs[k + j + half_m] * self.psis_bitrev[index_psi] - u = coeffs[k + j] - coeffs[k + j] += t - coeffs[k + j + half_m] = u - t - index_psi += 1 - - def _half_gentleman_sande(self, coeffs: list[complex]) -> None: - """Sandie-Gentleman FFT but assume the coeffs are half of the input vector. - - This is used to encode a list of complex numbers to a CKKS plaintext - polynomial. - """ - num_coeffs = len(coeffs) - if num_coeffs * 2 != self.degree: - raise ValueError('`coeffs` must have length degree / 2.') - log_len = rns_utils.num_bits(num_coeffs) - index_psi_base = 0 - for i in range(log_len): - half_m = 1 << i - m = half_m << 1 - index_psi_inv = index_psi_base - for k in range(0, num_coeffs, m): - for j in range(half_m): - t = coeffs[k + j + half_m] - u = coeffs[k + j] - coeffs[k + j] += t - coeffs[k + j + half_m] = (u - t) * self.psis_bitrev_inv[index_psi_inv] - index_psi_inv += 1 - index_psi_base += 1 << (log_len - i) - - def _crt(self, coeffs_qs: list[list[int]], qs: list[int]) -> list[int]: - """CRT interpolation of coeffs_qs mod (qs[i] for all i). - - Args: - coeffs_qs: The coefficients of a polynomial a(X) modulo q_i - , for all i. - qs: A list of moduli q_i's whose product is Q. - - Returns: - The coefficients of the polynomial a(X) modulo Q. - """ - num_moduli = len(qs) - if num_moduli == 1: - return [_rep(coeffs_qs[0][i], qs[0]) for i in range(self.degree)] - - q = functools.reduce(lambda x, y: x * y, qs) - q_hats = [q // q_i for q_i in qs] - q_hat_invs = [ - rns_utils.inverse_mod(q_hats[i], qs[i]) for i in range(num_moduli) - ] - - coeffs = [0] * self.degree - for i in range(num_moduli): - for j in range(self.degree): - coeffs[j] += coeffs_qs[i][j] * q_hat_invs[i] * q_hats[i] - for j in range(self.degree): - coeffs[j] = _rep(coeffs[j] % q, q) - return coeffs - - -@dataclasses.dataclass -class CkksCiphertext: - """A CKKS ciphertext. - - A CKKS ciphertext of degree k is a list of k+1 polynomials [c0, ..., ck] in - R_Q such that c0 + c1 * s + ... + ck * s^k = plaintext + error. - In addition, the level of a ciphertext is L such that Q is a product of L+1 - distinct prime moduli. - """ - - # the RNS moduli whose product is Q - moduli: list[int] - - # the polynomials. - components: list[RnsPolynomial] - - @property - def level(self) -> int: - return len(self.moduli) - 1 - - @property - def degree(self) -> int: - return len(self.components) - 1 - - def to_ntt_form(self, ntt_params) -> None: - for c in self.components: - c.to_ntt_form(ntt_params) - - def to_coeffs_form(self, ntt_params) -> None: - for c in self.components: - c.to_coeffs_form(ntt_params) - - def to_jnp_array(self) -> Tuple[jnp.ndarray, jnp.ndarray]: - return ( - jnp.array(self.moduli, dtype=jnp.uint64), - jnp.array([c.to_jnp_array() for c in self.components]), - ) - - -@dataclasses.dataclass -class CkksSecretKey: - """A CKKS secret key.""" - - # the RNS moduli whose product is Q - moduli: list[int] - - # the secret key. - key: RnsPolynomial - - @property - def level(self) -> int: - return len(self.moduli) - 1 - - @property - def degree(self) -> int: - return len(self.key) - 1 - - def to_ntt_form(self, ntt_params) -> None: - self.key.to_ntt_form(ntt_params) - - def to_coeffs_form(self, ntt_params) -> None: - self.key.to_coeffs_form(ntt_params) - - -def gen_uniform_polynomial(degree: int, moduli: list[int]) -> RnsPolynomial: - """Generate a uniformly random RNS polynomial in R_Q = Z[X] / (Q, X^N+1).""" - coeffs_q = [] - for q in moduli: - coeffs_q.append([secrets.randbelow(q) for _ in range(degree)]) - return RnsPolynomial(degree, moduli, coeffs_q, is_ntt=False) - - -def gen_gaussian_polynomial( - degree: int, moduli: list[int], sigma: float -) -> RnsPolynomial: - """Generate a random Gaussian polynomial in R_Q = Z[X] / (Q, X^N+1). - - Note: Each coefficient is independently sampled from a rounded Gaussian - distribution with parameter sigma. - - Args: - degree: The degree N of the ring R_Q. - moduli: The list of prime moduli q_i's whose product is Q. - sigma: The standard deviation of the Gaussian distribution. - - Returns: - An RNS polynomial with coefficients sampled from a Gaussian distribution. - """ - prng = secrets.SystemRandom() - coeffs = [round(prng.normalvariate(0, sigma)) for _ in range(degree)] - return gen_rns_polynomial(degree, coeffs, moduli) - - -def gen_secret_key(degree: int, moduli: list[int]) -> 'CkksSecretKey': - """Generate a CKKS secret key.""" - return CkksSecretKey( - moduli, gen_gaussian_polynomial(degree, moduli, sigma=3.2) - ) - - -def gen_ciphertext_from_jnp_array( - degree: int, - moduli: list[int], - components: jax.Array, - is_ntt: bool = True, -) -> CkksCiphertext: - """Generate a CKKS ciphertext from its JAX array representation.""" - return CkksCiphertext( - moduli, - [ - RnsPolynomial( - degree, - moduli, - coeffs, - is_ntt=is_ntt, - ) - for coeffs in components.tolist() - ], - ) - - -def encrypt( - secret_key: CkksSecretKey, - values: list[complex], - encoder: CkksEncoder, - rns_params: RnsParams, -) -> CkksCiphertext: - """Encrypts a vector of complex values and returns a CKKS ciphertext. - - In CKKS, a complex vector of dimension N/2 is encrypted as (c0, c1) ∈ R_Q^2 - where: - - c0 = a * s + e + Encode(values) - - c1 = -a - - Encode(values) performs a scaled FFT on values, rounded to integers, and then - represented as a polynomial in R_Q. For details see CkksEncoder. - - The polynomials (-a, a * s + e) form a RLWE sample wrt secret s. - - Args: - secret_key: The CKKS secret key. - values: A list of complex numbers to be encrypted in the slots. - encoder: The CKKS encoder. - rns_params: The RNS parameters. - - Returns: - A CKKS ciphertext. - """ - a = gen_uniform_polynomial(rns_params.degree, rns_params.moduli) - a.to_ntt_form(rns_params.ntt_params) - secret_key.to_ntt_form(rns_params.ntt_params) - e = gen_gaussian_polynomial(rns_params.degree, rns_params.moduli, sigma=3.2) - e.to_ntt_form(rns_params.ntt_params) - plaintext = encoder.encode(values) - plaintext.to_ntt_form(rns_params.ntt_params) - - c0 = a * secret_key.key + e + plaintext - c1 = -a - return CkksCiphertext(rns_params.moduli, [c0, c1]) - - -def decrypt( - secret_key: CkksSecretKey, - ciphertext: CkksCiphertext, - encoder: CkksEncoder, - rns_params: RnsParams, -) -> list[complex]: - """Decrypts a CKKS ciphertext and returns the decrypted complex vector. - - Note: This version of decryption function does not satisfy the IND-CPA-D - security which is typically required for CKKS. For more details, see - https://eprint.iacr.org/2020/1533 and https://eprint.iacr.org/2022/816. - - Args: - secret_key: The CKKS secret key. - ciphertext: The CKKS ciphertext. - encoder: The CKKS encoder. - rns_params: The RNS parameters. - - Returns: - A list of complex numbers. - """ - if ciphertext.level != secret_key.level: - raise ValueError(f'`ciphertext` and `secret_key` must have the same level.') - if ciphertext.degree < 1: - raise ValueError('`ciphertext` must have degree >= 1.') - - plaintext = ciphertext.components[0] - plaintext.to_ntt_form(rns_params.ntt_params) - secret = secret_key.key - secret.to_ntt_form(rns_params.ntt_params) - for i in range(1, ciphertext.degree + 1): - c = ciphertext.components[i] - c.to_ntt_form(rns_params.ntt_params) - plaintext += c * secret - secret *= secret_key.key - - plaintext.to_coeffs_form(rns_params.ntt_params) - return encoder.decode(plaintext) diff --git a/jaxite/jaxite_ckks/ckks_test.py b/jaxite/jaxite_ckks/ckks_test.py deleted file mode 100644 index 04c9c04..0000000 --- a/jaxite/jaxite_ckks/ckks_test.py +++ /dev/null @@ -1,299 +0,0 @@ -"""Basic tests for CKKS. - -CPU-based tests: -- Encode & decode, with additive and multiplicative homomorphisms; -- Encrypt & decrypt; - -TPU-based tests: -- Homomorphic add and subtraction. -""" - -from concurrent import futures -import random - -import jax -from jaxite.jaxite_ckks import ckks -from jaxite.jaxite_ckks import rns -from jaxite.jaxite_word import add -from jaxite.jaxite_word import sub -import parameterized - -from absl.testing import absltest -from absl.testing import parameterized as parameterized_test - - -ProcessPoolExecutor = futures.ProcessPoolExecutor - -jax.config.update('jax_enable_x64', True) -jax.config.update('jax_traceback_filtering', 'off') - - -@parameterized.parameterized_class([ - # The followings are toy parameters that should only be used for testing. - # In general, we instantiate the CKKS scheme with parameters defining - # the ring R_Q = Z[X] / (Q, X^N+1), a scaling factor used in the canonical - # embedding encoding, and bit precisions that should be achieved in the - # encoding and encryption. - { - 'degree': 8, - 'moduli': [335552513], - 'scaling_factor': 2**9, - 'encoding_precision': 5, - 'encryption_precision': 3, - }, - { - 'degree': 16, - 'moduli': [335552513, 65537], - 'scaling_factor': 2**16, - 'encoding_precision': 10, - 'encryption_precision': 8, - }, - { - 'degree': 1024, - 'moduli': [335552513, 65537], - 'scaling_factor': 2**16, - 'encoding_precision': 7, - 'encryption_precision': 5, - }, -]) -class CkksTest(parameterized_test.TestCase): - - def setUp(self): - super().setUp() - self.rns_params = rns.RnsParams(self.degree, self.moduli) - self.ntt_params = self.rns_params.ntt_params - - def _random_slots(self, degree: int) -> list[complex]: - """Generate a list of random complex numbers having norm <= 1.""" - rand_complex = lambda: complex(random.uniform(-1, 1), random.uniform(-1, 1)) - return [rand_complex() for _ in range(degree >> 1)] - - def _compare_with_precision( - self, a: list[complex], b: list[complex], bit_precision: int = 10 - ): - """Check if two lists of numbers are close componentwise.""" - assert len(a) == len(b) - percision_bound = pow(2, -bit_precision) - for i in range(len(a)): - assert abs(a[i] - b[i]) < percision_bound, ( - f'bad precision: a[{i}] = {a[i]}, b[{i}] = {b[i]}, expected precision' - f' = {bit_precision} bits' - ) - - def test_encoder(self): - """Encode and then decode should not lose too much precision.""" - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - slots = self._random_slots(self.degree) - plaintext = encoder.encode(slots) - self.assertEqual(plaintext.degree, self.degree) - self.assertEqual(plaintext.moduli, self.moduli) - - decoded = encoder.decode(plaintext) - self._compare_with_precision( - slots, decoded, bit_precision=self.encoding_precision - ) - - def test_encoding_additive_homomorphism(self): - """The encoding scheme should be approximately additive.""" - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - slots0 = self._random_slots(self.degree) - slots1 = self._random_slots(self.degree) - poly0 = encoder.encode(slots0) - poly1 = encoder.encode(slots1) - poly_sum = poly0 + poly1 - decoded = encoder.decode(poly_sum) - expected = [slots0[i] + slots1[i] for i in range(self.degree >> 1)] - self._compare_with_precision( - expected, decoded, bit_precision=self.encoding_precision - ) - - def test_encoding_multiplicative_homomorphism(self): - """The encoding scheme should be approximately multiplicative (slotwise).""" - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - slots0 = self._random_slots(self.degree) - slots1 = self._random_slots(self.degree) - poly0 = encoder.encode(slots0) - poly1 = encoder.encode(slots1) - # Make sure the polynomials are in the NTT form before multiplication. - poly0.to_ntt_form(self.ntt_params) - poly1.to_ntt_form(self.ntt_params) - - poly_prod = poly0 * poly1 - # Convert the product to the coefficient form before decoding. - poly_prod.to_coeffs_form(self.ntt_params) - decoded = encoder.decode(poly_prod) - assert len(decoded) == self.degree >> 1 - - # poly_prof encodes the component-wise product under square of the - # scaling factor. So normalize it before decoding. - for i in range(self.degree >> 1): - decoded[i] /= self.scaling_factor - expected = [slots0[i] * slots1[i] for i in range(self.degree >> 1)] - self._compare_with_precision( - expected, decoded, bit_precision=self.encoding_precision - ) - - def test_encrypt_decrypt(self): - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - secret_key = ckks.gen_secret_key(self.degree, self.moduli) - slots = self._random_slots(self.degree) - ciphertext = ckks.encrypt(secret_key, slots, encoder, self.rns_params) - decoded = ckks.decrypt(secret_key, ciphertext, encoder, self.rns_params) - self._compare_with_precision( - slots, decoded, bit_precision=self.encryption_precision - ) - - @parameterized_test.named_parameters( - { - 'testcase_name': 'jax_add', - 'test_target': add.jax_add, - }, - { - 'testcase_name': 'vmap_add', - 'test_target': add.vmap_add, - }, - ) - def test_homomorphic_add_with(self, test_target): - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - secret_key = ckks.gen_secret_key(self.degree, self.moduli) - slots0 = self._random_slots(self.degree) - slots1 = self._random_slots(self.degree) - ciphertext0 = ckks.encrypt(secret_key, slots0, encoder, self.rns_params) - ciphertext1 = ckks.encrypt(secret_key, slots1, encoder, self.rns_params) - - modulus_list, ciphertext_data0 = ciphertext0.to_jnp_array() - _, ciphertext_data1 = ciphertext1.to_jnp_array() - jax_results = test_target(ciphertext_data0, ciphertext_data1, modulus_list) - ciphertext_sum = ckks.gen_ciphertext_from_jnp_array( - self.degree, self.moduli, jax_results - ) - decoded = ckks.decrypt(secret_key, ciphertext_sum, encoder, self.rns_params) - expected = [slots0[i] + slots1[i] for i in range(self.degree >> 1)] - self._compare_with_precision( - expected, decoded, bit_precision=self.encryption_precision - ) - - @parameterized_test.named_parameters( - { - 'testcase_name': 'jax_sub', - 'test_target': sub.jax_sub, - }, - { - 'testcase_name': 'vmap_sub', - 'test_target': sub.vmap_sub, - }, - ) - def test_homomorphic_sub_with(self, test_target): - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - secret_key = ckks.gen_secret_key(self.degree, self.moduli) - slots0 = self._random_slots(self.degree) - slots1 = self._random_slots(self.degree) - ciphertext0 = ckks.encrypt(secret_key, slots0, encoder, self.rns_params) - ciphertext1 = ckks.encrypt(secret_key, slots1, encoder, self.rns_params) - - modulus_list, ciphertext_data0 = ciphertext0.to_jnp_array() - _, ciphertext_data1 = ciphertext1.to_jnp_array() - jax_results = test_target(ciphertext_data0, ciphertext_data1, modulus_list) - ciphertext_diff = ckks.gen_ciphertext_from_jnp_array( - self.degree, self.moduli, jax_results - ) - decoded = ckks.decrypt( - secret_key, ciphertext_diff, encoder, self.rns_params - ) - expected = [slots0[i] - slots1[i] for i in range(self.degree >> 1)] - self._compare_with_precision( - expected, decoded, bit_precision=self.encryption_precision - ) - - -class CkksNegativeTest(parameterized_test.TestCase): - """Testing negative cases for CKKS implementation.""" - - def setUp(self): - super().setUp() - self.degree = 8 - self.moduli = [12289] - self.scaling_factor = 2**5 - - @parameterized_test.named_parameters( - { - 'testcase_name': 'zero_degree', - 'invalid_degree': 0, - }, - { - 'testcase_name': 'odd_degree', - 'invalid_degree': 7, - }, - ) - def test_create_encoder_with_invalid_degree(self, invalid_degree): - with self.assertRaises(ValueError): - ckks.CkksEncoder( - degree=invalid_degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - - def test_encode_with_too_many_slots(self): - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - slots = [complex(0, 0)] * (self.degree // 2 + 1) - with self.assertRaises(ValueError): - encoder.encode(slots) - - def test_decode_with_invalid_degree(self): - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - # Create a polynomial with an invalid degree. - invalid_degree = self.degree - 1 - coeffs = [[0] * invalid_degree for _ in self.moduli] - plaintext = rns.RnsPolynomial( - invalid_degree, self.moduli, coeffs, is_ntt=False - ) - with self.assertRaises(ValueError): - encoder.decode(plaintext) - - def test_decode_with_ntt_polynomial(self): - encoder = ckks.CkksEncoder( - degree=self.degree, - moduli=self.moduli, - scaling_factor=self.scaling_factor, - ) - # Create a polynomial in the NTT form. - coeffs = [[0] * self.degree for _ in self.moduli] - plaintext = rns.RnsPolynomial(self.degree, self.moduli, coeffs, is_ntt=True) - with self.assertRaises(ValueError): - encoder.decode(plaintext) - - -if __name__ == '__main__': - absltest.main() diff --git a/jaxite/jaxite_ckks/rns_utils.py b/jaxite/jaxite_ckks/rns_utils.py deleted file mode 100644 index 2aa2ab5..0000000 --- a/jaxite/jaxite_ckks/rns_utils.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Utility functions for the RNS based RLWE schemes.""" - -from typing import Any - - -def inverse_mod(x: int, q: int) -> int: - """Returns the inverse of x mod q.""" - return int(pow(x, -1, q)) - - -def is_power_of_two(x: int) -> bool: - """Returns True if x is a power of two.""" - return x > 0 and (x & (x - 1)) == 0 - - -def num_bits(x: int) -> int: - """Returns the number of bits in x.""" - return x.bit_length() - 1 - - -def bit_reversal(x: int, num_bits: int) -> int: - """Returns the bit-reversal of x with num_bits representation.""" - result = 0 - for _ in range(num_bits): - result <<= 1 - result |= x & 1 - x >>= 1 - return result - - -def bit_reversal_array(xs: list[Any]) -> None: - """Rearrange the given array in bit-reversal order in place.""" - n = num_bits(len(xs)) - for i in range(len(xs)): - j = bit_reversal(i, n) - if i < j: - xs[i], xs[j] = xs[j], xs[i] diff --git a/jaxite/jaxite_word/add.py b/jaxite/jaxite_word/add.py index a07bcd0..c967f65 100644 --- a/jaxite/jaxite_word/add.py +++ b/jaxite/jaxite_word/add.py @@ -30,13 +30,7 @@ def jax_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array): def vmap_add(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array): """This function processes all degree of the two input polynomials in SIMD using jax.vmap. - Assumes the input data type is a 3-dimensional jax Array as (num_elements, num_towers, degree) - where: - * num_elements: number of polynomials - * num_towers: number of RNS limbs - * degree: degree of the polynomials - - This vmap_add can later be extended to batch ciphertexts. + Assuming the input data type is jax array. Args: value_a: the first operand of the addition. diff --git a/jaxite/jaxite_word/add_test.py b/jaxite/jaxite_word/add_test.py deleted file mode 100644 index e49b831..0000000 --- a/jaxite/jaxite_word/add_test.py +++ /dev/null @@ -1,108 +0,0 @@ -"""A module for operations on test CKKS evaluation kernels including. - -- ModAdd -- HEAdd -- HESub -- HEMul -- HERotate -""" - -from concurrent import futures -from typing import Any, Callable - -import jax -import jax.numpy as jnp -from jaxite.jaxite_word import add - -from absl.testing import absltest -from absl.testing import parameterized - - -ProcessPoolExecutor = futures.ProcessPoolExecutor - -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_traceback_filtering", "off") - - -class CKKSEvalKernelsTest(parameterized.TestCase): - """A base class for running bootstrap tests.""" - - def __init__(self, *args, **kwargs): - super(CKKSEvalKernelsTest, self).__init__(*args, **kwargs) - self.debug = False # dsiable it from printing the test input values - self.modulus_element_0_tower_0 = 1152921504606748673 - self.modulus_element_0_tower_1 = 268664833 - self.modulus_element_0_tower_2 = 557057 - self.random_key = jax.random.key(0) - - def random(self, shape, modulus_list, dtype=jnp.int32): - assert len(modulus_list) == shape[1] - - return jnp.concatenate( - [ - jax.random.randint( - self.random_key, - shape=(shape[0], 1, shape[2]), - minval=0, - maxval=bound, - dtype=dtype, - ) - for bound in modulus_list - ], - axis=1, - ) - - @parameterized.named_parameters( - dict( - testcase_name="jax_add", - test_target=add.jax_add, - modulus_list=[1152921504606748673, 268664833, 557057], - shape=(2, 3, 16384), # number of elements, number of towers, degree - ), - dict( - testcase_name="vmap_add", - test_target=add.vmap_add, - modulus_list=[1152921504606748673, 268664833, 557057], - shape=(2, 3, 16384), # number of elements, number of towers, degree - ), - ) - def test_add( - self, - test_target: Callable[[Any, Any, Any], Any], - modulus_list=jax.Array, - shape=tuple[int, int, int], - ): - """This function tests the add function using Python native integer data type with arbitrary precision. - - This test finishes in 1.05 second. - - Args: - test_target: The function to test. - modulus_list: A jax.Array of integers. - shape: A tuple of integers representing the shape of the input arrays. - """ - # Only test a single element to save comparison time, - # Correctness-wise, it's sufficient for add. - value_a = self.random(shape, modulus_list, dtype=jnp.uint64) - value_b = self.random(shape, modulus_list, dtype=jnp.uint64) - assert value_a.shape == shape - assert value_b.shape == shape - result_a_plus_b = [] - for element_id in range(value_a.shape[0]): - result_a_plus_b_one_element = [] - for tower_id in range(value_a.shape[1]): - add_res = int(value_b[element_id, tower_id, 0]) + int( - value_a[element_id, tower_id, 0] - ) - if add_res > modulus_list[tower_id]: - add_res = add_res - modulus_list[tower_id] - result_a_plus_b_one_element.append(add_res) - result_a_plus_b.append(result_a_plus_b_one_element) - result_a_plus_b = jnp.array(result_a_plus_b, dtype=jnp.uint64) - modulus_list = jnp.array(modulus_list, dtype=jnp.uint64) - result = test_target(value_a, value_b, modulus_list) - self.assertEqual(result[:, :, 0].all(), result_a_plus_b.all()) - - -if __name__ == "__main__": - absltest.main() diff --git a/jaxite/jaxite_word/bconv.py b/jaxite/jaxite_word/bconv.py new file mode 100644 index 0000000..d374de2 --- /dev/null +++ b/jaxite/jaxite_word/bconv.py @@ -0,0 +1,469 @@ +"""BConv: Basis Conversion class for JAX-based homomorphic encryption. + +This module provides the BConv class and subclasses which handle basis extension +and modulus switching +using efficient modular reduction. It is designed to work with +vectorized operations on JAX arrays. +""" + +import jax +import jax.numpy as jnp +from jaxite.jaxite_word import finite_field as ff_context +import jaxite.jaxite_word.util as util + +# Maintain 64-bit precision for large integer arithmetic +jax.config.update("jax_enable_x64", True) + + +def _is_nvidia(): + return "NVIDIA" in jax.devices()[0].device_kind + + +def matmul_bat_einsum(lhs: jax.Array, rhs: jax.Array, subscripts: str): + """Basis Aligned Transformation (BAT) based matrix multiplication + + Args: + lhs (jax.Array): input + rhs (jax.Array): twiddle factor matrix + subscripts (str): einsum subscripts + + Returns: + jax.Array: result + """ + # preprocess + lhs = lhs.view(jnp.uint8) + + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + + # computation + i8_products = jnp.einsum( + subscripts, lhs, rhs, preferred_element_type=jnp.uint32 + ) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) + + +class BConv: + + def __init__(self, overall_moduli): + """Initialize the BConv object. + + Args: + overall_moduli: A list or tuple of integers representing the all + available moduli. + """ + self.overall_moduli = overall_moduli + # Lists to store configurations for each control index + self.original_moduli = [] + self.target_moduli = [] + self.ff_ctx_origin = [] + self.ff_ctx_target = [] + + # Lists to store precomputed constants for each control index + self.QHatInvModq = [] + self.QHatModp = [] + self.QHatModpBAT = [] + + def _create_contexts(self, original_moduli, target_moduli): + """Initialize the finite field contexts. Must be implemented by subclasses. + + Returns: (ff_ctx_origin, ff_ctx_target) + """ + raise NotImplementedError + + def _generate_constants_single( + self, original_moduli, target_moduli, ff_ctx_origin, ff_ctx_target + ): + """Generates constants for a single configuration.""" + # compute_QHatInvModq_QHatModp returns lists, we convert them to JAX arrays + # with appropriate shapes for broadcasting. + QHatInvModq_list, QHatModp_list = util.compute_QHatInvModq_QHatModp( + original_moduli, target_moduli + ) + + # QHatInvModq: Inverse of (Q/q_i) mod q_i + # Shape: (sizeQ,) -> JAX array + QHatInvModq = jnp.array(QHatInvModq_list, dtype=jnp.uint64) + QHatInvModq = ff_ctx_origin.to_computation_format(QHatInvModq) + + # QHatModp: (Q/q_i) mod p_j + # Shape: (sizeQ, sizeP) -> JAX array + QHatModp = jnp.array(QHatModp_list, dtype=jnp.uint64) + QHatModp = ff_ctx_target.to_computation_format(QHatModp) + + # BAT Preprocessing + # QHatModpBAT + # Input QHatModp: (sizeQ, sizeP) + # _basis_aligned_transformation -> (4, sizeQ, sizeP, 4) (dims: a, q, p, b) + # We want to match input (..., d, q, a) -> output (..., d, p, b) + # Transpose to (q, a, p, b) -> (q*a, p, b) for einsum "...dq, qpb -> ...dpb" + QHatModpBAT_raw = self._basis_aligned_transformation( + QHatModp, target_moduli + ) + QHatModpBAT = QHatModpBAT_raw.transpose(1, 0, 2, 3).reshape( + -1, QHatModpBAT_raw.shape[2], 4 + ) + + return QHatInvModq, QHatModp, QHatModpBAT + + def control_gen(self, control_indices_list, perf_test=False): + """Generates and stores precomputed constants QHatInvModq and QHatModp necessary for + + the basis change operation. + + Args: + control_indices_list: A sequence of (original_index, target_index) + tuples/lists. + original_index: Indices of original_moduli in + overall_moduli. + target_index: Indices of target_moduli in + overall_moduli. + """ + # Clear existing lists + self.original_moduli = [] + self.target_moduli = [] + self.ff_ctx_origin = [] + self.ff_ctx_target = [] + self.QHatInvModq = [] + self.QHatModp = [] + self.QHatModpBAT = [] + + for original_index, target_index in control_indices_list: + omParams = [self.overall_moduli[i] for i in original_index] + tmParams = [self.overall_moduli[i] for i in target_index] + + self.original_moduli.append(omParams) + self.target_moduli.append(tmParams) + + ctx_origin, ctx_target = self._create_contexts(omParams, tmParams) + self.ff_ctx_origin.append(ctx_origin) + self.ff_ctx_target.append(ctx_target) + + if perf_test: + sizeQ = len(omParams) + sizeP = len(tmParams) + + # Mocking QHatInvModq: (sizeQ,) + QHatInvModq = util.random_parameters( + (sizeQ,), omParams, dtype=jnp.uint64 + ) + + # Mocking QHatModp: (sizeQ, sizeP) + QHatModp = util.random_parameters( + (sizeQ, sizeP), tmParams, dtype=jnp.uint64 + ) + + # Mocking QHatModpBAT: (sizeQ * 4, sizeP, 4) + # It's uint8, so just random bytes + QHatModpBAT = jnp.zeros((sizeQ * 4, sizeP, 4), dtype=jnp.uint8) + + self.QHatInvModq.append(QHatInvModq) + self.QHatModp.append(QHatModp) + self.QHatModpBAT.append(QHatModpBAT) + else: + QHatInvModq, QHatModp, QHatModpBAT = self._generate_constants_single( + omParams, tmParams, ctx_origin, ctx_target + ) + + self.QHatInvModq.append(QHatInvModq) + self.QHatModp.append(QHatModp) + self.QHatModpBAT.append(QHatModpBAT) + + def _basis_aligned_transformation(self, matrix: jnp.ndarray, moduli): + """Prepares a matrix for Basis Aligned Transformation (BAT). + + Adapted from ntt_mm.py. Assumes matrix last dimension corresponds to + 'moduli'. + """ + matrix_u64 = matrix.astype(jnp.uint64) + num_bytes = 4 + matrix_u64_byteshifted = jnp.array( + [matrix_u64 << (8 * byte_idx) for byte_idx in range(num_bytes)], + dtype=jnp.uint64, + ) + moduli_arr = jnp.array(moduli, dtype=jnp.uint64) + matrix_u64_byteshifted_mod_modulus = ( + matrix_u64_byteshifted % moduli_arr + ).astype(jnp.uint32) + # Output shape: (4, ..., moduli, 4) + matrix_u8 = jax.lax.bitcast_convert_type( + matrix_u64_byteshifted_mod_modulus, jnp.uint8 + ) + return matrix_u8 + + # @functools.partial(jax.jit, static_argnames=("self",)) + def basis_change( + self, in_tower: jnp.ndarray, control_index: int = 0 + ) -> jnp.ndarray: + """Performs the approximate basis change from original_moduli to target_moduli. + + Input: + in_tower: Coefficients in original basis. + Shape: (..., ring_dim, sizeQ) + control_index: Index of the control set to use. + + Output: + out_tower: Coefficients in new basis. + Shape: (..., ring_dim, sizeP) + """ + # Ensure inputs are correctly typed + in_tower = jnp.asarray(in_tower, dtype=jnp.uint64) + + # Retrieve constants and contexts for this control index + QHatInvModq = self.QHatInvModq[control_index] + QHatModp = self.QHatModp[control_index] + ff_ctx_origin = self.ff_ctx_origin[control_index] + ff_ctx_target = self.ff_ctx_target[control_index] + + # Step 1: Compute c_unreduced = in_tower * QHatInvModq + c_unreduced = in_tower * QHatInvModq + + # Step 2: Modular Reduction on c_unreduced using original moduli context + c = ff_ctx_origin.modular_reduction(c_unreduced) + + # Base Term: c * QHatModp + # Shape: (..., d, p) + if _is_nvidia(): + summed_terms = jnp.einsum( + "...dq,qp->...dp", + c.astype(jnp.uint32), + QHatModp.astype(jnp.uint32), + preferred_element_type=jnp.uint64, + ) + else: + products = ( + c[..., None].astype(jnp.uint64) * QHatModp[None, ...] + ) # Need to convert it into BAT based implementation + summed_terms = jnp.sum(products, axis=-2) + + # Step 4: Final Modular Reduction using target moduli context + out_tower = ff_ctx_target.modular_reduction(summed_terms) + + return out_tower + + # @functools.partial(jax.jit, static_argnames=("self",)) + def basis_change_bat( + self, in_tower: jnp.ndarray, control_index: int = 0 + ) -> jnp.ndarray: + """Performs the approximate basis change using BAT optimization. + + Currently does not support modulus switching. + + Input: + in_tower: Coefficients in original basis. + Shape: (..., ring_dim, sizeQ) + control_index: Index of the control set to use. + + Output: + out_tower: Coefficients in new basis. + Shape: (..., ring_dim, sizeP) + """ + + # Ensure inputs are u64 for BAT + # Note: We assume inputs fit in u64 (< 2^64) + in_tower_u64 = jnp.asarray(in_tower, dtype=jnp.uint64) + + # Retrieve constants and contexts + QHatInvModq = self.QHatInvModq[control_index] + QHatModpBAT = self.QHatModpBAT[control_index] + ff_ctx_origin = self.ff_ctx_origin[control_index] + ff_ctx_target = self.ff_ctx_target[control_index] + + # Step 1: Compute c_unreduced = in_tower * QHatInvModqBAT + c_unreduced = in_tower_u64 * QHatInvModq + + # Step 2: Modular Reduction + c = ff_ctx_origin.modular_reduction(c_unreduced).astype(jnp.uint32) + + # QHatModpBAT: (q*a, p, b) + summed_terms = matmul_bat_einsum(c, QHatModpBAT, "...q,qpb->...pb") + + # Step 4: Final Modular Reduction + out_tower = ff_ctx_target.modular_reduction(summed_terms) + + return out_tower + + +class BConvBarrett(BConv): + + def _create_contexts(self, original_moduli, target_moduli): + return ( + ff_context.BarrettContext(moduli=original_moduli), + ff_context.BarrettContext(moduli=target_moduli), + ) + + +class BConvMontgomery(BConv): + + def _create_contexts(self, original_moduli, target_moduli): + return ( + ff_context.MontgomeryContext(moduli=original_moduli), + ff_context.MontgomeryContext(moduli=target_moduli), + ) + + def basis_change( + self, in_tower: jnp.ndarray, control_index: int = 0 + ) -> jnp.ndarray: + in_tower = jnp.asarray(in_tower, dtype=jnp.uint64) + + QHatInvModq = self.QHatInvModq[control_index] + QHatModp = self.QHatModp[control_index] + ff_ctx_origin = self.ff_ctx_origin[control_index] + ff_ctx_target = self.ff_ctx_target[control_index] + + c_unreduced = in_tower * QHatInvModq + c = ff_ctx_origin.modular_reduction(c_unreduced) + + # Domain conversion for Montgomery + c_reduced_back = ff_ctx_origin.to_original_format(c) + c = ff_ctx_target.to_computation_format(c_reduced_back) + + if _is_nvidia(): + summed_terms = jnp.einsum( + "...dq,qp->...dp", + c.astype(jnp.uint32), + QHatModp.astype(jnp.uint32), + preferred_element_type=jnp.uint64, + ) + else: + products = ( + c[..., None].astype(jnp.uint64) * QHatModp[None, ...] + ) # Need to convert it into BAT based implementation + summed_terms = jnp.sum(products, axis=-2) + + out_tower = ff_ctx_target.modular_reduction(summed_terms) + return out_tower + + +class BConvShoup(BConv): + + def __init__(self, overall_moduli): + super().__init__(overall_moduli) + self.QHatInvModq_shoup = [] + self.QHatModp_shoup = [] + + def _create_contexts(self, original_moduli, target_moduli): + return ( + ff_context.ShoupContext(moduli=original_moduli), + ff_context.ShoupContext(moduli=target_moduli), + ) + + def control_gen(self, control_indices_list, perf_test=False): + super().control_gen(control_indices_list) + # Clear shoup lists + self.QHatInvModq_shoup = [] + self.QHatModp_shoup = [] + + for i in range(len(control_indices_list)): + QHatInvModq = self.QHatInvModq[i] + QHatModp = self.QHatModp[i] + ff_ctx_origin = self.ff_ctx_origin[i] + ff_ctx_target = self.ff_ctx_target[i] + + QHatInvModq_shoup = ff_ctx_origin.precompute_constant_operand(QHatInvModq) + QHatModp_shoup = ff_ctx_target.precompute_constant_operand(QHatModp) + + self.QHatInvModq_shoup.append(QHatInvModq_shoup) + self.QHatModp_shoup.append(QHatModp_shoup) + + def basis_change( + self, in_tower: jnp.ndarray, control_index: int = 0 + ) -> jnp.ndarray: + in_tower = jnp.asarray(in_tower, dtype=jnp.uint64) + + QHatInvModq = self.QHatInvModq[control_index] + QHatInvModq_shoup = self.QHatInvModq_shoup[control_index] + QHatModp = self.QHatModp[control_index] + QHatModp_shoup = self.QHatModp_shoup[control_index] + ff_ctx_origin = self.ff_ctx_origin[control_index] + ff_ctx_target = self.ff_ctx_target[control_index] + + c_unreduced = in_tower * QHatInvModq + + # Use dual-operand reduction + c_unreduced_shoup = in_tower * QHatInvModq_shoup + c = ff_ctx_origin.modular_reduction(c_unreduced, c_unreduced_shoup) + + if _is_nvidia(): + summed_terms = jnp.einsum( + "...dq,qp->...dp", + c.astype(jnp.uint32), + QHatModp.astype(jnp.uint32), + preferred_element_type=jnp.uint64, + ) + else: + products = ( + c[..., None].astype(jnp.uint64) * QHatModp[None, ...] + ) # Need to convert it into BAT based implementation + summed_terms = jnp.sum(products, axis=-2) + + summed_terms_shoup = jnp.einsum( + "...dq,qp->...dp", + c.astype(jnp.uint32), + QHatModp_shoup.astype(jnp.uint32), + preferred_element_type=jnp.uint64, + ) + out_tower = ff_ctx_target.modular_reduction( + summed_terms, summed_terms_shoup + ) + return out_tower + + +class BConvBATLazy(BConv): + + def _create_contexts(self, original_moduli, target_moduli): + return ( + ff_context.BATLazyContext(moduli=original_moduli), + ff_context.BATLazyContext(moduli=target_moduli), + ) + + def basis_change( + self, in_tower: jnp.ndarray, control_index: int = 0 + ) -> jnp.ndarray: + in_tower = jnp.asarray(in_tower, dtype=jnp.uint64) + + QHatInvModq = self.QHatInvModq[control_index] + QHatModp = self.QHatModp[control_index] + ff_ctx_origin = self.ff_ctx_origin[control_index] + ff_ctx_target = self.ff_ctx_target[control_index] + + c_unreduced = in_tower * QHatInvModq + c = ff_ctx_origin.modular_reduction(c_unreduced) + + # Force strict reduction for BATLazy correctness + c = ff_ctx_origin.to_original_format(c) + + if _is_nvidia(): + summed_terms = jnp.einsum( + "...dq,qp->...dp", + c.astype(jnp.uint32), + QHatModp.astype(jnp.uint32), + preferred_element_type=jnp.uint64, + ) + else: + products = ( + c[..., None].astype(jnp.uint64) * QHatModp[None, ...] + ) # Need to convert it into BAT based implementation + summed_terms = jnp.sum(products, axis=-2) + + out_tower = ff_ctx_target.modular_reduction(summed_terms) + return out_tower + + def basis_change_bat( + self, in_tower: jnp.ndarray, control_index: int = 0 + ) -> jnp.ndarray: + in_tower = jnp.asarray(in_tower, dtype=jnp.uint64) + + QHatInvModq = self.QHatInvModq[control_index] + QHatModpBAT = self.QHatModpBAT[control_index] + ff_ctx_origin = self.ff_ctx_origin[control_index] + ff_ctx_target = self.ff_ctx_target[control_index] + + c_unreduced = in_tower * QHatInvModq + c = ff_ctx_origin.modular_reduction(c_unreduced).astype(jnp.uint32) + + # Enforce strict reduction + c = ff_ctx_origin.to_original_format(c) + + summed_terms = matmul_bat_einsum(c, QHatModpBAT, "...q,qpb->...pb") + out_tower = ff_ctx_target.modular_reduction(summed_terms) + return out_tower diff --git a/jaxite/jaxite_word/bconv_test.py b/jaxite/jaxite_word/bconv_test.py new file mode 100644 index 0000000..5cfbd74 --- /dev/null +++ b/jaxite/jaxite_word/bconv_test.py @@ -0,0 +1,341 @@ +from absl.testing import absltest +from absl.testing import parameterized +import jaxite.jaxite_word.bconv as bconv +import jax +import jax.numpy as jnp +import numpy as np +from jaxite.jaxite_word.util import random_batched_ciphertext + +# Use 64-bit precision as in bconv.py +jax.config.update("jax_enable_x64", True) + +TEST_PARAMS = [ + ( + "L2_to_L5", + [ + [ + 180089039, + 904401266, + 277587483, + 381410246, + 867235356, + 971323117, + 934942938, + 338146069, + 129667711, + 97559399, + 337422188, + 364870460, + 916966745, + 312366062, + 762079964, + 605485434, + ], + [ + 540094309, + 1034680811, + 1057648335, + 677992674, + 650354195, + 558219774, + 502221165, + 503532224, + 1049911792, + 146837876, + 560962740, + 820076664, + 58915608, + 1034452760, + 724437159, + 68291682, + ], + ], + [1073741441, 1073740609], + [268437409, 268436801, 268435361, 268435649, 524353], + [249077041, 824663761], + [ + [268428382, 268430206, 268434526, 268433662, 390018], + [268429214, 268431038, 268435358, 268434494, 390850], + ], + [ + [ + 127196115, + 177098281, + 103398386, + 262465714, + 225857559, + 213539642, + 56845406, + 173328911, + 21637023, + 13036123, + 259867486, + 247888119, + 190104469, + 18415021, + 107052173, + 152967426, + ], + [ + 168457304, + 81199027, + 93565169, + 186078678, + 45587255, + 266135885, + 57716353, + 256503901, + 42940759, + 230451532, + 167299604, + 7499360, + 30178241, + 217184571, + 253380763, + 263628678, + ], + [ + 36832759, + 236716332, + 248966405, + 220314486, + 69166987, + 6190101, + 204761211, + 194300152, + 254116210, + 106417675, + 161103783, + 244499620, + 193481707, + 80047389, + 247286681, + 101528753, + ], + [ + 57170724, + 202300871, + 265775502, + 100961596, + 165644129, + 105852026, + 62793519, + 256209638, + 261792224, + 19441022, + 102172131, + 193804848, + 207565270, + 260633034, + 136301593, + 126040258, + ], + [ + 7075121, + 3849323, + 6595938, + 5951409, + 7277876, + 7164094, + 6139709, + 4952812, + 4506557, + 5030408, + 7544214, + 3494637, + 8199168, + 9397516, + 5558747, + 9127873, + ], + ], # , [258532, 178852, 303702, 183526, 461287, 347505, 371826, 233635, 311733, 311231, 203272, 348519, 333873, 483515, 315217, 213872]], + ), +] + + +class BConvContextTest(parameterized.TestCase): + # @absltest.skip("Skip a single test") + @parameterized.named_parameters(*TEST_PARAMS) + def test_barrett_context( + self, + partCtCloneCoef, + original_moduli, + target_moduli, + QHatInvModq, + QHatModp, + reference_result, + ): + """Verifies that basis_change works with BarrettContext""" + key = jax.random.PRNGKey(0) + in_tower = jax.numpy.array(partCtCloneCoef, dtype=jnp.uint64).T + reference_result = jax.numpy.array(reference_result, dtype=jnp.uint64).T + + # New API setup + overall_moduli = original_moduli + target_moduli + original_index = list(range(len(original_moduli))) + target_index = list(range(len(original_moduli), len(overall_moduli))) + + _bconv = bconv.BConvBarrett(overall_moduli) + _bconv.control_gen([(original_index, target_index)]) + + in_formatted = _bconv.ff_ctx_origin[0].to_computation_format(in_tower) + out_formatted = _bconv.basis_change(in_formatted) + out = _bconv.ff_ctx_target[0].to_original_format(out_formatted) + + np.testing.assert_array_equal(reference_result, out) + + # @absltest.skip("Skip a single test") + @parameterized.named_parameters(*TEST_PARAMS) + def test_bat_lazy_context( + self, + partCtCloneCoef, + original_moduli, + target_moduli, + QHatInvModq, + QHatModp, + reference_result, + ): + """Verifies that basis_change works with BATLazyContext""" + key = jax.random.PRNGKey(0) + in_tower = jax.numpy.array(partCtCloneCoef, dtype=jnp.uint64).T + reference_result = jax.numpy.array(reference_result, dtype=jnp.uint64).T + + overall_moduli = original_moduli + target_moduli + original_index = list(range(len(original_moduli))) + target_index = list(range(len(original_moduli), len(overall_moduli))) + + _bconv = bconv.BConvBATLazy(overall_moduli) + _bconv.control_gen([(original_index, target_index)]) + + in_formatted = _bconv.ff_ctx_origin[0].to_computation_format(in_tower) + out_formatted = _bconv.basis_change(in_formatted) + out = _bconv.ff_ctx_target[0].to_original_format(out_formatted) + + # BATLazy produces result congruent mod p, but not necessarily fully reduced. + # Hence we check the post modular reduction results. + target_moduli_arr = jnp.array(target_moduli, dtype=jnp.uint64) + diff = ( + out.astype(jnp.int64) - reference_result.astype(jnp.int64) + ) % target_moduli_arr.astype(jnp.int64) + np.testing.assert_array_equal(diff, jnp.zeros_like(diff)) + + def test_multiple_control_gen(self): + """Verifies that BConv supports multiple control generations.""" + # Define a simple setup + # overall_moduli = [q0, q1, p0, p1] + # Config 0: [q0] -> [p0] + # Config 1: [q1] -> [p1] + + # Using small primes for easy verification + q0, q1 = 17, 19 + p0, p1 = 23, 29 + overall_moduli = [q0, q1, p0, p1] + + # Config 0 + original_index_0 = [0] + target_index_0 = [2] + + # Config 1 + original_index_1 = [1] + target_index_1 = [3] + + _bconv = bconv.BConvBarrett(overall_moduli) + _bconv.control_gen( + [(original_index_0, target_index_0), (original_index_1, target_index_1)] + ) + + # Test Config 0 + val = 15 + in_tower_0 = jnp.array( + [[val]], dtype=jnp.uint64 + ) # shape (1, 1) to match (d, q) ? sizeQ=1 + in_tower = jnp.array([[[15]]], dtype=jnp.uint64) # (1, 1, 1) + out_0 = _bconv.basis_change(in_tower, control_index=0) + self.assertEqual(out_0[0, 0, 0], 15) + + in_tower_1 = jnp.array([[[20]]], dtype=jnp.uint64) + out_1 = _bconv.basis_change(in_tower_1, control_index=1) + self.assertEqual(out_1[0, 0, 0], 1) + + # Quick check for non-interference + in_tower_0_b = jnp.array([[[20]]], dtype=jnp.uint64) + out_0_b = _bconv.basis_change( + in_tower_0_b, control_index=0 + ) # 20 mod 17 -> 3 -> 3 + self.assertEqual(out_0_b[0, 0, 0], 3) + + +class BConvBATTest(absltest.TestCase): + + def setUp(self): + super().setUp() + # Define some example moduli. Both fit in 32 bits (required for BAT assumption). + # These are from basis_change_test.py (approx 2^27) + self.original_moduli = [ + 134219681, + 134218433, + 134219009, + 1073741857, + 1073740609, + ] + self.target_moduli = [268435361, 268435009, 6710893, 1067031829] + + self.overall_moduli = self.original_moduli + self.target_moduli + self.original_index = list(range(len(self.original_moduli))) + self.target_index = list( + range(len(self.original_moduli), len(self.overall_moduli)) + ) + + self.bconv = bconv.BConvBarrett(self.overall_moduli) + self.bconv.control_gen([(self.original_index, self.target_index)]) + + # @absltest.skip("Skip a single test") + def test_basis_change_bat_vs_standard(self): + """Verifies that basis_change_bat produces the same result as basis_change""" + key = jax.random.PRNGKey(0) + + # Dimensions + batch = 1 + elements = 2 + d = 128 # small ring dim + sizeQ = len(self.original_moduli) + in_tower = random_batched_ciphertext( + (batch, elements, d, sizeQ), self.original_moduli, jnp.uint32 + ) + + # Expected result (Standard) + expected = self.bconv.basis_change(in_tower) + + # Actual result (BAT) + actual = self.bconv.basis_change_bat(in_tower) + target_moduli_arr = jnp.array(self.target_moduli, dtype=jnp.uint64) + diff = ( + actual.astype(jnp.int64) - expected.astype(jnp.int64) + ) % target_moduli_arr.astype(jnp.int64) + np.testing.assert_array_equal(diff, jnp.zeros_like(diff)) + + # @absltest.skip("Skip a single test") + def test_basis_change_bat_random_big(self): + """Test with larger shapes to ensure robustness.""" + key = jax.random.PRNGKey(1) + batch = 4 + elements = 4 + d = 8 + sizeQ = len(self.original_moduli) + shape = (batch, elements, d, sizeQ) + + in_tower = random_batched_ciphertext( + shape, self.original_moduli, jnp.uint32 + ) + + expected = self.bconv.basis_change(in_tower) + actual = self.bconv.basis_change_bat(in_tower) + target_moduli_arr = jnp.array(self.target_moduli, dtype=jnp.uint64) + diff = ( + actual.astype(jnp.int64) - expected.astype(jnp.int64) + ) % target_moduli_arr.astype(jnp.int64) + np.testing.assert_array_equal(diff, jnp.zeros_like(diff)) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_word/ciphertext.py b/jaxite/jaxite_word/ciphertext.py new file mode 100644 index 0000000..0e433ea --- /dev/null +++ b/jaxite/jaxite_word/ciphertext.py @@ -0,0 +1,679 @@ +import math +from typing import List, Optional, Tuple, Union +from jaxite.jaxite_word.bconv import BConvBarrett +import jaxite.jaxite_word.finite_field as ff_context +import jax.numpy as jnp +import jaxite.jaxite_word.ntt_mm as ntt +import jaxite.jaxite_word.util as util + + +######################## +# Helper Functions +######################## +def gen_power_of_inv_psi_arr(moduli, ring_dim): + q_list = [moduli] if not isinstance(moduli, list) else moduli + psi_list = [util.root_of_unity(2 * ring_dim, q) for q in q_list] + inv_psi = [pow(psi, -1, q) for (q, psi) in zip(q_list, psi_list)] + power_of_inv_psi_arr = [ + [pow(inv_psi[idx], i, q_list[idx]) for i in range(ring_dim)] + for idx in range(len(psi_list)) + ] + return jnp.array(power_of_inv_psi_arr, jnp.uint64).T.reshape( + 1, 1, ring_dim, -1 + ) + + +class Ciphertext: + + def __init__(self, shapes: dict, parameters: Optional[dict] = None): + """Initialize the Ciphertext object. + + Each ciphertext is a 4D tensor with shape (batch, num_elements, num_moduli, + degree). + + Args: + shapes (dict): A dictionary containing the shapes of the ciphertext. - + batch: The batch size of the ciphertext. - num_elements: The number of + elements in the ciphertext. - degree: The degree of the ciphertext. + - + num_moduli: The number of moduli in the ciphertext. - precision: The + precision of the ciphertext. + parameters (Optional[dict], optional): A dictionary containing the + parameters of the ciphertext. - moduli: The moduli of the ciphertext. + - If the moduli is a single integer, the ciphertext will be a single + modulus. - If the moduli is a list of integers, the ciphertext will be + a multi-modulus. + finite_field_context (Optional[object], optional): The finite field + context to use. - If not provided, a default BarrettContext will be + created. + r (Optional[int], optional): The r parameter for the NTT context. + c (Optional[int], optional): The c parameter for the NTT context. + """ + self.batch = shapes['batch'] + self.num_elements = shapes['num_elements'] + self.num_moduli = shapes['num_moduli'] + self.degree = shapes['degree'] + log_degree = int(math.log2(self.degree)) + self.precision = shapes['precision'] + if 'degree_layout' in shapes: + self.degree_layout = shapes['degree_layout'] + else: + self.degree_layout = (self.degree,) + + if len(self.degree_layout) == 2: + self.r = self.degree_layout[0] + self.c = self.degree_layout[1] + else: + self.r = 1 << (log_degree // 2) + self.c = self.degree // self.r + + if self.precision <= 32: + self.modulus_dtype = jnp.uint32 + else: + self.modulus_dtype = jnp.uint64 + + if parameters is not None and 'moduli' in parameters: + self.moduli = parameters['moduli'] + else: + self.moduli = util.find_moduli_ntt( + self.num_moduli, self.precision, 2 * self.degree + ) + + # NTT Parameters + if parameters is not None and 'finite_field_context' in parameters: + finite_field_context = parameters['finite_field_context']( + moduli=self.moduli + ) + else: + finite_field_context = ff_context.BarrettContext(moduli=self.moduli) + + ntt_params = { + 'r': self.r, + 'c': self.c, + 'finite_field_context': finite_field_context, + } + self.bit_reverse_indices = jnp.array( + util.bit_reverse_indices(self.degree), jnp.uint32 + ) + + self.shape_in_ntt_all_limbs = (-1, self.r, self.c, self.num_moduli) + if ( + parameters is not None + and 'BAT_lazy' in parameters + and parameters['BAT_lazy'] + ): + self.ntt_ctx = ntt.NTTCiphertextBATLazyContext( + moduli=self.moduli, parameters=ntt_params + ) + else: + if isinstance(finite_field_context, ff_context.BarrettContext): + self.ntt_ctx = ntt.NTTCiphertextBarrettContext( + moduli=self.moduli, parameters=ntt_params + ) + elif isinstance(finite_field_context, ff_context.MontgomeryContext): + self.ntt_ctx = ntt.NTTCiphertextMontgomeryContext( + moduli=self.moduli, parameters=ntt_params + ) + elif isinstance(finite_field_context, ff_context.ShoupContext): + self.ntt_ctx = ntt.NTTCiphertextShoupContext( + moduli=self.moduli, parameters=ntt_params + ) + else: + raise ValueError( + 'Unsupported finite field context type:' + f' {type(finite_field_context)}' + ) + + self.moduli_array = jnp.array(self.moduli, dtype=self.modulus_dtype) + self.ciphertext = jnp.zeros( + (self.batch, self.num_elements, self.degree, self.num_moduli), + dtype=self.modulus_dtype, + ) + self.extend_ciphertext = jnp.zeros( + (self.batch, self.num_elements, self.degree, 1), + dtype=self.modulus_dtype, + ) + self.bconv = None + self.bconv_indices_list = [] + + def _create_bconv(self, moduli): + if self.bconv is None: + self.bconv = BConvBarrett(moduli) + return self.bconv + + def random_init(self): + self.ciphertext = util.random_batched_ciphertext( + (self.batch, self.num_elements, *self.degree_layout, self.num_moduli), + self.moduli, + dtype=self.modulus_dtype, + ) + + ##################### + # Getter Functions + ##################### + @property + def shape(self): + return self.ciphertext.shape + + def get_batch_ciphertext(self) -> jnp.ndarray: + return self.ciphertext + + def get_ciphertext(self, batch_index) -> jnp.ndarray: + return self.ciphertext[batch_index] + + def get_element(self, element_index) -> jnp.ndarray: + return self.ciphertext[:, element_index] + + def get_limb(self, limb_index) -> jnp.ndarray: + return self.ciphertext[..., limb_index] + + ##################### + # Setter Functions + # Note set_ciphertext, set_element, set_limb are in place operations, not recommended in JAX + ##################### + def set_batch_ciphertext(self, batch_ciphertext: jnp.ndarray) -> None: + self.ciphertext = batch_ciphertext + + def set_ciphertext( + self, batch_index: int, ciphertext: jnp.ndarray + ) -> None: + self.ciphertext = self.ciphertext.at[batch_index].set(ciphertext) + + def set_element( + self, element_index: int, element: jnp.ndarray + ) -> None: + self.ciphertext = self.ciphertext.at[:, element_index].set(element) + + def set_limb(self, limb_index: int, limb: jnp.ndarray) -> None: + self.ciphertext = self.ciphertext.at[..., limb_index].set(limb) + + def get_moduli_array(self) -> jnp.ndarray: + return self.moduli_array + + def get_moduli(self) -> Union[List[int], int]: + return self.moduli + + def get_modulus(self, index: int) -> int: + return self.moduli[index] + + ##################### + # Domain Conversion Functions + ##################### + def to_ntt_form(self): + current_shape = self.ciphertext.shape + reshaped_in = self.ciphertext.reshape(self.shape_in_ntt_all_limbs) + ntt_result = self.ntt_ctx.ntt(reshaped_in) + self.ciphertext = ntt_result.reshape(current_shape) + + def to_coeffs_form(self): + current_shape = self.ciphertext.shape + reshaped_in = self.ciphertext.reshape(self.shape_in_ntt_all_limbs) + intt_result = self.ntt_ctx.intt(reshaped_in) + self.ciphertext = intt_result.reshape(current_shape) + + def to_compute_format(self): + self.ciphertext = self.ntt_ctx.to_computation_format(self.ciphertext) + + def to_original_format(self): + self.ciphertext = self.ntt_ctx.to_original_format(self.ciphertext) + + ##################### + # Arithmetic Functions Entire Ciphertext + ##################### + def add(self, other: Union['Ciphertext', jnp.ndarray]): + other_array = other.ciphertext if isinstance(other, Ciphertext) else other + self.ciphertext = self.ciphertext + other_array + + def sub(self, other: Union['Ciphertext', jnp.ndarray]): + other_array = other.ciphertext if isinstance(other, Ciphertext) else other + self.ciphertext = self.ciphertext - other_array + + def mul(self, other: Union['Ciphertext', jnp.ndarray]): + other_array = other.ciphertext if isinstance(other, Ciphertext) else other + self.ciphertext = self.ciphertext.astype(jnp.uint64) * other_array.astype( + jnp.uint64 + ) + + def modmul(self, other: Union['Ciphertext', jnp.ndarray]): + other_array = other.ciphertext if isinstance(other, Ciphertext) else other + temp = self.ciphertext.astype(jnp.uint64) * other_array.astype(jnp.uint64) + reduced = self.ntt_ctx.ff_ctx.modular_reduction(temp) + self.ciphertext = reduced.astype(self.modulus_dtype) + + def mod_reduce(self): + reduced = self.ntt_ctx.ff_ctx.modular_reduction( + self.ciphertext.astype(jnp.uint64) + ) + self.ciphertext = reduced.astype(self.modulus_dtype) + + ##################### + # Modulus Dropping Functions + ##################### + def drop_last_modulus(self) -> jnp.ndarray: + if self.num_moduli <= 1: + raise ValueError('Cannot drop modulus from a single-limb ciphertext.') + + # Drop ciphertext limb and track the new modulus set. + self.ciphertext = self.ciphertext[..., :-1] + self.moduli = self.moduli[:-1] + self.moduli_array = self.moduli_array[:-1] + self.num_moduli -= 1 + + # Update finite field context and rebuild NTT context for the reduced limb set. + self.shape_in_ntt_all_limbs = (-1, self.r, self.c, self.num_moduli) + self.shape_in_ntt_last_limb = (-1, self.r, self.c) + self.ntt_ctx.drop_last_modulus() + return self.ciphertext + + ##################### + # Modulus Switching Helpers + ##################### + def modulus_switch_control_gen( + self, degree_layout: Optional[tuple] = None, perf_test: bool = False + ): + if degree_layout is None: + degree_layout = (self.degree,) + if self.num_moduli <= 1: + raise ValueError( + 'Cannot perform modulus switch with fewer than two moduli.' + ) + + ring_dim = self.degree + overall_psi = [util.root_of_unity(2 * ring_dim, q) for q in self.moduli] + overall_power_of_psi = jnp.array( + [ + [ + pow(overall_psi[idx], i, self.moduli[idx]) + for i in range(ring_dim) + ] + for idx in range(len(self.moduli)) + ], + jnp.uint64, + ) + + gammas, betas = util.gamma_beta_calculation( + self.moduli, perf_test=perf_test + ) + gammas_power_of_psi_no_last = ( + gammas[: self.num_moduli - 1, None].astype(jnp.uint64) + * overall_power_of_psi[: self.num_moduli - 1].astype(jnp.uint64) + ) % jnp.array(self.moduli[: self.num_moduli - 1], jnp.uint64)[:, None] + inv_psi_last = pow(overall_psi[-1], -1, self.moduli[-1]) + power_of_inv_psi_arr_last_tower = jnp.array( + [pow(inv_psi_last, i, self.moduli[-1]) for i in range(ring_dim)], + jnp.uint64, + ) + + ct_last_limb_shapes = { + 'batch': self.batch, + 'num_elements': 4, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': 1, + 'degree_layout': degree_layout, + } + self.ct_last_limb = Ciphertext( + ct_last_limb_shapes, + parameters={ + 'moduli': [self.moduli[-1]], + 'finite_field_context': ff_context.BarrettContext, + 'r': self.r, + 'c': self.c, + }, + ) + + self.gammas_power_of_psi_no_last = jnp.array( + gammas_power_of_psi_no_last, jnp.uint64 + ).T.reshape(1, *degree_layout, self.num_moduli - 1) + self.betas = jnp.array(betas, jnp.uint64)[: self.num_moduli] + + # Reshape power_of_inv_psi_arr_last_tower to (1, 1, degree, 1) for broadcasting + self.power_of_last_tower_inv_psi = jnp.array( + power_of_inv_psi_arr_last_tower, jnp.uint64 + ).reshape(1, *degree_layout, 1) + + self.moduli_no_last = jnp.array(self.moduli[:-1], dtype=self.modulus_dtype) + self.moduli_threshold = jnp.array( + (self.moduli[-1] + 1) // 2, dtype=jnp.uint32 + ) + self.drop_last_moduli_arr = jnp.array(self.moduli[:-1], jnp.uint32) + self.last_tower_moduli = self.moduli[-1] + + self.mod_switch_params = True # Flag to indicate initialization + self.drop_last_modulus() + + def rescale(self): + """Rescale implementation using Ciphertext class methods.""" + # Ensure control parameters are available on the ciphertext + if not hasattr(self, 'mod_switch_params'): + self.modulus_switch_control_gen(degree_layout=self.degree_layout) + + # Capture current state before dropping modulus + in_ciphertexts = self.get_batch_ciphertext() + + # 1. Extract Last Tower and Process logic similar to hemul.py + last_towers = in_ciphertexts[..., -1:] + self.ct_last_limb.set_batch_ciphertext(last_towers.astype(jnp.uint32)) + self.ct_last_limb.to_coeffs_form() + self.ct_last_limb.modmul(self.power_of_last_tower_inv_psi) + ct_last_limb_modred = self.ct_last_limb.get_batch_ciphertext().astype( + jnp.uint32 + ) + + condition = ct_last_limb_modred < self.moduli_threshold + result_if_lt_threshold = ct_last_limb_modred + result_if_ge_threshold = ( + jnp.array(self.drop_last_moduli_arr, jnp.uint64) + - self.last_tower_moduli + + ct_last_limb_modred + ) + last_poly_switch_modulus_coef = jnp.where( + condition, result_if_lt_threshold, result_if_ge_threshold + ) + last_poly_switch_modulus_coef_twisted = ( + last_poly_switch_modulus_coef.astype(jnp.uint64) + * self.gammas_power_of_psi_no_last + ) + self.set_batch_ciphertext(last_poly_switch_modulus_coef_twisted) + self.mod_reduce() + self.to_ntt_form() + mod_reduce_last_res_unreduced = self.get_batch_ciphertext() + + self.set_batch_ciphertext(in_ciphertexts[..., :-1]) + self.mul(self.betas) + self.set_batch_ciphertext( + self.get_batch_ciphertext() + + mod_reduce_last_res_unreduced.astype(jnp.uint64) + ) + self.mod_reduce() + return self.get_batch_ciphertext() + + ##################### + # FHE Kernel Functions + ##################### + def ciphertext_mult(self) -> Tuple[jnp.ndarray, jnp.ndarray]: + a0 = self.ciphertext[:, 0].astype(jnp.uint64) + a1 = self.ciphertext[:, 1].astype(jnp.uint64) + b0 = self.ciphertext[:, 2].astype(jnp.uint64) + b1 = self.ciphertext[:, 3].astype(jnp.uint64) + + mul0_t = a0 * b0 + mul0 = self.ntt_ctx.ff_ctx.modular_reduction(mul0_t).astype( + self.modulus_dtype + ) + + mul2_t = a1 * b1 + mul2 = self.ntt_ctx.ff_ctx.modular_reduction(mul2_t).astype( + self.modulus_dtype + ) + + t1_t = a0 * b1 + a1 * b0 + mul1 = self.ntt_ctx.ff_ctx.modular_reduction(t1_t) + + return ( + jnp.concatenate([mul0[:, None], mul1[:, None]], axis=1), + mul2[:, None], + ) + + ##################### + # Key Switching Functions + ##################### + def key_switch_control_gen( + self, + extend_moduli: List[int], + dnum: int, + evalkey_a, + evalkey_b, + perf_test: bool = False, + selected_moduli: List[int] | None = None, + degree_layout: tuple | int | None = None, + ): + """Generates all parameters needed for key switching.""" + if degree_layout is None: + degree_layout = (self.degree,) + # 1. Configuration + self.ks_params = {} + current_moduli = self.moduli if selected_moduli is None else selected_moduli + drop_last_extend_moduli = current_moduli + extend_moduli + self.ks_params['drop_last_extend_moduli'] = drop_last_extend_moduli + ring_dim = self.degree + + # Generate Drop Last Power of Psi + drop_last_psi = [ + util.root_of_unity(2 * ring_dim, q) for q in drop_last_extend_moduli + ] + if perf_test: + self.ks_params['drop_last_power_of_psi'] = util.random_parameters( + (*self.degree_layout, len(drop_last_extend_moduli)), + drop_last_extend_moduli, + dtype=jnp.uint64, + ) + else: + self.ks_params['drop_last_power_of_psi'] = jnp.array( + [ + [ + pow(drop_last_psi[idx], i, drop_last_extend_moduli[idx]) + for i in range(ring_dim) + ] + for idx in range(len(drop_last_extend_moduli)) + ], + jnp.uint64, + ).T.reshape(*self.degree_layout, len(drop_last_extend_moduli)) + + # Tower Indices + sizeQ_drop_last = len(current_moduli) + alpha = (sizeQ_drop_last + dnum - 1) // dnum + self.ks_params['alpha'] = alpha + self.ks_params['numPartQl'] = (sizeQ_drop_last + alpha - 1) // alpha + + original_moduli_extract_index = [] + for i in range(sizeQ_drop_last): + if i % alpha == 0: + original_moduli_extract_index.append([i]) + else: + original_moduli_extract_index[-1].append(i) + + ( + select_tower_index_overall, + non_select_tower_index_overall, + restore_indices, + ) = ([], [], []) + for part in range(self.ks_params['numPartQl']): + select_tower_overall_index = original_moduli_extract_index[part] + non_select_tower_overall_index = [ + i + for i in range(len(drop_last_extend_moduli)) + if i not in select_tower_overall_index + ] + concat_order = select_tower_overall_index + non_select_tower_overall_index + restore_index = [0] * len(concat_order) + for pos, val in enumerate(concat_order): + restore_index[val] = pos + select_tower_index_overall.append( + jnp.array(select_tower_overall_index, jnp.uint16) + ) + non_select_tower_index_overall.append( + jnp.array(non_select_tower_overall_index, jnp.uint16) + ) + restore_indices.append(jnp.array(restore_index, jnp.uint16)) + + self.ks_params['select_tower_index_overall'] = select_tower_index_overall + self.ks_params['non_select_tower_index_overall'] = ( + non_select_tower_index_overall + ) + self.ks_params['restore_indices'] = restore_indices + + # Helper function for inv psi + def gen_power_of_inv_psi_arr(moduli): + q_list = [moduli] if not isinstance(moduli, list) else moduli + psi_list = [util.root_of_unity(2 * ring_dim, q) for q in q_list] + inv_psi = [pow(psi, -1, q) for (q, psi) in zip(q_list, psi_list)] + power_of_inv_psi_arr = [ + [pow(inv_psi[idx], i, q_list[idx]) for i in range(ring_dim)] + for idx in range(len(psi_list)) + ] + return jnp.array(power_of_inv_psi_arr, jnp.uint64).T.reshape( + 1, 1, *degree_layout, -1 + ) + + if perf_test: + self.ks_params['power_of_inv_psi_arr_drop_last'] = util.random_parameters( + (len(current_moduli), ring_dim), current_moduli, dtype=jnp.uint64 + ).T.reshape(1, 1, *degree_layout, -1) + else: + self.ks_params['power_of_inv_psi_arr_drop_last'] = ( + gen_power_of_inv_psi_arr(current_moduli) + ) + + # Basis Change + bconv_indices_list = [] + for part in range(self.ks_params['numPartQl']): + bconv_indices_list.append(( + select_tower_index_overall[part].tolist(), + non_select_tower_index_overall[part].tolist(), + )) + + # Manage Global BConv + self.ks_params['ks_control_start_idx'] = len(self.bconv_indices_list) + self.bconv_indices_list.extend(bconv_indices_list) + self._create_bconv(drop_last_extend_moduli) + self.bconv.control_gen(self.bconv_indices_list, perf_test=perf_test) + + # KeySwitch Parts CTs + self.ks_params['ct_ks_parts'] = [] + + # Handle batch size for key switch + ks_batch_size = self.batch + + ct_shapes_common = { + 'batch': ks_batch_size, + 'num_elements': 2, + 'degree': ring_dim, + 'precision': self.precision, + 'degree_layout': self.degree_layout, + } + ct_params_common = { + 'r': self.r, + 'c': self.c, + 'finite_field_context': ff_context.BarrettContext, + } + + for part in range(self.ks_params['numPartQl']): + target_indices = non_select_tower_index_overall[part].tolist() + target_moduli = [drop_last_extend_moduli[i] for i in target_indices] + shapes_part = ct_shapes_common.copy() + shapes_part['num_moduli'] = len(target_moduli) + params_part = ct_params_common.copy() + params_part['moduli'] = target_moduli + self.ks_params['ct_ks_parts'].append(Ciphertext(shapes_part, params_part)) + + # CT for Drop Last + Extend (for ks_core result) + shapes_dle = ct_shapes_common.copy() + shapes_dle['num_moduli'] = len(drop_last_extend_moduli) + params_dle = ct_params_common.copy() + params_dle['moduli'] = drop_last_extend_moduli + self.ks_params['ct_drop_last_extend'] = Ciphertext(shapes_dle, params_dle) + + # Keys Setup + # Reshape keys + idx_cur_last_tower = len(current_moduli) + overall_sizeP = len(extend_moduli) + + self.ks_params['evk_a_precomp'] = jnp.concatenate( + [evalkey_a[..., :idx_cur_last_tower], evalkey_a[..., -overall_sizeP:]], + axis=-1, + ).reshape(-1, *self.degree_layout, len(drop_last_extend_moduli)) + self.ks_params['evk_b_precomp'] = jnp.concatenate( + [evalkey_b[..., :idx_cur_last_tower], evalkey_b[..., -overall_sizeP:]], + axis=-1, + ).reshape(-1, *self.degree_layout, len(drop_last_extend_moduli)) + self.ks_params['idx_cur_last_tower'] = idx_cur_last_tower + self.ks_params['overall_sizeP'] = overall_sizeP + + def key_switch(self): + """Performs key switching on the current ciphertext. + + Mutates self.ciphertext to the extended moduli basis. + """ + if not hasattr(self, 'ks_params'): + raise RuntimeError( + 'Key switch parameters not initialized. Call key_switch_control_gen' + ' first.' + ) + + # Aliases + params = self.ks_params + power_of_inv_psi_arr_drop_last = params['power_of_inv_psi_arr_drop_last'] + select_tower_index_overall = params['select_tower_index_overall'] + non_select_tower_index_overall = params['non_select_tower_index_overall'] + ks_control_start_idx = params['ks_control_start_idx'] + ct_ks_parts = params['ct_ks_parts'] + numPartQl = params['numPartQl'] + ct_drop_last_extend = params['ct_drop_last_extend'] + evk_a_precomp = params['evk_a_precomp'] + evk_b_precomp = params['evk_b_precomp'] + restore_indices = params['restore_indices'] + + # Save NTT input for reconstruction + input_ntt = self.get_batch_ciphertext() + + # Step 3: Precompute + self.to_coeffs_form() + self.modmul(power_of_inv_psi_arr_drop_last) + partCtCloneCoef = self.get_batch_ciphertext() + + res0 = None + res1 = None + for part in range(numPartQl): + select_idxs = select_tower_index_overall[part] + non_select_idxs = non_select_tower_index_overall[part] + _power_of_psi_arr = jnp.take( + params['drop_last_power_of_psi'], non_select_idxs, axis=-1 + ) + partCtCloneEval = self.bconv.basis_change_bat( + jnp.take(partCtCloneCoef, select_idxs, axis=-1), + control_index=ks_control_start_idx + part, + ).astype(jnp.uint64) + ct_part = ct_ks_parts[part] + ct_part.ciphertext = partCtCloneEval + ct_part.modmul(_power_of_psi_arr) + partCtCloneEval_scaled_multi_moduli = ct_part.ciphertext + ct_part.ciphertext = partCtCloneEval_scaled_multi_moduli + ct_part.to_ntt_form() + partsCtCompl_multi_moduli = ct_part.ciphertext + + partsCtExt_cur_part = jnp.concatenate( + [ + jnp.take(input_ntt, select_idxs, axis=-1), + partsCtCompl_multi_moduli, + ], + axis=-1, + ) + partsCtExt_cur_part = jnp.take( + partsCtExt_cur_part, restore_indices[part], axis=-1 + ) + + if res0 is None: + res0 = ( + partsCtExt_cur_part + * evk_b_precomp.astype(jnp.uint64)[part][None, None, :, :] + ) + res1 = ( + partsCtExt_cur_part + * evk_a_precomp.astype(jnp.uint64)[part][None, None, :, :] + ) + else: + res0 += ( + partsCtExt_cur_part + * evk_b_precomp.astype(jnp.uint64)[part][None, None, :, :] + ) + res1 += ( + partsCtExt_cur_part + * evk_a_precomp.astype(jnp.uint64)[part][None, None, :, :] + ) + + result = jnp.concatenate([res0, res1], axis=1) + ct_drop_last_extend.set_batch_ciphertext(result) + ct_drop_last_extend.mod_reduce() + ks_result = ct_drop_last_extend.get_batch_ciphertext() + self.set_batch_ciphertext(ks_result) diff --git a/jaxite/jaxite_word/ciphertext_test.py b/jaxite/jaxite_word/ciphertext_test.py new file mode 100644 index 0000000..2c034b4 --- /dev/null +++ b/jaxite/jaxite_word/ciphertext_test.py @@ -0,0 +1,271 @@ +"""Finite Field Test Suite + +Test cases: +- Montgomery Single Modulus Context +- Barrett Single Modulus Context +- Shoup Single Modulus Context +- Montgomery Multi Modulus Context +- Barrett Multi Modulus Context +- Shoup Multi Modulus Context + +Terminology: +- Modulus -- Moduli: Single form or plural form of modulus. + + +Usage: +- Specify the overall moduli for the context, and corresponding parameter +required for the modular reduction. +- Then feed "moduli" and "parameters" to the context constructor. +- Then context->modular_reduction(input) to get the reduced result for certain +inputs. +""" + +from absl.testing import absltest +from absl.testing import parameterized +import jaxite.jaxite_word.ciphertext as ct +import jaxite.jaxite_word.finite_field as ff_context +import jax +import jax.numpy as jnp +import numpy as np + +testing_params = [{"testcase_name": "0"}] + + +@parameterized.named_parameters(testing_params) +class FiniteFieldTest(parameterized.TestCase): + + def setUp(self): + # Setup random input data and their modmul reference results. + self.random_key = jax.random.key(0) + # batch, (w/ element), moduli (limbs/towers), degree + shapes = { + "batch": 3, + "num_elements": 2, + "num_moduli": 4, + "degree": 16, + "precision": 29, + } + + ct_a = ct.Ciphertext(shapes) + ct_b = ct.Ciphertext(shapes, {"moduli": ct_a.get_moduli()}) + ct_ab = ct.Ciphertext(shapes, {"moduli": ct_a.get_moduli()}) + ct_ab_modq = ct.Ciphertext(shapes, {"moduli": ct_a.get_moduli()}) + ct_a.random_init() + ct_b.random_init() + self.a = ct_a.get_batch_ciphertext().astype(jnp.uint64) + self.b = ct_b.get_batch_ciphertext().astype(jnp.uint64) + self.ab = self.a * self.b + self.ab_modq = (self.ab % ct_a.get_moduli_array()).astype(jnp.uint32) + ct_ab.set_batch_ciphertext(self.ab) + ct_ab_modq.set_batch_ciphertext(self.ab_modq) + + self.single_a = ct_a.get_limb(0).astype(jnp.uint64) + self.single_b = ct_b.get_limb(0).astype(jnp.uint64) + self.single_moduli = ct_ab.get_modulus(0) + self.single_ab = ct_ab.get_limb(0) + self.single_ab_modq = ct_ab_modq.get_limb(0) + + self.moduli = ct_ab.get_moduli() + + # @absltest.skip("test single implementation") + def test_montgomery_single_moduli_context(self): + context = ff_context.MontgomeryContext(self.single_moduli) + single_a_mont = context.to_computation_format( + self.single_a.astype(jnp.uint64) + ) + single_b_mont = context.to_computation_format( + self.single_b.astype(jnp.uint64) + ) + single_ab_mont = single_a_mont.astype(jnp.uint64) * single_b_mont.astype( + jnp.uint64 + ) + result_mont = context.modular_reduction(single_ab_mont) + result = context.to_original_format(result_mont.astype(jnp.uint64)) + np.testing.assert_array_equal(result, self.single_ab_modq) + + # @absltest.skip("test single implementation") + def test_barrett_single_moduli_context(self): + context = ff_context.BarrettContext(self.single_moduli) + result = context.modular_reduction(self.single_ab) + np.testing.assert_array_equal(result, self.single_ab_modq) + + # @absltest.skip("test single implementation") + def test_shoup_single_moduli_context(self): + context = ff_context.ShoupContext(self.single_moduli) + single_a_precomputed = context.precompute_constant_operand( + self.single_a.astype(jnp.uint64) + ) + single_ab = self.single_a.astype(jnp.uint64) * self.single_b.astype( + jnp.uint64 + ) + single_ab_shoup = single_a_precomputed * self.single_b.astype(jnp.uint64) + result_shoup = context.modular_reduction(single_ab, single_ab_shoup) + result = context.to_original_format(result_shoup.astype(jnp.uint64)) + np.testing.assert_array_equal(result, self.single_ab_modq) + + # @absltest.skip("test single implementation") + def test_montgomery_multi_moduli_context(self): + context = ff_context.MontgomeryContext(self.moduli) + a_mont = context.to_computation_format(self.a.astype(jnp.uint64)) + b_mont = context.to_computation_format(self.b.astype(jnp.uint64)) + ab_mont = a_mont.astype(jnp.uint64) * b_mont.astype(jnp.uint64) + result_mont = context.modular_reduction(ab_mont) + result = context.to_original_format(result_mont.astype(jnp.uint64)) + np.testing.assert_array_equal(result, self.ab_modq) + + # @absltest.skip("test single implementation") + def test_barrett_multi_moduli_context(self): + context = ff_context.BarrettContext(self.moduli) + result = context.modular_reduction(self.ab) + np.testing.assert_array_equal(result, self.ab_modq) + + # @absltest.skip("test single implementation") + def test_shoup_multi_moduli_context(self): + context = ff_context.ShoupContext(self.moduli) + a_precomputed = context.precompute_constant_operand( + self.a.astype(jnp.uint64) + ) + ab = self.a.astype(jnp.uint64) * self.b.astype(jnp.uint64) + ab_shoup = a_precomputed * self.b.astype(jnp.uint64) + result_shoup = context.modular_reduction(ab, ab_shoup) + result = context.to_original_format(result_shoup.astype(jnp.uint64)) + np.testing.assert_array_equal(result, self.ab_modq) + + # @absltest.skip("test single implementation") + def test_ntt_conversion(self): + shapes = { + "batch": 3, + "num_elements": 2, + "num_moduli": 4, + "degree": 16, + "precision": 29, + } + parameters = { + "r": 4, + "c": 4, + "finite_field_context": ( + ff_context.BarrettContext + ), # ff_context.BarrettContext, ff_context.MontgomeryContext, ff_context.ShoupContext + } + ct_temp = ct.Ciphertext(shapes, parameters) + ct_temp.random_init() + + original = ct_temp.get_batch_ciphertext() + + # Check NTT round trip + ct_temp.to_compute_format() + ct_temp.to_ntt_form() + ct_temp.to_coeffs_form() + ct_temp.to_original_format() + + np.testing.assert_array_equal(original, ct_temp.get_batch_ciphertext()) + + +class CiphertextLimbDomainConversionTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.limb_shapes = { + "batch": 2, + "num_elements": 2, + "num_moduli": 1, + "degree": 16, + "precision": 29, + } + self.drop_shapes = { + "batch": 2, + "num_elements": 2, + "num_moduli": 3, + "degree": 8, + "precision": 29, + } + self.limb_index = 0 + base_ct = ct.Ciphertext(self.limb_shapes) + base_ct.random_init() + self.original = base_ct.get_batch_ciphertext() + self.moduli = base_ct.get_moduli() + + def test_drop_last_modulus_preserves_remaining_limbs_and_context(self): + ct_temp = ct.Ciphertext(self.drop_shapes) + ct_temp.random_init() + before_drop = ct_temp.get_batch_ciphertext() + expected_remaining = before_drop[:, :, :, :-1] + + ct_temp.drop_last_modulus() + + np.testing.assert_array_equal( + ct_temp.get_batch_ciphertext(), expected_remaining + ) + self.assertEqual(ct_temp.num_moduli, self.drop_shapes["num_moduli"] - 1) + self.assertEqual( + ct_temp.ntt_ctx.ff_ctx.moduli_reduction.shape[0], + self.drop_shapes["num_moduli"] - 1, + ) + + +class CiphertextArithmeticTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + self.shapes = { + "batch": 1, + "num_elements": 1, + "num_moduli": 2, + "degree": 8, + "precision": 29, + } + self.ct1 = ct.Ciphertext(self.shapes) + self.ct1.random_init() + self.ct2 = ct.Ciphertext(self.shapes, {"moduli": self.ct1.get_moduli()}) + self.ct2.random_init() + + self.arr1 = self.ct1.get_batch_ciphertext() + self.arr2 = self.ct2.get_batch_ciphertext() + + # @absltest.skip("test single implementation") + def test_mul_ciphertext(self): + # mul modifies in-place + expected = self.arr1.astype(jnp.uint64) * self.arr2.astype(jnp.uint64) + self.ct1.mul(self.ct2) + np.testing.assert_array_equal(self.ct1.get_batch_ciphertext(), expected) + + # @absltest.skip("test single implementation") + def test_mul_array(self): + expected = self.arr1.astype(jnp.uint64) * self.arr2.astype(jnp.uint64) + self.ct1.mul(self.arr2) + np.testing.assert_array_equal(self.ct1.get_batch_ciphertext(), expected) + + # @absltest.skip("test single implementation") + def test_modmul_ciphertext(self): + expected_temp = self.arr1.astype(jnp.uint64) * self.arr2.astype(jnp.uint64) + expected = self.ct1.ntt_ctx.ff_ctx.modular_reduction(expected_temp).astype( + self.ct1.modulus_dtype + ) + + self.ct1.modmul(self.ct2) + np.testing.assert_array_equal(self.ct1.get_batch_ciphertext(), expected) + + # @absltest.skip("test single implementation") + def test_modmul_array(self): + expected_temp = self.arr1.astype(jnp.uint64) * self.arr2.astype(jnp.uint64) + expected = self.ct1.ntt_ctx.ff_ctx.modular_reduction(expected_temp).astype( + self.ct1.modulus_dtype + ) + + self.ct1.modmul(self.arr2) + np.testing.assert_array_equal(self.ct1.get_batch_ciphertext(), expected) + + # @absltest.skip("test single implementation") + def test_mod_reduce(self): + # Create a value that needs reduction + self.ct1.ciphertext = self.ct1.ciphertext.astype(jnp.uint64) * 100 + expected = self.ct1.ntt_ctx.ff_ctx.modular_reduction( + self.ct1.ciphertext + ).astype(self.ct1.modulus_dtype) + + self.ct1.mod_reduce() + np.testing.assert_array_equal(self.ct1.get_batch_ciphertext(), expected) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_word/ckks_ctx.py b/jaxite/jaxite_word/ckks_ctx.py new file mode 100644 index 0000000..5551254 --- /dev/null +++ b/jaxite/jaxite_word/ckks_ctx.py @@ -0,0 +1,809 @@ +import cmath +import math +import random +from typing import List +import jaxite.jaxite_word.ciphertext as ct +import jax.numpy as jnp +import jaxite.jaxite_word.key_gen as kg +import numpy as np +import jaxite.jaxite_word.util as util + + +MAX_INT_64 = 9223372036854775295 +sigma = 3.190000057220458984375 + + +######################## +# Common Functions +######################## +def _roots(m: int) -> List[complex]: + return [cmath.exp(1j * 2 * math.pi * k / m) for k in range(m)] + + +def _rot_group(m: int, nh: int, g: int = 5) -> List[int]: + # CKKS rotation subgroup (powers of 5 mod m) + r = [1] + for _ in range(1, nh): + r.append((r[-1] * g) % m) + return r + + +def _bitrev_inplace(a: List[complex]) -> None: + n = len(a) + j = 0 + for i in range(1, n): + bit = n >> 1 + while j & bit: + j ^= bit + bit >>= 1 + j ^= bit + if i < j: + a[i], a[j] = a[j], a[i] + + +def FFTSpecialInv(vals: List[complex], cycl_order: int) -> None: + """CKKS 'special' inverse FFT (DIF-style), in-place: + + - twiddles via rotation group, + - per-stage idx = ((lenq - (rg[j] % lenq)) % lenq) * (m/lenq), + - scale by 1/Nh, + - *post* bit-reversal (to match OpenFHE ordering in your test). + """ + m = cycl_order + nh = len(vals) + roots = _roots(m) + rg = _rot_group(m, nh, g=5) + + length = nh + while length >= 1: + half = length >> 1 + lenq = length << 2 + step = m // lenq + for i in range(0, nh, length): + for j in range(half): + mod = rg[j] % lenq + idx = ((lenq - mod) % lenq) * step + u = vals[i + j] + t = vals[i + j + half] + vals[i + j] = u + t + vals[i + j + half] = (u - t) * roots[idx] + length >>= 1 + + inv = 1.0 / nh + for k in range(nh): + vals[k] *= inv + + _bitrev_inplace(vals) + + +def FFTSpecial(vals: List[complex], cycl_order: int) -> None: + """CKKS 'special' forward FFT (DIT-style), in-place: + + - *pre* bit-reversal, + - per-stage idx = (rg[j] % lenq) * (m/lenq), + - no final scale. + """ + m = cycl_order + nh = len(vals) + roots = _roots(m) + rg = _rot_group(m, nh, g=5) + + _bitrev_inplace(vals) + + length = 2 + while length <= nh: + half = length >> 1 + lenq = length << 2 + step = m // lenq + for i in range(0, nh, length): + for j in range(half): + mod = rg[j] % lenq + idx = mod * step + u = vals[i + j] + v = vals[i + j + half] * roots[idx] + vals[i + j] = u + v + vals[i + j + half] = u - v + length <<= 1 + + +def nearest_int(x: float) -> int: + # matches OpenFHE's usual "nearest" behavior used for CKKS encode + return int(math.floor(x + 0.5)) if x >= 0 else -int(math.floor(-x + 0.5)) + + +def slot_to_coeffs(y: List[complex]) -> List[float]: + """Slot -> real polynomial coefficients (length N=2*Nh). + + Adjust here if your build packs coefficients differently. Current: + [Re(y_0..y_{Nh-1}), Im(y_0..y_{Nh-1})] + """ + nh = len(y) + re = [y[i].real for i in range(nh)] + im = [y[i].imag for i in range(nh)] + return re + im + + +def fit_to_native_vector( + vec: List[int], big_bound: int, modulus: int, ring_dim: int +) -> List[int]: + """Python equivalent of CKKSPackedEncoding::FitToNativeVector. + + - Places dslots values into a length-ring_dim vector at indices i*gap + + where gap = ring_dim // dslots. + - Maps each input value n using bigBound and modulus: + if n > bigBound/2: (n - (bigBound - modulus)) mod modulus + else: n mod modulus + """ + dslots = len(vec) + if dslots == 0: + return [0] * ring_dim + + big_value_half = big_bound >> 1 + diff = big_bound - modulus + gap = ring_dim // dslots + + native = [0] * ring_dim + for i, val in enumerate(vec): + n = int(val) + if n > big_value_half: + mapped = (n - diff) % modulus + else: + mapped = n % modulus + native[gap * i] = mapped + return native + + +def ckks_encrypt( + plaintext: List[List[int]], + public_key: List[List[List[int]]], + q_towers: List[int], + noise_scale_degree: int = 1, + sigma=3.190000057220458984375, + v=None, + e=None, + ba=None, + e_ref=None, + v_ref=None, +): + # plaintext is now (degree, moduli) + degree = len(plaintext) + num_towers = len(q_towers) + psi = None + if v is None: + psi = [ + util.root_of_unity(int(2 * degree), q_towers[t_id]) + for t_id in range(num_towers) + ] + v = kg.gen_ternary_uniform_polynomial(degree, q_towers).coeffs + v = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse(v[t_id], q_towers[t_id], psi[t_id]) + ) + for t_id in range(num_towers) + ] + if v_ref is not None: + np.testing.assert_array_equal(v, v_ref) + if e is None: + # This noise is too large! I need to copy the OpenFHE's implementation to fix it + psi = [ + util.root_of_unity(int(2 * degree), q_towers[t_id]) + for t_id in range(num_towers) + ] + e0 = kg.gen_gaussian_polynomial(degree, q_towers, sigma=sigma).coeffs + e0 = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse( + (e0[t_id]), q_towers[t_id], psi[t_id] + ) + ) + for t_id in range(num_towers) + ] + e1 = kg.gen_gaussian_polynomial(degree, q_towers, sigma=sigma).coeffs + e1 = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse( + (e1[t_id]), q_towers[t_id], psi[t_id] + ) + ) + for t_id in range(num_towers) + ] + if e_ref is not None: + np.testing.assert_array_equal(e0, e_ref[0]) + np.testing.assert_array_equal(e1, e_ref[1]) + else: + e0, e1 = e[0], e[1] + ns = noise_scale_degree + if len(plaintext[0]) < len(public_key[0]): + diff_length = len(public_key[0]) - len(plaintext[0]) + public_key[0] = public_key[0][:-diff_length] + public_key[1] = public_key[1][:-diff_length] + + # Prepare c0, c1 accumulators with shape (degree, moduli) + c0 = [[0] * num_towers for _ in range(degree)] + c1 = [[0] * num_towers for _ in range(degree)] + + # We need to iterate carefully. + # v is (moduli, degree) + # public_key is (2, moduli, degree) -> public_key[0], public_key[1] are (moduli, degree) + # e0, e1 are (moduli, degree) + # But we want output (degree, moduli) + + for i in range(degree): + # plaintext[i] is [m0, m1, ..., mk] corresponding to degree i + p_i_moduli = plaintext[i] + + for j in range(num_towers): + q_j = q_towers[j] + v_ji = v[j][i] + pk0_ji = public_key[0][j][i] + e0_ji = e0[j][i] + + val0 = (v_ji * pk0_ji + ns * e0_ji) % q_j + + # Add plaintext + val0 = (val0 + p_i_moduli[j]) % q_j + c0[i][j] = val0 + + pk1_ji = public_key[1][j][i] + e1_ji = e1[j][i] + + val1 = (v_ji * pk1_ji + ns * e1_ji) % q_j + c1[i][j] = val1 + + if ba is not None: + # ba corresponds to the value BEFORE adding plaintext (c0_pre, c1) + # c0 currently holds val0 + plaintext. We need to subtract plaintext for verification. + # c1 holds val1 (no plaintext added), so it's fine. + c0_T = [ + [(c0[d][m] - plaintext[d][m]) % q_towers[m] for d in range(degree)] + for m in range(num_towers) + ] + c1_T = [[c1[d][m] for d in range(degree)] for m in range(num_towers)] + np.testing.assert_array_equal(c0_T, ba[0]) + np.testing.assert_array_equal(c1_T, ba[1]) + + return [c0, c1] + + +def ckks_decrypt( + ciphertext: List[List[List[int]]], + private_key: List[List[int]], + q_towers: List[int], +): + # ciphertext is list of elements. element 0 is c0, etc. + # each element is (degree, moduli) + num_elements = len(ciphertext) + degree = len(ciphertext[0]) + num_towers = len(ciphertext[0][0]) + + s = private_key # (moduli, degree) + + if num_towers < len(s): + diff_length = len(s) - num_towers + s = s[:-diff_length] + + # Pre-transpose s for easier access or just index carefully + # s is (moduli, degree) + + # We want to compute: M(X) = c0 + c1*s + ... + # Result should be (degree, moduli) initially before NTT/CRT? + # Actually decrypt returns coefficients. + + # Let's accumulate in (moduli, degree) for the final NTT part which expects that layout usually, + # OR we adapt the rest of the function. + # The original returned `first_element_coef` which was (moduli, degree). + # But we want "ciphertext/plaintext in the layout of (degree, moduli)". + # So we should probably return (degree, moduli). + + # Let's accumulate in (degree, moduli). + + res_poly = [[0] * num_towers for _ in range(degree)] + + # s_power starts as s^1. s is (moduli, degree). + # We need s^k in (moduli, degree). + + s_powers = [s] # s^1 + # Generate powers if needed (for num_elements > 2) + # original code: s_power updated iteratively. + + cur_s_power = [list(row) for row in s] # Copy s + + # c0 + for d in range(degree): + for m in range(num_towers): + res_poly[d][m] = ciphertext[0][d][m] + + for i in range(1, num_elements): + ci = ciphertext[i] # (degree, moduli) + + for d in range(degree): + for m in range(num_towers): + # + ci * s^i + term = (ci[d][m] * cur_s_power[m][d]) % q_towers[m] + res_poly[d][m] = (res_poly[d][m] + term) % q_towers[m] + + if i < num_elements - 1: + # Update s_power to s^(i+1) + # s^(i+1) = s^i * s + new_s_power = [[0] * degree for _ in range(num_towers)] + for m in range(num_towers): + qi = q_towers[m] + for d in range(degree): + new_s_power[m][d] = (cur_s_power[m][d] * s[m][d]) % qi + cur_s_power = new_s_power + + # Now res_poly is (degree, moduli) + # We need to do inverse NTT. + # Existing utils utilize (moduli, degree) usually? + # util.bit_reverse_array takes 1D list. + # util.intt_negacyclic_bit_reverse takes 1D list. + + # So we can process row by row if we transpose or col by col. + # The original returned `first_element_coef` as list of lists (moduli, degree). + # We want to return (degree, moduli). + + final_res = [[0] * num_towers for _ in range(degree)] + + for m in range(num_towers): + # Extract column m + col = [res_poly[d][m] for d in range(degree)] + + # bit reverse + col_rev = util.bit_reverse_array(col) + + # intt + coef = util.intt_negacyclic_bit_reverse( + col_rev, q_towers[m], util.root_of_unity(2 * degree, q_towers[m]) + ) + + for d in range(degree): + final_res[d][m] = coef[d] + + return final_res + + +def ckks_encode( + slots: List[complex], + cycl_order: int, + q_towers: List[int], + p_towers: List[int], + scale: float, + noise_scale_degree: int = 1, + max_bits_in_word: int = 61, +): + """Encode slots to DCRTPoly EVAL form with given (Q,P) towers, NATIVE_INT=64. + + Returns dict with residues for Q and P towers and the scaled integer coeffs. + """ + nh = len(slots) + N = 2 * nh + m = cycl_order + assert m == 4 * nh, "cycl_order must be 4*Nh for CKKS special FFT size" + + # 1) inverse special FFT + y = list(slots) + FFTSpecialInv(y, m) + + # 2) slot->coeff packing + coeffs = slot_to_coeffs(y) # length N + + # 3) scale and determine bit length like OpenFHE (NATIVEINT==64) + # Find logc = ceil(log2(max(|scaled_real|, |scaled_imag|))) across slots + scaled_vals = [scale * v for v in coeffs] + logc = -(10**9) + for v in scaled_vals: + absv = abs(v) + if absv != 0.0: + logci = int(math.ceil(math.log2(absv))) + if logc < logci: + logc = logci + if logc == -(10**9): + logc = 0 + if logc < 0: + raise ValueError("Scaling factor too small") + # 4) approxFactor to keep values within 60-bit word, then quantize + log_valid = logc if logc <= max_bits_in_word else max_bits_in_word + log_approx = logc - log_valid + approx_factor = 2.0**log_approx + # Quantize with round-to-nearest after dividing by approx_factor + ints_base = [nearest_int(v / approx_factor) for v in scaled_vals] + ints_base = [x + MAX_INT_64 if x < 0 else x for x in ints_base] + + elements = [ + fit_to_native_vector(ints_base, MAX_INT_64, q_tower, N) + for q_tower in q_towers + ] + + # 5) Scale back up by approx_factor (power of two) in the ring + if log_approx > 0: + step = 1 << log_approx + ints = [ + [x * step % q_towers[mod_id] for x in elements[mod_id]] + for mod_id in range(len(q_towers)) + ] + else: + ints = elements + + # 6) If noise scale degree > 1, multiply by round(scale)^(d-1) + if noise_scale_degree > 1: + int_pow_p = int(round(scale)) + if int_pow_p != 1: + power = pow(int_pow_p, noise_scale_degree - 1) + for mod_id in range(len(q_towers)): + ints[mod_id] = [x * power % q_towers[mod_id] for x in ints[mod_id]] + + # 4) residues per tower + Q_res = [ + util.ntt_negacyclic_bit_reverse( + ints[mod_id], + q_towers[mod_id], + util.root_of_unity(m, q_towers[mod_id]), + ) + for mod_id in range(len(q_towers)) + ] + # Current Q_res is (moduli, degree) + # We want to return (degree, moduli) + + Q_res_T = [[0] * len(q_towers) for _ in range(N)] + for mod_id in range(len(q_towers)): + rev = util.bit_reverse_array(Q_res[mod_id]) + for deg in range(N): + Q_res_T[deg][mod_id] = rev[deg] + + return Q_res_T + + +def ckks_decode( + plaintext: List[int], + scalingFactor: float, + slots: int, + q: int, + p: int, + CKKS_M_FACTOR: int = 1, + ADD_NOISE: bool = False, +): + # Ported from notebook implementation + degree = len(plaintext) + q_half = q >> 1 + Nh = degree // 2 + gap = Nh // slots + powP_positive = pow(2, p) + powP = pow(2, -p) + + # Step 1: scale back to intermediate complex vector m(X) + sf_pre = (1.0 / scalingFactor) * powP_positive + + real_part_list = [] + imag_part_list = [] + for i in range(slots): + # real part from first half + r_val = plaintext[i] + if r_val > q_half: + real_part = -((q - r_val) * sf_pre) + else: + real_part = r_val * sf_pre + real_part_list.append(int(real_part)) + + # imag part from second half + im_val = plaintext[i + Nh] + if im_val > q_half: + imag_part = -((q - im_val) * sf_pre) + else: + imag_part = im_val * sf_pre + imag_part_list.append(int(imag_part)) + + curValues = [ + complex(real_part_list[i], imag_part_list[i]) for i in range(slots) + ] + + # Step 2: compute conjugate vector and estimated stddev (per OpenFHE logic) + def _conjugate(vec: List[complex]) -> List[complex]: + n = len(vec) + result = [0j] * n + for idx in range(1, n): + z = vec[n - idx] + result[idx] = complex(-z.imag, -z.real) + z0 = vec[0] + result[0] = complex(z0.real, -z0.imag) + return result + + def _stddev(vec: List[complex], conjugate: List[complex]) -> float: + import math as math + + s = len(vec) + if s == 1: + return vec[0].imag + dslots = s * 2 + complex_values = [vec[i] - conjugate[i] for i in range(s // 2 + 1)] + mean = 2 * sum((cv.real + cv.imag) for cv in complex_values[1 : (s // 2)]) + mean += complex_values[0].imag + mean += 2 * complex_values[s // 2].real + mean /= dslots - 1.0 + variance = 2 * sum( + ((cv.real - mean) ** 2 + (cv.imag - mean) ** 2) + for cv in complex_values[1 : (s // 2)] + ) + variance += (complex_values[0].imag - mean) ** 2 + variance += 2 * (complex_values[s // 2].real - mean) ** 2 + variance /= dslots - 2.0 + return 0.5 * math.sqrt(variance) + + conjugate = _conjugate(curValues) + + stddev_dbl = _stddev(curValues, conjugate) + logstd = math.log2(stddev_dbl) if stddev_dbl > 0 else float("-inf") + if stddev_dbl < 0.125 * math.sqrt(degree): + stddev_dbl = 0.125 * math.sqrt(degree) + if logstd > p - 5.0: + raise Exception( + "The decryption failed because the approximation error is too high." + " Check the parameters. " + ) + + stddev = math.sqrt(CKKS_M_FACTOR + 1) * stddev_dbl + scale = 0.5 * powP + + # For security, add tiny Gaussian noise scaled by 2^{-p}; it doesn't affect ~1e-3 accuracy + rng = random.Random() + + def _gauss(): + return rng.gauss(0.0, stddev) + + if ADD_NOISE: + curValues = [ + complex( + real_part_list[i] * scale + + conjugate[i].real * scale + + powP * _gauss(), + imag_part_list[i] * scale + + conjugate[i].imag * scale + + powP * _gauss(), + ) + for i in range(slots) + ] + else: + curValues = [ + complex( + real_part_list[i] * scale + conjugate[i].real * scale, + imag_part_list[i] * scale + conjugate[i].imag * scale, + ) + for i in range(slots) + ] + + # Step 3: Special forward FFT to slot values + FFTSpecial(curValues, degree * 2) + curValues = [complex(curValues[i].real, 0.0) for i in range(slots)] + # Return real parts only + return curValues + + +def _crt_combine_rns_plaintext( + rns_plaintext: List[List[int]], moduli: List[int] +) -> List[int]: + """Combine residues modulo pairwise-coprime moduli using the standard CRT formula. + + rns_plaintext is (degree, moduli). + """ + M = 1 + for q in moduli: + M *= q + Mi_list = [M // qi for qi in moduli] + inv_list = [pow(Mi, -1, qi) for Mi, qi in zip(Mi_list, moduli)] + + degree = len(rns_plaintext) + num_moduli = len(moduli) + + result = [] + for d in range(degree): + X = 0 + # rns_plaintext[d] is [r0, r1, ...] for degree d + residues = rns_plaintext[d] + + for i in range(num_moduli): + ri = residues[i] + qi = moduli[i] + Mi = Mi_list[i] + inv = inv_list[i] + X += (int(ri) % int(qi)) * int(Mi) * int(inv) + + result.append(X % M) + return result + + +######################## +# CKKS Context Class +######################## +class CKKSContext: + + def __init__(self, parameters: dict): + self.parameters = parameters + self.degree = parameters["degree"] + self.num_slots = parameters.get("num_slots", self.degree // 2) + self.scalingFactor = parameters.get("scalingFactor", 0.0) + self.output_scale = parameters.get("output_scale", 0.0) + self.q_towers = parameters["q_towers"] + self.p_towers = parameters.get("p_towers", []) + self.p = parameters.get("p", 0) + self.CKKS_M_FACTOR = parameters.get("CKKS_M_FACTOR", 1) + self.moduli = self.q_towers + + self.public_key = parameters.get("public_key", None) + self.secret_key = parameters.get("secret_key", None) + self.rotation_key = parameters.get("rotation_key", None) + self.evaluation_key = parameters.get("evaluation_key", None) + + def encrypt( + self, + plaintext: ct.Ciphertext, + v=None, + e=None, + ba=None, + e_ref=None, + v_ref=None, + ) -> ct.Ciphertext: + if self.public_key is None: + raise ValueError("Public key is not set in the context.") + + element = plaintext.get_element(0)[0] # Shape: (degree, num_moduli) + # element is already (degree, moduli), no transpose needed + encoded_values = element.tolist() + + c_poly = ckks_encrypt( + plaintext=encoded_values, + public_key=self.public_key, + q_towers=self.q_towers, + noise_scale_degree=self.parameters.get("noise_scale_degree", 1), + sigma=self.parameters.get("sigma", 3.190000057220458984375), + v=v, + e=e, + ba=ba, + e_ref=e_ref, + v_ref=v_ref, + ) + + shapes = { + "batch": 1, + "num_elements": 2, + "num_moduli": len(self.q_towers), + "degree": self.degree, + "precision": 32, + } + + res_ct = ct.Ciphertext(shapes, parameters={"moduli": self.q_towers}) + + # c0, c1 are (degree, moduli) naturally now + c0 = jnp.expand_dims( + jnp.array(c_poly[0], dtype=jnp.uint64), axis=0 + ) # (1, degree, moduli) + c1 = jnp.expand_dims(jnp.array(c_poly[1], dtype=jnp.uint64), axis=0) + + res_ct.set_element(0, c0) + res_ct.set_element(1, c1) + + return res_ct + + def decrypt(self, ciphertext: ct.Ciphertext) -> ct.Ciphertext: + if self.secret_key is None: + raise ValueError("Secret key is not set in the context.") + c_list = [ + ciphertext.ciphertext[0, 0].tolist(), # c0 + ciphertext.ciphertext[1, 0].tolist(), # c1 + ] + num_elems = ciphertext.num_elements + c_list = [] + for i in range(num_elems): + c_list.append(ciphertext.ciphertext[0, i].tolist()) # (degree, moduli) + num_moduli_ct = ciphertext.num_moduli + current_q_towers = self.q_towers[:num_moduli_ct] + decrypted_poly_rns = ckks_decrypt( + ciphertext=c_list, + private_key=self.secret_key, + q_towers=current_q_towers, + ) + shapes = { + "batch": 1, + "num_elements": 1, + "num_moduli": len(current_q_towers), + "degree": self.degree, + "precision": 32, + } + + res_ct = ct.Ciphertext(shapes, parameters={"moduli": current_q_towers}) + # decrypted_poly_rns is (degree, moduli) + elem = jnp.expand_dims( + jnp.array(decrypted_poly_rns, dtype=jnp.uint32), axis=0 + ) + res_ct.set_element(0, elem) + + return res_ct + + def encode(self, slots: List[complex], shift: int = 0) -> ct.Ciphertext: + m = self.degree * 2 + encoded_rns = ckks_encode( + slots=slots, + cycl_order=m, + q_towers=self.q_towers, + p_towers=self.p_towers, + scale=self.scalingFactor, + max_bits_in_word=self.parameters.get("max_bits_in_word", 61), + ) + + shapes = { + "batch": 1, + "num_elements": 1, + "num_moduli": len(self.q_towers), + "degree": self.degree, + "precision": 32, + } + + res_ct = ct.Ciphertext(shapes, parameters={"moduli": self.q_towers}) + # encoded_rns is (degree, moduli) + elem = jnp.expand_dims(jnp.array(encoded_rns, dtype=jnp.uint64), axis=0) + res_ct.set_element(0, elem) + + return res_ct + + def decode( + self, encoded_plaintext: ct.Ciphertext, is_ntt: bool = False + ) -> jnp.ndarray: + rns_poly = encoded_plaintext.ciphertext[0, 0].tolist() # (degree, moduli) + num_towers = len(encoded_plaintext.moduli) + + if is_ntt: + # rns_poly is (degree, moduli). + new_poly = [[0] * num_towers for _ in range(self.degree)] + for t_id in range(num_towers): + # get column + col = [rns_poly[d][t_id] for d in range(self.degree)] + + rev = util.bit_reverse_array(col) + intt_vals = util.intt_negacyclic_bit_reverse( + rev, + self.q_towers[t_id], + util.root_of_unity(2 * self.degree, self.q_towers[t_id]), + ) + + for d in range(self.degree): + new_poly[d][t_id] = intt_vals[d] + rns_poly = new_poly + + plain_combined = _crt_combine_rns_plaintext( + rns_poly, self.q_towers[:num_towers] + ) + + big_q = 1 + for qi in self.q_towers[:num_towers]: + big_q *= qi + + res = ckks_decode( + plaintext=plain_combined, + scalingFactor=self.output_scale, + slots=self.num_slots, + q=big_q, + p=self.p, + CKKS_M_FACTOR=self.CKKS_M_FACTOR, + ) + + return jnp.array(res) + + def multiply( + self, ciphertext1: ct.Ciphertext, ciphertext2: ct.Ciphertext + ) -> ct.Ciphertext: + raise NotImplementedError("Mul function not implemented") + + def rotate(self, ciphertext: ct.Ciphertext, shift: int) -> ct.Ciphertext: + raise NotImplementedError("Rotate function not implemented") + + def rescale(self, ciphertext: ct.Ciphertext) -> ct.Ciphertext: + raise NotImplementedError("Rescale function not implemented") + + def add( + self, ciphertext1: ct.Ciphertext, ciphertext2: ct.Ciphertext + ) -> ct.Ciphertext: + raise NotImplementedError("Add function not implemented") + + def sub( + self, ciphertext1: ct.Ciphertext, ciphertext2: ct.Ciphertext + ) -> ct.Ciphertext: + raise NotImplementedError("Sub function not implemented") diff --git a/jaxite/jaxite_word/ckks_ctx_test.py b/jaxite/jaxite_word/ckks_ctx_test.py new file mode 100644 index 0000000..8e50db2 --- /dev/null +++ b/jaxite/jaxite_word/ckks_ctx_test.py @@ -0,0 +1,192 @@ +from absl.testing import absltest +from absl.testing import parameterized +import jax +import jax.numpy as jnp +import jaxite.jaxite_word.key_gen as kg +import jaxite.jaxite_word.ckks_ctx as ckks_ctx +import numpy as np +import jaxite.jaxite_word.util as util +from jaxite.jaxite_word.hemul import HEMul +from jaxite.jaxite_word.herot import HERot +from jaxite.jaxite_word.ciphertext import Ciphertext + +jax.config.update('jax_enable_x64', True) +jax.config.update('jax_traceback_filtering', 'off') + +testing_params = [ + { + 'testcase_name': '0', + } +] + +@parameterized.named_parameters(testing_params) +class CKKSContextTest(parameterized.TestCase): + def setUp(self): + self.degree = 16 + self.num_slots = 8 + self.dnum = 3 + self.r, self.c = 4, 4 + + self.scalingFactor = 563019763943521 + self.q_towers = [1073742881, 1073742721, 1073741441, 1073741857, 524353] + self.p_towers = [1073740609, 1073739937, 1073739649] + self.q = 696985728458547852910430465530901300664961 + self.qt = 1329229981441028949792278227703286337 + self.p = 30 + self.CKKS_M_FACTOR = 1 + self.noise_scale_degree = 1 + self.max_bits_in_word = 61 + self.sigma = 3.190000057220458984375 + self.degree_layout = (self.r, self.c) + + key_pair = kg.gen_pke_pair(self.q_towers, self.p_towers, self.degree) + self.params = { + "degree": self.degree, + "num_slots": self.num_slots, + "scalingFactor": self.scalingFactor, + "output_scale": self.scalingFactor, + "q_towers": self.q_towers, + "p_towers": self.p_towers, + "p": self.p, + "CKKS_M_FACTOR": self.CKKS_M_FACTOR, + "max_bits_in_word": self.max_bits_in_word, + "noise_scale_degree": self.noise_scale_degree, + "public_key": key_pair["public_key"], + "secret_key": key_pair["secret_key"] + } + self.real_values_input_in1 = [ + complex(0.25, 0), complex(0.5, 0), complex(0.75, 0), complex(1, 0), + complex(2, 0), complex(3, 0), complex(4, 0), complex(5, 0), + ] + self.real_values_input_in2 = [ + complex(5, 0), complex(4, 0), complex(3, 0), complex(2, 0), + complex(1, 0), complex(0.75, 0), complex(0.5, 0), complex(0.25, 0), + ] + self.real_values_multiply_result = [ + complex(1.25, 0), complex(2, 0), complex(2.25, 0), complex(2, 0), + complex(2, 0), complex(2.25, 0), complex(2, 0), complex(1.25, 0), + ] + self.real_values_rotate_result = [ + complex(0.5, 0), complex(0.75, 0), complex(1, 0), complex(2, 0), + complex(3, 0), complex(4, 0), complex(5, 0), complex(0.25, 0), + ] + + # @absltest.skip("test a single experiment") + def test_ckks_context_encode_decode(self): + # Paramters Setup + ctx = ckks_ctx.CKKSContext(self.params) + # Step 1: Encoding + encoded_ct = ctx.encode(self.real_values_input_in1) + # Step 2: Decoding + decoded_values = ctx.decode(encoded_ct, is_ntt=True) + np.testing.assert_array_almost_equal(decoded_values, self.real_values_input_in1, decimal=3) + + # @absltest.skip("test a single experiment") + def test_ckks_context_encrypt_decrypt(self): + # Paramters Setup + ctx = ckks_ctx.CKKSContext(self.params) + # Step 1: Encoding + encoded_ct = ctx.encode(self.real_values_input_in1) + # Step 2: Encryption + encrypted_ct = ctx.encrypt(encoded_ct) + # Step 3: Decryption + decrypted_ct = ctx.decrypt(encrypted_ct) + # Step 4: Decoding + decoded_values = ctx.decode(decrypted_ct) + np.testing.assert_array_almost_equal(decoded_values, self.real_values_input_in1, decimal=3) + + # @absltest.skip("test a single experiment") + def test_ckks_context_encrypt_rotate_decrypt(self): + # Paramters Setup + rotate_idx = 1 + coef_map = util.precompute_auto_map(self.degree, kg.find_automorphism_index_2n_complex(rotate_idx, self.degree)) + # initialization + herot_obj = HERot(self.r, self.c, self.dnum, self.q_towers, self.p_towers) + ek = kg.gen_rotation_key(self.params["secret_key"], self.q_towers, self.p_towers, rot_index=rotate_idx, dnum=self.dnum, noise_std=self.sigma, noise_scale=self.noise_scale_degree) + herot_obj.setup_rotate(jnp.array(ek["a"], jnp.uint64).transpose(0,2,1).reshape(self.dnum,*self.degree_layout,-1), jnp.array(ek["b"], jnp.uint64).transpose(0,2,1).reshape(self.dnum,*self.degree_layout,-1), coef_map) + herot_obj.control_gen(batch=1, degree_layout=self.degree_layout) + ctx = ckks_ctx.CKKSContext(self.params) + # Step 1: Encoding + encoded_ct = ctx.encode(self.real_values_input_in1) + # Step 2: Encryption + encrypted_ct = ctx.encrypt(encoded_ct) + # Step 3: Rotate + result = herot_obj.rotate(encrypted_ct.ciphertext.reshape(1, 2, *self.degree_layout, len(self.q_towers))) + encrypted_ct.set_batch_ciphertext(result.reshape(1,2,self.degree,len(self.q_towers))) + # Step 4: Decryption + decrypted_ct = ctx.decrypt(encrypted_ct) + # Step 5: Decoding + decoded_values = ctx.decode(decrypted_ct) + np.testing.assert_array_almost_equal(decoded_values, self.real_values_rotate_result, decimal=3) + + # @absltest.skip("test a single experiment") + def test_ckks_context_encrypt_rescale_decrypt(self): + # Paramters Setup + batch, num_elements, degree, num_moduli = 1, 2, 16, 5 + ct_shapes = {'batch': 1, 'num_elements': 2, 'degree': 16, 'num_moduli': 5, 'precision': 32, 'degree_layout': self.degree_layout} + ct_params = {'moduli': self.q_towers, 'r': self.r, 'c': self.c} + params = self.params.copy() + params.update({ + "output_scale": (self.scalingFactor/self.q_towers[-1]), + }) + # Initialization + ctx = ckks_ctx.CKKSContext(params) + ct = Ciphertext(ct_shapes, ct_params) + ct.modulus_switch_control_gen(degree_layout=self.degree_layout) + + # Step 1: Encoding + encoded_ct = ctx.encode(self.real_values_input_in1) + # Step 2: Encryption + encrypted_ct = ctx.encrypt(encoded_ct) + # Step 3: Rescale + ct.set_batch_ciphertext(encrypted_ct.ciphertext.reshape(batch, num_elements, *self.degree_layout, num_moduli)) + ct.rescale() + # Step 4: Decryption + ct.ciphertext = ct.ciphertext.reshape(batch, num_elements, degree, num_moduli-1) + decrypted_ct = ctx.decrypt(ct) + # Step 5: Decoding + decoded_values = ctx.decode(decrypted_ct) + np.testing.assert_array_almost_equal(decoded_values, self.real_values_input_in1, decimal=3) + + # @absltest.skip("test a single experiment") + def test_ckks_context_encrypt_multiply_decrypt(self): + """ + Test the encryption, multiplication, and decryption of the CKKS context. + See hemul_test.py for the debugging version + """ + # Paramters Setup + r, c = 4, 4 + assert (r*c==self.degree) + batch, num_elements, dnum, num_eval_mult = 1, 2, self.dnum, 1 + self.ek = kg.gen_evaluation_key(self.params["secret_key"], q=self.q_towers, P=self.p_towers, noise_std=self.sigma, noise_scale=1, dnum=3) + eval_key_a, eval_key_b = jnp.array(self.ek["a"], dtype=jnp.uint32).transpose(0,2,1), jnp.array(self.ek["b"], dtype=jnp.uint32).transpose(0,2,1) + params = self.params.copy() + params.update({ + "evaluation_key": [eval_key_a, eval_key_b], + "output_scale": (self.scalingFactor/self.q_towers[-1])**2, + }) + # Initialization + ctx = ckks_ctx.CKKSContext(params) + he_mul = HEMul(batch, r, c, dnum, num_eval_mult, self.q_towers, self.p_towers) + he_mul.control_gen(degree_layout=self.degree_layout) + he_mul.setup_relinearization(eval_key_a, eval_key_b) + # Step 1: Encoding + encoded_ct1 = ctx.encode(self.real_values_input_in1) + encoded_ct2 = ctx.encode(self.real_values_input_in2) + # Step 2: Encryption + encrypted_ct1 = ctx.encrypt(encoded_ct1) + encrypted_ct2 = ctx.encrypt(encoded_ct2) + # Step 3: Homomorphic Multiplication + in_cts = jnp.concatenate([encrypted_ct1.ciphertext, encrypted_ct2.ciphertext], axis=1).reshape(batch, 2*num_elements, r, c, len(self.q_towers)).astype(jnp.uint32) + encrypted_result = he_mul.mul(in_cts) + # Step 4: Decryption + encrypted_ct1.drop_last_modulus() + encrypted_ct1.set_batch_ciphertext(encrypted_result.reshape(batch, 2, self.degree, len(self.q_towers)-1)) + decrypted_result = ctx.decrypt(encrypted_ct1) + # Step 5: Decoding + decoded_values = ctx.decode(decrypted_result, is_ntt=False) + np.testing.assert_array_almost_equal(decoded_values, self.real_values_multiply_result, decimal=3) + + +if __name__ == "__main__": + absltest.main() \ No newline at end of file diff --git a/jaxite/jaxite_word/finite_field.py b/jaxite/jaxite_word/finite_field.py new file mode 100644 index 0000000..f9de585 --- /dev/null +++ b/jaxite/jaxite_word/finite_field.py @@ -0,0 +1,458 @@ +"""Name: JAX Finite Field Context Integration + +Name Template: Context + - : (JAX accelerator backend). + - : [Optional] + - Empty: Standard scalar. + - RNS: Residue Number System. + - DRNS: Digitized RNS. + - RD: Radix Decomposition (Big Integer simulation). + - : [Optional] + - Montgomery: Montgomery reduction. + - Barrett: Barrett reduction. + - Shoup: Shoup reduction. + - : [Optional] + - MultipleModuli: vectorized over moduli. + - Lazy: Lazy reduction. + - Opt/Opt2: Optimization levels or specific variants. + - Context: Class suffix. + - : [Optional] Abstract base class. + +Explanation: This module adapts the generic finite field contexts for use with +JAX. It inherits from the base contexts in `finite_field_context.py` and adds +functionality to precompute and format parameters (such as modular inverses, RNS +matrices, and bit-shifted constants) into JAX-compatible arrays. It serves as +the configuration bridge between the mathematical specifications and the JAX +kernels. +""" + +import math +from typing import List, Union +import jax +import jax.numpy as jnp +import jaxite.jaxite_word.util as util + +jax.config.update("jax_enable_x64", True) + + +######################## +# Base Context Class +######################## +class FiniteFieldContextBase: + + def __init__(self, moduli: int): + self.moduli = moduli + + def to_computation_format(self, a: jnp.ndarray): + return a + + def to_original_format(self, a: jnp.ndarray): + return a + + def precompute_constant_operand(self, a: int): + # return [(a * (1 << self.w)) // m for m in self.moduli] # The algorithm being performed + return a + + def get_jax_parameters(self): + return {} + + def modular_reduction(self, a: jnp.ndarray) -> jnp.ndarray: + raise NotImplementedError("Subclasses must implement this method") + + def drop_last_modulus(self): + raise NotImplementedError("Subclasses must implement this method") + + +######################## +# Montgomery Modulus Reduction Context +######################## +class MontgomeryContext(FiniteFieldContextBase): + + def __init__(self, moduli: Union[List[int], int]): + super().__init__(moduli) + self.moduli = moduli + if type(self.moduli) is int: + self.moduli = [self.moduli] + self.w = 32 + self.w_inv = [util.modinv(1 << self.w, m) for m in self.moduli] + self.w_inv_reduction = jnp.array(self.w_inv, jnp.uint64) + + self.moduli_reduction = jnp.array(self.moduli, jnp.uint64) + + self.moduli_inv_32 = [util.modinv(m, 2**32) for m in self.moduli] + self.moduli_low16 = [m & 0xFFFF for m in self.moduli] + self.moduli_high16 = [m >> 16 for m in self.moduli] + + self.q = jnp.array(self.moduli, dtype=jnp.uint32) + self.q_low = jnp.array(self.moduli_low16, dtype=jnp.uint32) + self.q_high = jnp.array(self.moduli_high16, dtype=jnp.uint32) + self.q_inv_32 = jnp.array(self.moduli_inv_32, dtype=jnp.uint32) + + def to_computation_format(self, a: jnp.ndarray): + # return [(a * (1 << self.w)) % m for m in self.moduli] # The algorithm being performed + return ((a << self.w) % self.moduli_reduction).astype(jnp.uint32) + + def to_original_format(self, a: jnp.ndarray): + return (a * self.w_inv_reduction) % self.moduli_reduction + + def get_jax_parameters(self): + return { + "moduli": util.to_tuple(self.moduli), + "moduli_inv_32": util.to_tuple(self.moduli_inv_32), + "moduli_low": util.to_tuple(self.moduli_low16), + "moduli_high": util.to_tuple(self.moduli_high16), + } + + def modular_reduction(self, z: jnp.ndarray) -> jnp.ndarray: + """Montgomery reduction from u64 to u32 optimized version using only 32-bit operations + + Args: + z: - is u64 array of shape (B, M) - input + + parameters: + moduli: + - Tuple parameters constants + - is u32 array of shape (M) + - modular or moduli + moduli_low: + - Tuple parameters constants + - is u32 array of shape (M) + - low 16 bits of modular or moduli + moduli_high: + - Tuple parameters constants + - is u32 array of shape (M) + - high 16 bits of modular or moduli + moduli_inv_32: + - Tuple parameters constants + - is u32 array of shape (M) + - modular inverse of q mod 2^32 + Returns: + - is u32 array of shape (B, M) + - output + - reduced value + """ + + # Local constants + MASK32 = 0xFFFFFFFF + MASK16 = 0xFFFF + SHIFT16 = 16 + SHIFT32 = 32 + # Ensure dimensions for broadcasting + q = self.q + q_low = self.q_low + q_high = self.q_high + q_inv_32 = self.q_inv_32 + + # Computation + z_low = z.astype(jnp.uint32) + z_high = (z >> SHIFT32).astype(jnp.uint32) + t = (z_low * q_inv_32) & MASK32 + t_low = t & MASK16 + t_high = (t >> SHIFT16) & MASK16 + + prod_high = t_high * q_high # This contributes directly to upper 32 bits + prod_mid_high = t_high * q_low # Upper 16 bits go to upper 32 bits + prod_mid_low = t_low * q_high # Upper 16 bits go to upper 32 bits + prod_low = t_low * q_low # Upper 16 bits contribute to middle part + mid_low = ( + (prod_mid_high & MASK16) + + (prod_mid_low & MASK16) + + (prod_low >> SHIFT16) + ) + mid_high = ( + (prod_mid_high >> SHIFT16) + + (prod_mid_low >> SHIFT16) + + (mid_low >> SHIFT16) + ) + + # Final upper 32 bits + t_final = prod_high + mid_high + b = z_high + q - t_final + # Ensure strict reduction + # b = jnp.where(b >= q, b - q, b).astype(jnp.uint32) + return b.astype(jnp.uint32) + + def drop_last_modulus(self): + # self.moduli_reduction, self.moduli_inv_32, self.moduli_low16, self.moduli_high16 are not updated here. + # Because they are not used in the reduction. + # self.moduli = self.moduli[:-1] + self.moduli_reduction = self.moduli_reduction[:-1] + self.q = self.q[:-1] + self.q_low = self.q_low[:-1] + self.q_high = self.q_high[:-1] + self.q_inv_32 = self.q_inv_32[:-1] + + +######################## +# Barrett Modulus Reduction Context +######################## +class BarrettContext(FiniteFieldContextBase): + + def __init__(self, moduli: Union[List[int], int]): + super().__init__(moduli) + self.moduli = moduli + if type(self.moduli) is int: + self.moduli = [self.moduli] + + self.barrett_s = [2 * math.ceil(math.log2(m)) for m in self.moduli] + self.barrett_w = [min(s, 32) for s in self.barrett_s] + self.barrett_s_w = [s - w for s, w in zip(self.barrett_s, self.barrett_w)] + self.barrett_m = [ + math.floor(2**s / m) for s, m in zip(self.barrett_s, self.moduli) + ] + # used for run-time reduction + self.m = jnp.array(self.barrett_m, dtype=jnp.uint64) + self.moduli_reduction = jnp.array(self.moduli, dtype=jnp.uint64) + self.w = jnp.array(self.barrett_w, dtype=jnp.uint16) + self.s_w = jnp.array(self.barrett_s_w, dtype=jnp.uint16) + + def to_computation_format(self, a): + return a + + def to_original_format(self, a): + return a + + def get_jax_parameters(self): + return { + "barrett_m": util.to_tuple(self.barrett_m), + "moduli": util.to_tuple(self.moduli), + "barrett_w": util.to_tuple(self.barrett_w), + "barrett_s_w": util.to_tuple(self.barrett_s_w), + } + + def modular_reduction(self, z: jnp.ndarray) -> jnp.ndarray: + """Vectorized implementation of the Barrett reduction. + + Works for modulus `q` less than 31 bits. + + This implementation sets the internal shift width `w` to `min(s, 32)` so it + works with small modulus `moduli < 2^16`. + + Args: + z: The input value. + moduli: The RNS moduli. + s_w: The bit width of moduli. + w: The internal shift width. + m: The precomputed value for Barrett reduction. + + Returns: + The result of the Barrett reduction. + """ + m = self.m + moduli = self.moduli_reduction + w = self.w + s_w = self.s_w + + z1 = z & 0xFFFFFFFF + z2 = z >> w + t = ((z1 * m) >> w) + (z2 * m) + t = t >> s_w + z = z - t * moduli + pred = z >= moduli + return jnp.where(pred, z - moduli, z).astype(jnp.uint32) + # return (z - moduli * pred).astype(jnp.uint32) + + def modular_reduction_single_modulus( + self, z: jnp.ndarray, modulus_index: int + ) -> jnp.ndarray: + """Vectorized implementation of the Barrett reduction. + + Works for modulus `q` less than 31 bits. + + This implementation sets the internal shift width `w` to `min(s, 32)` so it + works with small modulus `moduli < 2^16`. + + Args: + z: The input value. + moduli: The RNS moduli. + s_w: The bit width of moduli. + w: The internal shift width. + m: The precomputed value for Barrett reduction. + + Returns: + The result of the Barrett reduction. + """ + m = self.m[modulus_index] + moduli = self.moduli_reduction[modulus_index] + w = self.w[modulus_index] + s_w = self.s_w[modulus_index] + + z1 = z.astype(jnp.uint32) + z2 = (z >> w).astype(jnp.uint32) + t = ((z1 * m) >> w) + (z2 * m) + t = t >> s_w + z = z - t * moduli + pred = z >= moduli + return jnp.where(pred, z - moduli, z).astype(jnp.uint32) + # return (z - moduli * pred).astype(jnp.uint32) + + def drop_last_modulus(self): + # self.barrett_s, self.barrett_w, self.barrett_s_w, self.barrett_m are not updated here. + # Because they are not used in the reduction. + # self.moduli = self.moduli[:-1] + self.m = self.m[:-1] + self.moduli_reduction = self.moduli_reduction[:-1] + self.w = self.w[:-1] + self.s_w = self.s_w[:-1] + + +######################## +# Shoup Modulus Reduction Context +######################## +class ShoupContext(FiniteFieldContextBase): + + def __init__(self, moduli: Union[List[int], int]): + super().__init__(moduli) + self.moduli = moduli + if type(self.moduli) is int: + self.moduli = [self.moduli] + self.moduli_reduction = jnp.array(self.moduli, jnp.uint64) + self.q = jnp.array(self.moduli, dtype=jnp.uint64) + self.w = 32 + + def to_computation_format(self, a: jnp.ndarray): + # return [(a % m) for m in self.moduli] # The algorithm being performed + return (a % self.moduli_reduction).astype(jnp.uint32) + + def to_original_format(self, a: jnp.ndarray): + return (a % self.moduli_reduction).astype(jnp.uint32) + + def precompute_constant_operand(self, a: int): + # return [(a * (1 << self.w)) // m for m in self.moduli] # The algorithm being performed + return (a << self.w) // self.moduli_reduction + + def get_jax_parameters(self): + return { + "moduli": util.to_tuple(self.moduli), + } + + def modular_reduction(self, z: jnp.ndarray, z_s: jnp.ndarray) -> jnp.ndarray: + """Shoup's reduction from u64 to u32 + + Args: + z: - is u64 array of shape (B, M) - input - z = a * b + z_s: - is u64 array of shape (B, M) - input - z_s = a * b_s - b_s is b + in Shoup's precomputation format + + parameters: + moduli: + - Tuple parameters constants + - is u32 array of shape (M) + - modular or moduli + Returns: + - is u32 array of shape (B, M) + - output + - reduced value + """ + t = z_s >> 32 + u = z - t * self.q + # Ensure strict reduction + # u = jnp.where(u >= self.q, u - self.q, u).astype(jnp.uint32) + return u.astype(jnp.uint32) + + def drop_last_modulus(self): + # self.moduli is not updated here. + # Because it is used in the precomputation. + # self.moduli = self.moduli[:-1] + self.moduli_reduction = self.moduli_reduction[:-1] + self.q = self.q[:-1] + + +######################## +# BAT Lazy Reduction Context +######################## +class BATLazyContext(FiniteFieldContextBase): + + def __init__(self, moduli: Union[List[int], int]): + super().__init__(moduli) + self.moduli = moduli + if type(self.moduli) is int: + self.moduli = [self.moduli] + + # L=4 bytes (for 32-bit modulus) + self.L = 4 + + # Precompute R matrix for each modulus + # R_i,j corresponds to the j-th byte of (256^(i+L) mod q) + # Dimensions: (M, 4, 4) because we have 4 high-bytes (B) and 4 result-bytes (L) + moduli_arr = jnp.array(self.moduli, dtype=jnp.uint64) + + # 1. Vectorize 'i' loop (bytes 4, 5, 6, 7): Compute r_val = 256^(i+4) % m + shifts_i = jnp.arange(4, 8, dtype=jnp.uint64) * 8 + # Broadcast shape: (1, 4) vs (M, 1) -> (M, 4) + r_vals = (jnp.array(1, dtype=jnp.uint64) << shifts_i[None, :]) % moduli_arr[ + :, None + ] + + # 2. Vectorize 'j' loop: Split r_vals into 4 bytes (little endian) + shifts_j = jnp.arange(4, dtype=jnp.uint64) * 8 + # Result: (M, 4, 4) + self.R = ((r_vals[:, :, None] >> shifts_j[None, None, :]) & 0xFF).astype( + jnp.uint8 + ) + self.moduli_reduction = jnp.array(self.moduli, jnp.uint64) + + def to_computation_format(self, a: jnp.ndarray): + return a + + def to_original_format(self, a: jnp.ndarray): + return (a % self.moduli_reduction).astype(jnp.uint32) + + def get_jax_parameters(self): + return {"moduli": util.to_tuple(self.moduli), "R": self.R} + + def modular_reduction(self, z: jnp.ndarray) -> jnp.ndarray: + """BAT Lazy Reduction from u64 to u32 + + Implements: result = B @ R + A + where z is split into Lower Part A (bytes 0-3) and Higher Part B (bytes + 4-7). + + Args: + z: u64 array of shape (..., M) if RNS, or arbitrary shape if single + modulus. + + Returns: + u32 array (Partially reduced) + """ + # 1. Extract bytes from z using bitcast + # This treats the 64-bit integers as vectors of 8 bytes (Little Endian) + z_bytes = jax.lax.bitcast_convert_type( + z.astype(jnp.uint64), new_dtype=jnp.uint8 + ) + + # 2. Split into Lower Part A (bytes 0-3) and Higher Part B (bytes 4-7) + # A_bytes, B_bytes each have shape (..., 4) where ... matches z's shape. + A_bytes, B_bytes = jnp.split(z_bytes, 2, axis=-1) + + # 3. Perform Matrix Multiplication: LazyReductionResult = B @ R + A + # Logic: + # - If we have a single modulus (M=1), we assume ALL input elements should be + # reduced by this same modulus, regardless of input shape dimensions. + # - If we have multiple moduli (M>1), we assume the LAST dimension of input + # corresponds to the moduli dimension M. + + # Unified implementation for both Single Modulus and RNS + # Use einsum for automatic broadcasting and hardware-efficient 8-bit matmul + # - Single Modulus: B (..., 4) @ R_squeezed (4, 4) -> (..., 4) + # - RNS: B (..., M, 4) @ R (M, 4, 4) -> (..., M, 4) + # Note: jnp.squeeze ensures R is (4, 4) when M=1, matching the "Single Modulus" lack of M-dim. + # We perform the input in 8-bit and accumulate in 32-bit for TPU efficiency. + matmul_res = jnp.einsum( + "...i,...ij->...j", B_bytes, self.R, preferred_element_type=jnp.uint32 + ) + + # 4. Add Lower Part A + result_bytes = matmul_res + A_bytes + + # 5. Reconstruct integer + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + result = jnp.sum( + result_bytes.astype(jnp.uint64) << shift_factors, axis=(-1,) + ) + + return result + + def drop_last_modulus(self): + self.moduli = self.moduli[:-1] + self.R = self.R[:-1] diff --git a/jaxite/jaxite_word/finite_field_test.py b/jaxite/jaxite_word/finite_field_test.py new file mode 100644 index 0000000..6d19778 --- /dev/null +++ b/jaxite/jaxite_word/finite_field_test.py @@ -0,0 +1,94 @@ +"""Finite Field Test Suite + +Test cases: +- Montgomery Single Modulus Context +- Barrett Single Modulus Context +- Shoup Single Modulus Context + +Terminology: +- Modulus: Single form of modulus. + +Usage: +- Specify the overall modulus for the context, and corresponding parameter +required for the modular reduction. +- Then feed "modulus" and "parameters" to the context constructor. +- Then context->modular_reduction(input) to get the reduced result for certain +inputs. +""" + +import warnings +from absl.testing import absltest +from absl.testing import parameterized +import jaxite.jaxite_word.finite_field as ff_context +import jax +import jax.numpy as jnp +import numpy as np +import jaxite.jaxite_word.util as util + +testing_params = [{"testcase_name": "0"}] + + +@parameterized.named_parameters(testing_params) +class FiniteFieldTest(parameterized.TestCase): + + def setUp(self): + # Setup random input data and their modmul reference results. + self.modulus = util.find_moduli_ntt(1, 31, 16)[0] + self.random_key = jax.random.key(0) + self.a = jax.random.randint( + self.random_key, (1,), 0, self.modulus - 1, dtype=jnp.int32 + ) + self.b = jax.random.randint( + self.random_key, (1,), 0, self.modulus - 1, dtype=jnp.int32 + ) + self.ab = self.a.astype(jnp.uint64) * self.b.astype(jnp.uint64) + self.ab_modq = (self.ab % self.modulus).astype(jnp.uint32) + + # @absltest.skip("test single implementation") + def test_montgomery_single_moduli_context(self): + context = ff_context.MontgomeryContext(self.modulus) + a_mont = context.to_computation_format(self.a[0].astype(jnp.uint64)) + b_mont = context.to_computation_format(self.b[0].astype(jnp.uint64)) + ab_mont = a_mont.astype(jnp.uint64) * b_mont.astype(jnp.uint64) + result_mont = context.modular_reduction(ab_mont) + result = context.to_original_format(result_mont.astype(jnp.uint64)) + np.testing.assert_array_equal(result[0], self.ab_modq) + + # @absltest.skip("test single implementation") + def test_barrett_single_moduli_context(self): + context = ff_context.BarrettContext(self.modulus) + ab = self.a.astype(jnp.uint64) * self.b.astype(jnp.uint64) + result = context.modular_reduction(ab) + np.testing.assert_array_equal(result[0], self.ab_modq) + + # @absltest.skip("test single implementation") + def test_shoup_single_moduli_context(self): + context = ff_context.ShoupContext(self.modulus) + warnings.warn( + "Shoup's reduction requires one operand to be known ahead of time." + ) + a_precomputed = context.precompute_constant_operand( + self.a.astype(jnp.uint64) + ) + ab = self.a.astype(jnp.uint64) * self.b.astype(jnp.uint64) + ab_shoup = a_precomputed * self.b.astype(jnp.uint64) + result_shoup = context.modular_reduction(ab, ab_shoup) + result = context.to_original_format(result_shoup.astype(jnp.uint64)) + np.testing.assert_array_equal(result[0], self.ab_modq) + + # @absltest.skip("test single implementation") + def test_bat_lazy_single_moduli_context(self): + context = ff_context.BATLazyContext(self.modulus) + warnings.warn( + "BATLazy's reduction requires one operand to be known ahead of time." + ) + result = context.modular_reduction(self.ab) + # Check mathematical correctness: result % modulus == expected % modulus + # Note: Lazy reduction guarantees result is congruent to ab mod q, but not necessarily strictly < q. + # We verify the congruence property. + res_mod = context.to_original_format(result.astype(jnp.uint64)) + np.testing.assert_array_equal(res_mod[0], self.ab_modq) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_word/hemul.py b/jaxite/jaxite_word/hemul.py new file mode 100644 index 0000000..146f6ca --- /dev/null +++ b/jaxite/jaxite_word/hemul.py @@ -0,0 +1,360 @@ +from jaxite.jaxite_word.bconv import BConvBarrett +from jaxite.jaxite_word.ciphertext import Ciphertext +from jaxite.jaxite_word.finite_field import BarrettContext +import jax +import jax.numpy as jnp +import jaxite.jaxite_word.util as util + +jax.config.update('jax_enable_x64', True) + + +class HEMul: + + def __init__( + self, batch, r, c, dnum, num_eval_mult, original_moduli, extend_moduli + ): + self.batch = batch + self.r = r + self.c = c + self.dnum = dnum + self.num_eval_mult = num_eval_mult + self.original_moduli = original_moduli + self.drop_last_extend_moduli = original_moduli[:-1] + extend_moduli + self.drop_last_moduli = original_moduli[:-1] + self.extend_moduli = extend_moduli + self.last_tower_moduli = original_moduli[-1] + + def control_gen(self, degree_layout=None, perf_test=False): + if degree_layout is not None: + self.degree_layout = degree_layout + else: + self.degree_layout = (self.r, self.c) + + self.perf_test = perf_test + # ========================================================================== + # 0. Configuration Derivation + # ========================================================================== + ring_dim = self.r * self.c + original_moduli = self.original_moduli + extend_moduli = self.extend_moduli + + sizeQ_in = len(original_moduli) + overall_sizeQ_in = sizeQ_in + overall_sizeP_in = len(extend_moduli) + self.overall_sizeQ_in, self.overall_sizeP_in = ( + overall_sizeQ_in, + overall_sizeP_in, + ) + overall_sizeQ_in_no_last = sizeQ_in - 1 + + # ========================================================================== + # 1. Drop Last & Extend Moduli Arrays + # ========================================================================== + last_tower_psi = util.root_of_unity(2 * ring_dim, self.last_tower_moduli) + power_of_last_tower_psi = jnp.array( + [ + pow(last_tower_psi, i, self.last_tower_moduli) + for i in range(ring_dim) + ], + jnp.uint64, + ) + last_tower_inv_psi = pow(last_tower_psi, -1, self.last_tower_moduli) + self.power_of_last_tower_inv_psi = jnp.array( + [ + pow(last_tower_inv_psi, i, self.last_tower_moduli) + for i in range(ring_dim) + ], + jnp.uint64, + ).reshape(1, 1, -1) + self.moduli_threshold = jnp.array( + (self.last_tower_moduli + 1) // 2, dtype=jnp.uint32 + ) + + # ========================================================================== + # 2. Instantiate Ciphertext Object + # Note: We must use precision=32 as per user request, matching the data type typically used here (uint32 for main storage) + # ========================================================================== + ct_refer_shapes = { + 'batch': self.batch, + 'num_elements': 4, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': overall_sizeQ_in, + 'degree_layout': self.degree_layout, + } + self.ct_refer = Ciphertext( + ct_refer_shapes, + parameters={ + 'moduli': self.original_moduli, + 'finite_field_context': BarrettContext, + 'r': self.r, + 'c': self.c, + }, + ) + self.ct_refer.modulus_switch_control_gen(degree_layout=self.degree_layout) + + ct_shapes = { + 'batch': self.batch, + 'num_elements': 4, + 'degree': self.r * self.c, + 'num_moduli': overall_sizeQ_in - 1, + 'precision': 32, + 'degree_layout': self.degree_layout, + } + self.ct_obj = Ciphertext( + ct_shapes, + parameters={ + 'moduli': self.drop_last_moduli, + 'finite_field_context': BarrettContext, + 'r': self.r, + 'c': self.c, + }, + ) + + ct_last_limb_shapes = { + 'batch': self.batch, + 'num_elements': 4, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': 1, + 'degree_layout': self.degree_layout, + } + self.ct_last_limb = Ciphertext( + ct_last_limb_shapes, + parameters={ + 'moduli': [self.last_tower_moduli], + 'finite_field_context': BarrettContext, + 'r': self.r, + 'c': self.c, + }, + ) + + idx_cur_last_tower = ( + overall_sizeQ_in - self.num_eval_mult + ) # Calculate here for control_gen + + ct_extend_shapes = { + 'batch': self.batch, + 'num_elements': 2, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': overall_sizeP_in, + 'degree_layout': self.degree_layout, + } + self.ct_extend = Ciphertext( + ct_extend_shapes, parameters={'moduli': self.extend_moduli} + ) + + # ========================================================================== + # 3. Parameter Generation + # ========================================================================== + if perf_test: + power_of_psi = util.random_parameters( + (len(self.drop_last_extend_moduli), ring_dim), + self.drop_last_extend_moduli, + dtype=jnp.uint64, + ).T + power_of_inv_psi_approx_down = util.random_parameters( + (len(self.drop_last_extend_moduli), ring_dim), + self.drop_last_extend_moduli, + dtype=jnp.uint64, + ).T + else: + extend_psi = [ + util.root_of_unity(2 * ring_dim, q) + for q in self.drop_last_extend_moduli + ] + power_of_psi = jnp.array( + [ + [ + pow(extend_psi[idx], i, self.drop_last_extend_moduli[idx]) + for i in range(ring_dim) + ] + for idx in range(len(self.drop_last_extend_moduli)) + ], + jnp.uint64, + ).T + extend_inv_psi = [ + pow(psi, -1, q) + for (q, psi) in zip(self.drop_last_extend_moduli, extend_psi) + ] + power_of_inv_psi_approx_down = jnp.array( + [ + [ + pow(extend_inv_psi[idx], i, self.drop_last_extend_moduli[idx]) + for i in range(ring_dim) + ] + for idx in range(len(self.drop_last_extend_moduli)) + ], + jnp.uint64, + ).T + self.power_of_psi = power_of_psi[:, :idx_cur_last_tower] + self.power_of_inv_psi_approx_down = power_of_inv_psi_approx_down[ + :, -overall_sizeP_in: + ] + self.bconv = BConvBarrett(self.drop_last_extend_moduli) + control_indices_list = [] + rotate_indices = list(range(idx_cur_last_tower)) + extend_indices = list( + range(idx_cur_last_tower, idx_cur_last_tower + overall_sizeP_in) + ) + control_indices_list.append((extend_indices, rotate_indices)) + self.bconv.control_gen(control_indices_list, perf_test=perf_test) + + current_moduli = self.extend_moduli + target_moduli = [ + item for item in self.drop_last_moduli if item not in current_moduli + ] + P = 1 + for moduli in current_moduli: + P *= moduli + PInvModq_approx_down = [util.modinv(P, q) for q in target_moduli] + self.PInvModq = jnp.asarray(PInvModq_approx_down, dtype=jnp.uint32).reshape( + idx_cur_last_tower + ) + + gammas, betas = util.gamma_beta_calculation( + self.original_moduli, perf_test=perf_test + ) + gammas_gen_power_of_psi = self.power_of_psi.T + self.gammas_power_of_psi_no_last = ( + gammas[:, None].astype(jnp.uint64) + * gammas_gen_power_of_psi.astype(jnp.uint64) + ) % jnp.array(self.drop_last_moduli, jnp.uint64)[:, None] + + # ========================================================================== + # 4. Parameter Reshape + # ========================================================================== + self.power_of_psi = self.power_of_psi.reshape( + *self.degree_layout, idx_cur_last_tower + ) + self.power_of_inv_psi_approx_down = ( + self.power_of_inv_psi_approx_down.reshape( + *self.degree_layout, overall_sizeP_in + ) + ) + self.gammas_power_of_psi_no_last = self.gammas_power_of_psi_no_last.T + self.betas = jnp.array(betas, jnp.uint64).reshape(1, 1, 1, -1) + self.drop_last_moduli_arr = jnp.array( + self.drop_last_moduli, jnp.uint32 + ).reshape(1, 1, 1, -1) + + self.drop_last_extend_moduli_arr = jnp.array( + self.drop_last_extend_moduli, jnp.uint32 + ) + self.q_correction = self.drop_last_extend_moduli_arr[ + :idx_cur_last_tower + ].reshape(1, 1, 1, -1) + self.post_rescale_shape = ( + self.batch, + 4, + *self.degree_layout, + len(self.drop_last_moduli), + ) + # self.post_rescale_shape = (self.batch, 4, ring_dim, len(self.drop_last_moduli)) + + def setup_relinearization(self, evalkey_a_vector, evalkey_b_vector): + # Reshape keys to align with 4D structure for broadcasting + # evalkey_a_vector: (K, D, M) -> (K, 1, D, M) + self.evalkey_a_vector = evalkey_a_vector.astype(jnp.uint64)[ + :, None, *self.degree_layout + ] + self.evalkey_b_vector = evalkey_b_vector.astype(jnp.uint64)[ + :, None, *self.degree_layout + ] + + # Delegate KS prep to Ciphertext + # We use None for perf_test if not set in control_gen used self.perf_test (set in control_gen) + # self.perf_test should be set in control_gen + self.ct_obj.key_switch_control_gen( + self.extend_moduli, + self.dnum, + evalkey_a_vector, + evalkey_b_vector, + perf_test=self.perf_test, + selected_moduli=self.drop_last_moduli, + degree_layout=self.degree_layout, + ) + + def mul(self, in_ciphertexts): + # ---------- Step 1: Modulus Reduction (drop last tower) per-ciphertext ---------- + # Unpack static params (single tuple) in the same fixed order used in generation + num_eval_mult = self.num_eval_mult + overall_sizeQ_in, overall_sizeP_in = ( + self.overall_sizeQ_in, + self.overall_sizeP_in, + ) + idx_cur_last_tower = overall_sizeQ_in - num_eval_mult + + # ---------- Step 1: Rescale ---------- + self.ct_refer.set_batch_ciphertext(in_ciphertexts) + temp_res = self.ct_refer.rescale() + self.ct_obj.set_batch_ciphertext(temp_res.reshape(self.post_rescale_shape)) + + # ---------- Step 2: Homomorphic multiplication core (inline) ---------- + ct_post_mult, last_ele_post_mult = self.ct_obj.ciphertext_mult() + + # ---------- Step 3 & 4: Key switch using Ciphertext methods ---------- + self.ct_obj.set_batch_ciphertext(last_ele_post_mult) + self.ct_obj.key_switch() + keyswitch_core_res = self.ct_obj.get_batch_ciphertext() + + # ---------- Step 5: Approximate modulus down (via Ciphertext) ---------- + result_ciphertext_list = [] + overall_moduli_jax = jnp.asarray(self.drop_last_moduli, dtype=jnp.uint32) + approx_down_in_jax = jnp.asarray(keyswitch_core_res, dtype=jnp.uint32) + + self.ct_extend.set_batch_ciphertext( + approx_down_in_jax[ + ..., idx_cur_last_tower : (idx_cur_last_tower + overall_sizeP_in) + ] + ) + self.ct_extend.to_coeffs_form() + self.ct_extend.modmul( + jnp.array(self.power_of_inv_psi_approx_down, jnp.uint64) + ) + reduced_approx_down = self.ct_extend.get_batch_ciphertext() + ct_new_basis_coef = self.bconv.basis_change_bat( + reduced_approx_down, control_index=0 + ).astype(jnp.uint64) + + for element_index in range(ct_new_basis_coef.shape[1]): + tower_new_basis_coef = ct_new_basis_coef[ + :, element_index : element_index + 1, ... + ] + self.ct_obj.set_batch_ciphertext(tower_new_basis_coef) + self.ct_obj.modmul(self.power_of_psi) + tower_new_basis_coef_scaled_muli_moduli_modq = ( + self.ct_obj.get_batch_ciphertext() + ) + + self.ct_obj.set_batch_ciphertext( + tower_new_basis_coef_scaled_muli_moduli_modq + ) + self.ct_obj.to_ntt_form() + tower_new_basis_jax = self.ct_obj.get_batch_ciphertext() + + current_approx_down_in = approx_down_in_jax[ + :, element_index : element_index + 1, ..., :idx_cur_last_tower + ] + sub_result = jnp.where( + current_approx_down_in < tower_new_basis_jax, + current_approx_down_in + overall_moduli_jax - tower_new_basis_jax, + current_approx_down_in - tower_new_basis_jax, + ) + + self.ct_obj.set_batch_ciphertext(sub_result) + self.ct_obj.modmul(self.PInvModq) + reduced_elem_modq = self.ct_obj.get_batch_ciphertext() + + result_ciphertext_list.append(reduced_elem_modq) + + approx_mod_down_custom = jnp.concatenate(result_ciphertext_list, axis=1) + + # ---------- Step 6: Add and return ---------- + result = ct_post_mult + approx_mod_down_custom + val = jnp.where( + result >= self.q_correction, result - self.q_correction, result + ) + + return val diff --git a/jaxite/jaxite_word/herot.py b/jaxite/jaxite_word/herot.py new file mode 100644 index 0000000..651bf0c --- /dev/null +++ b/jaxite/jaxite_word/herot.py @@ -0,0 +1,416 @@ +import jaxite.jaxite_word.bconv as bconv +from jaxite.jaxite_word.ciphertext import Ciphertext +import jax +import jax.numpy as jnp +import numpy as np +import jaxite.jaxite_word.util as util + +# enable 64-bit computation in jax +jax.config.update('jax_enable_x64', True) + + +def is_power_of_two(x: int) -> bool: + """Returns True if x is a power of two.""" + return x > 0 and (x & (x - 1)) == 0 + + +def mat_1d_shuffle_to_2d(coefMap, r, c): + """Memory Aligned Transformation + + Perform 1D data shuffing of O(N) in matrix fashion with O(sqrt(N)) memory + cost. + Precomputes the 2D indices. + coefMap is the 1D permuted indices. + + Factor coefMap (length r*c) into row_perm (len r) and col_perm (len c) such + that: + coefMap.reshape(r,c)[i,j] == row_perm[i]*c + col_perm[j] + + Returns: + row_perm: int32[r] + col_perm: int32[c] + """ + if coefMap.ndim != 1 or coefMap.shape[0] != r * c: + raise ValueError( + f'coefMap must be 1D of length r*c. Got shape {coefMap.shape},' + f' r*c={r*c}.' + ) + if r <= 0 or c <= 0: + raise ValueError('r and c must be positive.') + # (Recommended, since your degree is power-of-2) + if not (is_power_of_two(r) and is_power_of_two(c)): + raise ValueError('For your setting, r and c should be powers of two.') + coef2d = coefMap.reshape(r, c) + + # If coef2d[i,j] = row_perm[i]*c + col_perm[j], then: + row_perm = (coef2d[:, 0] // c).astype(jnp.int32) + col_perm = (coef2d[0, :] % c).astype(jnp.int32) + + coef2d_h = np.asarray(jax.device_get(coef2d)) + row_h = np.asarray(jax.device_get(row_perm)) + col_h = np.asarray(jax.device_get(col_perm)) + expected_h = row_h[:, None] * c + col_h[None, :] + if not np.array_equal(coef2d_h, expected_h): + raise ValueError( + 'coefMap is NOT decomposable into a single global row permutation +' + f' column permutation for r={r}, c={c}. (i.e., not P_row ⊗ P_col). Pick' + ' a different (r,c) factorization, or keep using jnp.take(a, coefMap,' + ' axis=2).' + ) + + return row_perm, col_perm + + +class HERot: + + def __init__(self, r, c, dnum, rotate_in_ciphertext_moduli, extend_moduli): + self.r = r + self.c = c + self.dnum = dnum + self.rotate_in_ciphertext_moduli = rotate_in_ciphertext_moduli + self.extend_moduli = extend_moduli + self.overall_moduli_init = ( + self.rotate_in_ciphertext_moduli + self.extend_moduli + ) + # Instantiating a single class of BConvBarrett + self.bconv = bconv.BConvBarrett(self.overall_moduli_init) + self.evalkey_a_vector = None + self.evalkey_b_vector = None + + def control_gen(self, batch=1, degree_layout=None, perf_test=False): + if degree_layout is None: + degree_layout = (self.r * self.c,) + # degree_layout = (self.r, self.c) + self.degree_layout = degree_layout + sizeQl_in = len(self.rotate_in_ciphertext_moduli) + sizeQlP_in = len(self.extend_moduli) + sizeQl_in + alpha = (sizeQl_in + self.dnum - 1) // self.dnum + ring_dim = self.r * self.c + overall_moduli = self.rotate_in_ciphertext_moduli + self.extend_moduli + self.perf_test = perf_test + + # External Input + if perf_test: + power_of_psi = util.random_parameters( + (*degree_layout, len(overall_moduli)), + overall_moduli, + dtype=jnp.uint64, + ) + power_of_inv_psi = util.random_parameters( + (*degree_layout, len(overall_moduli)), + overall_moduli, + dtype=jnp.uint64, + ) + else: + original_psi = [ + util.root_of_unity(2 * ring_dim, q) for q in overall_moduli + ] + power_of_psi = jnp.array( + [ + [ + pow(original_psi[idx], i, overall_moduli[idx]) + for i in range(ring_dim) + ] + for idx in range(len(overall_moduli)) + ], + jnp.uint64, + ).T.reshape(*degree_layout, len(overall_moduli)) + inv_psi = [ + pow(psi, -1, q) for (q, psi) in zip(overall_moduli, original_psi) + ] + power_of_inv_psi = jnp.array( + [ + [ + pow(inv_psi[idx], i, overall_moduli[idx]) + for i in range(ring_dim) + ] + for idx in range(len(overall_moduli)) + ], + jnp.uint64, + ).T.reshape(*degree_layout, len(overall_moduli)) + + ## parameters generation for approximation mod down + current_moduli = self.extend_moduli + target_moduli = [ + item for item in overall_moduli if item not in current_moduli + ] + + P = 1 + for moduli in current_moduli: + P *= moduli + PInvModq_approx_down = [util.modinv(P, q) for q in target_moduli] + + self.overall_moduli = overall_moduli + self.PInvModq = jnp.asarray(PInvModq_approx_down, dtype=jnp.uint32).reshape( + sizeQl_in + ) + self.power_of_psi = jnp.array(power_of_psi, jnp.uint64) + self.power_of_inv_psi = power_of_inv_psi[..., :sizeQl_in] + self.power_of_inv_psi_approx_down = power_of_inv_psi[ + ..., sizeQl_in:sizeQlP_in + ] + self.sizeQlP, self.sizeQl = sizeQlP_in, sizeQl_in + self.batch = batch + + # BConv control generation + original_moduli_extract_index = [] + for i in range(sizeQl_in): + if i % alpha == 0: + original_moduli_extract_index.append([i]) + else: + original_moduli_extract_index[-1].append(i) + numPartQl = (sizeQl_in + alpha - 1) // alpha + + control_indices_list = [] + + self.select_tower_index = [] + self.non_select_tower_index = [] + + # 1. Basis change for key switch decomposition + for part in range(numPartQl): + sel_index = original_moduli_extract_index[part] + non_sel_index = [ + i for i in range(len(overall_moduli)) if i not in sel_index + ] + self.select_tower_index.append(sel_index) + self.non_select_tower_index.append(non_sel_index) + control_indices_list.append((sel_index, non_sel_index)) + + # Precompute restore indices for scatter optimization + self.restore_indices = [] + for part in range(numPartQl): + sel_index = self.select_tower_index[part] + non_sel_index = self.non_select_tower_index[part] + + # The resulting array after concatenation will have elements in this order: + # [elements corresponding to sel_index, elements corresponding to non_sel_index] + concat_order = sel_index + non_sel_index + + # We want to map this back to the natural order [0, 1, 2, ..., len(overall_moduli)-1] + # restore_index[i] should be the position of 'i' in concat_order + restore_index = [0] * len(concat_order) + for pos, val in enumerate(concat_order): + restore_index[val] = pos + + self.restore_indices.append(jnp.array(restore_index, dtype=jnp.uint16)) + + # 2. Basis change for approximation modulus switch + rotate_indices = list(range(sizeQl_in)) + extend_indices = list(range(sizeQl_in, sizeQlP_in)) + control_indices_list.append((extend_indices, rotate_indices)) + + self.bconv.control_gen(control_indices_list, perf_test=perf_test) + + # Pre-allocate Ciphertext objects to amortize NTT context creation + ct_in_shapes = { + 'batch': batch, + 'num_elements': 1, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': sizeQl_in, + 'degree_layout': degree_layout, + } + self.ct_in = Ciphertext( + ct_in_shapes, + parameters={'moduli': overall_moduli[:sizeQl_in], 'BAT_lazy': False}, + ) + + self.ct_parts = [] + for part in range(numPartQl): + _target_indices_list = self.non_select_tower_index[part] + _target_moduli_list = [overall_moduli[i] for i in _target_indices_list] + _num_moduli_part = len(_target_moduli_list) + ct_part_shapes = { + 'batch': batch, + 'num_elements': 1, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': _num_moduli_part, + 'degree_layout': degree_layout, + } + self.ct_parts.append( + Ciphertext( + ct_part_shapes, + parameters={'moduli': _target_moduli_list, 'BAT_lazy': False}, + ) + ) + + ct_full_shapes = { + 'batch': batch, + 'num_elements': 1, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': sizeQlP_in, + 'degree_layout': degree_layout, + } + self.ct_full = Ciphertext( + ct_full_shapes, + parameters={'moduli': self.overall_moduli, 'BAT_lazy': False}, + ) + + ct_approx_shapes = { + 'batch': batch, + 'num_elements': 1, + 'degree': ring_dim, + 'precision': 32, + 'num_moduli': sizeQlP_in - sizeQl_in, + 'degree_layout': degree_layout, + } + self.ct_approx = Ciphertext( + ct_approx_shapes, + parameters={'moduli': self.extend_moduli, 'BAT_lazy': False}, + ) + + def setup_rotate(self, evalkey_a_vector, evalkey_b_vector, coefMap): + self.evalkey_a_vector = evalkey_a_vector.astype(jnp.uint64) + self.evalkey_b_vector = evalkey_b_vector.astype(jnp.uint64) + self.coefMap = jnp.asarray(coefMap, dtype=jnp.int32) + + def rotate( + self, + in_ciphertexts, + ): + assert self.evalkey_a_vector is not None + assert self.evalkey_b_vector is not None + # ------------------------------- + # Inline of key_switch_precompute_core_28bit for in_ciphertexts[-1] + # ------------------------------- + batch, r, c, dnum = self.batch, self.r, self.c, self.dnum + sizeQlP, sizeQl = self.sizeQlP, self.sizeQl + select_tower_index, non_select_tower_index = ( + self.select_tower_index, + self.non_select_tower_index, + ) + power_of_inv_psi = self.power_of_inv_psi + power_of_psi = self.power_of_psi + power_of_inv_psi_approx_down = self.power_of_inv_psi_approx_down + sizeP = sizeQlP - sizeQl + ring_dim = r * c + + overall_moduli_jax = jnp.asarray(self.overall_moduli, dtype=jnp.uint32) + original_moduli = jnp.take(overall_moduli_jax, jnp.arange(sizeQl), axis=0) + in_tower = in_ciphertexts[:, -1:, ..., :sizeQl] + + # ---------- Step 1: Keyswitch ---------- + self.ct_in.ciphertext = in_tower + # self.ct_in.key_switch() # Rotate implements inline keyswitch for better performance. + self.ct_in.to_coeffs_form() + self.ct_in.modmul(jnp.array(power_of_inv_psi, jnp.uint64)) + partCtCloneCoef = self.ct_in.ciphertext + partsCtExt = [] + res0 = None + res1 = None + for part in range(dnum): + select_tower_index_arr = jnp.array(select_tower_index[part], jnp.uint16) + non_select_tower_index_arr = jnp.array( + non_select_tower_index[part], jnp.uint16 + ) + power_of_psi_arr_part = jnp.take( + power_of_psi, non_select_tower_index_arr, axis=-1 + ) + + input_for_bconv = jnp.take( + partCtCloneCoef, select_tower_index_arr, axis=-1 + ) + partCtCloneEval = self.bconv.basis_change_bat( + input_for_bconv, control_index=part + ).astype(jnp.uint64) + + ct_part = self.ct_parts[part] + ct_part.ciphertext = partCtCloneEval + ct_part.modmul(power_of_psi_arr_part) + partCtCloneEval_scaled_multi_moduli = ct_part.ciphertext + + ct_part.ciphertext = partCtCloneEval_scaled_multi_moduli + ct_part.to_ntt_form() + partsCtCompl_multi_moduli = ct_part.ciphertext + + partsCtExt_cur_part = jnp.concatenate( + [ + jnp.take(in_tower, select_tower_index_arr, axis=-1), + partsCtCompl_multi_moduli, + ], + axis=-1, + ) + partsCtExt_cur_part = jnp.take( + partsCtExt_cur_part, self.restore_indices[part], axis=-1 + ) + if res0 is None: + res0 = ( + partsCtExt_cur_part * self.evalkey_b_vector[part][None, None, ...] + ) + res1 = ( + partsCtExt_cur_part * self.evalkey_a_vector[part][None, None, ...] + ) + else: + res0 += ( + partsCtExt_cur_part * self.evalkey_b_vector[part][None, None, ...] + ) + res1 += ( + partsCtExt_cur_part * self.evalkey_a_vector[part][None, None, ...] + ) + + keyswitch_core_res = jnp.concatenate([res0, res1], axis=1) + self.ct_full.ciphertext = keyswitch_core_res + self.ct_full.mod_reduce() + keyswitch_core_res = self.ct_full.get_batch_ciphertext() + # keyswitch_core_res = self.ct_in.get_batch_ciphertext() # If use function + + # ---------- Step 3: Approximation modulus switch (inline) ---------- + result_ciphertext_list = [] + overall_moduli_jax = jnp.asarray(original_moduli, dtype=jnp.uint32) + approx_down_in_jax = jnp.asarray(keyswitch_core_res, dtype=jnp.uint32) + + for element_index in range(keyswitch_core_res.shape[1]): + _current_slice = approx_down_in_jax[ + :, element_index : element_index + 1, ..., sizeQl : (sizeQl + sizeP) + ] + + self.ct_approx.ciphertext = _current_slice + self.ct_approx.to_coeffs_form() + self.ct_approx.modmul(jnp.array(power_of_inv_psi_approx_down, jnp.uint64)) + reduced_approx_down = self.ct_approx.ciphertext + + tower_new_basis_coef = self.bconv.basis_change_bat( + reduced_approx_down, control_index=dnum + ).astype(jnp.uint64) + + self.ct_in.ciphertext = tower_new_basis_coef + self.ct_in.modmul(power_of_psi[..., :sizeQl]) + tower_new_basis_coef_scaled_muli_moduli_modq = self.ct_in.ciphertext + + self.ct_in.ciphertext = tower_new_basis_coef_scaled_muli_moduli_modq + self.ct_in.to_ntt_form() + tower_new_basis_jax = self.ct_in.ciphertext + + current_approx_down_in = approx_down_in_jax[ + :, element_index : element_index + 1, ..., :sizeQl + ] + sub_result = jnp.where( + current_approx_down_in < tower_new_basis_jax, + current_approx_down_in + overall_moduli_jax - tower_new_basis_jax, + current_approx_down_in - tower_new_basis_jax, + ) + + self.ct_in.ciphertext = sub_result + self.ct_in.modmul(self.PInvModq) + reduced_elem_modq = self.ct_in.ciphertext + + result_ciphertext_list.append(reduced_elem_modq) + + mod_down_res = jnp.concatenate(result_ciphertext_list, axis=1) + + # ---------- Step 4: Add and return ---------- + base0 = jnp.asarray(in_ciphertexts[:, 0:1], dtype=jnp.uint32) + q_b = overall_moduli_jax + base0 = base0 + mod_down_res[:, 0:1] + base0_modq = jnp.where(base0 >= q_b, base0 - q_b, base0) + ks_results = mod_down_res.at[:, 0:1].set(base0_modq) + + # ---------- Step 5: Coefficient map ---------- + # coef_idx = self.coefMap + ks_results = jnp.take( + ks_results.reshape(-1, 2, ring_dim, sizeQl), self.coefMap, axis=2 + ) + + return ks_results diff --git a/jaxite/jaxite_word/key_gen.py b/jaxite/jaxite_word/key_gen.py new file mode 100644 index 0000000..89641d8 --- /dev/null +++ b/jaxite/jaxite_word/key_gen.py @@ -0,0 +1,426 @@ +import math +import secrets +from typing import Any, Dict, List, Tuple +import jax.numpy as jnp +import jaxite.jaxite_word.rns as rns +import jaxite.jaxite_word.util as util + +RnsPolynomial = rns.RnsPolynomial +RnsParams = rns.RnsParams +gen_rns_polynomial = rns.gen_rns_polynomial + +MAX_INT_64 = 9223372036854775295 +sigma = 3.190000057220458984375 + + +######################## +# Key Generation +######################## +def gen_ternary_uniform_polynomial( + degree: int, moduli: list[int] +) -> RnsPolynomial: + """Generate a uniformly random RNS polynomial in R_Q = Z[X] / (Q, X^N+1).""" + coeffs_q = [secrets.randbelow(2) for _ in range(degree)] + return gen_rns_polynomial(degree, coeffs_q, moduli) + + +def gen_uniform_polynomial(degree: int, moduli: list[int]) -> RnsPolynomial: + """Generate a uniformly random RNS polynomial in R_Q = Z[X] / (Q, X^N+1).""" + coeffs_q = [] + for q in moduli: + coeffs_q.append([secrets.randbelow(q) for _ in range(degree)]) + return RnsPolynomial(degree, moduli, coeffs_q, is_ntt=False) + + +def gen_gaussian_polynomial( + degree: int, moduli: list[int], sigma: float +) -> RnsPolynomial: + """Generate a random Gaussian polynomial in R_Q = Z[X] / (Q, X^N+1). + + Note: Each coefficient is independently sampled from a rounded Gaussian + distribution with parameter sigma. + + Args: + degree: The degree N of the ring R_Q. + moduli: The list of prime moduli q_i's whose product is Q. + sigma: The standard deviation of the Gaussian distribution. + + Returns: + An RNS polynomial with coefficients sampled from a Gaussian distribution. + """ + prng = secrets.SystemRandom() + coeffs = [round(prng.normalvariate(0, sigma)) for _ in range(degree)] + return gen_rns_polynomial(degree, coeffs, moduli) + + +def _validate_private_key(private_key: List[List[int]]) -> Tuple[int, int]: + + if not isinstance(private_key, list) or not private_key: + raise ValueError("private_key must be a non-empty 2D list") + num_elements = len(private_key) + degree = None + for row in private_key: + if not isinstance(row, list) or not row: + raise ValueError( + "private_key must be a non-empty 2D list with non-empty rows" + ) + if degree is None: + degree = len(row) + elif len(row) != degree: + raise ValueError( + "All rows in private_key must have the same length (degree)" + ) + for coeff in row: + if not isinstance(coeff, int): + raise TypeError("All coefficients in private_key must be integers") + return num_elements, degree # type: ignore[return-value] + + +def _mod_q(x: int, q: int) -> int: + r = x % q + return r if r >= 0 else r + q + + +def modulus_switch( + coefficient: int | List[int], + cur_moduli: int = 524353, + target_moduli: int = 1152921504606845473, +) -> int | List[int]: + """Switch coefficients from cur_moduli to target_moduli using centered representation. + + If c < (cur_moduli + 1) // 2, it is unchanged. Otherwise, it is treated as + a negative representative (c - cur_moduli) and lifted to Z_{target_moduli} + by computing target_moduli + (c - cur_moduli). + + Args: + coefficient: An integer or a list of integers in Z_{cur_moduli}. + cur_moduli: Current modulus (default: 524353). + target_moduli: Target modulus (default: 1152921504606845473). + + Returns: + The switched coefficient with the same container type as input + (int for int input, List[int] for list input). + """ + threshold = (cur_moduli + 1) // 2 + + def _switch_one(value: int) -> int: + v = int(value) + return v if v < threshold else target_moduli + v - cur_moduli + + if isinstance(coefficient, list): + return [_switch_one(c) for c in coefficient] + else: + return _switch_one(int(coefficient)) + + +def gen_evaluation_key( + private_key: List[List[int]], + q: int | List[int], + P: int | List[int] = 1, + noise_std: float = 3.190000057220458984375, + noise_scale: int = 1, + a: List[List[List[int]]] | None = None, + e: List[List[List[int]]] | None = None, + dnum: int = 3, +): + + num_elements, degree = _validate_private_key(private_key) + + q_list: List[int] = list(q) + p_list: List[int] = list(P) if isinstance(P, list) else [int(P)] + if len(private_key) != len(q_list): + raise ValueError("private_key must have one row per q modulus (len(q))") + size_q = len(q_list) + # size_p = len(p_list) + # size_qp = size_q + size_p + + sk_q = private_key + # sOld is s^2 mod q for Q part + sOld = [ + [_mod_q(sk_q[i][j] * sk_q[i][j], q_list[i]) for j in range(degree)] + for i in range(size_q) + ] + + return key_switch_gen( + sOld=sOld, + sNew=sk_q, + q_list=q_list, + p_list=p_list, + noise_std=noise_std, + noise_scale=noise_scale, + a=a, + e=e, + dnum=dnum, + ) + + +def key_switch_gen( + sOld: List[List[int]], + sNew: List[List[int]], + q_list: List[int], + p_list: List[int], + noise_std: float = 3.190000057220458984375, + noise_scale: int = 1, + a: List[List[List[int]]] | None = None, + e: List[List[List[int]]] | None = None, + dnum: int = 3, +) -> Dict[str, Any]: + """Construct evaluation key parts given secret forms sOld (Q) and sNew (time domain). + + Args: + sOld: Secret squared residues modulo each q in Q, shape [|Q|][N]. + sNew: Secret in time domain for Q basis, shape [|Q|][N]. + q_list: List of moduli forming Q. + p_list: List of moduli forming P. + noise_std: Standard deviation for error sampling. + noise_scale: Integer multiplier applied to error samples. + a: Optional pre-specified a samples per part and limb. + e: Optional pre-specified e samples per part and limb. + dnum: Number of partitions over Q for HYBRID scheme. + + Returns: + Dict with keys: "a", "b", "modulus", "P", and "shape". + """ + + sOut = [] + degree = len(sNew[0]) + for limb, q in zip(sNew, q_list): + psi = util.root_of_unity(2 * degree, q) + test_in = util.bit_reverse_array(limb) + temp = util.intt_negacyclic_bit_reverse(test_in, q, psi) + sOut.append(temp) + + size_q = len(q_list) + size_p = len(p_list) + size_qp = size_q + size_p + + degree = len(sOut[0]) + s_qp = [] + for q_p in p_list: + temp = [sOut[0][j] for j in range(degree)] + temp = modulus_switch(temp, q_list[0], q_p) + s_qp.append(temp) + + s_p_eva = [] + for limb, p in zip(s_qp, p_list): + psi = util.root_of_unity(2 * degree, p) + temp = util.ntt_negacyclic_bit_reverse(limb, p, psi) + s_p_eva.append(util.bit_reverse_array(temp)) + + s_qp_eva = sNew + s_p_eva + P_prod = 1 + for p in p_list: + P_prod *= p + P_mod_q = [P_prod % qi for qi in q_list] + + num_per_part_q = (size_q + dnum - 1) // dnum + num_part_q = math.ceil(size_q / num_per_part_q) + + a_parts: List[List[List[int]]] = [] + b_parts: List[List[List[int]]] = [] + moduli_list = q_list + p_list + for part in range(num_part_q): + start_idx = num_per_part_q * part + end_idx = min(size_q, start_idx + num_per_part_q) + a_sample = gen_ternary_uniform_polynomial(degree, moduli_list[:size_qp]) + a_rows = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse( + a_sample.coeffs[i], + modulus_i, + util.root_of_unity(int(degree << 1), modulus_i), + ) + ) + for i, modulus_i in enumerate(moduli_list[:size_qp]) + ] + e_sample = gen_gaussian_polynomial( + degree, moduli_list[:size_qp], sigma=noise_std + ) + e_rows = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse( + e_sample.coeffs[i], + modulus_i, + util.root_of_unity(int(degree << 1), modulus_i), + ) + ) + for i, modulus_i in enumerate(moduli_list[:size_qp]) + ] + b_rows: List[List[int]] = [] + for i in range(size_qp): + modulus_i = q_list[i] if i < size_q else p_list[i - size_q] + a_row = a[part][i] if a is not None else a_rows[i] + e_row = e[part][i] if e is not None else e_rows[i] + s_row = s_qp_eva[i] + if i < start_idx or i >= end_idx: + b_row = [ + _mod_q( + _mod_q(-a_row[j] * s_row[j], modulus_i) + + _mod_q(noise_scale * e_row[j], modulus_i), + modulus_i, + ) + for j in range(degree) + ] + else: + b_row = [ + _mod_q( + ( + _mod_q(-a_row[j] * s_row[j], modulus_i) + + _mod_q(P_mod_q[i] * sOld[i][j], modulus_i) + + _mod_q(noise_scale * e_row[j], modulus_i) + ), + modulus_i, + ) + for j in range(degree) + ] + b_rows.append(b_row) + a_parts.append(a_rows) + b_parts.append(b_rows) + + return { + "a": a_parts if a is None else a, + "b": b_parts, + "modulus": {"Q": q_list, "P": p_list}, + "P": p_list, + "shape": (num_part_q, size_qp, degree), + } + + +def find_automorphism_index_2n_complex(i: int, m: int): + if i == 0: + return 1 + elif i == m - 1: + return i + if not util.is_power_of_two(m): + raise ValueError("m should be a power of two.") + + g0 = pow(5, -1, m) if i < 0 else 5 # modular inverse of 5 mod m when i < 0 + g = g0 + i_unsigned = abs(i) + for _ in range(1, i_unsigned): + g = (g * g0) & (m - 1) # modulo m since m is a power of two + return g + + +def precompute_rotation_key_map(n: int, k: int) -> list[int]: + m = n << 1 + logm = int(round(math.log2(m))) + logn = int(round(math.log2(n))) + precomp = [0] * n + for j in range(n): + j_tmp = (j << 1) + 1 + mul = j_tmp * k + idx = (mul - ((mul >> logm) << logm)) >> 1 + jrev = util.bit_reverse(j, logn) + idxrev = util.bit_reverse(idx, logn) + precomp[jrev] = idxrev + return precomp + + +def gen_rotation_key( + sk, + original_moduli, + extend_moduli, + rot_index, + dnum=3, + noise_std=3.190000057220458984375, + noise_scale=1, + a=None, + e=None, +): + n = len(sk[0]) + result = find_automorphism_index_2n_complex(rot_index, 2 * n) + key_map_idx = util.modinv(result, 2 * n) + target_order = precompute_rotation_key_map(n, key_map_idx) + # transform sk based on the order. + sk_rot = jnp.array(sk)[:, jnp.array(target_order)].tolist() + ek = key_switch_gen( + sk, + sNew=sk_rot, + q_list=original_moduli, + p_list=extend_moduli, + noise_std=noise_std, + noise_scale=noise_scale, + a=a, + e=e, + dnum=dnum, + ) + return ek + + +def gen_pke_pair( + q_towers: List[int], + p_towers: List[int], + degree: int, + noise_std: float = 3.190000057220458984375, + noise_scale: int = 1, + a_ref=None, + s_ref=None, + e_ref=None, +) -> Dict[str, Any]: + """Generate a PKE pair. + + Args: + q_towers: List of moduli forming Q. + p_towers: List of moduli forming P. + degree: The degree N of the ring R_Q. + noise_std: Standard deviation for error sampling. + noise_scale: Integer multiplier applied to error samples. + + Returns: + Dict with keys: "public_key", "secret_key". + """ + moduli_list = q_towers + p_towers + s = gen_ternary_uniform_polynomial(degree, moduli_list) + s = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse( + s.coeffs[i], + modulus_i, + util.root_of_unity(int(degree << 1), modulus_i), + ) + ) + for i, modulus_i in enumerate(moduli_list) + ] + s = s_ref if s_ref is not None else s + a = gen_uniform_polynomial(degree, moduli_list) + a = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse( + a.coeffs[i], + modulus_i, + util.root_of_unity(int(degree << 1), modulus_i), + ) + ) + for i, modulus_i in enumerate(moduli_list) + ] + a = a_ref if a_ref is not None else a + e = gen_gaussian_polynomial(degree, moduli_list, sigma=noise_std) + e = [ + util.bit_reverse_array( + util.ntt_negacyclic_bit_reverse( + e.coeffs[i], + modulus_i, + util.root_of_unity(int(degree << 1), modulus_i), + ) + ) + for i, modulus_i in enumerate(moduli_list) + ] + e = e_ref if e_ref is not None else e + + b = [ + [ + _mod_q( + _mod_q(e[i][j] * noise_scale, moduli_list[i]) + - _mod_q(a[i][j] * s[i][j], moduli_list[i]), + moduli_list[i], + ) + for j in range(degree) + ] + for i in range(len(moduli_list)) + ] + s = s[: len(q_towers)] + return { + "public_key": [b, a], + "secret_key": s, + } diff --git a/jaxite/jaxite_word/ntt.py b/jaxite/jaxite_word/ntt.py deleted file mode 100644 index 987fc84..0000000 --- a/jaxite/jaxite_word/ntt.py +++ /dev/null @@ -1,1243 +0,0 @@ -"""JAX implementation of Gentalman Sande NTT.""" - -import concurrent.futures -import functools - -import jax -import jax.numpy as jnp -import numpy as np - - -######################## -# Offline Functions -######################## -def chunk_decomposition(x, chunkwidth=8): - """Precision-level data conversion. - - Args: - x: The input data. - chunkwidth: The chunkwidth. - - Returns: - The decomposed data. - """ - dtype = jnp.uint8 - if chunkwidth == 16: - dtype = jnp.uint16 - elif chunkwidth == 32: - dtype = jnp.uint32 - - elements = [] - mask = (1 << chunkwidth) - 1 - # Mask to extract the lower bits (e.g., 32 bits -> 0xFFFFFFFF) - - # Extract each element from the integer - while x > 0: - elements.append(x & mask) # Extract the lower bits - x >>= chunkwidth # Shift to remove the extracted bits - - # Convert the list to a JAX array - return jnp.array(elements, dtype=dtype) - - -def rechunkify(arr_a, chunkwidth): - """Rechunkify the input array back to the desired precision. - - Args: - arr_a: The input array. - chunkwidth: The chunkwidth. - - Returns: - The rechunkified array. - """ - dtype_double_length = jnp.uint16 - if chunkwidth == 16: - dtype_double_length = jnp.uint32 - elif chunkwidth == 32: - dtype_double_length = jnp.uint64 - - # assume the precision of partial sum is <= 2 * precision of input value. - bitmask = (1 << chunkwidth) - 1 - - # # Data Type Illustration - # We need to accumulate these data - # - Could directly perform bitwidth concatenation to generate the final - # result if there is no overlap across each partial sum - # LSB MSB - # |-----------------> bit - # | a0 - # | ==-- - # | a1 - # | ==-- - # | a2 - # | ==-- - # | a3 - # v ==-- - - # whole a0 a1 a2 a3 - # precision ==-- ==-- ==-- ==-- - - # lower a0 a1 a2 a3 - # half == == == == - - # upper a0 a1 a2 a3 - # half -- -- -- -- - - # # Chunk Splitting -> upper and lower half - # padding to align - # lower a0 a1 a2 a3 0 - # half == == == == == - - # upper 0 a0 a1 a2 a3 - # half -- -- -- -- -- - - # # Vectorized Accumulation - # lower a0 a1 a2 a3 0 - # half == == == == == - # + + + + + - # upper 0 a0 a1 a2 a3 - # half -- -- -- -- -- - - # -> result b0 b1 b2 b3 b4 - # -- 1/0-- 1/0-- 1/0-- -- - # (b1 and b4 does not have carry for sure.) - - # Each result chunk might have one more bit for carry. - # Perform one more chunk decomposition and accumulation. - - # # One more Chunk Splitting for partial sum "b" to take care of carry bit. - # carry b0 b1 b2 b3 b4 - # 0 1/0 1/0 1/0 0 - - # carry b4 b0 b1 b2 b3 - # right 0 0 1/0 1/0 1/0 - # shift - # (wrap around rotation, b4 is always zero so will be correct) - # + + + + + - # lower b0 b1 b2 b3 b4 - # half -- -- -- -- -- - # = = = = = - # c0 c1 c2 c3 c4 - # -> -- -- -- -- 1/0-- - # (! c4 might overflow, need one more chunk decomposition) - - # c0 c1 c2 c3 c4 c5 - # -> -- -- -- -- -- 1/0 - - # Chunk Splitting -> upper and lower half - arr_a_lower_half = jnp.bitwise_and(arr_a, bitmask) - arr_a_upper_half = jnp.right_shift(arr_a, chunkwidth) - - # Padding to align - arr_a_lower_half_pad = jnp.pad(arr_a_lower_half, (0, 1)) - arr_a_upper_half_pad = jnp.pad(arr_a_upper_half, (1, 0)) - - # Vectorized Accumulation - arr_b = jnp.add( - arr_a_lower_half_pad.astype(dtype_double_length), - arr_a_upper_half_pad.astype(dtype_double_length), - ) - - while not jnp.all(arr_b <= bitmask): - arr_b_lower_half = jnp.bitwise_and(arr_b, bitmask) - arr_b_carry = jnp.right_shift(arr_b, chunkwidth) - arr_b = jnp.roll(arr_b_carry, 1, axis=-1) - arr_b = jnp.add(arr_b_lower_half, arr_b) - - # Vectorized Accumulation - arr_c = arr_b - - # break top chunk into upper and lower to avoid overflow. - arr_c = jnp.pad(arr_c, (0, 1)) - arr_c = arr_c.at[-1].set(jnp.right_shift(arr_c[-2], chunkwidth)) - arr_c = arr_c.at[-2].set(jnp.bitwise_and(arr_c[-2], bitmask)) - - return arr_c - - -def smul_as_dense_gemv_bat( - x, total_in_precision=32, chunkwidth=8, q=4294967291 -): - """This is the implementation of Basis Align Transformation (BAT). - - Major improvement to achieve dense matrix. - - Args: - x: The input matrix. - total_in_precision: The total precision of the input matrix. - chunkwidth: The chunkwidth. - q: The modulus. - - Returns: - The dense matrix. - - Steps: - 1. break x into [x0, x1, x2, x3] - 2. reform [x0, x1, x2, x3] into the output - [ - x0 r00 r00 r00 # 2^0 - x1 x0+r01 r01 r01 # 2^8 - x2 x1+r02 x0+r02 r02 # 2^16 - x3 x2+r03 x1+r03 x0+r03 # 2^24 - ] - - Note: prefilled value are just examples. - We pick largest 2^32-1 to make sure that intermediate results might - exceed 32-bit precision range, and expose potential precision overflow. - """ - dtype_double_length = jnp.uint16 - chunk_upper_bound = (1 << 8) - 1 - if chunkwidth == 16: - dtype_double_length = jnp.uint32 - chunk_upper_bound = (1 << 16) - 1 - elif chunkwidth == 32: - dtype_double_length = jnp.uint64 - chunk_upper_bound = (1 << 32) - 1 - - total_chunk_num = int(jnp.ceil(total_in_precision / chunkwidth)) - - # the number of row in left matrix - height = total_chunk_num + total_chunk_num - 1 - x_dtype = chunk_decomposition(x, chunkwidth) - x_dense = jnp.zeros( - (total_chunk_num + total_chunk_num - 1, total_chunk_num), - dtype=dtype_double_length, - ) - for j in range(total_chunk_num): - upper_idx = min(total_chunk_num, x_dtype.shape[0] + j) - x_dense = x_dense.at[j:upper_idx, j].set(x_dtype[: upper_idx - j]) - - # [ - # x0 # 2^0 - # x1 x0 # 2^8 - # x2 x1 x0 # 2^16 - # x3 x2 x1 x0 # 2^24 - # ----------- - # x3 x2 x1 # 2^32 iterate all elements in the bottom block - # x3 x2 # 2^40 - # x3 # 2^48 - # ] - - # Perform BAT to the following block of the matrix - # j 2 1 0 - # x3 x2 x1 # 2^32 i=0 - # x3 x2 # 2^40 i=1 - # x3 # 2^48 i=2 - - for i in range(x_dtype.shape[0] - 1): - for j in range(x_dtype.shape[0] - 1 - i): - basis = (total_chunk_num + i) * chunkwidth - projected_data = (int(x_dtype[i + j + 1]) << basis) % q - r = chunk_decomposition(projected_data, chunkwidth).astype( - dtype_double_length - ) - - x_dense = x_dense.at[: len(r), total_chunk_num - 1 - j].set( - jnp.add(r, x_dense[: len(r), total_chunk_num - 1 - j]) - ) - - while not jnp.all(x_dense <= chunk_upper_bound) or not jnp.all( - x_dense[total_chunk_num:, :] == 0 - ): - for j in range(total_chunk_num - 1): - # Iterate over different columns - if not jnp.all(x_dense[:, total_chunk_num - 1 - j] <= chunk_upper_bound): - arr_new_chunkified = rechunkify( - x_dense[:, total_chunk_num - 1 - j], chunkwidth - ) - x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( - arr_new_chunkified[:height] - ) - - # j 2 1 0 - # x3 x2 x1 # 2^32 i=0 - # x3 x2 # 2^40 i=1 - # x3 # 2^48 i=2 - for i in range(x_dtype.shape[0] - 1): - for j in range(x_dtype.shape[0] - 1 - i): - data = x_dense[total_chunk_num + i, total_chunk_num - 1 - j] - if data > 0: - basis = (total_chunk_num + i) * chunkwidth - projected_data = (int(data) << basis) % q - r = chunk_decomposition(projected_data, chunkwidth).astype( - dtype_double_length - ) - - x_dense = x_dense.at[: len(r), total_chunk_num - 1 - j].set( - jnp.add(r, x_dense[: len(r), total_chunk_num - 1 - j]) - ) - x_dense = x_dense.at[ - total_chunk_num + i, total_chunk_num - 1 - j - ].set(0) - return x_dense[:total_chunk_num, :].astype(jnp.uint8) - - -def smul_as_dense_gemv_bat_jax(x, q=4294967291): - """This is the implementation of bat; Major improvement to achieve dense matrix. - - Args: - x: The input matrix. - q: The modulus. - - Returns: - The dense matrix. - - Steps: - 1. break x into [x0, x1, x2, x3] - 2. reform [x0, x1, x2, x3] into the output - [ - x0 r00 r00 r00 # 2^0 - x1 x0+r01 r01 r01 # 2^8 - x2 x1+r02 x0+r02 r02 # 2^16 - x3 x2+r03 x1+r03 x0+r03 # 2^24 - ] - """ - assert x.dtype == jnp.uint32 - chunkwidth = 8 - chunk_upper_bound = (1 << 8) - 1 - total_chunk_num = 4 - - # the number of row in left matrix - height = 7 - x_dtype = jax.lax.bitcast_convert_type(x, new_dtype=jnp.uint8) - x_dense = jnp.array( - [ - [x_dtype[0], 0, 0, 0], - [x_dtype[1], x_dtype[0], 0, 0], - [x_dtype[2], x_dtype[1], x_dtype[0], 0], - [x_dtype[3], x_dtype[2], x_dtype[1], x_dtype[0]], - [0, x_dtype[3], x_dtype[2], x_dtype[1]], - [0, 0, x_dtype[3], x_dtype[2]], - [0, 0, 0, x_dtype[3]], - ], - dtype=jnp.uint16, - ) - - # [ - # x0 # 2^0 - # x1 x0 # 2^8 - # x2 x1 x0 # 2^16 - # x3 x2 x1 x0 # 2^24 - # ----------- - # x3 x2 x1 # 2^32 iterate all elements in the bottom block - # x3 x2 # 2^40 - # x3 # 2^48 - # ] - - # Perform BAT to the following block of the matrix - # j 2 1 0 - # x3 x2 x1 # 2^32 i=0 - # x3 x2 # 2^40 i=1 - # x3 # 2^48 i=2 - - for i in range(x_dtype.shape[0] - 1): - for j in range(x_dtype.shape[0] - 1 - i): - basis = (total_chunk_num + i) * chunkwidth - projected_data = ((x_dtype[i + j + 1].astype(jnp.uint64)) << basis) % q - r = jax.lax.bitcast_convert_type( - projected_data, new_dtype=jnp.uint8 - ).astype(jnp.uint16) - - x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( - jnp.add(r[:height], x_dense[:, total_chunk_num - 1 - j]) - ) - - while not jnp.all(x_dense <= chunk_upper_bound) or not jnp.all( - x_dense[total_chunk_num:, :] == 0 - ): - # for _ in range(2): # rechunkify won't exceed 3 times. - for j in range(total_chunk_num - 1): - # Iterate over different columns - if not jnp.all(x_dense[:, total_chunk_num - 1 - j] <= chunk_upper_bound): - arr_new_chunkified = rechunkify( - x_dense[:, total_chunk_num - 1 - j], chunkwidth - ) - x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( - arr_new_chunkified[:height] - ) - - # j 2 1 0 - # x3 x2 x1 # 2^32 i=0 - # x3 x2 # 2^40 i=1 - # x3 # 2^48 i=2 - for i in range(x_dtype.shape[0] - 1): - for j in range(x_dtype.shape[0] - 1 - i): - data = x_dense[total_chunk_num + i, total_chunk_num - 1 - j] - if data > 0: - basis = (total_chunk_num + i) * chunkwidth - projected_data = (data.astype(jnp.uint64) << basis) % q - r = jax.lax.bitcast_convert_type( - projected_data, new_dtype=jnp.uint8 - ).astype(jnp.uint16) - x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( - jnp.add(r[:height], x_dense[:, total_chunk_num - 1 - j]) - ) - x_dense = x_dense.at[ - total_chunk_num + i, total_chunk_num - 1 - j - ].set(0) - return x_dense[:total_chunk_num, :].astype(jnp.uint8) - - -def hpmatmul_offline_compile_bat(mat_a, q): - """Convert the input (m,n) matrix into (m,n,p,q), i.e. - - replace each element in the original matrix by a p*q matrix (p==q). - - Args: - mat_a: The input matrix. - q: The modulus. - - Returns: - The converted matrix. - """ - assert mat_a.dtype == jnp.uint32 # This version is defined for 32-bit input. - m, n = mat_a.shape[0], mat_a.shape[1] - total_in_precision = 32 - chunkwidth = 8 - # Convert left-side matrix - total_chunk_num = int(jnp.ceil(total_in_precision / chunkwidth)) - - left_mat = jnp.zeros( - (m, n, total_chunk_num, total_chunk_num), dtype=jnp.uint16 - ) - - with concurrent.futures.ThreadPoolExecutor() as executor: - futures = [] - for i in range(m): - futures.extend( - executor.submit( - smul_as_dense_gemv_bat, - mat_a[i, k], - total_in_precision, - chunkwidth, - q, - ) - for k in range(n) - ) - index_pairs = [] - for i in range(m): - for k in range(n): - index_pairs.append((i, k)) - for future, (i, k) in zip(futures, index_pairs): - left_mat = left_mat.at[i, k, :, :].set(future.result()) - - return left_mat - - -def find_all_root_of_unity(n, q): - """Find the n-th root of unity in the field of integers modulo q.""" - all_root_of_unity = [] - for v in range(2, q): - if (v**n) % q == 1: - find_root = all((v**i) % q != 1 for i in range(1, n)) - if find_root: - all_root_of_unity.append(v) - print(v) - return all_root_of_unity - - -def prime_factors(n): - """Return the set of prime factors of n.""" - factors = set() - # Divide out factors of 2 - while n % 2 == 0: - factors.add(2) - n //= 2 - # Check odd factors from 3 to sqrt(n) - p = 3 - while p**2 <= n: - while n % p == 0: - factors.add(p) - n //= p - p += 2 - if n > 1: - factors.add(n) - return factors - - -def find_generator(q): - """Find a primitive root modulo q. - - Args: - q (int): The prime modulus. - - Returns: - A generator of GF(q)^*. - - Raises: - ValueError: If no generator is found, indicating q is not prime. - """ - phi = q - 1 - factors = prime_factors(phi) - - # Test candidates from 2 to q-1. - for g in range(2, q): - is_generator = all(pow(g, phi // p, q) != 1 for p in factors) - if is_generator: - return g - raise ValueError("No generator found, check that q is prime.") - - -def nth_primitive_root(n, q): - """Returns a primitive n-th root of unity in GF(q). - - Args: - n (int): The desired order of the root of unity. - q (int): The prime modulus of the finite field GF(q). - - Returns: - int: An element omega in GF(q) such that omega^n = 1 and its order is - exactly n. - - Precondition: n divides (q-1). - """ - if (q - 1) % n != 0: - raise ValueError( - "n must divide q-1 for a primitive n-th root of unity to exist." - ) - - # Find a generator g of GF(q)^* (a primitive element). - g = find_generator(q) - # Compute omega = g^((q-1)/n) mod q. - exponent = (q - 1) // n - omega = pow(g, exponent, q) - - # Optional: Verify that omega is indeed of order n. - if pow(omega, n, q) != 1: - raise ValueError("Something went wrong: omega^n != 1") - # Check that no smaller positive exponent gives 1. - for d in range(1, n): - if n % d == 0 and pow(omega, d, q) == 1: - raise ValueError( - "Found an exponent d < n with omega^d == 1, so omega is not" - " primitive." - ) - - return omega - - -def bit_reverse(x, bits): - """Compute the bit-reversal of integer x with the given number of bits.""" - result = 0 - for i in range(bits): - if (x >> i) & 1: # if i-th bit of x is 1 - result |= 1 << (bits - 1 - i) # set the corresponding reversed bit - return result - - -def tonelli_shanks(q, omega): - """Solve for x in x^2 ≡ omega (mod q) using the Tonelli-Shanks algorithm. - - Args: - q: The prime modulus. - omega: The element to find the square root of. - - Returns: - One square root of omega modulo q, or None if no square root exists. - - Raises: - ValueError: If the algorithm fails to find the exponent i. - """ - # Check if a is a quadratic residue mod p - if pow(omega, (q - 1) // 2, q) != 1: - return None # no solution exists - - # Special case: p ≡ 3 (mod 4) - if q % 4 == 3: - return pow(omega, (q + 1) // 4, q) - - # Write p - 1 as q_minus_1 * 2^s with q_minus_1 odd. - s = 0 - q_minus_1 = q - 1 - while q_minus_1 % 2 == 0: - s += 1 - q_minus_1 //= 2 - - # Find a quadratic non-residue z - z = 2 - while pow(z, (q - 1) // 2, q) != q - 1: - z += 1 - - m = s - c = pow(z, q_minus_1, q) - t = pow(omega, q_minus_1, q) - r = pow(omega, (q_minus_1 + 1) // 2, q) - - while t != 1: - # Find the least i, 0 < i < m, such that t^(2^i) ≡ 1 (mod p) - i = 1 - temp = pow(t, 2, q) - while temp != 1: - temp = pow(temp, 2, q) - i += 1 - if i == m: - raise ValueError("Algorithm failed to find the exponent i") - # Update values - b = pow(c, 2 ** (m - i - 1), q) - m = i - c = pow(b, 2, q) - t = (t * c) % q - r = (r * b) % q - return r - - -def compute_psi(omega, n, q): - """Compute psi such that psi^2 = omega and psi^n = -1 mod q. - - Given an n-th primitive root of unity omega in GF(q), compute psi. - - Args: - omega: The n-th primitive root of unity. - n: The order of the root of unity. - q: The prime modulus. - - Returns: - An element psi in GF(q) such that psi^2 = omega and psi^n = -1 mod q. - """ - psi = tonelli_shanks(q, omega) - if psi is None: - raise ValueError("No square root exists for omega modulo q.") - - # Check the negacyclic condition: psi^n should equal -1 modulo q. - if pow(psi, n, q) != q - 1: - # Try the other square root (q - psi) - psi = q - psi - if pow(psi, n, q) != q - 1: - raise ValueError("Neither square root of omega satisfies psi^n = -1 mod q.") - return psi - - -def gen_twiddle_matrix(rows, cols, q, omega): - """Precompute the twiddle matrix T of shape (rows, cols), where T[r, c] = omega^(r*c) mod q. - - Args: - rows: The number of rows in the matrix. - cols: The number of columns in the matrix. - q: The modulus. - omega: The primitive root of unity. - - Returns: - The twiddle matrix. - """ - twiddle_matrix = np.zeros((rows, cols), dtype=int) - for r in range(rows): - for c in range(cols): - twiddle_matrix[r, c] = pow(omega, r * c, q) - return twiddle_matrix - - -def gen_twiddle_matrix_inv(rows, cols, q, omega): - """Precompute the inverse twiddle matrix T_inv of shape (rows, cols). - - T_inv[r, c] = omega^{- (r*c)} mod q. - - Args: - rows: The number of rows in the matrix. - cols: The number of columns in the matrix. - q: The modulus. - omega: The primitive root of unity. - - Returns: - The inverse twiddle matrix. - """ - twiddle_matrix_inv = np.zeros((rows, cols), dtype=int) - for r in range(rows): - for c in range(cols): - twiddle_matrix_inv[r, c] = pow(omega, -r * c, q) - return twiddle_matrix_inv - - -def ntt_original_form(v, q, omega): - length = len(v) - coef_mat = gen_twiddle_matrix(length, length, q, omega) - result = [0] * length - for k in range(length): - acc = 0 - for j in range(length): - acc = (acc + v[j] * coef_mat[j, k]) % q - result[k] = acc - return result - - -def intt_original_form(v, q, omega): - """Compute the Inverse NTT (naive O(length^2) algorithm) of vector v of length length over GF(q). - - omega_inv is a primitive L-th root of unity for the inverse transform, i.e. - - if the forward NTT uses omega, then we use omega_inv = omega^{-1} mod q. - The result is normalized by multiplying by the modular inverse of L. - - Args: - v: The input vector. - q: The prime modulus. - omega: The primitive L-th root of unity. - - Returns: - The inverse NTT of v. - """ - - length = len(v) - omega_inv = pow(omega, -1, q) # modular inverse of root - coef_mat = gen_twiddle_matrix(length, length, q, omega_inv) - result = [0] * length - # Compute the modular inverse of L modulo q - length_inv = pow(length, -1, q) - for j in range(length): - acc = 0 - for k in range(length): - # Using omega_inv^(j*k) - acc = (acc + v[k] * coef_mat[j, k]) % q - result[j] = (acc * length_inv) % q - - return result - - -def ntt_bit_reverse(a, q, omega): - """Compute the Number Theoretic Transform of array a modulo p using a given primitive omega of unity.""" - n = len(a) - # Ensure that omega^n ≡ 1 (mod p) and n divides p-1 for validity. - # (This should be true if omega is a correct n-th omega of unity.) - # Bit-reverse the input array indices - bits = n.bit_length() - 1 # number of bits needed for indexes 0..n-1 - for i in range(n): - j = bit_reverse(i, bits) - if i < j: - a[i], a[j] = a[j], a[i] # swap to achieve bit-reversed order - # Cooley-Tukey iterative FFT (NTT) - length = 2 - while length <= n: - # Compute twiddle factor step: use omega^(n/length) as the increment - w_m = pow(omega, n // length, q) - half = length // 2 - for i in range(0, n, length): # loop over sub-FFT blocks - w = 1 - for j in range(i, i + half): # loop within each block - u = a[j] - v = a[j + half] * w % q # multiply by current twiddle factor - a[j] = (u + v) % q # butterfly: combine top part - a[j + half] = (u - v) % q # butterfly: combine bottom part - w = w * w_m % q # advance twiddle factor for next element - length *= 2 - return a - - -def intt_bit_reverse(a, q, omega): - """Compute the Inverse Number Theoretic Transform of array a modulo p using the given primitive root.""" - n = len(a) - inv_root = pow(omega, -1, q) # modular inverse of root - # Decimation-in-frequency (Gentleman-Sande) butterfly operations - length = n - while length >= 2: - w_m = pow(inv_root, n // length, q) - half = length // 2 - for i in range(0, n, length): - w = 1 - for j in range(i, i + half): - u = a[j] - v = a[j + half] - a[j] = (u + v) % q # combine pairs (top value) - a[j + half] = ( - ((u - v) % q) * w % q - ) # combine pairs (bottom), then multiply by twiddle - w = w * w_m % q # advance twiddle factor - length //= 2 - # Bit-reverse the result (to invert the initial bit-reversal - # permutation in NTT) - bits = n.bit_length() - 1 - for i in range(n): - j = bit_reverse(i, bits) - if i < j: - a[i], a[j] = a[j], a[i] - # Divide by n (multiply by n^{-1} mod p) to finish the inverse transform - inv_n = pow(n, -1, q) - for i in range(n): - a[i] = a[i] * inv_n % q - return a - - -def ntt_four_step(x, q, omega, rows, cols): - """Compute the 4-step NTT of the input vector x (length N = rows * cols) over GF(q). - - Args: - x: list or 1D numpy array (length N). - q: prime modulus. - omega: the primitive N-th root of unity. - rows: factors of N, so that N = rows * cols. - cols: factors of N, so that N = rows * cols. - - Returns: - A list representing the NTT result. - - Process: - 1. Columns: NTT on each column (length rows) using omega_col = omega^cols. - 2. Twiddle: Multiply by T[r,c] = omega^(r*c). - 3. Rows: NTT on each row (length cols) using omega_row = omega^rows. - 4. Reordering: Final output is flatten(transpose(Z)). - """ - num_elements = rows * cols - if len(x) != num_elements: - raise ValueError("Length of x must equal rows * cols") - omega_row = pow(omega, rows, q) - omega_col = pow(omega, cols, q) - matrix_a = np.array(x, dtype=int).reshape((rows, cols)) - y = np.zeros((rows, cols), dtype=int) - for c in range(cols): - col = matrix_a[:, c].tolist() - y[:, c] = ntt_original_form(col, q, omega_col) - print(f"after Step 1={y}") - - twiddle = gen_twiddle_matrix(rows, cols, q, omega) - y = (y * twiddle) % q - print(f"after Step 2={y}") - - matrix_z = np.zeros((rows, cols), dtype=int) - for r in range(rows): - row = y[r, :].tolist() - matrix_z[r, :] = ntt_original_form(row, q, omega_row) - print(f"after Step 3={matrix_z}") - matrix_x = np.array( - matrix_z.T - ).flatten() # forward transform reorders via transpose flattening - print(f"after Step 3 after transpose={matrix_x}") - return matrix_x.tolist() - - -def intt_four_step(x, q, omega, rows, cols): - """Compute the 4-step Inverse NTT of the input vector X (length N = rows * cols) over GF(q). - - Forward transform recap: - - Columns: NTT on each column (length rows) using omega_col = omega^cols. - - Twiddle: Multiply by T[r,c] = omega^(r*c). - - Rows: NTT on each row (length cols) using omega_row = omega^rows. - - Reordering: Final output is flatten(transpose(Z)). - - To invert, we perform: - 0. Compute the appropriate inverse roots. - 1. Undo the reordering. - 2. Inverse row transform (length cols) on each row. - 3. Multiply by the inverse twiddle matrix T_inv[r,c] = omega^(-r*c). - 4. Inverse column transform (length rows) on each column. - 5. Reassemble the final result. - - Note: The naive inverse NTT (intt_original_form) already divides by the - transform length. - Hence, the two stages provide an overall normalization of 1/(rows·cols) = 1/N. - - Args: - x: list or 1D numpy array (length N) that is the forward NTT result. - q: prime modulus. - omega: the primitive N-th root of unity used in the forward transform. - (Forward transform used: rowNTT with omega_row = omega^R and columnNTT - with omega_col = omega^C, plus twiddle multiplication T[r,c] = - omega^(r*c).) - rows: factors of N, so that N = rows * cols. - cols: factors of N, so that N = rows * cols. - - Returns: - A list representing the inverse NTT result (the original vector). - """ - num_elements = rows * cols - if len(x) != num_elements: - raise ValueError("Length of X must equal rows * cols") - - # Step 0: Compute necessary inverse roots and normalization factors. - # For the inverse column transform (of length rows): - omega_col = pow(omega, cols, q) - # For the inverse row transform (of length cols): - omega_row = pow(omega, rows, q) - - # Step 1: Undo the final reordering of the forward transform. - # The forward transform did: X = flatten(transpose(Z)) with Z of shape - # (rows, cols). - # To recover Z, first reshape X into shape (cols, rows) then transpose. - matrix_z = np.array(x, dtype=int).reshape((cols, rows)).T - # Now Z is an rows x cols matrix. - - # Step 2: Inverse row transform. - # For each row of Y (length cols), compute the inverse NTT using omega_row. - y = np.zeros((rows, cols), dtype=int) - for r in range(rows): - row = matrix_z[r, :].tolist() - # intt on each row of length cols using inv_omega_row - # (inverse happens inside intt_original_form) - y[r, :] = intt_original_form(row, q, omega_row) - - # Step 3: Multiply by the inverse twiddle factor matrix. - # The forward twiddle matrix was T[r,c] = omega^(r*c). Its inverse is: - # T_inv[r,c] = omega^{-r*c} mod q. - twiddle_inv = gen_twiddle_matrix_inv(rows, cols, q, omega) - y = (y * twiddle_inv) % q - - # Step 4: Inverse column transform. - # For each column of Z (length rows), compute the inverse NTT using - # inv_omega_col. - matrix_a = np.zeros((rows, cols), dtype=int) - for c in range(cols): - col = y[:, c].tolist() - # intt on each column of length rows using inv_omega_col - # (inverse happens inside intt_original_form). - matrix_a[:, c] = intt_original_form(col, q, omega_col) - - # Step 5: Reassemble the final result. - # The forward transform mapped the original vector x to X using a reordering. - # Here, we flatten A (row-major order) to obtain the original x. - x_recovered = np.array(matrix_a).flatten() - return x_recovered.tolist() - - -def ntt_negacyclic(a, q, psi, rows, cols): - """Compute the negacyclic NTT of array a (length n) modulo q. - - Args: - a: list (or 1D array) of integers (length n). - q: prime modulus. - psi: an element in GF(q) such that psi^(2*n) = 1 and psi^n = -1 mod q. (That - is, psi is a primitive 2n-th root of unity; note that then ω = psi^2 is a - primitive n-th root of unity.) - rows: Number of rows in the matrix. - cols: Number of columns in the matrix. - - Returns: - The negacyclic NTT of a. - - Process: - 1. Pre-twist: multiply each coefficient a[i] by psi^i. - 2. Compute the vanilla NTT (for example, using ntt_bit_reverse) with ω = - psi^2. - """ - n = len(a) - # Check that psi^n = -1 mod q. - if pow(psi, n, q) != q - 1: - raise ValueError( - "psi is not a valid 2n-th root of unity for negacyclic NTT (psi^n must" - " equal -1 mod q)." - ) - - # Pre-twisting: multiply a[i] by psi^i. - a_twisted = [(a[i] * pow(psi, i, q)) % q for i in range(n)] - - # Compute vanilla NTT using ω = psi². - omega = pow(psi, 2, q) - - # a_transformed = ntt_bit_reverse(a_twisted.copy(), q, omega) - return ntt_four_step(a_twisted.copy(), q, omega, rows, cols) - - -def intt_negacyclic(a, q, psi, rows, cols): - """Compute the inverse negacyclic NTT of array a (length n) modulo q. - - Args: a : list (or 1D array) of integers (length n) in the negacyclic - evaluation domain. q : prime modulus. psi : an element in GF(q) such that - psi^(2*n) = 1 and psi^n = -1 mod q. (That is, psi is a primitive 2n-th root of - unity; note that then ω = psi^2 is a primitive n-th root of unity.) - - Returns: - The original input vector (i.e. the inverse transform). - - Process: - 1. Compute the inverse vanilla NTT using ω = psi². - 2. Post-twist: multiply the result by psi^(–i) for coefficient index i. - """ - n = len(a) - omega = pow(psi, 2, q) - - # Compute the inverse vanilla NTT. - # a_inv = intt_bit_reverse(a.copy(), q, omega) - a_inv = intt_four_step(a.copy(), q, omega, rows, cols) - - # Post-twisting: multiply a_inv[i] by psi^(–i). - psi_inv = pow(psi, -1, q) - return [(a_inv[i] * pow(psi_inv, i, q)) % q for i in range(n)] - - -def ntt_negacyclic_tpu_algorithm( - a, q, psi, rows, cols, tf_step1, coef_step2, tf_step3 -): - """Compute the negacyclic NTT of array a (length n) modulo q. - - Args: - a: list (or 1D array) of integers (length n). - q: prime modulus. - psi: an element in GF(q) such that psi^(2*n) = 1 and psi^n = -1 mod q. (That - is, psi is a primitive 2n-th root of unity; note that then ω = psi^2 is a - primitive n-th root of unity.) - rows: Number of rows in the matrix. - cols: Number of columns in the matrix. - tf_step1: The twiddle factor matrix for step 1. - coef_step2: The twiddle factor matrix for step 2 (element-wise - multiplication). - tf_step3: The twiddle factor matrix for step 3. - - Returns: - The negacyclic NTT of a. - - Process: - 1. Pre-twist: multiply each coefficient a[i] by psi^i. - 2. Compute the vanilla NTT (for example, using ntt_bit_reverse) with ω = - psi^2. - """ - n = len(a) - # Check that psi^n = -1 mod q. - if pow(psi, n, q) != q - 1: - raise ValueError( - "psi is not a valid 2n-th root of unity for negacyclic NTT (psi^n must" - " equal -1 mod q)." - ) - - # Pre-twisting: multiply a[i] by psi^i. - a_twisted = [(a[i] * pow(psi, i, q)) % q for i in range(n)] - - num_elements = rows * cols - if len(a_twisted) != num_elements: - raise ValueError("Length of a_twisted must equal rows * cols") - matrix_a = np.array(a_twisted, dtype=int).reshape((rows, cols)) - y = np.matmul(tf_step1, matrix_a) - y = y % q - - y = y * coef_step2 - y = y % q - - z = np.matmul(y, tf_step3) - z = z % q - x = np.array( - z.T - ).flatten() # forward transform reorders via transpose flattening - return x.tolist() - - -def intt_negacyclic_tpu_algorithm( - a, q, psi, rows, cols, inv_tf_step1, inv_coef_step2, inv_tf_step3 -): - """Compute the inverse negacyclic NTT of array a (length n) modulo q using TPU-friendly operations. - - Args: a : list (or 1D array) of integers (length n) in the negacyclic - evaluation domain. q : prime modulus. psi : an element in GF(q) such that - psi^(2*n) = 1 and psi^n = -1 mod q. (That is, psi is a primitive 2n-th root of - unity; note that then ω = psi^2 is a primitive n-th root of unity.) - rows: Number of rows in the matrix. - cols: Number of columns in the matrix. - inv_tf_step1: The inverse of the first transform matrix. - inv_coef_step2: The inverse of the second coefficient matrix. - inv_tf_step3: The inverse of the third transform matrix. - - Returns: - The original input vector (i.e. the inverse transform). - - Process: - 1. Compute the inverse vanilla NTT using ω = psi². - 2. Post-twist: multiply the result by psi^(–i) for coefficient index i. - """ - n = len(a) - - num_elements = rows * cols - if len(a) != num_elements: - raise ValueError("Length of a must equal rows * cols") - - # Step 1: Undo the final reordering of the forward transform. - # The forward transform did: X = flatten(transpose(Z)) with Z of shape - # (rows, cols). - # To recover Z, first reshape X into shape (cols, rows) then transpose. - z = np.array(a, dtype=int).reshape((cols, rows)).T - # Now z is an rows x cols matrix. - - # Step 2: Inverse row transform. - # For each row of Y (length cols), compute the inverse NTT using omega_row. - y = np.matmul(z, inv_tf_step1) % q - cols_inv = pow(cols, -1, q) - y = y * cols_inv % q - # Step 3: Multiply by the inverse twiddle factor matrix. - # The forward twiddle matrix was T[r,c] = omega^(r*c). Its inverse is: - # T_inv[r,c] = omega^{-r*c} mod q. - y = (y * inv_coef_step2) % q - - # Step 4: Inverse column transform. - # For each column of Z (length rows), compute the inverse NTT using - # inv_omega_col. - a = np.matmul(inv_tf_step3, y) % q - rows_inv = pow(rows, -1, q) - a = a * rows_inv % q - - # Step 5: Reassemble the final result. - # The forward transform mapped the original vector x to X using a reordering. - # Here, we flatten A (row-major order) to obtain the original x. - x_recovered = np.array(a).flatten() - - # Post-twisting: multiply x_recovered[i] by psi^(–i). - psi_inv = pow(psi, -1, q) - return [(x_recovered[i] * pow(psi_inv, i, q)) % q for i in range(n)] - - -######################## -# Online Functions -######################## -@functools.partial( - jax.jit, - static_argnames=("q", "s", "m"), -) -def barret_reduction(z, q, s, m): - """Vectorized implementation of the Barrett reduction. - - This implementation sets the internal shift width `w` to `min(s, 32)` so it - works with small modulus `q < 2^16`. - - Args: - z: The input value (at most 64 bits). - q: The modulus. - s: The bit width of q. - m: The precomputed value for Barrett reduction. - - Returns: - The result of the Barrett reduction. - """ - w = min(s, 32) - z1 = z & (2**w - 1) - z2 = z >> w - t = ((z1 * m) >> w) + (z2 * m) - t = t >> (s - w) - z = (z - t * q).astype(jnp.uint32) - pred = z >= q - return jnp.where(pred, z - q, z) - - -@functools.partial( - jax.jit, - static_argnames=("q", "s"), -) -def barret_reduction_static_q(z, q, s): - """Vectorized implementation of the Barrett reduction. - - This implementation specializes on the value of `q`, which allows XLA to - apply aggressive compile-time optimizations. - TODO: remove `m` in the function arguments - - Args: - z: The input value. - q: The modulus. - s: The bit width of q. - - Returns: - The result of the Barrett reduction. - """ - # if this implementation fails to pass any test, move this line out - # and add 'm' into static_argnames - m = jnp.floor(2**s / q).astype(jnp.uint32) - w = min(s, 32) - z1 = z & (2**w - 1) - z2 = z >> w - t = ((z1 * m) >> w) + (z2 * m) - t = t >> (s - w) - z = (z - t * q).astype(jnp.uint32) - pred = z >= q - return jnp.where(pred, z - q, z) - - -@jax.jit -def hpmatmul_bat_coef_lhs_batch(lhs: jax.Array, y: jax.Array): - """Input (m, k) Left Matrix -> (m, k, p, q) Left Matrix, where each element in the original (m, k) matrix is replaced by a (p, q) matrix. - - Expect the dtype of `lhs` and `rhs` to be `jnp.uint32`. - - Args: - lhs: The input left matrix. - y: The input right matrix. - - Returns: - The result of the bat coefficient multiplication, with the same batch size - as the input matrices. - """ - rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) - i8_products = jnp.einsum( - "mkpq,bknq->bmnp", - lhs, - rhs, - preferred_element_type=jnp.int32, - ) - shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) - return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) - - -@jax.jit -def hpmatmul_bat_coef_rhs_batch(y: jax.Array, rhs: jax.Array): - """Input (k, n) right Matrix -> (k, n, p, q) right Matrix, where each element in the original (k, n) matrix is replaced by a (p, q) matrix. - - Expect the dtype of `lhs` and `rhs` to be `jnp.uint32`. - - Args: - y: The input left matrix. - rhs: The input right matrix. - - Returns: - The result of the bat coefficient multiplication, with the same batch size - as the input matrices. - """ - - lhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) - i8_products = jnp.einsum( - "bmkq,knpq->bmnp", - lhs, - rhs, - preferred_element_type=jnp.int32, - ) - shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) - return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=("q", "s", "m"), -) -def ntt_layout_invariant_batch( - poly_coef_2d, - tf_step1, - coef_step2, - tf_step3, - q, - s, - m, -): - """Jax implementation of Gentalman Sande NTT, vectorized implementation on VPU.""" - assert poly_coef_2d.dtype == jnp.uint32 - assert tf_step1.dtype == jnp.uint8 - assert coef_step2.dtype == jnp.uint32 - assert tf_step3.dtype == jnp.uint8 - - result_step1 = hpmatmul_bat_coef_lhs_batch(tf_step1, poly_coef_2d) - - result_step1_mod_q = barret_reduction(result_step1, q, s, m) - - result_step2 = jax.numpy.multiply(result_step1_mod_q, coef_step2) - result_step2_mod_q = barret_reduction(result_step2, q, s, m) - result_step3 = hpmatmul_bat_coef_rhs_batch(result_step2_mod_q, tf_step3) - return barret_reduction(result_step3, q, s, m) - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=("rows_inv", "q", "s", "m"), -) -def intt_layout_invariant_batch( - poly_coef_2d, - tf_step1, - coef_step2, - tf_step3, - rows_inv, - q, - s, - m, -): - """Jax implementation of Gentalman Sande NTT, vectorized implementation on VPU.""" - assert poly_coef_2d.dtype == jnp.uint32 - assert tf_step1.dtype == jnp.uint8 - assert coef_step2.dtype == jnp.uint32 - assert tf_step3.dtype == jnp.uint8 - - result_step1 = hpmatmul_bat_coef_rhs_batch(poly_coef_2d, tf_step1) - result_step1_mod_q = barret_reduction(result_step1, q, s, m) - result_step2 = jax.numpy.multiply(result_step1_mod_q, coef_step2) - result_step2_mod_q = barret_reduction(result_step2, q, s, m) - result_step3 = hpmatmul_bat_coef_lhs_batch(tf_step3, result_step2_mod_q) - result_step3_mod_q = barret_reduction(result_step3, q, s, m) - result_scaled = jax.numpy.multiply(result_step3_mod_q, rows_inv) - return barret_reduction(result_scaled, q, s, m) diff --git a/jaxite/jaxite_word/ntt_mm.py b/jaxite/jaxite_word/ntt_mm.py new file mode 100644 index 0000000..b8994f2 --- /dev/null +++ b/jaxite/jaxite_word/ntt_mm.py @@ -0,0 +1,887 @@ +"""This script is specifically designed for NTT/INTT used for ciphertext. + +Main difference to ntt_o.py is this script supports (1) multiple moduli (2) +multiple batch (3) distributed sharding +""" + +import concurrent.futures +from typing import List, Union +import jaxite.jaxite_word.finite_field as ff_context +import jax +import jax.numpy as jnp +import numpy as np +import jaxite.jaxite_word.util as util + +def _is_nvidia(): + return "NVIDIA" in jax.devices()[0].device_kind + + +######################## +# Common Functions +######################## +def matmul_bat_einsum(lhs: jax.Array, rhs: jax.Array, subscripts: str): + """Basis Aligned Transformation (BAT) based matrix multiplication + + Args: + lhs (jax.Array): input + rhs (jax.Array): twiddle factor matrix + subscripts (str): einsum subscripts + + Returns: + jax.Array: result + """ + # preprocess + lhs = jax.lax.bitcast_convert_type(lhs, new_dtype=jnp.uint8) + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + + # computation + i8_products = jnp.einsum( + subscripts, lhs, rhs, preferred_element_type=jnp.uint32 + ) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) + + +def matmul_conv_flexible_kernel( + x: jnp.ndarray, y: jnp.ndarray, subscripts: tuple[str, str, str] +) -> jnp.ndarray: + assert x.dtype == jnp.uint32 + assert y.dtype == jnp.uint32 + + lhs: jax.Array = jax.lax.bitcast_convert_type(x, new_dtype=jnp.uint8) # bnmp + rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) # nk1q + # https://github.com/google/jax/issues/11483 + rhs = jax.lax.rev(rhs, [2]) + + if _is_nvidia(): + u8_products = jax.lax.conv_general_dilated( + lhs.astype( + jnp.int16 + ), # NVIDIA GPU does not support uint8 as input type + rhs.astype( + jnp.int16 + ), # NVIDIA GPU does not support uint8 as input type + window_strides=(1,), + padding=((3, 3),), + dimension_numbers=subscripts, + preferred_element_type=jnp.float32, # NVIDIA GPU does not support uint32 as output type + ) + else: + u8_products = jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=(1,), + padding=((3, 3),), + dimension_numbers=subscripts, + preferred_element_type=jnp.uint32, + ) + + shift_factors = jnp.array([0, 8, 16, 24, 32, 40, 48], dtype=jnp.uint32) + return jnp.sum(u8_products.astype(jnp.uint64) << shift_factors, axis=(2,)) + + +######################## +# Parameter Generation Functions +######################## +def gen_twiddle_matrix(rows, cols, q, omega): + """Precompute the twiddle matrix T of shape (rows, cols), where T[r, c] = omega^(r*c) mod q. + + Args: + rows: The number of rows in the matrix. + cols: The number of columns in the matrix. + q: The modulus. + omega: The primitive root of unity. + + Returns: + The twiddle matrix. + """ + # Vectorized modular exponentiation via exponent bit-decomposition + r_idx = np.arange(rows, dtype=np.int64)[:, None] + c_idx = np.arange(cols, dtype=np.int64)[None, :] + exponents = r_idx * c_idx # shape (rows, cols) + twiddle_matrix = np.zeros((rows, cols), dtype=int) + + def compute_row(r): + for c in range(cols): + twiddle_matrix[r, c] = pow(int(omega), int(exponents[r, c]), int(q)) + + with concurrent.futures.ThreadPoolExecutor() as executor: + list(executor.map(compute_row, range(rows))) + return twiddle_matrix + + +def gen_twiddle_matrix_inv(rows, cols, q, omega): + """Precompute the inverse twiddle matrix T_inv of shape (rows, cols). + + T_inv[r, c] = omega^{- (r*c)} mod q. + + Args: + rows: The number of rows in the matrix. + cols: The number of columns in the matrix. + q: The modulus. + omega: The primitive root of unity. + + Returns: + The inverse twiddle matrix. + """ + twiddle_matrix_inv = np.zeros((rows, cols), dtype=int) + for r in range(rows): + for c in range(cols): + twiddle_matrix_inv[r, c] = pow(int(omega), int(-r * c), int(q)) + return twiddle_matrix_inv + + +######################## +# NTT Context with different modular reduction methods +######################## +class NTTCiphertextContextBase: + """Base class for NTT Context with different modular reduction methods + + This class implements the numpy version of three-step NTT algorithm. + Args: + moduli: The modulus. + transform_length: The transform length. + parameters: The parameters. + + Returns: + The NTT Context. + """ + + def __init__( + self, moduli: Union[int, List[int]], parameters: dict, perf_test=False + ): + self.ff_ctx = parameters.get("finite_field_context", None) + self.num_bytes = 4 + + if type(moduli) == int: + moduli = [moduli] + self.moduli = moduli + self.parameters = parameters + assert all(q < 2**31 for q in moduli), "moduli must be less than 2**32" + self.r = parameters.get("r", 0) + self.c = parameters.get("c", 0) + assert self.r != 0, "r must be non-zero" + assert self.c != 0, "c must be non-zero" + self.transform_length = self.r * self.c + self.psi_list = [ + util.root_of_unity(2 * self.transform_length, q) for q in self.moduli + ] + self.omega_list = [ + (psi**2) % q for psi, q in zip(self.psi_list, self.moduli) + ] + if perf_test: + # Use random data for performance testing to avoid expensive precomputation + key = jax.random.PRNGKey(0) + self.ntt_bat_tf_step1 = jax.random.bits( + key, (self.r, 4, self.r, 4, len(moduli)), dtype=jnp.uint8 + ) + self.ntt_tf_step2 = jax.random.bits( + key, (self.r, self.c, len(moduli)), dtype=jnp.uint64 + ) + self.ntt_bat_tf_step3 = jax.random.bits( + key, (self.c, 4, self.c, 4, len(moduli)), dtype=jnp.uint8 + ) + self.intt_bat_tf_step1 = jax.random.bits( + key, (self.c, 4, self.c, 4, len(moduli)), dtype=jnp.uint8 + ) + self.intt_tf_step2 = jax.random.bits( + key, (self.r, self.c, len(moduli)), dtype=jnp.uint64 + ) + self.intt_bat_tf_step3 = jax.random.bits( + key, (self.r, 4, self.r, 4, len(moduli)), dtype=jnp.uint8 + ) + + if isinstance(self, NTTCiphertextShoupContext): + self.ntt_tf_step1 = jax.random.bits( + key, (self.r, self.r, len(moduli)), dtype=jnp.uint32 + ) + self.ntt_tf_step3 = jax.random.bits( + key, (self.c, self.c, len(moduli)), dtype=jnp.uint32 + ) + self.intt_tf_step1 = jax.random.bits( + key, (self.c, self.c, len(moduli)), dtype=jnp.uint32 + ) + self.intt_tf_step3 = jax.random.bits( + key, (self.r, self.r, len(moduli)), dtype=jnp.uint32 + ) + + self.ntt_tf_step1_shoup = jax.random.bits( + key, (self.r, self.r, len(moduli)), dtype=jnp.uint32 + ) + self.ntt_tf_step2_shoup = jax.random.bits( + key, (self.r, self.c, len(moduli)), dtype=jnp.uint64 + ) + self.ntt_tf_step3_shoup = jax.random.bits( + key, (self.c, self.c, len(moduli)), dtype=jnp.uint32 + ) + self.intt_tf_step1_shoup = jax.random.bits( + key, (self.c, self.c, len(moduli)), dtype=jnp.uint32 + ) + self.intt_tf_step2_shoup = jax.random.bits( + key, (self.r, self.c, len(moduli)), dtype=jnp.uint64 + ) + self.intt_tf_step3_shoup = jax.random.bits( + key, (self.r, self.r, len(moduli)), dtype=jnp.uint32 + ) + else: + self.memory_aligned_transformation() + self.ntt_tf_step1, self.ntt_tf_step2, self.ntt_tf_step3 = ( + self.ntt_coefficients_precompute() + ) + self.intt_tf_step1, self.intt_tf_step2, self.intt_tf_step3 = ( + self.intt_coefficients_precompute() + ) + self.ntt_bat_tf_step1 = self.basis_aligned_transformation( + self.to_computation_format(self.ntt_tf_step1) + ) + self.ntt_tf_step2 = self.to_computation_format(self.ntt_tf_step2).astype( + jnp.uint64 + ) + self.ntt_bat_tf_step3 = self.basis_aligned_transformation( + self.to_computation_format(self.ntt_tf_step3) + ) + self.intt_bat_tf_step1 = self.basis_aligned_transformation( + self.to_computation_format(self.intt_tf_step1) + ) + self.intt_tf_step2 = self.to_computation_format( + self.intt_tf_step2 + ).astype(jnp.uint64) + self.intt_bat_tf_step3 = self.basis_aligned_transformation( + self.to_computation_format(self.intt_tf_step3) + ) + + ######################## + # Offline Functions + ######################## + def ntt_coefficients_precompute(self): + """R = self.r, C = self.c, M = len(self.moduli) + + - ntt_tf_step1: shape (R, R, M), u32 + - ntt_tf_step2: shape (R, C, M), u32 + - ntt_tf_step3: shape (C, C, M), u32 + """ + tf_step1_list, tf_step2_list, tf_step3_list = [], [], [] + for i, modulus in enumerate(self.moduli): + omega_col = pow(self.omega_list[i], self.c, modulus) + omega_row = pow(self.omega_list[i], self.r, modulus) + tf_step1_one_modulus = gen_twiddle_matrix( + self.r, self.r, modulus, omega_col + ) + tf_step2_one_modulus = gen_twiddle_matrix( + self.r, self.c, modulus, self.omega_list[i] + ) + tf_step3_one_modulus = gen_twiddle_matrix( + self.c, self.c, modulus, omega_row + ) + tf_step1_one_modulus = tf_step1_one_modulus[ + self.perm_r, : + ] # Memory Aligned Transformation + tf_step2_one_modulus = tf_step2_one_modulus[ + self.perm_r, : + ] # Memory Aligned Transformation + tf_step3_one_modulus = tf_step3_one_modulus[ + :, self.perm_c + ] # Memory Aligned Transformation + tf_step1_list.append(tf_step1_one_modulus) + tf_step2_list.append(tf_step2_one_modulus) + tf_step3_list.append(tf_step3_one_modulus) + tf_step1 = jnp.array(tf_step1_list, dtype=jnp.uint32).transpose( + 1, 2, 0 + ) # Make moduli the last dimension + tf_step2 = jnp.array(tf_step2_list, dtype=jnp.uint32).transpose( + 1, 2, 0 + ) # Make moduli the last dimension + tf_step3 = jnp.array(tf_step3_list, dtype=jnp.uint32).transpose( + 1, 2, 0 + ) # Make moduli the last dimension + return tf_step1, tf_step2, tf_step3 + + def intt_coefficients_precompute(self): + """R = self.r, C = self.c, M = len(self.moduli) + + - intt_tf_step1: shape (C, C, M), u32 + - intt_tf_step2: shape (R, C, M), u32 + - intt_tf_step3: shape (R, R, M), u32 + """ + intt_tf_step1_list, intt_tf_step2_list, intt_tf_step3_list = [], [], [] + for i, modulus in enumerate(self.moduli): + omega_col = pow(self.omega_list[i], self.c, modulus) + omega_row = pow(self.omega_list[i], self.r, modulus) + inv_omega_col = pow(omega_col, -1, modulus) + inv_omega_row = pow(omega_row, -1, modulus) + intt_tf_step1_one_modulus = gen_twiddle_matrix( + self.c, self.c, modulus, inv_omega_row + ) + intt_tf_step2_one_modulus = gen_twiddle_matrix_inv( + self.r, self.c, modulus, self.omega_list[i] + ) + intt_tf_step3_one_modulus = gen_twiddle_matrix( + self.r, self.r, modulus, inv_omega_col + ) + intt_tf_step1_one_modulus = intt_tf_step1_one_modulus[ + self.perm_c, : + ] # Memory Aligned Transformation + intt_tf_step2_one_modulus = intt_tf_step2_one_modulus[ + self.perm_r, : + ] # Memory Aligned Transformation + intt_tf_step3_one_modulus = intt_tf_step3_one_modulus[ + :, self.perm_r + ] # Memory Aligned Transformation + col_inv = pow(self.c, -1, modulus) + row_inv = pow(self.r, -1, modulus) + intt_tf_step2_one_modulus = ( + intt_tf_step2_one_modulus * col_inv + ) % modulus + intt_tf_step3_one_modulus = ( + intt_tf_step3_one_modulus * row_inv + ) % modulus + intt_tf_step1_list.append(intt_tf_step1_one_modulus) + intt_tf_step2_list.append(intt_tf_step2_one_modulus) + intt_tf_step3_list.append(intt_tf_step3_one_modulus) + intt_tf_step1 = jnp.array(intt_tf_step1_list, dtype=jnp.uint32).transpose( + 1, 2, 0 + ) # Make moduli the last dimension + intt_tf_step2 = jnp.array(intt_tf_step2_list, dtype=jnp.uint32).transpose( + 1, 2, 0 + ) # Make moduli the last dimension + intt_tf_step3 = jnp.array(intt_tf_step3_list, dtype=jnp.uint32).transpose( + 1, 2, 0 + ) # Make moduli the last dimension + return intt_tf_step1, intt_tf_step2, intt_tf_step3 + + def to_computation_format(self, a: Union[np.ndarray, jax.Array]): + assert self.ff_ctx is not None + return self.ff_ctx.to_computation_format(a.astype(jnp.uint64)).astype( + jnp.uint32 + ) + + def to_original_format(self, a: Union[np.ndarray, jax.Array]): + assert self.ff_ctx is not None + return self.ff_ctx.to_original_format(a.astype(jnp.uint64)).astype( + jnp.uint32 + ) + + def basis_aligned_transformation(self, matrix: np.ndarray): + matrix_u64 = matrix.astype(np.uint64) + matrix_u64_byteshifted = np.array( + [matrix_u64 << (8 * byte_idx) for byte_idx in range(self.num_bytes)], + dtype=np.uint64, + ) + # shape is (4, rows, cols, moduli) + matrix_u64_byteshifted_mod_modulus = ( + matrix_u64_byteshifted % jnp.array(self.moduli, dtype=np.uint64) + ).astype(np.uint32) + # shape is (4, rows, cols, moduli, bytes=4) + matrix_u8 = jax.lax.bitcast_convert_type( + matrix_u64_byteshifted_mod_modulus, jnp.uint8 + ).transpose(1, 0, 2, 4, 3) + return matrix_u8 + + def memory_aligned_transformation(self): + """Memory Aligned Transformation (MAT) + + Must run after gen_twiddle_matrix() + """ + + def get_bit_reverse_perm(n): + """Generates a list of indices for bit-reversal permutation of size n.""" + if n <= 0: + return [] + bits = n.bit_length() - 1 + perm = [0] * n + for i in range(n): + # Reverse bits of i + r = 0 + temp = i + for _ in range(bits): + r = (r << 1) | (temp & 1) + temp >>= 1 + perm[i] = r + return perm + + self.perm_r = get_bit_reverse_perm(self.r) + self.perm_c = get_bit_reverse_perm(self.c) + + def get_jax_parameters(self): + assert self.ff_ctx is not None + return { + "ntt_bat_tf_step1": util.to_tuple(self.ntt_bat_tf_step1), + "ntt_tf_step2": util.to_tuple(self.ntt_tf_step2), + "ntt_bat_tf_step3": util.to_tuple(self.ntt_bat_tf_step3), + "intt_bat_tf_step1": util.to_tuple(self.intt_bat_tf_step1), + "intt_tf_step2": util.to_tuple(self.intt_tf_step2), + "intt_bat_tf_step3": util.to_tuple(self.intt_bat_tf_step3), + "finite_field_parameters": self.ff_ctx.get_jax_parameters(), + "rows": self.r, + "cols": self.c, + } + + ######################## + # Online Functions + ######################## + def ntt_limb(self, v: jax.Array, limb_index: int): + """NTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + + Args: + v: - is u32 array of shape (B, R, C) - will be casted into u8 array of + shape (B, R, C, Q) + ntt_bat_tf_step1: - is u8 array of shape (R, 4, R, 4) + ntt_tf_step2: - is u32 array of shape (R, C) + ntt_bat_tf_step3: - is u8 array of shape (C, 4, C, 4) + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.ntt_bat_tf_step1[..., limb_index], "brcq,zqrp->bzcp" + ) + assert self.ff_ctx is not None + result_step1_reduced = self.ff_ctx.modular_reduction_single_modulus( + result_step1, limb_index + ) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), + self.ntt_tf_step2[..., limb_index], + ) + result_step2_reduced = self.ff_ctx.modular_reduction_single_modulus( + result_step2, limb_index + ) + result_step3 = matmul_bat_einsum( + result_step2_reduced, + self.ntt_bat_tf_step3[..., limb_index], + "brcq,cqnp->brnp", + ) + result_step3_reduced = self.ff_ctx.modular_reduction_single_modulus( + result_step3, limb_index + ) + return result_step3_reduced + + def intt_limb(self, v: jax.Array, limb_index: int): + """INTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + + Args: + v: - is u32 array of shape (B, R, C) - will be casted into u8 array of + shape (B, R, C, Q) + intt_bat_tf_step1: - is u8 array of shape (C, 4, C, 4) + intt_tf_step2: - is u32 array of shape (R, C) + intt_bat_tf_step3: - is u8 array of shape (R, 4, R, 4) + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.intt_bat_tf_step1[..., limb_index], "brcq,cqlp->brlp" + ) + assert self.ff_ctx is not None + result_step1_reduced = self.ff_ctx.modular_reduction_single_modulus( + result_step1, limb_index + ) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), + self.intt_tf_step2[..., limb_index], + ) + result_step2_reduced = self.ff_ctx.modular_reduction_single_modulus( + result_step2, limb_index + ) + result_step3 = matmul_bat_einsum( + result_step2_reduced, + self.intt_bat_tf_step3[..., limb_index], + "brcq,lqrp->blcp", + ) + result_step3_reduced = self.ff_ctx.modular_reduction_single_modulus( + result_step3, limb_index + ) + return result_step3_reduced + + def ntt(self, v: jax.Array): + """NTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + M = len(self.moduli) + + Args: + v: - is u32 array of shape (B, R, C, M) - will be casted into u8 array + of shape (B, R, C, M, Q) + ntt_bat_tf_step1: - is u8 array of shape (R, 4, R, 4, M) + ntt_tf_step2: - is u32 array of shape (R, C, M) + ntt_bat_tf_step3: - is u8 array of shape (C, 4, C, 4, M) + + Returns: + - is u32 array of shape (B, R, C, M) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.ntt_bat_tf_step1, "brcmq,zqrpm->bzcmp" + ) # "mqkp,bknq->bmnp"; "bkncq,mqkpc->bmncp" + assert self.ff_ctx is not None + result_step1_reduced = self.ff_ctx.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_step2 + ) + result_step2_reduced = self.ff_ctx.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced, self.ntt_bat_tf_step3, "brcmq,cqnpm->brnmp" + ) # "bmkq,kqnp->bnmp" "bmkcq,kqnpc->bnmcp" + result_step3_reduced = self.ff_ctx.modular_reduction(result_step3) + return result_step3_reduced + + def intt(self, v: jax.Array): + """INTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + M = len(self.moduli) + + Args: + v: - is u32 array of shape (B, R, C, M) - will be casted into u8 array + of shape (B, R, C, M, Q) + intt_bat_tf_step1: - is u8 array of shape (C, 4, C, 4, M) + intt_tf_step2: - is u32 array of shape (R, C, M) + intt_bat_tf_step3: - is u8 array of shape (R, 4, R, 4, M) + + Returns: + - is u32 array of shape (B, R, C, M) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.intt_bat_tf_step1, "brcmq,cqlpm->brlmp" + ) + assert self.ff_ctx is not None + result_step1_reduced = self.ff_ctx.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_step2 + ) + result_step2_reduced = self.ff_ctx.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced, self.intt_bat_tf_step3, "brcmq,lqrpm->blcmp" + ) + result_step3_reduced = self.ff_ctx.modular_reduction(result_step3) + return result_step3_reduced + + ######################## + # Modulus Dropping Functions + ######################## + def drop_last_modulus(self): + self.ntt_bat_tf_step1 = self.ntt_bat_tf_step1[..., :-1] + self.ntt_tf_step2 = self.ntt_tf_step2[..., :-1] + self.ntt_bat_tf_step3 = self.ntt_bat_tf_step3[..., :-1] + self.intt_bat_tf_step1 = self.intt_bat_tf_step1[..., :-1] + self.intt_tf_step2 = self.intt_tf_step2[..., :-1] + self.intt_bat_tf_step3 = self.intt_bat_tf_step3[..., :-1] + assert self.ff_ctx is not None + self.ff_ctx.drop_last_modulus() + + +class NTTCiphertextBarrettContext(NTTCiphertextContextBase): + + def __init__(self, moduli: int, parameters: dict, perf_test=False): + super().__init__(moduli, parameters, perf_test=perf_test) + if type(self.moduli) is int: + self.moduli = [self.moduli] + if self.ff_ctx is None: + self.ff_ctx = ff_context.BarrettContext(moduli) + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + + +class NTTCiphertextMontgomeryContext(NTTCiphertextContextBase): + + def __init__(self, moduli: int, parameters: dict, perf_test=False): + super().__init__(moduli, parameters, perf_test=perf_test) + if type(self.moduli) is int: + self.moduli = [self.moduli] + if self.ff_ctx is None: + self.ff_ctx = ff_context.MontgomeryContext(moduli) + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + + +class NTTCiphertextShoupContext(NTTCiphertextContextBase): + """NTT with Shoup's Modular Reduction + + Note that Shoup's Reduction is NOT compatible with Basis Aligned + Transformation (BAT). + We use 1-d convolution to perform matrix multiplication for Shoup. + """ + + def __init__(self, moduli: int, parameters: dict, perf_test=False): + super().__init__(moduli, parameters, perf_test=perf_test) + if type(self.moduli) is int: + self.moduli = [self.moduli] + if self.ff_ctx is None: + self.ff_ctx = ff_context.ShoupContext(moduli) + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + + if not perf_test: + self.ntt_bat_tf_step1 = self.to_computation_format( + self.ntt_tf_step1 + ).astype(jnp.uint32) + self.ntt_tf_step2 = self.to_computation_format(self.ntt_tf_step2).astype( + jnp.uint64 + ) + self.ntt_bat_tf_step3 = self.to_computation_format( + self.ntt_tf_step3 + ).astype(jnp.uint32) + self.intt_bat_tf_step1 = self.to_computation_format( + self.intt_tf_step1 + ).astype(jnp.uint32) + self.intt_tf_step2 = self.to_computation_format( + self.intt_tf_step2 + ).astype(jnp.uint64) + self.intt_bat_tf_step3 = self.to_computation_format( + self.intt_tf_step3 + ).astype(jnp.uint32) + + self.ntt_tf_step1_shoup = self.to_shoup_computation_format( + self.ntt_tf_step1 + ).astype(jnp.uint32) + self.ntt_tf_step2_shoup = self.to_shoup_computation_format( + self.ntt_tf_step2 + ).astype(jnp.uint64) + self.ntt_tf_step3_shoup = self.to_shoup_computation_format( + self.ntt_tf_step3 + ).astype(jnp.uint32) + self.intt_tf_step1_shoup = self.to_shoup_computation_format( + self.intt_tf_step1 + ).astype(jnp.uint32) + self.intt_tf_step2_shoup = self.to_shoup_computation_format( + self.intt_tf_step2 + ).astype(jnp.uint64) + self.intt_tf_step3_shoup = self.to_shoup_computation_format( + self.intt_tf_step3 + ).astype(jnp.uint32) + + def to_computation_format(self, a: Union[np.ndarray, jax.Array]): + return self.ff_ctx.to_computation_format(a.astype(jnp.uint64)) + + def to_shoup_computation_format(self, a: Union[np.ndarray, jax.Array]): + return self.ff_ctx.precompute_constant_operand(a.astype(jnp.uint64)) + + def to_original_format(self, a: Union[np.ndarray, jax.Array]): + return self.ff_ctx.to_original_format(a.astype(jnp.uint64)) + + def get_jax_parameters(self): + return { + "ntt_tf_step1": util.to_tuple(self.ntt_tf_step1), + "ntt_tf_step2": util.to_tuple(self.ntt_tf_step2), + "ntt_tf_step3": util.to_tuple(self.ntt_tf_step3), + "intt_tf_step1": util.to_tuple(self.intt_tf_step1), + "intt_tf_step2": util.to_tuple(self.intt_tf_step2), + "intt_tf_step3": util.to_tuple(self.intt_tf_step3), + "finite_field_parameters": self.ff_ctx.get_jax_parameters(), + "rows": self.r, + "cols": self.c, + "ntt_tf_step1_shoup": util.to_tuple(self.ntt_tf_step1_shoup), + "ntt_tf_step2_shoup": util.to_tuple(self.ntt_tf_step2_shoup), + "ntt_tf_step3_shoup": util.to_tuple(self.ntt_tf_step3_shoup), + "intt_tf_step1_shoup": util.to_tuple(self.intt_tf_step1_shoup), + "intt_tf_step2_shoup": util.to_tuple(self.intt_tf_step2_shoup), + "intt_tf_step3_shoup": util.to_tuple(self.intt_tf_step3_shoup), + } + + def ntt(self, v: jax.Array): + """NTT with modular u32 + + Args: + v: - is u32 array of shape (B, R, C) - input + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + conv_over_rns = jax.vmap( + lambda x, y: matmul_conv_flexible_kernel(x, y, ("NCW", "IOW", "NCW")), + in_axes=(-1, -1), + out_axes=-1, + ) + batched_conv_step1 = jax.vmap( + lambda x, y_b: conv_over_rns(x, y_b), + in_axes=(None, 0), # x is shared across B, y_b iterates over axis 0 + out_axes=0, + ) + + conv_over_rns_step3 = jax.vmap( + lambda x, y: matmul_conv_flexible_kernel(x, y, ("NCW", "IOW", "CNW")), + in_axes=(-1, -1), + out_axes=-1, + ) + batched_conv_step3 = jax.vmap( + lambda x_b, y: conv_over_rns_step3(x_b, y), + in_axes=(0, None), + out_axes=0, + ) + result_step1 = batched_conv_step1(self.ntt_tf_step1, v) + result_step1_shoup = batched_conv_step1(self.ntt_tf_step1_shoup, v) + result_step1_reduced = self.ff_ctx.modular_reduction( + result_step1, result_step1_shoup + ) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_step2 + ) + result_step2_shoup = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_step2_shoup + ) + result_step2_reduced = self.ff_ctx.modular_reduction( + result_step2, result_step2_shoup + ) + result_step3 = batched_conv_step3(result_step2_reduced, self.ntt_tf_step3) + result_step3_shoup = batched_conv_step3( + result_step2_reduced, self.ntt_tf_step3_shoup + ) + result_step3_reduced = self.ff_ctx.modular_reduction( + result_step3, result_step3_shoup + ) + result_step3_reduced = result_step3_reduced.transpose(0, 2, 1, 3) + return result_step3_reduced + + def intt(self, v: jax.Array): + """INTT with modular u32 + + Args: + v: - is u32 array of shape (B, R, C) - input + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + # computation + conv_over_rns = jax.vmap( + lambda x, y: matmul_conv_flexible_kernel(x, y, ("CNW", "IOW", "NCW")), + in_axes=(-1, -1), + out_axes=-1, + ) + batched_conv_step1 = jax.vmap( + lambda x_b, y: conv_over_rns(x_b, y), + in_axes=(0, None), # x is shared across B, y_b iterates over axis 0 + out_axes=0, + ) + + conv_over_rns_step3 = jax.vmap( + lambda x, y: matmul_conv_flexible_kernel(x, y, ("NCW", "IOW", "NCW")), + in_axes=(-1, -1), + out_axes=-1, + ) + batched_conv_step3 = jax.vmap( + lambda x, y_b: conv_over_rns_step3(x, y_b), + in_axes=(None, 0), + out_axes=0, + ) + v = v.transpose(0, 2, 1, 3) + result_step1 = batched_conv_step1(v, self.intt_tf_step1) + result_step1_shoup = batched_conv_step1(v, self.intt_tf_step1_shoup) + result_step1_reduced = self.ff_ctx.modular_reduction( + result_step1, result_step1_shoup + ) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_step2 + ) + result_step2_shoup = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_step2_shoup + ) + result_step2_reduced = self.ff_ctx.modular_reduction( + result_step2, result_step2_shoup + ) + result_step3 = batched_conv_step3(self.intt_tf_step3, result_step2_reduced) + result_step3_shoup = batched_conv_step3( + self.intt_tf_step3_shoup, result_step2_reduced + ) + result_step3_reduced = self.ff_ctx.modular_reduction( + result_step3, result_step3_shoup + ) + return result_step3_reduced + + +class NTTCiphertextBATLazyContext(NTTCiphertextContextBase): + + def __init__(self, moduli: int, parameters: dict, perf_test=False): + super().__init__(moduli, parameters, perf_test=perf_test) + if type(self.moduli) is int: + self.moduli = [self.moduli] + if self.ff_ctx is None: + self.ff_ctx = ff_context.BarrettContext(moduli) + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + self.ff_ctx_bat_lazy = ff_context.BATLazyContext(moduli) + + def ntt(self, v: jax.Array): + """NTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + M = len(self.moduli) + + Args: + v: - is u32 array of shape (B, R, C, M) - will be casted into u8 array + of shape (B, R, C, M, Q) + ntt_bat_tf_step1: - is u8 array of shape (R, 4, R, 4, M) + ntt_tf_step2: - is u32 array of shape (R, C, M) + ntt_bat_tf_step3: - is u8 array of shape (C, 4, C, 4, M) + + Returns: + - is u32 array of shape (B, R, C, M) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.ntt_bat_tf_step1, "brcmq,zqrpm->bzcmp" + ) # "mqkp,bknq->bmnp"; "bkncq,mqkpc->bmncp" + result_step1_reduced = self.ff_ctx_bat_lazy.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_step2 + ) + result_step2_reduced = self.ff_ctx.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced.astype(jnp.uint32), + self.ntt_bat_tf_step3, + "brcmq,cqnpm->brnmp", + ) # "bmkq,kqnp->bnmp" "bmkcq,kqnpc->bnmcp" + result_step3_reduced = self.ff_ctx.modular_reduction(result_step3) + return result_step3_reduced + + def intt(self, v: jax.Array): + """INTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + M = len(self.moduli) + + Args: + v: - is u32 array of shape (B, R, C, M) - will be casted into u8 array + of shape (B, R, C, M, Q) + intt_bat_tf_step1: - is u8 array of shape (C, 4, C, 4, M) + intt_tf_step2: - is u32 array of shape (R, C, M) + intt_bat_tf_step3: - is u8 array of shape (R, 4, R, 4, M) + + Returns: + - is u32 array of shape (B, R, C, M) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.intt_bat_tf_step1, "brcmq,cqlpm->brlmp" + ) + result_step1_reduced = self.ff_ctx_bat_lazy.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_step2 + ) + result_step2_reduced = self.ff_ctx.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced, self.intt_bat_tf_step3, "brcmq,lqrpm->blcmp" + ) + result_step3_reduced = self.ff_ctx.modular_reduction(result_step3) + return result_step3_reduced diff --git a/jaxite/jaxite_word/ntt_mm_test.py b/jaxite/jaxite_word/ntt_mm_test.py new file mode 100644 index 0000000..077221c --- /dev/null +++ b/jaxite/jaxite_word/ntt_mm_test.py @@ -0,0 +1,315 @@ +from absl.testing import absltest +from absl.testing import parameterized +import jaxite.jaxite_word.finite_field as ff_context +import jax +import jaxite.jaxite_word.ntt_mm as ntt + +jax.config.update("jax_enable_x64", True) +import numpy as np +import jax.numpy as jnp +import jaxite.jaxite_word.util as util +import os + + +NTT = [ + ( + "0", + [134219681, 134219681, 134219681], + None, + 3, + 4, + 4, + [ + [ + 105825732, + 68433452, + 36629220, + 126901109, + 89469849, + 106633716, + 15102657, + 108374459, + 68789927, + 23451922, + 93538050, + 20585372, + 30604976, + 37517995, + 65644325, + 102451383, + ], + [ + 105825732, + 68433452, + 36629220, + 126901109, + 89469849, + 106633716, + 15102657, + 108374459, + 68789927, + 23451922, + 93538050, + 20585372, + 30604976, + 37517995, + 65644325, + 102451383, + ], + [ + 105825732, + 68433452, + 36629220, + 126901109, + 89469849, + 106633716, + 15102657, + 108374459, + 68789927, + 23451922, + 93538050, + 20585372, + 30604976, + 37517995, + 65644325, + 102451383, + ], + ], + [ + [ + 26196696, + 45475009, + 10055359, + 23277424, + 69041040, + 71916973, + 73894069, + 3311254, + 44646798, + 49882443, + 28097016, + 70484730, + 10811958, + 11946041, + 61318182, + 19099272, + ], + [ + 26196696, + 45475009, + 10055359, + 23277424, + 69041040, + 71916973, + 73894069, + 3311254, + 44646798, + 49882443, + 28097016, + 70484730, + 10811958, + 11946041, + 61318182, + 19099272, + ], + [ + 26196696, + 45475009, + 10055359, + 23277424, + 69041040, + 71916973, + 73894069, + 3311254, + 44646798, + 49882443, + 28097016, + 70484730, + 10811958, + 11946041, + 61318182, + 19099272, + ], + ], + ), +] + + +class NTTTest(parameterized.TestCase): + + def __init__(self, *args, **kwargs): + super(NTTTest, self).__init__(*args, **kwargs) + self.random_key = jax.random.key(0) + + # @absltest.skip("test single implementation") + @parameterized.named_parameters(*NTT) + def test_NTT_Barrett(self, q, psi, batch, r, c, coef_in, eval_in): + b = 2 # batch size + coef_in = jnp.concatenate( + [ + jnp.array(coef_in, dtype=jnp.uint64) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ).astype(jnp.uint32) + eval_in = jnp.concatenate( + [ + jnp.array(eval_in, dtype=jnp.uint32) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ).astype(jnp.uint32) + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.BarrettContext(moduli=q), + } + ntt_ctx = ntt.NTTCiphertextBarrettContext(moduli=q, parameters=parameters) + # bit_reverse_indices = jnp.array(util.bit_reverse_indices(r*c), jnp.uint32) + ntt_result_cf = ntt_ctx.ntt(coef_in.reshape(b, r, c, -1)) + # coef_in_br = jnp.take(ntt_result_cf.reshape(b, r*c, -1), bit_reverse_indices, axis=-2) + np.testing.assert_array_equal(eval_in, ntt_result_cf.reshape(b, r * c, -1)) + intt_result = ntt_ctx.intt(ntt_result_cf) + np.testing.assert_array_equal( + coef_in, intt_result.reshape(b, r * c, -1).tolist() + ) + + # @absltest.skip("test single implementation") + @parameterized.named_parameters(*NTT) + def test_NTT_Montgomery(self, q, psi, batch, r, c, coef_in, eval_in): + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.MontgomeryContext(moduli=q), + } + b = 2 # batch size + coef_in = jnp.concatenate( + [ + jnp.array(coef_in, dtype=jnp.uint64) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ) + eval_in = jnp.concatenate( + [ + jnp.array(eval_in, dtype=jnp.uint32) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ) + + ntt_ctx = ntt.NTTCiphertextMontgomeryContext( + moduli=q, parameters=parameters + ) + # bit_reverse_indices = jnp.array(util.bit_reverse_indices(r*c), jnp.uint32) + test_in_cf = ( + ntt_ctx.to_computation_format(coef_in) + .astype(jnp.uint32) + .reshape(b, r, c, -1) + ) + ntt_result_cf = ntt_ctx.ntt(test_in_cf) + eval_recovered = ntt_ctx.to_original_format( + ntt_result_cf.reshape(b, r * c, -1).astype(jnp.uint64) + ) + # coef_in_br = jnp.take(eval_recovered, bit_reverse_indices, axis=-2) + np.testing.assert_array_equal(eval_in, eval_recovered) + intt_result = ntt_ctx.intt(ntt_result_cf) + x_recovered = ntt_ctx.to_original_format(intt_result.reshape(b, r * c, -1)) + np.testing.assert_array_equal( + coef_in, x_recovered.reshape(b, r * c, -1).tolist() + ) + jit_ntt = jax.jit(ntt_ctx.ntt) + jit_ntt(test_in_cf) + profile_name = f"NTT_Montgomery_Performance" + file_path = os.path.join( + os.environ.get("TEST_TMPDIR", "/tmp"), profile_name + ) + with jax.profiler.trace(file_path): + jit_ntt(test_in_cf) + + # @absltest.skip("test single implementation") + @parameterized.named_parameters(*NTT) + def test_NTT_Shoup(self, q, psi, batch, r, c, coef_in, eval_in): + b = 2 + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.ShoupContext(moduli=q), + } + coef_in = jnp.concatenate( + [ + jnp.array(coef_in, dtype=jnp.uint64) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ) + eval_in = jnp.concatenate( + [ + jnp.array(eval_in, dtype=jnp.uint32) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ) + + ntt_ctx = ntt.NTTCiphertextShoupContext(moduli=q, parameters=parameters) + ntt_result_cf = ntt_ctx.ntt( + jnp.array(coef_in, dtype=jnp.uint32).reshape(b, r, c, -1) + ) + eval_recovered = ntt_ctx.to_original_format(ntt_result_cf) + np.testing.assert_array_equal(eval_in, eval_recovered.reshape(b, r * c, -1)) + intt_result = ntt_ctx.intt(ntt_result_cf) + x_recovered = ntt_ctx.to_original_format(intt_result) + np.testing.assert_array_equal( + coef_in, x_recovered.reshape(b, r * c, -1).tolist() + ) + + # @absltest.skip("test single implementation") + @parameterized.named_parameters(*NTT) + def test_NTT_BATLazy(self, q, psi, batch, r, c, coef_in, eval_in): + b = 2 # batch size + coef_in = jnp.concatenate( + [ + jnp.array(coef_in, dtype=jnp.uint64) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ).astype(jnp.uint32) + eval_in = jnp.concatenate( + [ + jnp.array(eval_in, dtype=jnp.uint32) + .transpose(1, 0) + .reshape(1, r * c, -1) + for _ in range(b) + ], + axis=0, + ).astype(jnp.uint32) + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.BarrettContext(moduli=q), + } + ntt_ctx = ntt.NTTCiphertextBATLazyContext(moduli=q, parameters=parameters) + ntt_result_cf = ntt_ctx.ntt(coef_in.reshape(b, r, c, -1)) + np.testing.assert_array_equal(eval_in, ntt_result_cf.reshape(b, r * c, -1)) + intt_result = ntt_ctx.intt(ntt_result_cf) + np.testing.assert_array_equal( + coef_in, intt_result.reshape(b, r * c, -1).tolist() + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_word/ntt_sm.py b/jaxite/jaxite_word/ntt_sm.py new file mode 100644 index 0000000..20fb615 --- /dev/null +++ b/jaxite/jaxite_word/ntt_sm.py @@ -0,0 +1,634 @@ +import concurrent.futures +import jax +import jax.numpy as jnp +import jaxite.jaxite_word.finite_field as ff_context +import jaxite.jaxite_word.util as util +import numpy as np + + +######################## +# Common Functions +######################## +def matmul_bat_einsum(lhs: jax.Array, rhs: jax.Array, subscripts: str): + """Basis Aligned Transformation (BAT) based matrix multiplication + + Args: + lhs (jax.Array): input + rhs (jax.Array): twiddle factor matrix + subscripts (str): einsum subscripts + + Returns: + jax.Array: result + """ + # preprocess + lhs = jax.lax.bitcast_convert_type(lhs, new_dtype=jnp.uint8) + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + + # computation + i8_products = jnp.einsum( + subscripts, lhs, rhs, preferred_element_type=jnp.uint32 + ) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(-1,)) + + +def matmul_conv_flexible_kernel( + x: jnp.ndarray, y: jnp.ndarray, subscripts: tuple[str, str, str] +) -> jnp.ndarray: + assert x.dtype == jnp.uint32 + assert y.dtype == jnp.uint32 + + lhs: jax.Array = jax.lax.bitcast_convert_type(x, new_dtype=jnp.uint8) # bnmp + rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) # nk1q + # https://github.com/google/jax/issues/11483 + rhs = jax.lax.rev(rhs, [2]) + + if "NVIDIA" in jax.devices()[0].device_kind: + u8_products = jax.lax.conv_general_dilated( + lhs.astype( + jnp.int16 + ), # NVIDIA GPU does not support uint8 as input type + rhs.astype( + jnp.int16 + ), # NVIDIA GPU does not support uint8 as input type + window_strides=(1,), + padding=((3, 3),), + dimension_numbers=subscripts, + preferred_element_type=jnp.float32, # NVIDIA GPU does not support uint32 as output type + ) + else: + u8_products = jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=(1,), + padding=((3, 3),), + dimension_numbers=subscripts, + preferred_element_type=jnp.uint32, + ) + + shift_factors = jnp.array([0, 8, 16, 24, 32, 40, 48], dtype=jnp.uint32) + return jnp.sum(u8_products.astype(jnp.uint64) << shift_factors, axis=(2,)) + + +######################## +# Parameter Generation Functions +######################## +def gen_twiddle_matrix(rows, cols, q, omega): + """Precompute the twiddle matrix T of shape (rows, cols), where T[r, c] = omega^(r*c) mod q. + + Args: + rows: The number of rows in the matrix. + cols: The number of columns in the matrix. + q: The modulus. + omega: The primitive root of unity. + + Returns: + The twiddle matrix. + """ + # Vectorized modular exponentiation via exponent bit-decomposition + r_idx = np.arange(rows, dtype=np.int64)[:, None] + c_idx = np.arange(cols, dtype=np.int64)[None, :] + exponents = r_idx * c_idx # shape (rows, cols) + twiddle_matrix = np.zeros((rows, cols), dtype=int) + + def compute_row(r): + for c in range(cols): + twiddle_matrix[r, c] = pow(int(omega), int(exponents[r, c]), int(q)) + + with concurrent.futures.ThreadPoolExecutor() as executor: + list(executor.map(compute_row, range(rows))) + return twiddle_matrix + + +def gen_twiddle_matrix_inv(rows, cols, q, omega): + """Precompute the inverse twiddle matrix T_inv of shape (rows, cols). + + T_inv[r, c] = omega^{- (r*c)} mod q. + + Args: + rows: The number of rows in the matrix. + cols: The number of columns in the matrix. + q: The modulus. + omega: The primitive root of unity. + + Returns: + The inverse twiddle matrix. + """ + twiddle_matrix_inv = np.zeros((rows, cols), dtype=int) + for r in range(rows): + for c in range(cols): + twiddle_matrix_inv[r, c] = pow(int(omega), int(-r * c), int(q)) + return twiddle_matrix_inv + + +######################## +# NTT Context with different modular reduction methods +######################## +class NTTContextBase: + """Base class for NTT Context with different modular reduction methods + + This class implements the numpy version of three-step NTT algorithm. + Args: + moduli: The modulus. + transform_length: The transform length. + parameters: The parameters. + + Returns: + The NTT Context. + """ + + def __init__(self, moduli: int, parameters: dict): + self.ff_ctx = parameters.get("finite_field_context", ff_context.BarrettContext(moduli)) + self.num_bytes = 4 + + self.moduli = moduli + self.parameters = parameters + assert self.moduli < 2**31, "moduli must be less than 2**32" + self.r = parameters.get("r", 0) + self.c = parameters.get("c", 0) + assert self.r != 0, "r must be non-zero" + assert self.c != 0, "c must be non-zero" + self.transform_length = self.r * self.c + self.psi = util.root_of_unity(2 * self.transform_length, self.moduli) + self.omega = (self.psi**2) % self.moduli + + self.ntt_tf_step1, self.ntt_tf_step2, self.ntt_tf_step3 = ( + self.ntt_coefficients_precompute() + ) + self.intt_tf_step1, self.intt_tf_step2, self.intt_tf_step3 = ( + self.intt_coefficients_precompute() + ) + + self.memory_aligned_transformation() + self.ntt_tf_bat_mat_comp_step1 = self.basis_aligned_transformation( + self.to_computation_format(self.ntt_tf_mat_step1) + ) + self.ntt_tf_mat_comp_step2 = self.to_computation_format( + self.ntt_tf_mat_step2 + ).astype(jnp.uint64) + self.ntt_tf_bat_mat_comp_step3 = self.basis_aligned_transformation( + self.to_computation_format(self.ntt_tf_mat_step3) + ) + self.intt_tf_bat_mat_comp_step1 = self.basis_aligned_transformation( + self.to_computation_format(self.intt_tf_mat_step1) + ) + self.intt_tf_mat_comp_step2 = self.to_computation_format( + self.intt_tf_mat_step2 + ).astype(jnp.uint64) + self.intt_tf_bat_mat_comp_step3 = self.basis_aligned_transformation( + self.to_computation_format(self.intt_tf_mat_step3) + ) + + ######################## + # Offline Functions + ######################## + def ntt_coefficients_precompute(self): + omega_col = pow(self.omega, self.c, self.moduli) + omega_row = pow(self.omega, self.r, self.moduli) + tf_step1 = gen_twiddle_matrix(self.r, self.r, self.moduli, omega_col) + tf_step2 = gen_twiddle_matrix(self.r, self.c, self.moduli, self.omega) + tf_step3 = gen_twiddle_matrix(self.c, self.c, self.moduli, omega_row) + return tf_step1, tf_step2, tf_step3 + + def intt_coefficients_precompute(self): + omega_col = pow(self.omega, self.c, self.moduli) + omega_row = pow(self.omega, self.r, self.moduli) + inv_omega_col = pow(omega_col, -1, self.moduli) + inv_omega_row = pow(omega_row, -1, self.moduli) + intt_tf_step1 = gen_twiddle_matrix( + self.c, self.c, self.moduli, inv_omega_row + ) + intt_tf_step2 = gen_twiddle_matrix_inv( + self.r, self.c, self.moduli, self.omega + ) + # Precompute col_inv * step2 to merge the two multiplication steps in intt + col_inv = pow(self.c, -1, self.moduli) + row_inv = pow(self.r, -1, self.moduli) + intt_tf_step2 = (intt_tf_step2 * col_inv) % self.moduli + intt_tf_step3 = gen_twiddle_matrix( + self.r, self.r, self.moduli, inv_omega_col + ) + intt_tf_step3 = (intt_tf_step3 * row_inv) % self.moduli + return intt_tf_step1, intt_tf_step2, intt_tf_step3 + + def to_computation_format(self, a: np.ndarray): + return self.ff_ctx.to_computation_format(a) + + def to_original_format(self, a: np.ndarray): + return self.ff_ctx.to_original_format(a) + + def basis_aligned_transformation(self, matrix: np.ndarray): + n_row, n_col = matrix.shape # might not be the same as self.r and self.c + matrix_u64 = matrix.astype(np.uint64) + matrix_u64_byteshifted = np.array( + [matrix_u64 << (8 * byte_idx) for byte_idx in range(self.num_bytes)], + dtype=np.uint64, + ) + # shape is (4, rows, cols) + matrix_u64_byteshifted_mod_modulus = ( + matrix_u64_byteshifted % self.moduli + ).astype(np.uint32) + # shape is (4, rows, cols, bytes=4) + matrix_u8 = jax.lax.bitcast_convert_type( + matrix_u64_byteshifted_mod_modulus, jnp.uint8 + ).transpose(1, 0, 2, 3) + return matrix_u8 + + def memory_aligned_transformation(self): + """Memory Aligned Transformation (MAT) + + Must run after gen_twiddle_matrix() + """ + + def get_bit_reverse_perm(n): + """Generates a list of indices for bit-reversal permutation of size n.""" + if n <= 0: + return [] + bits = n.bit_length() - 1 + perm = [0] * n + for i in range(n): + # Reverse bits of i + r = 0 + temp = i + for _ in range(bits): + r = (r << 1) | (temp & 1) + temp >>= 1 + perm[i] = r + return perm + + perm_r = get_bit_reverse_perm(self.r) + perm_c = get_bit_reverse_perm(self.c) + self.ntt_tf_mat_step1 = self.ntt_tf_step1[perm_r, :] + self.ntt_tf_mat_step2 = self.ntt_tf_step2[perm_r, :] + self.ntt_tf_mat_step3 = self.ntt_tf_step3[:, perm_c] + self.intt_tf_mat_step1 = self.intt_tf_step1[perm_c, :] + self.intt_tf_mat_step2 = self.intt_tf_step2[perm_r, :] + self.intt_tf_mat_step3 = self.intt_tf_step3[:, perm_r] + + def ntt_three_step_reference(self, x): + """3-step NTT algorithm reference implementation + + Args: + x: The input vector. + + Returns: + The NTT result. + """ + assert ( + len(x) == self.transform_length + ), "x must have length transform_length" + twist_factor = self.twist_factor + tf_step1 = self.ntt_tf_step1.astype(np.uint64) + step2 = self.ntt_tf_step2.astype(np.uint64) + tf_step3 = self.ntt_tf_step3.astype(np.uint64) + x = np.array(x, dtype=np.uint64) + + x_twisted = np.mod(x * twist_factor, self.moduli) + x_matrix = x_twisted.reshape((self.r, self.c)) + y = np.mod(np.matmul(tf_step1, x_matrix), self.moduli) + y = np.mod(y * step2, self.moduli) + z = np.mod(np.matmul(y, tf_step3), self.moduli) + x = z.flatten() + return x.tolist() + + def intt_three_step_reference(self, x): + """3-step Inverse NTT algorithm reference implementation + + Args: + x: The input vector. + + Returns: + The Inverse NTT result. + """ + assert ( + len(x) == self.transform_length + ), "x must have length transform_length" + tf_step1 = self.intt_tf_step1.astype(np.uint64) + step2 = self.intt_tf_step2.astype(np.uint64) + tf_step3 = self.intt_tf_step3.astype(np.uint64) + x = np.array(x, dtype=np.uint64) + + z = x.reshape((self.c, self.r)) + y = np.mod(np.matmul(z, tf_step1), self.moduli) + y = np.mod(y * step2, self.moduli) # step2 includes col_inv + a = np.mod(np.matmul(tf_step3, y), self.moduli) # tf_step3 includes row_inv + x_recovered = np.array(a).flatten() + x = np.mod(x_recovered * self.untwist_factor, self.moduli) + + return x.tolist() + + ######################## + # Online Functions + ######################## + def ntt(self, v: jax.Array): + """NTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + + Args: + v: - is u32 array of shape (B, R, C) - will be casted into u8 array of + shape (B, R, C, Q) + ntt_bat_tf_step1: - is u8 array of shape (R, 4, R, 4) + ntt_tf_step2: - is u32 array of shape (R, C) + ntt_bat_tf_step3: - is u8 array of shape (C, 4, C, 4) + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.ntt_tf_bat_mat_comp_step1, "brcq,zqrp->bzcp" + ) + result_step1_reduced = self.ff_ctx.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_mat_comp_step2 + ) + result_step2_reduced = self.ff_ctx.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced, self.ntt_tf_bat_mat_comp_step3, "brcq,cqnp->brnp" + ) + result_step3_reduced = self.ff_ctx.modular_reduction(result_step3) + return result_step3_reduced + + def intt(self, v: jax.Array): + """INTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + + Args: + v: - is u32 array of shape (B, R, C) - will be casted into u8 array of + shape (B, R, C, Q) + intt_bat_tf_step1: - is u8 array of shape (C, 4, C, 4) + intt_tf_step2: - is u32 array of shape (R, C) + intt_bat_tf_step3: - is u8 array of shape (R, 4, R, 4) + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.intt_tf_bat_mat_comp_step1, "brcq,cqlp->brlp" + ) + result_step1_reduced = self.ff_ctx.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_mat_comp_step2 + ) + result_step2_reduced = self.ff_ctx.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced, self.intt_tf_bat_mat_comp_step3, "brcq,lqrp->blcp" + ) + result_step3_reduced = self.ff_ctx.modular_reduction(result_step3) + return result_step3_reduced + + +class NTTBarrettContext(NTTContextBase): + + def __init__(self, moduli: int, parameters: dict): + super().__init__(moduli, parameters) + if type(self.moduli) is int: + self.moduli = [self.moduli] + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + + +class NTTMontgomeryContext(NTTContextBase): + + def __init__(self, moduli: int, parameters: dict): + super().__init__(moduli, parameters) + if type(self.moduli) is int: + self.moduli = [self.moduli] + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + + +class NTTBATLazyContext(NTTContextBase): + + def __init__(self, moduli: int, parameters: dict): + super().__init__(moduli, parameters) + if type(self.moduli) is int: + self.moduli = [self.moduli] + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + self.ff_ctx_full = ff_context.BarrettContext(moduli) + + ######################## + # Online Functions + ######################## + def ntt(self, v: jax.Array): + """NTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + + Args: + v: - is u32 array of shape (B, R, C) - will be casted into u8 array of + shape (B, R, C, Q) + ntt_bat_tf_mat_comp_step1: - is u8 array of shape (R, 4, R, 4) + ntt_tf_mat_comp_step2: - is u32 array of shape (R, C) + ntt_bat_tf_mat_comp_step3: - is u8 array of shape (C, 4, C, 4) + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.ntt_tf_bat_mat_comp_step1, "brcq,zqrp->bzcp" + ) + result_step1_reduced = self.ff_ctx.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_mat_comp_step2 + ) + result_step2_reduced = self.ff_ctx_full.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced, self.ntt_tf_bat_mat_comp_step3, "brcq,cqnp->brnp" + ) + result_step3_reduced = self.ff_ctx_full.modular_reduction(result_step3) + return result_step3_reduced + + def intt(self, v: jax.Array): + """INTT with modular u32 + + B = Batch size, R = self.r, C = self.c + Q = 4 (number of bytes per element) + + Args: + v: - is u32 array of shape (B, C, R) - will be casted into u8 array of + shape (B, C, R, Q) + intt_bat_tf_mat_comp_step1: - is u8 array of shape (C, 4, C, 4) + intt_tf_mat_comp_step2: - is u32 array of shape (R, C) # Step 1 + multiplication changes its order + intt_bat_tf_mat_comp_step3: - is u8 array of shape (R, 4, R, 4) + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + result_step1 = matmul_bat_einsum( + v, self.intt_tf_bat_mat_comp_step1, "brcq,cqlp->brlp" + ) + result_step1_reduced = self.ff_ctx.modular_reduction(result_step1) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_mat_comp_step2 + ) + result_step2_reduced = self.ff_ctx_full.modular_reduction(result_step2) + result_step3 = matmul_bat_einsum( + result_step2_reduced, self.intt_tf_bat_mat_comp_step3, "brcq,lqrp->blcp" + ) + result_step3_reduced = self.ff_ctx_full.modular_reduction(result_step3) + return result_step3_reduced + + +class NTTShoupContext(NTTContextBase): + """NTT with Shoup's Modular Reduction + + Note that Shoup's Reduction is NOT compatible with Basis Aligned + Transformation (BAT). + We use 1-d convolution to perform matrix multiplication for Shoup. + """ + + def __init__(self, moduli: int, parameters: dict): + super().__init__(moduli, parameters) + if type(self.moduli) is int: + self.moduli = [self.moduli] + assert self.ff_ctx is not None, "finite_field_context must be provided" + assert ( + self.moduli == self.ff_ctx.moduli + ), "moduli must be the same as the moduli of the finite_field_context" + self.ntt_tf_mat_step1 = self.to_computation_format( + self.ntt_tf_mat_step1 + ).astype(jnp.uint32) + self.ntt_tf_mat_step2 = self.to_computation_format( + self.ntt_tf_mat_step2 + ).astype(jnp.uint64) + self.ntt_tf_mat_step3 = self.to_computation_format( + self.ntt_tf_mat_step3 + ).astype(jnp.uint32) + self.intt_tf_mat_step1 = self.to_computation_format( + self.intt_tf_mat_step1 + ).astype(jnp.uint32) + self.intt_tf_mat_step2 = self.to_computation_format( + self.intt_tf_mat_step2 + ).astype(jnp.uint64) + self.intt_tf_mat_step3 = self.to_computation_format( + self.intt_tf_mat_step3 + ).astype(jnp.uint32) + + self.ntt_tf_mat_step1_shoup = self.to_shoup_computation_format( + self.ntt_tf_mat_step1 + ).astype(jnp.uint32) + self.ntt_tf_mat_step2_shoup = self.to_shoup_computation_format( + self.ntt_tf_mat_step2 + ).astype(jnp.uint64) + self.ntt_tf_mat_step3_shoup = self.to_shoup_computation_format( + self.ntt_tf_mat_step3 + ).astype(jnp.uint32) + self.intt_tf_mat_step1_shoup = self.to_shoup_computation_format( + self.intt_tf_mat_step1 + ).astype(jnp.uint32) + self.intt_tf_mat_step2_shoup = self.to_shoup_computation_format( + self.intt_tf_mat_step2 + ).astype(jnp.uint64) + self.intt_tf_mat_step3_shoup = self.to_shoup_computation_format( + self.intt_tf_mat_step3 + ).astype(jnp.uint32) + + def to_shoup_computation_format(self, a: np.ndarray): + shape = a.shape + a = a.flatten() + a_list = a.tolist() + a_computation_format = [ + self.ff_ctx.precompute_constant_operand(a_i) for a_i in a_list + ] + a_computation_format = np.array(a_computation_format, dtype=np.uint64) + a_computation_format = a_computation_format.reshape(*shape) + return a_computation_format + + def ntt(self, v: jax.Array): + """NTT with modular u32 + + Args: + v: - is u32 array of shape (B, R, C) - input + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + result_step1 = matmul_conv_flexible_kernel( + self.ntt_tf_mat_step1, v, ("NCW", "IOW", "NCW") + ) + result_step1_shoup = matmul_conv_flexible_kernel( + self.ntt_tf_mat_step1_shoup, v, ("NCW", "IOW", "NCW") + ) + result_step1_reduced = self.ff_ctx.modular_reduction( + result_step1, result_step1_shoup + ) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_mat_step2 + ) + result_step2_shoup = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.ntt_tf_mat_step2_shoup + ) + result_step2_reduced = self.ff_ctx.modular_reduction( + result_step2, result_step2_shoup + ) + result_step3 = matmul_conv_flexible_kernel( + result_step2_reduced, self.ntt_tf_mat_step3, ("NCW", "IOW", "CNW") + ) + result_step3_shoup = matmul_conv_flexible_kernel( + result_step2_reduced, self.ntt_tf_mat_step3_shoup, ("NCW", "IOW", "CNW") + ) + result_step3_reduced = self.ff_ctx.modular_reduction( + result_step3, result_step3_shoup + ) + result_step3_reduced = result_step3_reduced.T + return result_step3_reduced.astype(jnp.uint32) + + def intt(self, v: jax.Array): + """INTT with modular u32 + + Args: + v: - is u32 array of shape (B, R, C) - input + + Returns: + - is u32 array of shape (B, R, C) + - output + """ + # computation + v = v.T + result_step1 = matmul_conv_flexible_kernel( + v, self.intt_tf_mat_step1, ("CNW", "IOW", "NCW") + ) + result_step1_shoup = matmul_conv_flexible_kernel( + v, self.intt_tf_mat_step1_shoup, ("CNW", "IOW", "NCW") + ) + result_step1_reduced = self.ff_ctx.modular_reduction( + result_step1, result_step1_shoup + ) + result_step2 = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_mat_step2 + ) + result_step2_shoup = jnp.multiply( + result_step1_reduced.astype(jnp.uint64), self.intt_tf_mat_step2_shoup + ) + result_step2_reduced = self.ff_ctx.modular_reduction( + result_step2, result_step2_shoup + ) + result_step3 = matmul_conv_flexible_kernel( + self.intt_tf_mat_step3, result_step2_reduced, ("NCW", "IOW", "NCW") + ) + result_step3_shoup = matmul_conv_flexible_kernel( + self.intt_tf_mat_step3_shoup, + result_step2_reduced, + ("NCW", "IOW", "NCW"), + ) + result_step3_reduced = self.ff_ctx.modular_reduction( + result_step3, result_step3_shoup + ) + return result_step3_reduced diff --git a/jaxite/jaxite_word/ntt_sm_test.py b/jaxite/jaxite_word/ntt_sm_test.py new file mode 100644 index 0000000..bbe8abb --- /dev/null +++ b/jaxite/jaxite_word/ntt_sm_test.py @@ -0,0 +1,137 @@ +from absl.testing import absltest +from absl.testing import parameterized +import jaxite.jaxite_word.finite_field as ff_context +import jax +import jaxite.jaxite_word.ntt_sm as ntt + +jax.config.update("jax_enable_x64", True) +import numpy as np +import jax.numpy as jnp +import jaxite.jaxite_word.util as util + +NTT = [( + "0", + 134219681, + None, + 1, + 4, + 4, + [ + 105825732, + 68433452, + 36629220, + 126901109, + 89469849, + 106633716, + 15102657, + 108374459, + 68789927, + 23451922, + 93538050, + 20585372, + 30604976, + 37517995, + 65644325, + 102451383, + ], + [ + 26196696, + 45475009, + 10055359, + 23277424, + 69041040, + 71916973, + 73894069, + 3311254, + 44646798, + 49882443, + 28097016, + 70484730, + 10811958, + 11946041, + 61318182, + 19099272, + ], +)] + + +class NTTTest(parameterized.TestCase): + + def __init__(self, *args, **kwargs): + super(NTTTest, self).__init__(*args, **kwargs) + self.random_key = jax.random.key(0) + + @parameterized.named_parameters(*NTT) + def test_NTT_Barrett(self, q, psi, batch, r, c, coef_in, eval_in): + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.BarrettContext(moduli=q), + } + ntt_ctx = ntt.NTTContextBase(moduli=q, parameters=parameters) + ntt_result_cf = ntt_ctx.ntt( + jnp.array(coef_in, dtype=jnp.uint32).reshape(-1, r, c) + ) + np.testing.assert_array_equal(eval_in, ntt_result_cf.flatten().tolist()) + intt_result = ntt_ctx.intt(ntt_result_cf) + np.testing.assert_array_equal(coef_in, intt_result[0].flatten().tolist()) + + # @absltest.skip("test single implementation") + @parameterized.named_parameters(*NTT) + def test_NTT_Montgomery(self, q, psi, batch, r, c, coef_in, eval_in): + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.MontgomeryContext(moduli=q), + } + ntt_ctx = ntt.NTTContextBase(moduli=q, parameters=parameters) + test_in_cf = ntt_ctx.to_computation_format( + jnp.array(coef_in, dtype=jnp.uint64) + ).reshape(-1, r, c) + ntt_result_cf = ntt_ctx.ntt(test_in_cf) + eval_recovered = ntt_ctx.to_original_format( + ntt_result_cf.astype(jnp.uint64).flatten() + ) + np.testing.assert_array_equal(eval_in, eval_recovered.tolist()) + intt_result = ntt_ctx.intt(ntt_result_cf) + x_recovered = ntt_ctx.to_original_format(intt_result.astype(jnp.uint64)) + np.testing.assert_array_equal(coef_in, x_recovered[0].flatten().tolist()) + + # @absltest.skip("test single implementation") + @parameterized.named_parameters(*NTT) + def test_NTT_BATLazy(self, q, psi, batch, r, c, coef_in, eval_in): + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.BATLazyContext(moduli=q), + } + ntt_ctx = ntt.NTTBATLazyContext(moduli=q, parameters=parameters) + ntt_result_cf = ntt_ctx.ntt( + jnp.array(coef_in, dtype=jnp.uint32).reshape(-1, r, c) + ) + np.testing.assert_array_equal(eval_in, ntt_result_cf.flatten() % q) + intt_result = ntt_ctx.intt(ntt_result_cf) + x_recovered = ntt_ctx.to_original_format(intt_result) + np.testing.assert_array_equal(coef_in, x_recovered[0].flatten() % q) + + # @absltest.skip("test single implementation") + @parameterized.named_parameters(*NTT) + def test_NTT_Shoup(self, q, psi, batch, r, c, coef_in, eval_in): + parameters = { + "r": r, + "c": c, + "finite_field_context": ff_context.ShoupContext(moduli=q), + } + ntt_ctx = ntt.NTTShoupContext(moduli=q, parameters=parameters) + ntt_result_cf = ntt_ctx.ntt( + jnp.array(coef_in, dtype=jnp.uint32).reshape(r, c) + ) + eval_recovered = ntt_ctx.to_original_format(ntt_result_cf.flatten()) + np.testing.assert_array_equal(eval_in, eval_recovered) + intt_result = ntt_ctx.intt(ntt_result_cf) + x_recovered = ntt_ctx.to_original_format(intt_result) + np.testing.assert_array_equal(coef_in, x_recovered.flatten().tolist()) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite/jaxite_word/ntt_test.py b/jaxite/jaxite_word/ntt_test.py deleted file mode 100644 index 2bcb683..0000000 --- a/jaxite/jaxite_word/ntt_test.py +++ /dev/null @@ -1,584 +0,0 @@ -"""A module for operations on test CKKS evaluation kernels including. - -- NTT -""" - -import functools -import json -import math - -import jax -import jax.numpy as jnp -from jaxite.jaxite_word import ntt -import numpy as np - -# copybara: from google3.perftools.accelerators.xprof.api.python import xprof_analysis_client -# copybara: from google3.perftools.accelerators.xprof.api.python import xprof_session -from absl.testing import absltest -from absl.testing import parameterized - - -jax.config.update("jax_enable_x64", True) - - -TEST_PARAMS = [ - ( - "test_degree_4", - 113, - 1, - 2, - 2, - [1, 2, 4, 1], - ), - ( - "test_degree_8", - 113, - 1, - 2, - 4, - [1, 2, 4, 1, 3, 5, 6, 8], - ), -] - - -TEST_PARAMS_FULL = [ - ( - "test_degree_16_case1", - 1152921504606844513, - 7645792537133126, - 1, - 4, - 4, - [ - 83963236163890886, - 351982916215618729, - 19109541039909111, - 209897461014368538, - 310692812327896771, - 1059668300217304028, - 1050480625312143092, - 318185408722083429, - 691103099785523933, - 925312603743652858, - 943007261004624982, - 1151476381807175556, - 70951954169609321, - 259753300466854865, - 717282208767975565, - 714233289979413983, - ], - ), - ( - "test_degree_16_case2", - 1152921504606844417, - 97466480447807994, - 1, - 4, - 4, - [ - 83963265707149094, - 351982931975631049, - 19109585857500759, - 209897508168723930, - 310692857143987363, - 1059668342128390172, - 1050480648209416628, - 318185432882906469, - 691103122593273821, - 925312641738174490, - 943007287323534262, - 1151476401561067428, - 70951966266556073, - 259753315402657457, - 717282250621563469, - 714233320002290303, - ], - ), -] - - -class CKKSEvalNTTTest(parameterized.TestCase): - """A base class for running bootstrap tests. - - Example Test Case: - If use GF(17) and N = 8 (so q=17 and N=8). - In GF(17), the multiplicative group has order 16. - Suppose the forward transform used a primitive 8th root of unity. - For example, we can use omega = 2, since 2^8 mod 17 == 1 and its order is 8. - """ - - def __init__(self, *args, **kwargs): - super(CKKSEvalNTTTest, self).__init__(*args, **kwargs) - self.random_key = jax.random.key(0) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_vanilla_ntt_original_form( - self, - q, - batch, - r, - c, - test_in, - ): - print("Test test_vanilla_ntt_original_form") - n = r * c - row_count, col_count = r, c # for example, n = row_count * col_count - assert row_count * col_count == n - omega = ntt.nth_primitive_root(n, q) - - ntt_result = ntt.ntt_original_form(test_in, q, omega) - print("Forward original form NTT of x:", ntt_result) - - x_recovered = ntt.intt_original_form(ntt_result, q, omega) - print("Recovered x from inverse original form NTT:", x_recovered) - - self.assertEqual(test_in, x_recovered) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_vanilla_ntt_cooley_tukey( - self, - q, - batch, - r, - c, - test_in, - ): - print("Test test_vanilla_ntt_cooley_tukey") - - n = r * c - row_count, col_count = r, c # for example, n = row_count * col_count - assert row_count * col_count == n - omega = ntt.nth_primitive_root(n, q) - - ntt_result = ntt.ntt_bit_reverse(test_in, q, omega) - print("Forward bit-reverse NTT of x:", ntt_result) - - x_recovered = ntt.intt_bit_reverse(ntt_result, q, omega) - print("Recovered x from inverse bit-reverse NTT:", x_recovered) - - self.assertEqual(test_in, x_recovered) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_vanilla_ntt_4_step( - self, - q, - batch, - r, - c, - test_in, - ): - print("Test test_vanilla_ntt_4_step") - - n = r * c - row_count, col_count = r, c # for example, n = row_count * col_count - assert row_count * col_count == n - omega = ntt.nth_primitive_root(n, q) - print("omega=", omega) - ntt_result = ntt.ntt_four_step(test_in, q, omega, row_count, col_count) - print("Forward 4-step NTT of x:", ntt_result) - - x_recovered = ntt.intt_four_step(ntt_result, q, omega, row_count, col_count) - # x_recovered = ntt.intt_bit_reverse(ntt_result, q, omega) - print("Recovered x from inverse 4-step NTT:", x_recovered) - - self.assertEqual(test_in, x_recovered) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_negacyclic_ntt_4_step( - self, - q, - batch, - r, - c, - test_in, - ): - print("Test test_negacyclic_ntt_4_step_degree_4") - - n = r * c - row_count, col_count = r, c # for example, n = row_count * col_count - assert row_count * col_count == n - omega = ntt.nth_primitive_root(n, q) - psi = ntt.compute_psi(omega, n, q) - - ntt_result = ntt.ntt_negacyclic(test_in, q, psi, row_count, col_count) - print("Forward negacyclic NTT of x:", ntt_result) - - x_recovered = ntt.intt_negacyclic(ntt_result, q, psi, row_count, col_count) - print("Recovered x from inverse negacyclic NTT:", x_recovered) - - self.assertEqual(test_in, x_recovered) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_barrett_reduction( - self, - q, - batch, - r, - c, - test_in, - ): - print("Test test_negacyclic_ntt_4_step_degree_4") - test_in = [[[69, 95, 147, 139], [7617, 6977, 8472, 7687]]] - result_ref = [[[69, 95, 34, 26], [46, 84, 110, 3]]] - s = 2 * math.ceil(math.log2(q)) - m = math.floor(2**s / q) - print(f"s={s}, m={m}") - ntt_result = ntt.barret_reduction( - jnp.array(test_in, dtype=jnp.uint64), q, s, m - ) - print("Forward negacyclic NTT of x:", ntt_result.tolist()) - - self.assertEqual(result_ref, ntt_result.tolist()) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_negacyclic_ntt_tpu_algorithm( - self, - q, - batch, - r, - c, - test_in, - ): - """This is testing the 4-step implementation of negacyclic NTT O(N sqrt(N)) complexity. - - The coefficients generation parts have been moved to offline, while online - performs the - computation. - """ - print("Test test_negacyclic_ntt_tpu_algorithm") - # We use GF(q) and N = r*c (so q=q and N=r*c). - # In GF(q), the multiplicative group has order q-1. - - n = r * c - row_count, col_count = r, c # for example, n = row_count * col_count - assert row_count * col_count == n - - omega = ntt.nth_primitive_root(n, q) - psi = ntt.compute_psi(omega, n, q) - - omega_col = pow(omega, c, q) - omega_row = pow(omega, r, q) - tf_mat_step1 = jnp.array( - ntt.gen_twiddle_matrix(r, r, q, omega_col), dtype=int - ) - coef_step2 = jnp.array(ntt.gen_twiddle_matrix(r, c, q, omega), dtype=int) - tf_mat_step3 = jnp.array( - ntt.gen_twiddle_matrix(c, c, q, omega_row), dtype=int - ) - - inv_omega_col = pow( - omega_col, -1, q - ) # inverse primitive R-th root for columns - inv_omega_row = pow( - omega_row, -1, q - ) # inverse primitive C-th root for rows - inv_tf_mat_step3 = jnp.array( - ntt.gen_twiddle_matrix(r, r, q, inv_omega_col), dtype=int - ) - inv_coef_step2 = jnp.array( - ntt.gen_twiddle_matrix_inv(r, c, q, omega), dtype=int - ) - inv_tf_mat_step1 = jnp.array( - ntt.gen_twiddle_matrix(c, c, q, inv_omega_row), dtype=int - ) - - np.testing.assert_array_equal(tf_mat_step1.T, tf_mat_step1) - np.testing.assert_array_equal(tf_mat_step3.T, tf_mat_step3) - np.testing.assert_array_equal(inv_tf_mat_step1.T, inv_tf_mat_step1) - np.testing.assert_array_equal(inv_tf_mat_step3.T, inv_tf_mat_step3) - - ntt_result = ntt.ntt_negacyclic_tpu_algorithm( - test_in, q, psi, r, c, tf_mat_step1, coef_step2, tf_mat_step3 - ) - print("Forward negacyclic NTT of x:", ntt_result) - - x_recovered = ntt.intt_negacyclic_tpu_algorithm( - ntt_result, - q, - psi, - r, - c, - inv_tf_mat_step1, - inv_coef_step2, - inv_tf_mat_step3, - ) - print("Recovered x from inverse negacyclic NTT:", x_recovered) - - self.assertEqual(test_in, x_recovered) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_ntt_layout_invariant_batch( - self, - q, - batch, - r, - c, - test_in, - ): - print("Test ntt_layout_invariant_batch") - s = 2 * math.ceil(math.log2(q)) - m = math.floor(2**s / q) - n = r * c - omega = ntt.nth_primitive_root(n, q) - psi = ntt.compute_psi(omega, n, q) - # psi_inv = pow(psi, -1, q) - - omega_col = pow(omega, c, q) - omega_row = pow(omega, r, q) - tf_mat_step1 = jnp.array( - ntt.gen_twiddle_matrix(r, r, q, omega_col), dtype=int - ) - coef_step2 = jnp.array(ntt.gen_twiddle_matrix(r, c, q, omega), dtype=int) - tf_mat_step3 = jnp.array( - ntt.gen_twiddle_matrix(c, c, q, omega_row), dtype=int - ) - - np.testing.assert_array_equal(tf_mat_step1.T, tf_mat_step1) - np.testing.assert_array_equal(tf_mat_step3.T, tf_mat_step3) - - tf_mat_bat_step1 = ntt.hpmatmul_offline_compile_bat( - tf_mat_step1.astype(jnp.uint32), q - ) - coef_step2 = coef_step2.astype(jnp.uint32) - tf_mat_bat_step3 = ntt.hpmatmul_offline_compile_bat( - tf_mat_step3.astype(jnp.uint32), q - ) - - tf_step1 = tf_mat_bat_step1.astype(jnp.uint8) - tf_step3 = tf_mat_bat_step3.astype(jnp.uint8) - assert tf_step1.shape == (r, r, 4, 4) - assert coef_step2.shape == (r, c) - assert tf_step3.shape == (c, c, 4, 4) - - if c == r: - np.testing.assert_array_equal(tf_mat_step1, tf_mat_step3) - np.testing.assert_array_equal(tf_mat_step1.T, tf_mat_step1) - np.testing.assert_array_equal(tf_mat_step3.T, tf_mat_step3) - - ntt_result = ntt.ntt_negacyclic_tpu_algorithm( - test_in, q, psi, r, c, tf_mat_step1, coef_step2, tf_mat_step3 - ) - - dut = functools.partial(ntt.ntt_layout_invariant_batch, q=q, s=s, m=m) - test_in_twisted = jnp.array( - [(test_in[i] * pow(psi, i, q)) % q for i in range(n)], jnp.uint32 - ) - test_in_twisted = test_in_twisted.reshape(batch, r, c) - result = dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - result = np.array(result.T).flatten().tolist() - print([result[i] for i in range(n)]) - - print(f"input={test_in_twisted}, after NTT = {result}") - - self.assertEqual(ntt_result, result) - - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - - # copybara: session = xprof_session.XprofSession() - # copybara: session.start_session() - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - # copybara: session_id = session.end_session_and_get_session_id() - print(f"session_id: http://xprof/?session_id={session_id}") - # copybara: client = xprof_analysis_client.XprofAnalysisClient() - trace = client.get_profile_data("trace_viewer.json", session_id) - jtrace = json.loads(trace[1]) - results = [] - for e in jtrace["traceEvents"]: - if "ntt_layout_invariant_batch" in e["name"]: - results.append(e["dur"]) - print(jnp.mean(jnp.array(results[:8])), "us") - - @parameterized.named_parameters(*TEST_PARAMS_FULL) - def test_ntt_intt_layout_invariant_batch( - self, - q, - psi, - batch, - r, - c, - test_in, - ): - if math.log2(q) > 32: - print( - "Skip this test as we don't support modulus > 32, because numpy as" - " max precision as 64" - ) - return - s = 2 * math.ceil(math.log2(q)) - m = math.floor(2**s / q) - n = r * c - if psi is not None: - omega = (psi**2) % q - else: - omega = ntt.nth_primitive_root(n, q) - - omega_col = pow(omega, c, q) - omega_row = pow(omega, r, q) - tf_mat_step1 = jnp.array( - ntt.gen_twiddle_matrix(r, r, q, omega_col), dtype=int - ) - coef_step2 = jnp.array(ntt.gen_twiddle_matrix(r, c, q, omega), dtype=int) - tf_mat_step3 = jnp.array( - ntt.gen_twiddle_matrix(c, c, q, omega_row), dtype=int - ) - - inv_omega_col = pow(omega_col, -1, q) - # inverse primitive R-th root for columns - inv_omega_row = pow(omega_row, -1, q) - # inverse primitive C-th root for rows - inv_tf_mat_step3 = jnp.array( - ntt.gen_twiddle_matrix(r, r, q, inv_omega_col), dtype=int - ) - # intt needs to scale the corresponding coefficients. - inv_c = pow(c, -1, q) - inv_r = pow(r, -1, q) - inv_coef_step2_ori = jnp.array( - ntt.gen_twiddle_matrix_inv(r, c, q, omega), dtype=int - ) - inv_coef_step2 = inv_c * inv_coef_step2_ori - inv_tf_mat_step1 = jnp.array( - ntt.gen_twiddle_matrix(c, c, q, inv_omega_row), dtype=int - ) - - if c == r: - np.testing.assert_array_equal(tf_mat_step1, tf_mat_step3) - np.testing.assert_array_equal(tf_mat_step1.T, tf_mat_step1) - np.testing.assert_array_equal(tf_mat_step3.T, tf_mat_step3) - np.testing.assert_array_equal(inv_tf_mat_step1.T, inv_tf_mat_step1) - np.testing.assert_array_equal(inv_tf_mat_step3.T, inv_tf_mat_step3) - - tf_step1 = ntt.hpmatmul_offline_compile_bat( - tf_mat_step1.astype(jnp.uint32), q - ).astype(jnp.uint8) - coef_step2 = coef_step2.astype(jnp.uint32) - tf_step3 = ntt.hpmatmul_offline_compile_bat( - tf_mat_step3.astype(jnp.uint32), q - ).astype(jnp.uint8) - inv_tf_step1 = ntt.hpmatmul_offline_compile_bat( - inv_tf_mat_step1.astype(jnp.uint32), q - ).astype(jnp.uint8) - inv_coef_step2 = inv_coef_step2.astype(jnp.uint32) - inv_tf_step3 = ntt.hpmatmul_offline_compile_bat( - inv_tf_mat_step3.astype(jnp.uint32), q - ).astype(jnp.uint8) - - assert tf_step1.shape == (r, r, 4, 4) - assert coef_step2.shape == (r, c) - assert tf_step3.shape == (c, c, 4, 4) - assert inv_tf_step1.shape == (c, c, 4, 4) - assert inv_coef_step2.shape == (r, c) - assert inv_tf_step3.shape == (r, r, 4, 4) - - ntt_result = ntt.ntt_negacyclic_tpu_algorithm( - test_in, q, psi, r, c, tf_mat_step1, coef_step2, tf_mat_step3 - ) - - dut = functools.partial(ntt.ntt_layout_invariant_batch, q=q, s=s, m=m) - test_in_twisted = jnp.array( - [(test_in[i] * pow(psi, i, q)) % q for i in range(n)], jnp.uint32 - ) - test_in_twisted = test_in_twisted.reshape(batch, r, c) - result_jax = dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - result = np.array(result_jax.T).flatten().tolist() - self.assertEqual(ntt_result, result) - - # INTT - x_ref = ntt.intt_negacyclic_tpu_algorithm( - ntt_result, - q, - psi, - r, - c, - inv_tf_mat_step1, - inv_coef_step2_ori, - inv_tf_mat_step3, - ) - self.assertEqual(x_ref, test_in) - - x_untwisted = ntt.intt_layout_invariant_batch( - result_jax, - inv_tf_step1, - inv_coef_step2, - inv_tf_step3, - inv_r, - q, - s, - m, - ) - psi_inv = pow(psi, -1, q) - x_untwisted = x_untwisted.flatten().tolist() - x_recovered = [ - (x_untwisted[i] * pow(psi_inv, i, q)) % q - for i in range(len(x_untwisted)) - ] - self.assertEqual(x_recovered, test_in) - - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - - # copybara: session = xprof_session.XprofSession() - # copybara: session.start_session() - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - jax.block_until_ready( - dut(test_in_twisted, tf_step1, coef_step2, tf_step3) - ) - # copybara: session_id = session.end_session_and_get_session_id() - print(f"session_id: http://xprof/?session_id={session_id}") - # copybara: client = xprof_analysis_client.XprofAnalysisClient() - trace = client.get_profile_data("trace_viewer.json", session_id) - jtrace = json.loads(trace[1]) - results = [] - for e in jtrace["traceEvents"]: - if "ntt_layout_invariant_batch" in e["name"]: - results.append(e["dur"]) - print(jnp.mean(jnp.array(results[:8])), "us") - - -if __name__ == "__main__": - absltest.main() diff --git a/jaxite/jaxite_word/profiler.py b/jaxite/jaxite_word/profiler.py new file mode 100644 index 0000000..dde1eee --- /dev/null +++ b/jaxite/jaxite_word/profiler.py @@ -0,0 +1,797 @@ +import csv +import gzip +import json +import os +import statistics +from typing import Any, Callable, Dict, List, Optional, Tuple, cast +import warnings +import jax +import jax.numpy as jnp +import pandas as pd + + +class DataFrameGenerator: + """A utility class for building pandas DataFrames from column data.""" + + def __init__(self): + """Initialize an empty DataFrameGenerator.""" + self.data: Dict[str, List[Any]] = {} + + def add_data(self, column_name: str, values: List[Any]) -> None: + """Add data to a specific column. + + Args: + column_name: Name of the column to add data to + values: List of values to add to the column + """ + if not isinstance(column_name, str): + raise ValueError('column_name must be a string') + if not isinstance(values, list): + raise ValueError('values must be a list') + + if column_name not in self.data: + self.data[column_name] = [] + self.data[column_name].extend(values) + + def add_single_value(self, column_name: str, value: Any) -> None: + """Add a single value to a specific column. + + Args: + column_name: Name of the column to add data to + value: Single value to add to the column + """ + self.add_data(column_name, [value]) + + def get_column_lengths(self) -> Dict[str, int]: + """Get the length of each column. + + Returns: + Dictionary mapping column names to their lengths + """ + return {col: len(values) for col, values in self.data.items()} + + def is_balanced(self) -> bool: + """Check if all columns have the same length. + + Returns: + True if all columns have the same length, False otherwise + """ + if not self.data: + return True + lengths = set(len(col) for col in self.data.values()) + return len(lengths) == 1 + + def to_dataframe(self, auto_balance: bool = True) -> pd.DataFrame: + """Convert the stored data to a pandas DataFrame. + + Args: + auto_balance: If True, automatically trim columns to the minimum length. + If False, raise an error if columns have different lengths. + + Returns: + pandas DataFrame with the stored data + + Raises: + ValueError: If auto_balance is False and columns have different lengths + """ + if not self.data: + return pd.DataFrame() + + if not auto_balance and not self.is_balanced(): + lengths = self.get_column_lengths() + raise ValueError(f'Columns have different lengths: {lengths}') + + # Find the minimum length among all columns + min_len = min(len(col) for col in self.data.values()) + + # Trim each column to the minimum length + trimmed_data = {k: v[:min_len] for k, v in self.data.items()} + + return pd.DataFrame(trimmed_data) + + def clear(self) -> None: + """Clear all stored data.""" + self.data.clear() + + def get_column_names(self) -> List[str]: + """Get the names of all columns. + + Returns: + List of column names + """ + return list(self.data.keys()) + + def has_column(self, column_name: str) -> bool: + """Check if a column exists. + + Args: + column_name: Name of the column to check + + Returns: + True if the column exists, False otherwise + """ + return column_name in self.data + + def merge(self, other_dataframe_generator: 'DataFrameGenerator'): + """Merge the stored data with another DataFrameGenerator. + + Args: + other_dataframe_generator: Another DataFrameGenerator to merge with + + Returns: + Merged DataFrameGenerator + """ + if not isinstance(other_dataframe_generator, DataFrameGenerator): + raise ValueError('other_dataframe_generator must be a DataFrameGenerator') + # Check if this DataFrameGenerator is empty + if not self.data: + self.data = other_dataframe_generator.data + return + # Check if the other DataFrameGenerator has the same column names + if not set(self.get_column_names()) == set( + other_dataframe_generator.get_column_names() + ): + print('The two DataFrameGenerators have different column names') + return + # raise ValueError("The two DataFrameGenerators have different column names") + # Merge the data + for column_name in other_dataframe_generator.get_column_names(): + self.add_data(column_name, other_dataframe_generator.data[column_name]) + + def get_header(self) -> List[str]: + """Get the header of the DataFrameGenerator. + + Returns: + List of column names + """ + return list(self.data.keys()) + + def get_row_dict(self, index: int) -> Dict[str, Any]: + """Get a row of the DataFrameGenerator. + + Returns: + Dictionary of column names and values + """ + return { + column_name: self.data[column_name][index] + for column_name in self.get_column_names() + } + + +class TraceParser: + + def __init__(self, trace_dir: str): + self.trace_dir = trace_dir + + def set_trace_dir(self, new_dir: str): + """Set a new trace directory for the parser.""" + self.trace_dir = new_dir + + def find_trace_file(self): + """Recursively search for the latest .trace.json.gz file in the trace_dir. + + Returns the full path to the file, or None if not found. + """ + trace_files = [] + for root, _, files in os.walk(self.trace_dir): + for file in files: + if file.endswith('.trace.json.gz'): + trace_files.append(os.path.join(root, file)) + + if not trace_files: + return None + + # Return the most recently modified file + return max(trace_files, key=os.path.getmtime) + + def read_trace_json(self): + """Finds, unzips, and reads the JSON content from the trace file. + + Returns the loaded JSON object, or None if not found or error. + """ + trace_file = self.find_trace_file() + if trace_file is None: + print('No trace file found.') + return None + try: + with gzip.open(trace_file, 'rt', encoding='utf-8') as f: + data = json.load(f) + return data + except Exception as e: + print(f'Error reading trace file: {e}') + return None + + def parse_trace_csv(self): + """Parses the trace CSV file and returns a list of trace events.""" + csv_file = os.path.join(self.trace_dir, 'trace_events.csv') + + # Read the trace JSON data + trace_data = self.read_trace_json() + if trace_data is None: + print('Failed to read trace data') + return None + + # Extract trace events + trace_events = trace_data.get('traceEvents', []) + if not trace_events: + print('No trace events found in the data') + return None + + headers = ['pid', 'tid', 'ts', 'dur', 'ph', 'name', 'args'] + # Write to CSV directly + with open(csv_file, 'w', newline='', encoding='utf-8') as f: + writer = csv.DictWriter(f, fieldnames=headers) + writer.writeheader() + for event in trace_events: + # Convert args dictionary to string if it exists + if 'args' in event: + event['args'] = json.dumps(event['args']) + else: + event['args'] = '' + + # Write the event + writer.writerow(event) + print(f'Trace events written to: {csv_file}') + + +def calculate_statistics(data: List[Any]) -> Dict[str, Any]: + """Calculate the statistics of the data. + + Args: + data: List of data + + Returns: + Dictionary containing the statistics + """ + mean_value = statistics.mean(data) + if len(data) == 1: + std_value = 0 + else: + std_value = statistics.stdev(data) + min_value = min(data) + max_value = max(data) + median_value = statistics.median(data) + return { + 'mean': mean_value, + 'std': std_value, + 'min': min_value, + 'max': max_value, + 'median': median_value, + } + + +def list_add(list1: List[Any], list2: List[Any]) -> List[Any]: + """Sum two lists element-wise. + + Args: + list1: First list to sum + list2: Second list to sum + + Returns: + List of the sum of the two lists + """ + assert len(list1) == len(list2), 'The two lists must have the same length' + return [e1 + e2 for e1, e2 in zip(list1, list2)] + + +class KernelWrapper: + + def __init__( + self, + kernel_name: str, + function_to_wrap: Callable, + input_structs: List[Tuple[Tuple[int, ...], jnp.dtype]], + mesh: Optional[jax.sharding.Mesh] = None, + input_shardings: Optional[Tuple[jax.sharding.Sharding, ...]] = None, + output_sharding: Optional[jax.sharding.Sharding] = None, + parameters: Optional[Dict[str, Any]] = {}, + enable_sharding: bool = False, + ): + self.kernel_name = kernel_name + self.callable_function = function_to_wrap + self.input_structs = input_structs + self.parameters = parameters + self.mesh = mesh + self.input_shardings = input_shardings + self.output_sharding = output_sharding + self.enable_sharding = enable_sharding + + self.jit_lower = None + self.jit_compiled_function = None + + # Compile immediately upon initialization + self._compile() + + def _compile(self): + jax_input_structs = [] + if self.enable_sharding and self.input_shardings: + for (shape, dtype), sharding in zip( + self.input_structs, self.input_shardings + ): + jax_input_structs.append( + jax.ShapeDtypeStruct(shape, dtype, sharding=sharding) + ) + else: + for shape, dtype in self.input_structs: + jax_input_structs.append(jax.ShapeDtypeStruct(shape, dtype)) + + # NOTE: Do not change the name of the function, it is used for profiling + if self.parameters: + + def compiled_kernel_function(*jax_array_inputs): + return self.callable_function( + *jax_array_inputs, parameters=self.parameters + ) + + else: + + def compiled_kernel_function(*jax_array_inputs): + return self.callable_function(*jax_array_inputs) + + if self.enable_sharding and self.mesh: + with self.mesh: + self.jit_lower = jax.jit( + jax.named_call(compiled_kernel_function, name=self.kernel_name), + in_shardings=self.input_shardings, + out_shardings=self.output_sharding, + ).lower(*jax_input_structs) + else: + self.jit_lower = jax.jit( + jax.named_call(compiled_kernel_function, name=self.kernel_name) + ).lower(*jax_input_structs) + + self.jit_compiled_function = self.jit_lower.compile() + + def get_compiled_function(self) -> Callable[..., jnp.ndarray]: + assert self.jit_compiled_function is not None, 'Kernel not compiled' + if self.enable_sharding and self.mesh: + mesh = self.mesh + + def compiled_with_mesh(*jax_array_inputs): + with mesh: + return self.jit_compiled_function(*jax_array_inputs) + + return compiled_with_mesh + return self.jit_compiled_function + + def get_input_structs(self): + return self.input_structs + + def get_kernel_name(self) -> str: + return self.kernel_name + + def shard_inputs(self, input_arrays: List[jnp.ndarray]) -> List[jnp.ndarray]: + """Place inputs on the provided sharding.""" + if self.enable_sharding and self.input_shardings: + return [ + jax.device_put(arr, sharding) + for arr, sharding in zip(input_arrays, self.input_shardings) + ] + return input_arrays + + +class Profiler: + + def __init__( + self, + output_trace_path: str, + profile_naming: str, + configuration: Optional[Dict[str, Any]] = None, + ): + self.trace_dir = output_trace_path + self.profiler_name = profile_naming + self.profile_dir = os.path.join(self.trace_dir, self.profiler_name) + if not os.path.exists(self.profile_dir): + os.makedirs(self.profile_dir) + + self.configuration = configuration or {} + self.random_seed = self.configuration.get('random_seed', 0) + self.iterations = self.configuration.get('iterations', 1) + self.save_to_file = self.configuration.get('save_to_file', True) + self.enable_sharding = self.configuration.get('enable_sharding', False) + + self.profiles: List[Dict[str, Any]] = [] + self.profile_name_list: List[str] = [] + + # Storage for results + self.storage_file = os.path.join( + self.profile_dir, f'{self.profiler_name}_results.csv' + ) + + def add_profile( + self, + name: str, + kernel_wrapper: KernelWrapper, + kernel_setting_cols: Dict[str, Any] = {}, + ): + if name in self.profile_name_list: + raise ValueError(f'Profiler name {name} already exists') + + self.profile_name_list.append(name) + + profile_folder = os.path.join(self.profile_dir, name) + if not os.path.exists(profile_folder): + os.makedirs(profile_folder) + + self.profiles.append({ + 'name': name, + 'wrapper': kernel_wrapper, + 'settings': kernel_setting_cols, + 'folder': profile_folder, + 'failed': False, + 'trace_events': None, + 'filtered_events': None, + 'stats': None, + }) + + def _get_input_arrays(self, kernel_wrapper: KernelWrapper): + def get_max_value(dtype): + if dtype == jnp.uint8: + return 128 + elif dtype == jnp.uint16: + return 32768 + elif dtype == jnp.uint32: + return 4294967295 + elif dtype == jnp.uint64: + return 4294967295 + raise ValueError(f'Unsupported dtype: {dtype}') + + random_key = jax.random.key(self.random_seed) + input_arrays = [] + for shape, dtype in kernel_wrapper.get_input_structs(): + if jnp.issubdtype(dtype, jnp.floating): + input_arrays.append(jax.random.uniform(random_key, shape, dtype)) + elif jnp.issubdtype(dtype, jnp.integer): + input_arrays.append( + jax.random.randint( + random_key, shape, 0, get_max_value(dtype), dtype + ) + ) + elif jnp.issubdtype(dtype, jnp.bool_): + input_arrays.append(jax.random.bernoulli(random_key, p=0.5, shape=shape)) + else: + raise ValueError(f'Unsupported dtype: {dtype}') + for input_array in input_arrays: + input_array.block_until_ready() + + if self.enable_sharding: + input_arrays = kernel_wrapper.shard_inputs(input_arrays) + + return input_arrays + + def profile_all_profilers(self): + for profile in self.profiles: + print(f"Profiling {profile['name']}") + try: + # Kernel wrapper is already compiled in its init + wrapper = cast(KernelWrapper, profile['wrapper']) + compiled_function = wrapper.get_compiled_function() + input_arrays = self._get_input_arrays(wrapper) + + with jax.profiler.trace(profile['folder']): + for _ in range(self.iterations): + compiled_function(*input_arrays).block_until_ready() + except Exception as e: + print(f"Error profiling {profile['name']}:\n {e}") + profile['failed'] = True + + def _parse_json_trace(self, profile): + trace_parser = TraceParser(profile['folder']) + trace_file_path = trace_parser.find_trace_file() + profile_json = trace_parser.read_trace_json() + if profile_json is None: + warnings.warn( + f"{profile['name']}: No trace events found in the data", UserWarning + ) + profile['failed'] = True + return None + trace_events = profile_json.get('traceEvents', []) + if not trace_events: + warnings.warn( + f"{profile['name']}: No trace events found in the data", UserWarning + ) + profile['failed'] = True + return None + if self.save_to_file: + # Save into the same folder as the raw trace file + output_dir = os.path.dirname(trace_file_path) + profile['output_folder'] = output_dir + with open(os.path.join(output_dir, 'trace_events.json'), 'w') as f: + json.dump(trace_events, f, indent=2) + profile['trace_events'] = trace_events + return trace_events + + def _filter_trace_events(self, profile): + trace_events = profile['trace_events'] + if trace_events is None: + return None + + def merge_filtered_events_by_name(filtered_events): + grouped = {} + for event in filtered_events: + event_name = event.get('name', 'unknown') + if ( + 'args' in event.keys() + and 'deduplicated_name' in event['args'].keys() + ): + event_name += '_' + event['args']['deduplicated_name'] + elif ( + 'custom-call' in event['name'] + and 'args' in event.keys() + and 'tf_op' in event['args'].keys() + ): + event_name += '_' + event['args']['tf_op'] + if event_name not in grouped: + grouped[event_name] = [] + grouped[event_name].append(event) + + merged_filtered_events = {} + for event_name, events in grouped.items(): + merged = events[0].copy() + merged['dur'] = [e.get('dur') for e in events if 'dur' in e] + merged['ts'] = [e.get('ts') for e in events if 'ts' in e] + merged['repeat_count'] = len(events) + merged_filtered_events[event_name] = merged + return merged_filtered_events + + filtered_events_list = [] + # Check if NVIDIA is in device kind OR CPU is used as a fallback if explicit check needed + # But generally JAX trace events differ by backend. + # Assuming typical CPU/GPU separation. + device_kind = jax.devices()[0].device_kind + + if 'NVIDIA' in device_kind: + for e in trace_events: + if 'args' in e and 'tf_op' in e['args']: + # Loosen the check for compiled_kernel_function as it might be nested differently or named differently + if 'compiled_kernel_function' in e['args'].get( + 'hlo_module', '' + ) or 'compiled_kernel_function' in e['args'].get('long_name', ''): + merged_event = False + # Try to merge with existing events + for f in filtered_events_list: + # Check if correlation_id exists before accessing it + if ( + 'correlation_id' in f['args'] + and 'correlation_id' in e['args'] + and f['args']['correlation_id'] == e['args']['correlation_id'] + and f['name'] == e['name'] + ): + f['dur'] = f['dur'] + e['dur'] + merged_event = True + if not merged_event: + filtered_events_list.append(e) + profile['filtered_events'] = merge_filtered_events_by_name( + filtered_events_list + ) + + elif 'TPU' in device_kind: + for event in trace_events: + if ( + 'pid' not in event.keys() or event['pid'] != 3 + ): # ToDo: change it into automatic PID detection based on "TPU:0". + continue + if ( + 'name' in event.keys() + and 'compiled_kernel_function' in event['name'] + and 'args' in event.keys() + ): + filtered_events_list.append(event) + elif 'args' in event.keys() and 'long_name' in event['args'].keys(): + filtered_events_list.append(event) + else: + continue + profile['filtered_events'] = merge_filtered_events_by_name( + filtered_events_list + ) + else: + # Fallback for CPU or other devices + # CPU traces might be different. Let's try to capture events related to our kernel. + for event in trace_events: + if 'name' in event and 'compiled_kernel_function' in event['name']: + filtered_events_list.append(event) + profile['filtered_events'] = merge_filtered_events_by_name( + filtered_events_list + ) + + # Always save filtered events if we have any + if self.save_to_file: + # Make sure we don't crash if profile['filtered_events'] is None + events_to_dump = ( + profile['filtered_events'] + if profile['filtered_events'] is not None + else {} + ) + with open( + os.path.join(profile['output_folder'], 'filtered_events.json'), 'w' + ) as f: + json.dump(events_to_dump, f, indent=2) + + def _calculate_profiling_statistics(self, profile): + if profile['filtered_events'] is None: + return + + repeat_count = self.iterations + kernel_duration = [0] * repeat_count + + device_kind = jax.devices()[0].device_kind + + if 'NVIDIA' in device_kind: + for event in profile['filtered_events'].values(): + if 'compiled_kernel_function' in event['args'].get('hlo_module', ''): + durations = event['dur'] + if not isinstance(durations, list): + durations = [durations] + + if len(durations) == repeat_count: + kernel_duration = list_add(kernel_duration, durations) + elif ( + len(durations) > repeat_count + and len(durations) % repeat_count == 0 + ): + # Assume sequential execution of kernels within one iteration + chunk_size = len(durations) // repeat_count + aggregated_durations = [ + sum(durations[i * chunk_size : (i + 1) * chunk_size]) + for i in range(repeat_count) + ] + kernel_duration = list_add(kernel_duration, aggregated_durations) + else: + # Fallback: just take first N or handle mismatch. + # For now, adopting CPU strategy of taking first N but this is likely under-reporting. + # Ideally log a warning. + kernel_duration = list_add( + kernel_duration, durations[:repeat_count] + ) + elif 'TPU' in device_kind: + for event in profile['filtered_events'].values(): + if 'compiled_kernel_function' in event['name']: + kernel_duration = list_add(kernel_duration, event['dur']) + else: + # CPU logic - assuming direct name match from filtered events + for event in profile['filtered_events'].values(): + # On CPU, events might be simpler + if 'compiled_kernel_function' in event.get('name', ''): + # DUR might be a single value or list depending on how it was merged + durations = event['dur'] + if not isinstance(durations, list): + durations = [durations] + + # If we have less durations than repeat_count, we might need to pad or it's a mismatch + # For now, let's just add what we have, assuming 1-to-1 or aggregated + if len(durations) == repeat_count: + kernel_duration = list_add(kernel_duration, durations) + elif len(durations) > repeat_count: + # Take first N + kernel_duration = list_add( + kernel_duration, durations[:repeat_count] + ) + else: + # Append 0s? Or just take what we have + padded = durations + [0] * (repeat_count - len(durations)) + kernel_duration = list_add(kernel_duration, padded) + + profile['stats'] = { + 'kernel_all': kernel_duration, + } + + def post_process_all_profilers(self): + for profile in self.profiles: + if profile['failed']: + continue + + events = self._parse_json_trace(profile) + if events is None: + continue + + self._filter_trace_events(profile) + self._calculate_profiling_statistics(profile) + + self.write_results() + + def get_profiling_dataframe_generator_all_profilers(self): + df_generator = DataFrameGenerator() + for profile in self.profiles: + if profile['failed'] or profile['stats'] is None: + continue + + p_df_gen = DataFrameGenerator() + p_df_gen.add_single_value( + 'operation_name', profile['wrapper'].get_kernel_name() + ) + + for key, value in profile['settings'].items(): + p_df_gen.add_single_value(key, value) + + all_kernel_duration = profile['stats']['kernel_all'] + for i, duration in enumerate(all_kernel_duration): + p_df_gen.add_single_value(f'sample_{i}', duration) + + df_generator.merge(p_df_gen) + return df_generator + + def write_results(self): + storage_dataframe_generator = ( + self.get_profiling_dataframe_generator_all_profilers() + ) + # Check if file exists to determine if we need to write header + file_exists = os.path.exists(self.storage_file) + mode = 'a' if file_exists else 'w' + header = not file_exists + storage_dataframe_generator.to_dataframe().to_csv( + self.storage_file, mode=mode, header=header, index=False + ) + print( + storage_dataframe_generator.to_dataframe().to_csv() + ) # Need to see the content of the file in terminal as Google does not have file system + print(f'Results written to: {self.storage_file}') + + +def collect_logs(root_dir='.', output_csv_name='all_logs_collected'): + """Collects all CSV files found under directories named 'log' + + and aggregates them into a single CSV file. + Handles varying headers by taking the union of all found columns. + """ + all_files = [] + + # Fieldnames set to collect all unique columns + all_fieldnames = set() + # To preserve some order, we can use a list and add new ones as we see them + ordered_fieldnames = [] + + # First pass: identify files and collect all possible fieldnames + for dirpath, dirnames, filenames in os.walk(root_dir): + path_parts = dirpath.split(os.sep) + if 'log' in path_parts: + for file in filenames: + if file.endswith('.csv'): + full_path = os.path.join(dirpath, file) + all_files.append(full_path) + try: + with open(full_path, 'r', newline='') as csvfile: + reader = csv.reader(csvfile) + try: + header = next(reader) + for h in header: + if h not in all_fieldnames: + all_fieldnames.add(h) + ordered_fieldnames.append(h) + except StopIteration: + # Empty file + pass + except Exception as e: + print(f'Error reading header of {full_path}: {e}') + + if not all_files: + print('No CSV files found.') + return + + print(f'Found {len(all_files)} CSV files.') + print(f'Unified collected columns: {ordered_fieldnames}') + + output_file = os.path.join(root_dir, f'{output_csv_name}.csv') + total_rows = 0 + + try: + with open(output_file, 'w', newline='') as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=ordered_fieldnames) + writer.writeheader() + + for full_path in all_files: + try: + with open(full_path, 'r', newline='') as infile: + reader = csv.DictReader(infile) + # The DictReader uses the file's own header mapping + # We just iterate and write to the master dict writer + for row in reader: + writer.writerow(row) + total_rows += 1 + except Exception as e: + print(f'Error processing {full_path}: {e}') + + print(f'Saved aggregated logs to {os.path.abspath(output_file)}') + print(f'Total rows collected: {total_rows}') + + except Exception as e: + print(f'Error writing output file: {e}') diff --git a/jaxite/jaxite_ckks/rns.py b/jaxite/jaxite_word/rns.py similarity index 95% rename from jaxite/jaxite_ckks/rns.py rename to jaxite/jaxite_word/rns.py index e3d63c4..a748ce8 100644 --- a/jaxite/jaxite_ckks/rns.py +++ b/jaxite/jaxite_word/rns.py @@ -6,10 +6,9 @@ """ import dataclasses -from typing import Any import jax.numpy as jnp -from jaxite.jaxite_ckks import rns_utils +import jaxite.jaxite_word.util as util def _mod_exp(x: int, n: int, q: int) -> int: @@ -29,7 +28,7 @@ def _primitive_root(m: int, q: int): Note: q-1 must be divisible by m for the primitive roots to exist. Also assume m is a power of two. """ - if not rns_utils.is_power_of_two(m): + if not util.is_power_of_two(m): raise ValueError('`m` must be a power of two.') if (q - 1) % m != 0: raise ValueError('`q - 1` must be divisible by m.') @@ -66,39 +65,39 @@ class Ntt: psis_inv_bitrev: list[int] = dataclasses.field(init=False) def __post_init__(self): - if not rns_utils.is_power_of_two(self.n): + if not util.is_power_of_two(self.n): raise ValueError('`n` must be a power of two.') if self.q % (2 * self.n) != 1: raise ValueError('`q - 1` must be divisible by 2N.') n = self.n q = self.q - self.n_inv_mod_q = rns_utils.inverse_mod(n, q) + self.n_inv_mod_q = util.modinv(n, q) # Generating the powers of primitive root psi to be used in Cooley-Tukey and # Gentleman-Sande. psi = _primitive_root(2 * n, q) self.psis_bitrev = [_mod_exp(psi, i, q) for i in range(n)] self.psis_inv_bitrev = list(self.psis_bitrev) - rns_utils.bit_reversal_array(self.psis_bitrev) + self.psis_bitrev = util.bit_reverse_array(self.psis_bitrev) self.psis_inv_bitrev = ( self.psis_inv_bitrev[:1:] + self.psis_inv_bitrev[:0:-1] ) neg_psi_inv = self.psis_inv_bitrev[1] # psi^(n-1) = -psi^{-1} mod q psi_inv = (-neg_psi_inv) % q # psi^{-1} mod q - rns_utils.bit_reversal_array(self.psis_inv_bitrev) + self.psis_inv_bitrev = util.bit_reverse_array(self.psis_inv_bitrev) self.psis_inv_bitrev[0] = (self.psis_inv_bitrev[0] * psi_inv) % q for i in range(1, n): self.psis_inv_bitrev[i] = (self.psis_inv_bitrev[i] * neg_psi_inv) % q def forward(self, coeffs: list[int]) -> None: """Forward NTT.""" - self._iterative_cooley_tukey(coeffs, rns_utils.num_bits(len(coeffs))) + self._iterative_cooley_tukey(coeffs, util.num_bits(len(coeffs))) def backward(self, coeffs: list[int]) -> None: """Backward NTT (normalized).""" - self._iterative_gentleman_sande(coeffs, rns_utils.num_bits(len(coeffs))) + self._iterative_gentleman_sande(coeffs, util.num_bits(len(coeffs))) for i in range(len(coeffs)): coeffs[i] = (coeffs[i] * self.n_inv_mod_q) % self.q diff --git a/jaxite/jaxite_ckks/rns_test.py b/jaxite/jaxite_word/rns_test.py similarity index 66% rename from jaxite/jaxite_ckks/rns_test.py rename to jaxite/jaxite_word/rns_test.py index a71d98b..1343c03 100644 --- a/jaxite/jaxite_ckks/rns_test.py +++ b/jaxite/jaxite_word/rns_test.py @@ -1,25 +1,23 @@ """Tests for RnsPolynomial.""" import random - -from jaxite.jaxite_ckks import rns -from jaxite.jaxite_ckks import rns_utils -import parameterized - from absl.testing import absltest -from absl.testing import parameterized as parameterized_test +from absl.testing import parameterized +import jaxite.jaxite_word.rns as rns +import jaxite.jaxite_word.util as util -@parameterized.parameterized_class([ - {"degree": 8, "moduli": [12289]}, - {"degree": 16, "moduli": [12289, 65537]}, - {"degree": 1024, "moduli": [12289, 65537]}, -]) -class RnsPolynomialTest(absltest.TestCase): - def setUp(self): - super().setUp() - self.ntt_params = [rns.Ntt(self.degree, modulus) for modulus in self.moduli] +@parameterized.named_parameters([ + {"testcase_name": "deg8_single_mod", "degree": 8, "moduli": [12289]}, + {"testcase_name": "deg16_two_mod", "degree": 16, "moduli": [12289, 65537]}, + { + "testcase_name": "deg1024_two_mod", + "degree": 1024, + "moduli": [12289, 65537], + }, +]) +class RnsPolynomialTest(parameterized.TestCase): def _random_coeffs(self, degree: int, modulus: int) -> list[int]: return [random.randint(0, modulus - 1) for _ in range(degree)] @@ -30,9 +28,9 @@ def _random_rns_polynomial( coeffs = [self._random_coeffs(degree, modulus) for modulus in moduli] return rns.RnsPolynomial(degree, moduli, coeffs, is_ntt=is_ntt) - def test_ntt(self): - ntt = rns.Ntt(self.degree, self.moduli[0]) - coeffs = self._random_coeffs(self.degree, self.moduli[0]) + def test_ntt(self, degree, moduli): + ntt = rns.Ntt(degree, moduli[0]) + coeffs = self._random_coeffs(degree, moduli[0]) # NTT^-1( NTT( coeffs )) should be the same as coeffs. evals = list(coeffs) ntt.forward(evals) @@ -40,13 +38,13 @@ def test_ntt(self): ntt.backward(coeffs_back) self.assertEqual(coeffs, coeffs_back) - def test_iterative_cooley_tukey(self): + def test_iterative_cooley_tukey(self, degree, moduli): # Skip large degrees as we compute the expected results in O(n^2). - if self.degree >= 32: + if degree >= 32: return - n = self.degree - q = self.moduli[0] + n = degree + q = moduli[0] ntt = rns.Ntt(n, q) psi = rns._primitive_root(2 * n, q) coeffs = self._random_coeffs(n, q) @@ -54,25 +52,24 @@ def test_iterative_cooley_tukey(self): # evaluation form of the polynomial has coefficients c'_i, i = 0..n-1, for # c'_i = sum(psi^j * coeffs[j] * psi^(2i*j), j = 0..n-1) expected_ntt_coeffs = [ - sum([psi ** j * coeffs[j] * psi ** (2 * i * j) % q for j in range(n)]) - % q + sum([psi**j * coeffs[j] * psi ** (2 * i * j) % q for j in range(n)]) % q for i in range(n) ] - rns_utils.bit_reversal_array(expected_ntt_coeffs) + expected_ntt_coeffs = util.bit_reverse_array(expected_ntt_coeffs) ntt_coeffs = coeffs.copy() - ntt._iterative_cooley_tukey(ntt_coeffs, rns_utils.num_bits(len(coeffs))) + ntt._iterative_cooley_tukey(ntt_coeffs, util.num_bits(len(coeffs))) self.assertEqual(ntt_coeffs, expected_ntt_coeffs) - def test_iterative_gentleman_sande(self): + def test_iterative_gentleman_sande(self, degree, moduli): # Skip large degrees as we compute the expected results in O(n^2). - if self.degree >= 32: + if degree >= 32: return - n = self.degree - q = self.moduli[0] + n = degree + q = moduli[0] ntt = rns.Ntt(n, q) psi = rns._primitive_root(2 * n, q) - psi_inv = rns_utils.inverse_mod(psi, q) + psi_inv = util.modinv(psi, q) ntt_coeffs = self._random_coeffs(n, q) # Since we want to compute negacyclic convolution in Z[X]/(q, X^n+1), the # coefficient form of the polynomial has coefficients c_i, i = 0..n-1, for @@ -84,15 +81,15 @@ def test_iterative_gentleman_sande(self): for i in range(n) ] coeffs_bitrev = ntt_coeffs.copy() - rns_utils.bit_reversal_array(coeffs_bitrev) + coeffs_bitrev = util.bit_reverse_array(coeffs_bitrev) ntt._iterative_gentleman_sande( - coeffs_bitrev, rns_utils.num_bits(len(ntt_coeffs)) + coeffs_bitrev, util.num_bits(len(ntt_coeffs)) ) self.assertEqual(coeffs_bitrev, expected_coeffs) - def test_rns_polynomial_addition(self): - poly0 = self._random_rns_polynomial(self.degree, self.moduli, is_ntt=False) - poly1 = self._random_rns_polynomial(self.degree, self.moduli, is_ntt=False) + def test_rns_polynomial_addition(self, degree, moduli): + poly0 = self._random_rns_polynomial(degree, moduli, is_ntt=False) + poly1 = self._random_rns_polynomial(degree, moduli, is_ntt=False) # First compute a + b in the coefficient form. poly_sum0 = poly0 + poly1 @@ -100,38 +97,40 @@ def test_rns_polynomial_addition(self): # Then compute a + b in the NTT form. The result (once converted back to the # coefficient form) should be the same. - poly0.to_ntt_form(self.ntt_params) - poly1.to_ntt_form(self.ntt_params) + ntt_params = [rns.Ntt(degree, modulus) for modulus in moduli] + poly0.to_ntt_form(ntt_params) + poly1.to_ntt_form(ntt_params) assert poly0.is_ntt assert poly1.is_ntt poly_sum1 = poly0 + poly1 assert poly_sum1.is_ntt - poly_sum1.to_coeffs_form(self.ntt_params) + poly_sum1.to_coeffs_form(ntt_params) self.assertEqual(poly_sum1, poly_sum0) - def test_rns_polynomial_negation(self): - zero_coeffs = [[0] * self.degree for _ in range(len(self.moduli))] - zero = rns.RnsPolynomial(self.degree, self.moduli, zero_coeffs) + def test_rns_polynomial_negation(self, degree, moduli): + zero_coeffs = [[0] * degree for _ in range(len(moduli))] + zero = rns.RnsPolynomial(degree, moduli, zero_coeffs) - poly0 = self._random_rns_polynomial(self.degree, self.moduli, is_ntt=False) + poly0 = self._random_rns_polynomial(degree, moduli, is_ntt=False) poly0_neg = -poly0 assert not poly0_neg.is_ntt poly0_sum = poly0 + poly0_neg assert not poly0_sum.is_ntt self.assertEqual(poly0_sum, zero) - poly1 = self._random_rns_polynomial(self.degree, self.moduli, is_ntt=True) + poly1 = self._random_rns_polynomial(degree, moduli, is_ntt=True) poly1_neg = -poly1 assert poly1_neg.is_ntt poly1_sum = poly1 + poly1_neg assert poly1_sum.is_ntt - poly1_sum.to_coeffs_form(self.ntt_params) + ntt_params = [rns.Ntt(degree, modulus) for modulus in moduli] + poly1_sum.to_coeffs_form(ntt_params) self.assertEqual(poly1_sum, zero) - def test_rns_polynomial_multiplication(self): - a = self._random_rns_polynomial(self.degree, self.moduli, is_ntt=True) - b = self._random_rns_polynomial(self.degree, self.moduli, is_ntt=True) - c = self._random_rns_polynomial(self.degree, self.moduli, is_ntt=True) + def test_rns_polynomial_multiplication(self, degree, moduli): + a = self._random_rns_polynomial(degree, moduli, is_ntt=True) + b = self._random_rns_polynomial(degree, moduli, is_ntt=True) + c = self._random_rns_polynomial(degree, moduli, is_ntt=True) # First we compute (a + b) * c ab = a + b @@ -144,44 +143,45 @@ def test_rns_polynomial_multiplication(self): assert acbc.is_ntt # Check (a + b) * c = a * c + b * c) self.assertEqual(abc, acbc) - - -class RnsNegativeTest(parameterized_test.TestCase): + + +class RnsNegativeTest(parameterized.TestCase): """Testing negative cases for RNS implementation.""" - + def setUp(self): super().setUp() self.degree = 8 self.moduli = [12289] - - @parameterized_test.named_parameters( + @parameterized.named_parameters( { - 'testcase_name': 'n_is_zero', - 'invalid_n': 0, + "testcase_name": "n_is_zero", + "invalid_n": 0, }, { - 'testcase_name': 'n_is_odd', - 'invalid_n': 7, + "testcase_name": "n_is_odd", + "invalid_n": 7, }, - ) + ) def test_create_ntt_with_invalid_n(self, invalid_n): with self.assertRaises(ValueError): rns.Ntt(invalid_n, self.moduli[0]) - + def test_create_ntt_with_invalid_q(self): - invalid_q = 2 * self.degree # set q = 2*N which isn't NTT-friendly + invalid_q = 2 * self.degree # set q = 2*N which isn't NTT-friendly with self.assertRaises(ValueError): rns.Ntt(self.degree, invalid_q) - + def test_add_polynomials_with_different_degrees(self): coeffs0 = [[0] * self.degree for _ in range(len(self.moduli))] coeffs1 = [[0] * (self.degree * 2) for _ in range(len(self.moduli))] poly0 = rns.RnsPolynomial(self.degree, self.moduli, coeffs0, is_ntt=False) - poly1 = rns.RnsPolynomial(self.degree * 2, self.moduli, coeffs1, is_ntt=False) + poly1 = rns.RnsPolynomial( + self.degree * 2, self.moduli, coeffs1, is_ntt=False + ) with self.assertRaises(ValueError): poly0 + poly1 - + def test_add_polynomials_with_incompatible_coeffs(self): coeffs0 = [[0] * self.degree for _ in range(len(self.moduli))] coeffs1 = [[0] * self.degree for _ in range(len(self.moduli) + 1)] @@ -189,7 +189,7 @@ def test_add_polynomials_with_incompatible_coeffs(self): poly1 = rns.RnsPolynomial(self.degree, self.moduli, coeffs1, is_ntt=False) with self.assertRaises(ValueError): poly0 + poly1 - + def test_add_polynomials_with_different_moduli(self): moduli0 = self.moduli moduli1 = moduli0 + [65537] @@ -199,22 +199,21 @@ def test_add_polynomials_with_different_moduli(self): poly1 = rns.RnsPolynomial(self.degree, moduli1, coeffs1, is_ntt=False) with self.assertRaises(ValueError): poly0 + poly1 - + def test_add_polynomials_with_different_forms(self): coeffs = [[0] * self.degree for _ in range(len(self.moduli))] poly0 = rns.RnsPolynomial(self.degree, self.moduli, coeffs, is_ntt=False) poly1 = rns.RnsPolynomial(self.degree, self.moduli, coeffs, is_ntt=True) with self.assertRaises(ValueError): poly0 + poly1 - + def test_multiply_polynomials_in_coefficient_form(self): coeffs = [[0] * self.degree for _ in range(len(self.moduli))] poly0 = rns.RnsPolynomial(self.degree, self.moduli, coeffs, is_ntt=False) poly1 = rns.RnsPolynomial(self.degree, self.moduli, coeffs, is_ntt=False) with self.assertRaises(ValueError): poly0 * poly1 - - + if __name__ == "__main__": absltest.main() diff --git a/jaxite/jaxite_word/sub.py b/jaxite/jaxite_word/sub.py deleted file mode 100644 index 3527a98..0000000 --- a/jaxite/jaxite_word/sub.py +++ /dev/null @@ -1,56 +0,0 @@ -"""TPU kernels for Evaluation of the CKKS algorithm.""" - -import jax -import jax.numpy as jnp - - -def jax_sub(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array): - """This function processes all degree of the two input polynomials in parallel using multi-trheading. - - Assuming the input data type is jax array of shape (n, k, d) where - n: Number of polynomials in the ciphertext - k: The number of limbs - d: The degree of the polynomials - - Args: - value_a: the first operand of the subtraction. - value_b: the second operand of the subtraction. - modulus_list: the list of moduli for each degree. - - Returns: - The result of the subtraction. - """ - num_elements, _, degree = value_a.shape - modulus_broadcast = jnp.tile( - modulus_list[None, :, None], (num_elements, 1, degree) - ) - result = value_a - value_b - result_mod_back = modulus_broadcast + result - return jnp.where( - value_a > value_b, result, result_mod_back - ) # jnp.mod(value_a + value_b, modulus_broadcast) - - -def vmap_sub(value_a: jax.Array, value_b: jax.Array, modulus_list: jax.Array): - """This function processes all degree of the two input polynomials in SIMD using jax.vmap. - - Assuming the input data type is jax array. - - Args: - value_a: the first operand of the subtraction. - value_b: the second operand of the subtraction. - modulus_list: the list of moduli for each degree. - - Returns: - The result of the subtraction. - """ - num_elements, num_towers, degree = value_a.shape - modulus_broadcast = jnp.tile( - modulus_list[None, :, None], (num_elements, 1, degree) - ) - - def chunk_wise_subtract(value_a, value_b, mod): - result = value_a - value_b - return jnp.where(value_a > value_b, result, result + mod) - - return jax.vmap(chunk_wise_subtract)(value_a, value_b, modulus_broadcast) diff --git a/jaxite/jaxite_word/sub_test.py b/jaxite/jaxite_word/sub_test.py deleted file mode 100644 index 287dac3..0000000 --- a/jaxite/jaxite_word/sub_test.py +++ /dev/null @@ -1,112 +0,0 @@ -"""A module for operations on test CKKS evaluation kernels including. - -- Modsub -- HESub -""" - -from concurrent import futures -from typing import Any, Callable - -import jax -import jax.numpy as jnp -from jaxite.jaxite_word import sub - -from absl.testing import absltest -from absl.testing import parameterized - - -ProcessPoolExecutor = futures.ProcessPoolExecutor - -jax.config.update("jax_enable_x64", True) -jax.config.update("jax_traceback_filtering", "off") - - -class CKKSEvalSubTest(parameterized.TestCase): - """A base class for running bootstrap tests.""" - - def __init__(self, *args, **kwargs): - super(CKKSEvalSubTest, self).__init__(*args, **kwargs) - self.debug = False # dsiable it from printing the test input values - self.modulus_element_0_tower_0 = 1152921504606748673 # 60 (k=60->2k=120) - self.modulus_element_0_tower_1 = 268664833 # 28 (k=28->2k=56) - self.modulus_element_0_tower_2 = 557057 # 19 (k=19->2k=38) - self.random_key = jax.random.key(0) - self.in_c1 = [ - [761974115069642497, 186812814, 396780], - [1119697542422587247, 195711320, 415240], - ] - self.in_c2 = [ - [723287396072165360, 91967352, 112274], - [251652059326221653, 111494737, 534294], - ] - self.refer_sub_result = [ - [38686718997477137, 94845462, 284506], - [868045483096365594, 84216583, 438003], - ] - - self.random_key = jax.random.key(0) - - def random(self, shape, modulus_list, dtype=jnp.int32): - assert len(modulus_list) == shape[1] - - return jnp.concatenate( - [ - jax.random.randint( - self.random_key, - shape=(shape[0], 1, shape[2]), - minval=0, - maxval=bound, - dtype=dtype, - ) - for bound in modulus_list - ], - axis=1, - ) - - @parameterized.named_parameters( - dict( - testcase_name="jax_sub", - test_target=sub.jax_sub, - modulus_list=[1152921504606748673, 268664833, 557057], - shape=(2, 3, 16384), # number of elements, number of towers, degree - ), - dict( - testcase_name="vmap_sub", - test_target=sub.vmap_sub, - modulus_list=[1152921504606748673, 268664833, 557057], - shape=(2, 3, 16384), # number of elements, number of towers, degree - ), - ) - def test_sub( - self, - test_target: Callable[[Any, Any, Any], Any], - modulus_list=jax.Array, - shape=tuple[int, int, int], - ): - """This function tests the sub function using Python native integer data type with arbitrary precision. - - This test finishes in 1.05 second. - - Args: - test_target: The function to test. - modulus_list: A jax.Array of integers. - shape: A tuple of integers representing the shape of the input arrays. - """ - # Only test a single element to save comparison time, - # Correctness-wise, it's sufficient for sub. - value_a = self.random(shape, modulus_list, dtype=jnp.uint64) - value_b = self.random(shape, modulus_list, dtype=jnp.uint64) - for i in range(shape[0]): - for j in range(shape[1]): - value_a = value_a.at[i, j, 0].set(self.in_c1[i][j]) - value_b = value_b.at[i, j, 0].set(self.in_c2[i][j]) - assert value_a.shape == shape - assert value_b.shape == shape - modulus_list = jnp.array(modulus_list, dtype=jnp.uint64) - refer_sub_result = jnp.array(self.refer_sub_result, dtype=jnp.uint64) - result = test_target(value_a, value_b, modulus_list) - self.assertEqual(result[:, :, 0].all(), refer_sub_result.all()) - - -if __name__ == "__main__": - absltest.main() diff --git a/jaxite/jaxite_word/util.py b/jaxite/jaxite_word/util.py index 7bfd1d6..d6193c5 100644 --- a/jaxite/jaxite_word/util.py +++ b/jaxite/jaxite_word/util.py @@ -5,128 +5,287 @@ import math import jax +import jax.sharding as shd +import re +import os +import json +import gzip import jax.numpy as jnp +from typing import Any, Callable, List, Tuple, Union +import copy gcd = math.gcd -# Always Fixed Parameters -BASE = 16 -BASE_TYPE = jnp.uint16 # this type must match the BASE, i.e. jnp.uint -U16_MASK = 0xFFFF -U32_MASK = 0xFFFFFFFF -U16_CHUNK_SHIFT_BITS = 16 -U32_CHUNK_SHIFT_BITS = 3 -MODULUS_DEFAULT = 1152921504606748673 -MU_DEFAULT = 1152921504606945279 - -# Workload Dependent Parameters -NUM_ELEMENTS = 2 -NUM_TOWERS = 3 -NUM_DEGREE = 65536 -U32_CHUNK_NUM_DEFAULT = 2 -U16_CHUNK_NUM_DEFAULT = 4 -U8_CHUNK_NUM_DEFAULT = 8 -U32_CHUNK_NUM_U32 = 1 -U16_CHUNK_NUM_U32 = 2 -U8_CHUNK_NUM_U32 = 4 - -MODULUS_LIST = (1152921504606748673, 268664833, 557057) -MU_LIST = (1152921504606945279, 268206274, 493446) - -MODULUS_ARRAY_ALL = ((32769, 65534, 65535, 4095), (32769, 4099), (32769, 8)) -MU_ARRAY_ALL = ( - (32767, 1, 0, 4096), - (32962, 4092), - (34694, 7), -) - -BARRETT_SHIFT_U8_ALL = (15, 7, 4.75) -BARRETT_SHIFT_U16_ALL = (7.5, 3.5, 2.375) -U32_CHUNK_NUM_ALL = (2, 1, 1) -U16_CHUNK_NUM_ALL = (4, 2, 2) -U8_CHUNK_NUM_ALL = (8, 4, 4) - -## modulus set 1 -# MODULUS_64 = 1152921504606748673 -# MU_64 = 1152921504606945279 -# BARRETT_SHIFT_U8 = 15 -# BARRETT_SHIFT_U16 = 7.5 -# MODULUS_ARRAY = (32769, 65534, 65535, 4095) -# MU_ARRAY = (32767, 1, 0, 4096) -# U32_CHUNK_NUM = 2 -# U16_CHUNK_NUM = 4 -# U8_CHUNK_NUM = 8 - -## modulus set 2 -MODULUS_32 = 268664833 -MU_32 = 268206274 -BARRETT_SHIFT_U8 = 7 -BARRETT_SHIFT_U16 = 3.5 -MODULUS_ARRAY = (32769, 4099) -MU_ARRAY = (32962, 4092) -U32_CHUNK_NUM = 1 -U16_CHUNK_NUM = 2 -U8_CHUNK_NUM = 4 - -## modulus set 3 -# MODULUS_64 = 557057 -# MU_64 = 493446 -# BARRETT_SHIFT_U8 = 4.75 -# BARRETT_SHIFT_U16 = 2.375 -# MODULUS_ARRAY = (32769, 8) -# MU_ARRAY = (34694, 7) - -# NTT Test -NTT_BATCH_SIZE = 128 -NTT_N1 = 64 -NTT_N2 = 128 -NTT_DEGREE = 8192 - -# Lazy Reduction -U16_EXT_CHUNK_NUM = 5 - - -def int_to_array( - python_int, base=BASE, dtype=jnp.uint16, array_size=U16_CHUNK_NUM -): - """Chunk decompose a Python integer into a JAX array of fixed dtype and fixed size. +# Capture existing profile run directories so we can identify the new one. +profile_root = os.path.join("./log/xprof", "plugins", "profile") +try: + pre_existing_dirs = set(os.listdir(profile_root)) if os.path.isdir(profile_root) else set() +except Exception: + pre_existing_dirs = set() + +#################################### +# Utility Functions +#################################### +def _square_like_mesh_shape(device_count: int) -> Tuple[int, int]: + """Return a near-square 2D mesh shape that covers all available devices.""" + if device_count <= 0: + raise ValueError("At least one device is required to build a mesh.") + sqrt_devices = math.isqrt(device_count) + for dim0 in range(sqrt_devices, 0, -1): + if device_count % dim0 == 0: + return dim0, device_count // dim0 + return 1, device_count + + +def create_sharding(): + """Create default batch and replicated shardings for the current device mesh.""" + available_devices = jax.devices() + if not available_devices: + raise RuntimeError("No devices available for sharding test.") + if len(available_devices) == 8: + mesh_shape = (2, 4) + elif len(available_devices) == 4: + mesh_shape = (2, 2) + elif len(available_devices) == 2: + mesh_shape = (2, 1) + else: + mesh_shape = (1, 1) + + mesh = jax.make_mesh(mesh_shape, ('x', 'y')) + shd.set_mesh(mesh) + + partition_spec = jax.sharding.PartitionSpec + return mesh, partition_spec + + +def num_bits(x: int) -> int: + """Returns the number of bits in x.""" + return x.bit_length() - 1 + + +def is_power_of_two(x: int) -> bool: + """Returns True if x is a power of two.""" + return x > 0 and (x & (x - 1)) == 0 + + +def to_tuple(a): + """Create to convert numpy array into tuple.""" + try: + return tuple(to_tuple(i) for i in a) + except TypeError: + return a + + +def slice_first_k_along_axis0(arrays, k): + """ + Given an iterable of array-like or sequence objects, return a tuple where + each element is the slice of the original object taking the first k entries + along axis 0 (i.e., obj[:k]). + + Example: + (s_tuple, s_w_tuple, w_tuple, m_tuple) -> + (s_tuple[:k], s_w_tuple[:k], w_tuple[:k], m_tuple[:k]) + """ + return tuple(arr[:k] for arr in arrays) + + +def slice_k_to_end_along_axis0(arrays, k): + """ + Given an iterable of array-like or sequence objects, return a tuple where + each element is the slice of the original object taking the k to end entries + along axis 0 (i.e., obj[k:]). + + Example: + (s_tuple, s_w_tuple, w_tuple, m_tuple) -> + (s_tuple[k:], s_w_tuple[k:], w_tuple[k:], m_tuple[k:]) + """ + return tuple(arr[k:] for arr in arrays) + + +def slice_kth_along_axis0(arrays, k): + """ + Given an iterable of array-like or sequence objects, return a tuple where + each element is the slice of the original object taking the first k entries + along axis 0 (i.e., obj[k]). + + Example: + (s_tuple, s_w_tuple, w_tuple, m_tuple) -> + (s_tuple[k], s_w_tuple[k], w_tuple[k], m_tuple[k]) + """ + return tuple(arr[k] for arr in arrays) + + +def slice_0_to_k0_to_k1_along_axis0(arrays, k0, k1): + """ + Given an iterable of array-like or sequence objects, return a tuple where + each element is the slice of the original object from k0 to k1 along axis 0 (i.e., obj[k0:k1]). + + Example: + (s_tuple, s_w_tuple, w_tuple, m_tuple) -> + (s_tuple[k0:k1], s_w_tuple[k0:k1], w_tuple[k0:k1], m_tuple[k0:k1]) + """ + if isinstance(arrays[0], jnp.ndarray): + return tuple(jnp.concatenate([x[:k0], x[k1:]]) for x in arrays) + elif isinstance(arrays, tuple): + return tuple([arr[:k0] + arr[k1:] for arr in arrays]) + elif isinstance(arrays, list): + return [arr[:k0] + arr[k1:] for arr in arrays] + else: + raise ValueError(f"Unsupported type: {type(arrays)}") + + +def slice_k0_to_k1_axis0(arrays, k0, k1): + """ + Given an iterable of array-like or sequence objects, return a tuple where + each element is the slice of the original object from k0 to k1 along axis 0 (i.e., obj[k0:k1]). + + Example: + (s_tuple, s_w_tuple, w_tuple, m_tuple) -> + (s_tuple[k0:k1], s_w_tuple[k0:k1], w_tuple[k0:k1], m_tuple[k0:k1]) + """ + return tuple(arr[k0:k1] for arr in arrays) + +#################################### +# Math Functions +#################################### +def extended_gcd(a, b): + """Return a tuple of (g, x, y) such that a*x + b*y = g = gcd(a, b).""" + if b == 0: + return (a, 1, 0) + else: + g, x, y = extended_gcd(b, a % b) + return (g, y, x - (a // b) * y) + + +def modinv_manual(x, q): + """Returns the inverse of x mod q.""" + g, x, _ = extended_gcd(x, q) + if g != 1: + raise Exception(f'Modular inverse does not exist for {x} modulo {q}') + else: + return x % q + + +def modinv(x: int, q: int) -> int: + """Returns the inverse of x mod q.""" + return int(pow(x, -1, q)) + + +def prime_factors(n): + """Return the set of prime factors of n.""" + factors = set() + # Divide out factors of 2 + while n % 2 == 0: + factors.add(2) + n //= 2 + # Check odd factors from 3 to sqrt(n) + p = 3 + while p**2 <= n: + while n % p == 0: + factors.add(p) + n //= p + p += 2 + if n > 1: + factors.add(n) + return factors - Args: - python_int: The Python integer to convert. - base: The base of the integer representation of the coordinates. - dtype: The data type of the JAX array. If None, the data type will be - automatically determined based on the base. - array_size: The size of the JAX array. If None, the array will have the - minimum size necessary to store the integer. - Note that: the default parameter is only for 384-bit data. +def find_generator(q): + """Find a primitive root modulo q. + + Args: + q (int): The prime modulus. Returns: - A JAX array representing the integer. + A generator of GF(q)^*. + + Raises: + ValueError: If no generator is found, indicating q is not prime. """ - mask = (1 << base) - 1 - # Chunk Decomposition - elements = [] - while python_int > 0: - elements.append(python_int & mask) # Extract the lower bits - python_int >>= base # Shift to remove the extracted bits + phi = q - 1 + factors = prime_factors(phi) - # we pad or trim the result to match the desired size - if array_size is not None: - assert array_size >= len(elements) - elements = elements[:array_size] + [0] * (array_size - len(elements)) + # Test candidates from 2 to q-1. + for g in range(2, q): + is_generator = all(pow(g, phi // p, q) != 1 for p in factors) + if is_generator: + return g + raise ValueError("No generator found, check that q is prime.") - return jnp.array(elements, dtype=dtype) +#################################### +# Parameters Generation +#################################### +def root_of_unity(m: int, q: int) -> Union[complex, float, int]: + """Canonical primitive m-th root of unity modulo q that **works with NTT**. -def array_to_int(jax_array: jax.Array, base) -> int: - """Converts a JAX array to a single Python integer.""" - result = 0 + Args: + m (int): The order of the root of unity. + q (int): The prime modulus. - for i, elem in enumerate(jax_array): - result |= int(elem) << (i * base) + Returns: + int: The canonical primitive m-th root of unity modulo q. - return result + Usage: + root_of_unity(16, 134219681) # This works with NTT. + computed_psi = [root_of_unity(m, q) for q in original_modulus] + """ + assert (q - 1) % m == 0, "q-1 must be divisible by m" + # Step 1: multiplicative generator of Z_q^* + g = find_generator(q) + # Step 2: raise to (q-1)/m to get an m-th root candidate + r = pow(g, (q - 1) // m, q) + # Step 3: among r^k with gcd(k,m)=1, pick the minimal value whose order is exactly m + # For m=2^t, order check is psi^(m/2) == q-1 (i.e., == -1 mod q) + candidates = [] + half = m // 2 + for k in range(1, m): + if gcd(k, m) != 1: + continue + psi = pow(r, k, q) + if pow(psi, half, q) == q - 1 and pow(psi, m, q) == 1: + candidates.append(psi) + assert candidates, "No primitive m-th root found" + return min(candidates) + + +def any_primitive_root_of_unity(n, q): + """Canonical primitive m-th root of unity modulo q that **may not work with NTT**. + + Args: + m (int): The order of the root of unity. + q (int): The prime modulus. + + Returns: + int: The canonical primitive m-th root of unity modulo q. + + Usage: + root_of_unity(16, 134219681) # This may not work with NTT. + computed_psi = [root_of_unity(m, q) for q in original_modulus] + """ + if (q - 1) % n != 0: + raise ValueError( + "n must divide q-1 for a primitive n-th root of unity to exist." + ) + + # Find a generator g of GF(q)^* (a primitive element). + g = find_generator(q) + # Compute omega = g^((q-1)/n) mod q. + exponent = (q - 1) // n + omega = pow(g, exponent, q) + + # Optional: Verify that omega is indeed of order n. + if pow(omega, n, q) != 1: + raise ValueError("Something went wrong: omega^n != 1") + # Check that no smaller positive exponent gives 1. + for d in range(1, n): + if n % d == 0 and pow(omega, d, q) == 1: + raise ValueError( + "Found an exponent d < n with omega^d == 1, so omega is not" + " primitive." + ) + + return omega def compute_barrett_mu(modulus): @@ -149,45 +308,792 @@ def compute_barrett_mu(modulus): return barrett_mu, k_val -def int_list_to_jax_array(int_list, base=BASE, array_size=U16_CHUNK_NUM): - """Converts a (potentially multi-dimensional) list of integers to a JAX array.""" +def compute_QHatInvModq_QHatModp(original_moduli, target_moduli, perf_test=False): + """ + Given a list of moduli original_moduli, compute QHatInvModq. + Input: + - original_moduli (list[int]): + The list of primes (moduli) defining the original CRT basis (Q). + - target_moduli (list[int]): + The list of primes (moduli) defining the target CRT basis (P). + + For each modulus q_i, compute: + - Qhat_i = Q // q_i + - QHatInvModq[i] = modular inverse of Qhat_i modulo q_i + - QHatModp: Precomputed Q̂ modulo each prime in P. Used in approximate basis switching. + """ + if perf_test: + sizeP = len(original_moduli) + sizeQ = len(target_moduli) + # Random arrays with matching shapes/dtypes + PInvModq = random_parameters((sizeQ,), target_moduli, dtype=jnp.uint32).tolist() + QHatInvModq = random_parameters((sizeP,), target_moduli, dtype=jnp.uint32).tolist() + QHatModp = random_parameters((sizeP, sizeQ), [min(target_moduli + original_moduli)], dtype=jnp.uint32).tolist() + return to_tuple((QHatInvModq, QHatModp)) + else: + Q = 1 + for qi in original_moduli: + Q *= qi + + QHatInvModq = [] + QHat = [] + for qi in original_moduli: + Qhat_i = Q // qi + inv = modinv(Qhat_i, qi) + QHat.append(Qhat_i) + QHatInvModq.append(inv) + + QHatModp = [] + for i in range(len(original_moduli)): + QHatModp_sgl = [] + for j in range(len(target_moduli)): + QHatModp_sgl.append(QHat[i] % target_moduli[j]) + QHatModp.append(QHatModp_sgl) + + return QHatInvModq, QHatModp + + +def approx_mod_down_control_generation(current_moduli, target_moduli, perf_test=False): + if perf_test: + PInvModq = random_parameters((len(target_moduli),), target_moduli, dtype=jnp.uint32).tolist() + else: + P = 1 + for moduli in current_moduli: + P *= moduli + PInvModq = [modinv(P, q) for q in target_moduli] + overall_moduli = current_moduli + target_moduli + QHatInvModq, QHatModp = compute_QHatInvModq_QHatModp(current_moduli, target_moduli, perf_test=perf_test) + + return PInvModq, len(overall_moduli) - len(target_moduli), len(overall_moduli) - len(current_moduli), QHatInvModq, QHatModp + + +def compute_powers_of_psi(ring_dim, moduli, perf_test=False): + """Computes powers of psi for the given moduli.""" + if perf_test: + return random_parameters((len(moduli), ring_dim), moduli, dtype=jnp.uint64) + else: + psi = [root_of_unity(2 * ring_dim, q) for q in moduli] + return jnp.array( + [ + [pow(psi[idx], i, moduli[idx]) for i in range(ring_dim)] + for idx in range(len(moduli)) + ], + jnp.uint64, + ) + + +def is_prime_deterministic(n): + """ + Deterministic primality test for n < 2^64. + Uses Trial Division for speed + Deterministic Miller-Rabin for correctness. + """ + if n < 2: return False + if n == 2 or n == 3: return True + if n % 2 == 0: return False + + # 1. SPEED OPTIMIZATION: Trial Division + # Check divisibility by small primes to fail fast on obvious composites. + # This filters out ~85% of candidates without expensive modular exponentiation. + small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53] + for p in small_primes: + if n == p: return True + if n % p == 0: return False + + # 2. DETERMINISTIC MILLER-RABIN + # For n < 2^64, verifying these specific bases guarantees primality. + # No randomness involved. + d = n - 1 + s = 0 + while d % 2 == 0: + d //= 2 + s += 1 + + # Bases required for deterministic check up to 2^64 + bases = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37] + + for a in bases: + if a >= n: break + + x = pow(a, d, n) + if x == 1 or x == n - 1: + continue + + for _ in range(s - 1): + x = pow(x, 2, n) + if x == n - 1: + break + else: + return False # Composite + + return True # Prime + + +def find_moduli_ntt(total_number, precision, ntt_length): + """ + Deterministically finds the largest valid NTT moduli. + + Args: + total_number: Number of moduli to find. + precision: Bit-width (e.g., 60 for < 2^60). + ntt_length: The required N-th root of unity (e.g., 1024). + """ + overall_moduli = [] + + # Upper bound + limit = 2**precision + + # Start search from the largest possible k + # P = k * ntt_length + 1 + k = (limit - 1) // ntt_length + + while len(overall_moduli) < total_number and k > 0: + candidate_p = k * ntt_length + 1 + + # Check candidate + if is_prime_deterministic(candidate_p): + overall_moduli.append(candidate_p) + + k -= 1 + + return overall_moduli - def recursive_convert(lst): - if isinstance(lst, list): - return [recursive_convert(item) for item in lst] - else: - return int_to_array(lst, base, array_size=array_size) - result = recursive_convert(int_list) - return jnp.array(result, dtype=jnp.uint16) +def gamma_beta_calculation(moduli_list, perf_test=False): + if perf_test: + # Shapes: gammas: (len(moduli_list)-1,), betas: (len(moduli_list)-1,) + assert len(moduli_list) > 1, "moduli_list must have at least 2 moduli" + gamma_rand = random_parameters((len(moduli_list)-1,), moduli_list[:-1], dtype=jnp.uint64) + beta_rand = random_parameters((len(moduli_list)-1,), moduli_list[:-1], dtype=jnp.uint64) + return jnp.array(gamma_rand, jnp.uint64), jnp.array(beta_rand, jnp.uint64) + # Compute Q as the product of the moduli for the remaining towers. + Q = 1 + for m in moduli_list[:-1]: + Q *= m + num_towers = len(moduli_list) + q_l = moduli_list[-1] + # Compute Q_inv_mod_ql: the inverse of Q modulo q_l. + Q_inv_mod_ql = modinv(Q, q_l) -def jax_array_to_int_list(jax_array, base): - """Converts a (potentially multi-dimensional) JAX array into a nested list of integers. + # Compute gamma_common such that: + # Q * Q_inv_mod_ql = 1 + gamma_common * q_l. + # Hence, gamma_common = (Q * Q_inv_mod_ql - 1) // q_l. + gamma_common = (Q * Q_inv_mod_ql - 1) // q_l - The function recursively traverses the array until it reaches a 1D vector, - then applies `array_to_int` to convert that vector into an integer. + # For each remaining tower compute gamma_i and beta_i. + gammas = [] + betas = [] + for i in range(num_towers - 1): + mod_i = moduli_list[i] + gamma_i = gamma_common % mod_i + beta_i = modinv(q_l, mod_i) + gammas.append(gamma_i) + betas.append(beta_i) + return jnp.array(gammas, jnp.uint64), jnp.array(betas, jnp.uint64) + +#################################### +# Random Functions +#################################### +def random_batched_ciphertext(shape, modulus_list, dtype=jnp.int32): + assert len(modulus_list) == shape[-1] + random_key = jax.random.key(0) + return jnp.concatenate( + [ + jax.random.randint( + random_key, + shape=(shape[0], shape[1], shape[2], 1), + minval=0, + maxval=bound, + dtype=dtype, + ) + for bound in modulus_list + ], + axis=3, + ) + + +def random_ciphertext(shape, modulus_list, dtype=jnp.int32): + assert len(modulus_list) == shape[-1] + random_key = jax.random.key(0) + return jnp.concatenate( + [ + jax.random.randint( + random_key, + shape=(shape[0],shape[1], 1), + minval=0, + maxval=bound, + dtype=dtype, + ) + for bound in modulus_list + ], + axis=2, + ) + + +def random_parameters(shape, modulus_list, dtype=jnp.int32): + random_key = jax.random.key(0) + min_modulus = 2**127 + for modulus in modulus_list: + if modulus < min_modulus: + min_modulus = modulus + return jax.random.randint(random_key, shape=shape, minval=0, maxval=min_modulus-1, dtype=dtype) + + +#################################### +# Parse Functions +#################################### +def parse_ciphertext_string(input_str): + """ + Parses the input string into two objects: + - data: a list of element groups, each a list of evaluations (list of lists of numbers). + Shape: (num_element, num_eval, num_numbers) + - modulus: a one-dimensional list of modulus values corresponding to each evaluation index. + All element groups are assumed to share the same modulus per evaluation. + + Parameters: + input_str (str): The string containing the input data. + + Returns: + tuple: (data, modulus) as described. + """ + data = [] + global_modulus = [] # This will store the modulus once per evaluation index. + + + # Process the input line by line. + for line in input_str.strip().splitlines(): + line = line.strip() + + # Check for an "Element" header. + if line.startswith("Element"): + # Start a new element group. + current_data_group = [] + data.append(current_data_group) + + # Check if there is extra content on the same line after the header. + header_match = re.match(r'^Element\s+\d+:\s*(.*)', line) + if header_match: + remainder = header_match.group(1).strip() + if remainder: + # Process an evaluation if it appears on the same line. + if eval_match := re.match(r'^(\d+):\s*EVAL:\s*\[(.*?)\]\s*modulus:\s*(\d+)', remainder): + numbers_str = eval_match.group(2) + mod_val = int(eval_match.group(3)) + numbers = [int(num) for num in numbers_str.split()] + current_data_group.append(numbers) + # For the first element group, record the modulus; otherwise, check consistency. + eval_idx = len(current_data_group) - 1 + if len(data) == 1: + global_modulus.append(mod_val) + else: + if eval_idx < len(global_modulus) and global_modulus[eval_idx] != mod_val: + raise ValueError(f"Inconsistent modulus at evaluation index {eval_idx}") + + # Otherwise, check if the line is an evaluation line. + elif eval_match := re.match(r'^(\d+):\s*EVAL:\s*\[(.*?)\]\s*modulus:\s*(\d+)', line): + numbers_str = eval_match.group(2) + mod_val = int(eval_match.group(3)) + numbers = [int(num) for num in numbers_str.split()] + # Holds the current element's data evaluations. + current_data_group = [] + current_data_group.append(numbers) + eval_idx = len(current_data_group) - 1 + # For the first element group, record the modulus; for subsequent groups, check consistency. + if len(data) == 1: + global_modulus.append(mod_val) + else: + if eval_idx < len(global_modulus) and global_modulus[eval_idx] != mod_val: + raise ValueError(f"Inconsistent modulus at evaluation index {eval_idx}") + + return data, global_modulus + + +#################################### +# Bit Reverse Functions +#################################### +def bit_reverse(x, bits): + """Compute the bit-reversal of integer x with the given number of bits.""" + result = 0 + for i in range(bits): + if (x >> i) & 1: # if i-th bit of x is 1 + result |= 1 << (bits - 1 - i) # set the corresponding reversed bit + return result + + +def bit_reverse_array(in_tower): + x = copy.deepcopy(in_tower) + bits = len(x).bit_length() - 1 + for i in range(len(x)): + j = bit_reverse(i, bits) + if i < j: + x[i], x[j] = x[j], x[i] + return x + + +def bit_reverse_indices(n: int) -> jnp.ndarray: + """ + Compute an array rev_idx of shape (n,) such that rev_idx[i] is the bit-reversal + of i over log2(n) bits. + """ + bits = int(math.log2(n)) + idx = jnp.arange(n) + # build the reversed index by summing shifted bits + rev = sum( + ((idx >> i) & 1) << (bits - 1 - i) + for i in range(bits) + ) + return rev + + + +#################################### +# Automorphism Functions +#################################### +def precompute_auto_map(n: int, k: int) -> List[int]: + m = n << 1 # cyclOrder + logm = int(round(math.log2(m))) + logn = int(round(math.log2(n))) + + precomp: List[int] = [0] * n + for j in range(n): + j_tmp = (j << 1) + 1 + t = j_tmp * k + # ((t % m) >> 1) but written to mirror the C++ bit ops exactly + idx = (t - ((t >> logm) << logm)) >> 1 + + j_rev = bit_reverse(j, logn) + idx_rev = bit_reverse(idx, logn) + precomp[j_rev] = idx_rev + + return precomp + + +def find_automorphism_index_2n_complex(i: int, m: int) -> int: + """Python translation of nbtheory2.cpp FindAutomorphismIndex2nComplex (243-263). + + Mirrors the C++ logic including early exits, power-of-two validation, and + modulus via bitmask for m being a power of two. + """ + if i == 0: + return 1 + if i == (m - 1): + return int(i) + + if not is_power_of_two(m): + raise ValueError("m should be a power of two.") + + # Conjugation automorphism generator + g0 = pow(5, -1, m) if i < 0 else 5 + g = g0 + i_unsigned = abs(i) + mask = m - 1 + for _ in range(1, i_unsigned): + # Equivalent to (g * g0) % m since m is a power of two + g = (g * g0) & mask + return int(g) + + +#################################### +# Number Theory Transformation +# Negacyclic NTT is used in CKKS +#################################### +def ntt_bit_reverse(a, q, omega): + """Compute cyclic Number Theoretic Transform of array a modulo q using a given primitive omega of unity.""" + n = len(a) + # Ensure that omega^n ≡ 1 (mod q) and n divides q-1 for validity. + # (This should be true if omega is a correct n-th omega of unity.) + # Bit-reverse the input array indices + bits = n.bit_length() - 1 # number of bits needed for indexes 0..n-1 + for i in range(n): + j = bit_reverse(i, bits) + if i < j: + a[i], a[j] = a[j], a[i] # swap to achieve bit-reversed order + # Cooley-Tukey iterative FFT (NTT) + length = 2 + while length <= n: + # Compute twiddle factor step: use omega^(n/length) as the increment + w_m = pow(omega, n // length, q) + half = length // 2 + for i in range(0, n, length): # loop over sub-FFT blocks + w = 1 + for j in range(i, i + half): # loop within each block + u = a[j] + v = a[j + half] * w % q # multiply by current twiddle factor + a[j] = (u + v) % q # butterfly: combine top part + a[j + half] = (u - v) % q # butterfly: combine bottom part + w = w * w_m % q # advance twiddle factor for next element + length *= 2 + return a + + +def intt_bit_reverse(a, q, omega): + """Compute the Inverse Number Theoretic Transform of array a modulo p using the given primitive root.""" + n = len(a) + inv_root = pow(omega, -1, q) # modular inverse of root + # Decimation-in-frequency (Gentleman-Sande) butterfly operations + length = n + while length >= 2: + w_m = pow(inv_root, n // length, q) + half = length // 2 + for i in range(0, n, length): + w = 1 + for j in range(i, i + half): + u = a[j] + v = a[j + half] + a[j] = (u + v) % q # combine pairs (top value) + a[j + half] = ( + ((u - v) % q) * w % q + ) # combine pairs (bottom), then multiply by twiddle + w = w * w_m % q # advance twiddle factor + length //= 2 + # Bit-reverse the result (to invert the initial bit-reversal + # permutation in NTT) + bits = n.bit_length() - 1 + for i in range(n): + j = bit_reverse(i, bits) + if i < j: + a[i], a[j] = a[j], a[i] + # Divide by n (multiply by n^{-1} mod p) to finish the inverse transform + inv_n = pow(n, -1, q) + for i in range(n): + a[i] = a[i] * inv_n % q + return a + + +def ntt_negacyclic_bit_reverse(a, q, psi): + """Compute the negacyclic NTT of array a (length n) modulo q. Args: - jax_array: The JAX array to convert. - base: The base of the integer representation. + a: list (or 1D array) of integers (length n). + q: prime modulus. + psi: an element in GF(q) such that psi^(2*n) = 1 and psi^n = -1 mod q. + (That is, psi is a primitive 2n-th root of unity; note that then ω = + psi^2 + is a primitive n-th root of unity.) + rows: Number of rows in the matrix. + cols: Number of columns in the matrix. Returns: - A nested list of integers. + The negacyclic NTT of a. + + Process: + 1. Pre-twist: multiply each coefficient a[i] by psi^i. + 2. Compute the vanilla NTT (for example, using ntt_bit_reverse) with ω = + psi^2. """ - if jax_array.ndim == 1: - return array_to_int(jax_array, base) + n = len(a) + # Check that psi^n = -1 mod q. + if pow(psi, n, q) != q - 1: + raise ValueError( + "psi is not a valid 2n-th root of unity for negacyclic NTT (psi^n must" + " equal -1 mod q)." + ) + + # Pre-twisting: multiply a[i] by psi^i. + a_twisted = [(a[i] * pow(psi, i, q)) % q for i in range(n)] + + # Compute vanilla NTT using ω = psi². + omega = pow(psi, 2, q) + + return ntt_bit_reverse(a_twisted.copy(), q, omega) + + +def intt_negacyclic_bit_reverse(a, q, psi): + """Compute the inverse negacyclic NTT of array a (length n) modulo q. + + Args: + a : list (or 1D array) of integers (length n) in the negacyclic evaluation + domain. + q : prime modulus. + psi : an element in GF(q) such that psi^(2*n) = 1 and psi^n = -1 mod q. + (That is, psi is a primitive 2n-th root of unity; note that then ω = + psi^2 + is a primitive n-th root of unity.) + Returns: + The original input vector (i.e. the inverse transform). + + Process: + 1. Compute the inverse vanilla NTT using ω = psi². + 2. Post-twist: multiply the result by psi^(–i) for coefficient index i. + """ + n = len(a) + omega = pow(psi, 2, q) + + # Compute the inverse vanilla NTT. + a_inv = intt_bit_reverse(a.copy(), q, omega) + + # Post-twisting: multiply a_inv[i] by psi^(–i). + psi_inv = pow(psi, -1, q) + return [(a_inv[i] * pow(psi_inv, i, q)) % q for i in range(n)] + + +#################################### +# Precision Lowering Functions (outside Google) +#################################### +def chunk_decomposition(x, chunkwidth=8): + """Precision-level data conversion. + + Args: + x: The input data. + chunkwidth: The chunkwidth. + + Returns: + The decomposed data. + """ + dtype = jnp.uint8 + if chunkwidth == 16: + dtype = jnp.uint16 + elif chunkwidth == 32: + dtype = jnp.uint32 + + elements = [] + mask = (1 << chunkwidth) - 1 + # Mask to extract the lower bits (e.g., 32 bits -> 0xFFFFFFFF) + + # Extract each element from the integer + while x > 0: + elements.append(x & mask) # Extract the lower bits + x >>= chunkwidth # Shift to remove the extracted bits + + # Convert the list to a JAX array + return jnp.array(elements, dtype=dtype) + + +#################################### +# Performance Profiler Functions (outside Google) +#################################### +def dump_hlo_from_lowered(lowered: Any, out_dir: str, out_filename: str) -> str: + """ + Extract HLO (or StableHLO) textual IR from a lowered computation and write it to a file. + Returns the full output file path. + """ + # Prefer XLA HLO; fall back to StableHLO if necessary + try: + ir_obj = lowered.compiler_ir(dialect="hlo") + # Handle XLA computation objects and MLIR modules + if hasattr(ir_obj, "as_hlo_text"): + hlo_text = ir_obj.as_hlo_text() + elif hasattr(ir_obj, "operation"): + try: + hlo_text = ir_obj.operation.get_asm(enable_debug_info=True) + except Exception: + hlo_text = ir_obj.operation.get_asm() + else: + try: + hlo_text = ir_obj.as_text() + except Exception: + hlo_text = str(ir_obj) + except Exception: + ir_obj = lowered.compiler_ir(dialect="stablehlo") + if hasattr(ir_obj, "operation"): + try: + hlo_text = ir_obj.operation.get_asm(enable_debug_info=True) + except Exception: + hlo_text = ir_obj.operation.get_asm() + else: + try: + hlo_text = ir_obj.as_text() + except Exception: + hlo_text = str(ir_obj) + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, out_filename) + with open(out_path, "w") as f: + f.write(hlo_text) + return out_path + + +def dump_llo_from_lowered(lowered: Any, out_dir: str, out_filename: str) -> str: + """ + Extract a lower-level XLA IR (LMHLO / LLVM if available) from a lowered computation + and write it to a file. Returns the full output file path. + """ + ir_text = None + dialect_candidates = ["llvm", "lmhlo", "mhlo"] + # Try both precompiled lowered and compiled executable views + sources = [lowered] + try: + compiled = lowered.compile() + sources.append(compiled) + except Exception: + pass + for source in sources: + for dialect in dialect_candidates: + try: + ir_obj = source.compiler_ir(dialect=dialect) + if hasattr(ir_obj, "operation"): + try: + ir_text = ir_obj.operation.get_asm(enable_debug_info=True) + except Exception: + ir_text = ir_obj.operation.get_asm() + else: + # Some dialects (e.g., llvm) may expose textual interfaces differently + if hasattr(ir_obj, "as_text"): + ir_text = ir_obj.as_text() + else: + ir_text = str(ir_obj) + if ir_text and len(ir_text) > 0: + break + except Exception: + continue + if ir_text: + break + if ir_text is None: + # As a last resort, try to stringify generic compiler_ir without dialect hints + try: + generic = lowered.compiler_ir() + if hasattr(generic, "operation"): + ir_text = generic.operation.get_asm(enable_debug_info=True) + elif hasattr(generic, "as_text"): + ir_text = generic.as_text() + else: + ir_text = str(generic) + except Exception: + raise RuntimeError("Unable to extract LLO/LMHLO/LLVM IR from lowered/compiled computation.") + os.makedirs(out_dir, exist_ok=True) + out_path = os.path.join(out_dir, out_filename) + with open(out_path, "w") as f: + f.write(ir_text) + return out_path + + +def profile_jax_functions_xprof( + tasks: List[Tuple[Callable[..., Any], Tuple[Any, ...]]], + profile_name: str = "jax_profile", + kernel_name: str = "kernel_name", +): + """Profiles a list of JAX functions. + + Args: + tasks: A list of tuples, where each tuple contains a JAX function and its + arguments. + profile_name: The name of the profile, written in log/xprof/plugins/profile. + kernel_name: The name of the kernel, used to find the latency of the kernel in the trace file. + Usage: + tasks = [ + (jit_pdul_barrett_xyzz_pack, (point_a_jax,)), + ] + profile_name = "jit_pdul_barrett_xyzz_pack" + profile_jax_functions(tasks, profile_name, kernel_name="jit_pdul_barrett_xyzz_pack") + """ + latency = 0 + n = 1 # number of times running the kernel + final_folder_name = os.path.join(profile_root, profile_name) + options = jax.profiler.ProfileOptions() + options.python_tracer_level = 3 + options.host_tracer_level = 3 # https://docs.jax.dev/en/latest/profiling.html#general-options + options.advanced_configuration = {"tpu_trace_mode" : "TRACE_COMPUTE_AND_SYNC", "tpu_num_chips_to_profile_per_task" : 4} + + repo_root = os.path.dirname(__file__) + xprof_dir = os.path.join(repo_root, "log/xprof") + with jax.profiler.trace(xprof_dir): + # Launch all JAX computations + results = [] + for func, args_tuple in tasks: + result = func(*args_tuple) + results.append(result) + + # Wait for all computations launched in the loop to complete + if results: + jax.block_until_ready(results) + + # Rename the newly created timestamped directory to the designated profile_name. + try: + if os.path.isdir(profile_root): + post_dirs = set(os.listdir(profile_root)) + created_dirs = [d for d in (post_dirs - pre_existing_dirs) if os.path.isdir(os.path.join(profile_root, d))] + + target_dir = None + if created_dirs: + # Choose the most recently modified among the newly created ones. + target_dir = max(created_dirs, key=lambda d: os.path.getmtime(os.path.join(profile_root, d))) + else: + # Fallback: pick the most recent dir in case set diff failed (e.g., pre list failed). + all_dirs = [d for d in post_dirs if os.path.isdir(os.path.join(profile_root, d))] + if all_dirs: + target_dir = max(all_dirs, key=lambda d: os.path.getmtime(os.path.join(profile_root, d))) + + if target_dir: + + # Avoid overwriting existing destination; add numeric suffix if necessary. + if os.path.exists(final_folder_name): + suffix = 1 + while os.path.exists(f"{final_folder_name}_{suffix}"): + suffix += 1 + final_folder_name = f"{final_folder_name}_{suffix}" + os.rename(os.path.join(profile_root, target_dir), final_folder_name) + except Exception as e: + print(f"Profile rename failed: {e}") + + # Read the trace file and print the latency of the kernel + # Find the file that ends with 'trace.json.gz' in the destination directory + trace_file = None + if os.path.exists(final_folder_name): + for fname in os.listdir(final_folder_name): + if fname.endswith("trace.json.gz"): + trace_file = os.path.join(final_folder_name, fname) + break + if trace_file: + with gzip.open(trace_file, 'rt') as f: + jtrace = json.loads(f.read()) + if jtrace: + if "NVIDIA" in jax.devices()[0].device_kind: + for e in jtrace["traceEvents"]: + if 'args' in e and 'tf_op' in e['args']: + if kernel_name in e['args']["hlo_module"]: + latency += e['dur'] + elif "TPU" in jax.devices()[0].device_kind: + pid = 999999 # an invalid PID + for e in jtrace["traceEvents"]: + if 'args' in e and 'name' in e['args'] and 'TPU:0' in e['args']['name']: + pid = e['pid'] + if 'args' in e and 'tf_op' in e['args'] and kernel_name in e['args']['tf_op']: + if e['pid'] == pid: + latency += e['dur'] + if 'args' in e and 'hlo_category' in e['args'] and 'copy' in e['args']['hlo_category']: + if e['pid'] == pid: + latency += e['dur'] + else: + print(f"Trace file not found: {trace_file}") else: - return [jax_array_to_int_list(sub_array, base) for sub_array in jax_array] + print(f"Final folder name not found: {final_folder_name}") + return latency -def random_list(shape, max_val, dtype=jnp.int32): - return jax.random.randint( - jax.random.key(0), shape=shape, minval=0, maxval=max_val, dtype=dtype - ).tolist() +# paper full case evaluation. +original_moduli_51_limbs = [1073753729, 1073738977, 1073753281, 1073739041, 1073753089, 1073747137, 1073752417, 1073739169, 1073745697, 1073739361, 1073752129, 1073746337, 1073748737, 1073746529, 1073748289, 1073747393, 1073749889, 1073748449, 1073751713, 1073749153, 1073750593, 1073749409, 1073751521, 1073750017, 1073751169, 1073750497, 1073751073, 1073750113, 1073750849, 1073739617, 1073746273, 1073745473, 1073745889, 1073742881, 1073745377, 1073739649, 1073745121, 1073741953, 1073744993, 1073739937, 1073744417, 1073742913, 1073744257, 1073742113, 1073743457, 1073742209, 1073743393, 1073740609, 1073742721, 1073741441, 1073741857, 524353] +original_psi_51_limbs = [1093151, 90892563, 108899655, 56634236, 235160291, 12265314, 191995239, 21404433, 40083131, 3916344, 113671079, 34500367, 61894143, 20463380, 13205216, 60050555, 145308815, 87067229, 10533116, 133048918, 13697511, 47895671, 14807533, 10994638, 25005605, 44429319, 77617905, 22756112, 21182116, 46947055, 41148497, 163086225, 60397627, 176334344, 30766686, 77429283, 67466901, 67653750, 4536048, 135444559, 63788661, 110966687, 9716122, 12174708, 49591386, 81862273, 51874541, 12155428, 60746932, 68809976, 28870916, 19017] +extend_moduli_51_limbs = [1152921504606845473, 1152921504606844513, 1152921504606844417, 1152921504606844289, 1152921504606843233, 1152921504606843073, 1152921504606842753, 1152921504606841793, 1152921504606841441, 1152921504606840929] -def random_array(shape, max_val, dtype=jnp.int32): - return jax.random.randint( - jax.random.key(0), shape=shape, minval=0, maxval=max_val, dtype=dtype - ) +NTT_PARAMETERS_BY_DEGREE = { + 16: { + "moduli": [1073759809, 1073759041, 1073759777, 1073758337, 1073759329, 1073758849, 1073759233, 1073738273, 1073754113, 1073738753, 1073753729, 1073738977, 1073753281, 1073739041, 1073753089, 1073747137, 1073752417, 1073739169, 1073745697, 1073739361, 1073752129, 1073746337, 1073748737, 1073746529, 1073748289, 1073747393, 1073749889, 1073748449, 1073751713, 1073749153, 1073750593, 1073749409, 1073751521, 1073750017, 1073751169, 1073750497, 1073751073, 1073750113, 1073750849, 1073739617, 1073746273, 1073745473, 1073745889, 1073742881, 1073745377, 1073739649, 1073745121, 1073741953, 1073744993, 1073739937, 1073744417, 1073742913, 1073744257, 1073742113, 1073743457, 1073742209, 1073743393, 1073740609, 1073742721, 1073741441, 1073741857, 524353], + "root_of_unity": [149761193, 17168328, 145519847, 68042513, 3491826, 21109149, 48183983, 49547540, 15369996, 12935385, 1093151, 90892563, 108899655, 56634236, 235160291, 12265314, 191995239, 21404433, 40083131, 3916344, 113671079, 34500367, 61894143, 20463380, 13205216, 60050555, 145308815, 87067229, 10533116, 133048918, 13697511, 47895671, 14807533, 10994638, 25005605, 44429319, 77617905, 22756112, 21182116, 46947055, 41148497, 163086225, 60397627, 176334344, 30766686, 77429283, 67466901, 67653750, 4536048, 135444559, 63788661, 110966687, 9716122, 12174708, 49591386, 81862273, 51874541, 12155428, 60746932, 68809976, 28870916, 19017], + }, + 4096: { + "moduli": [268730369, 268689409, 268361729, 268582913, 268369921, 268460033, 557057, 1152921504606830593, 1152921504606748673], + "root_of_unity": [8801, 19068, 58939, 11033, 62736, 77090, 474, 116777451583545, 271802498405390], + }, + 8192: { + "moduli": [269402113, 268091393, 268730369, 268271617, 269221889, 268664833, 268861441, 268369921, 268582913, 557057, 1152921504606830593, 1152921504606748673], + "root_of_unity": [18987, 2826, 1678, 18925, 2446, 31335, 40892, 65274, 15787, 268, 25959043411404, 100406242475323], + }, + 16384: { + "moduli": [274726913, 272760833, 274628609, 267059201, 270499841, 267550721, 270237697, 267943937, 268861441, 268042241, 268730369, 268238849, 269844481, 268271617, 269221889, 268369921, 268664833, 557057, 1152921504606748673, 1152921504606683137, 1152921504606584833], + "root_of_unity": [9358, 15613, 1976, 5381, 15236, 9622, 5177, 2469, 792, 63914, 9742, 12308, 3704, 7216, 7564, 10360, 2023, 19, 62213374832584, 212089012217363, 92166579128688], + }, + 65536: { + "moduli": [384040961, 376569857, 371458049, 375521281, 371589121, 383778817, 377880577, 379453441, 323092481, 351797249, 349962241, 351404033, 260702209, 308150273, 304742401, 307888129, 302776321, 306708481, 304218113, 347996161, 319291393, 347078657, 323223553, 337248257, 323878913, 336855041, 329515009, 332660737, 329777153, 335413249, 325844993, 330301441, 327548929, 332267521, 328728577, 344850433, 336068609, 340000769, 261488641, 302252033, 297664513, 299499521, 261881857, 295305217, 263323649, 277086209, 263454721, 292159489, 279838721, 291373057, 284950529, 290455553, 281935873, 285474817, 283508737, 288882689, 264634369, 276430849, 270532609, 274726913, 272760833, 276037633, 265420801, 270794753, 268042241, 269221889, 786433], + "root_of_unity": [1197, 4622, 9335, 5748, 719, 1497, 2281, 3163, 3548, 80, 6577, 4942, 435, 3498, 316, 4503, 1433, 5766, 440, 2739, 1792, 13, 545, 7539, 7418, 7033, 32540, 1301, 4354, 16962, 10301, 289, 4195, 3322, 1005, 1747, 13384, 7659, 2200, 1035, 2142, 6961, 2774, 910, 43, 1949, 4343, 6648, 787, 2879, 4743, 563, 3385, 5648, 5875, 9494, 2122, 852, 6279, 1335, 712, 2017, 929, 142, 5274, 3264, 8], + }, +} + +moduli_28_list = { + degree: params["moduli"] + for degree, params in NTT_PARAMETERS_BY_DEGREE.items() +} + +roof_of_unity = { + degree: params["root_of_unity"] + for degree, params in NTT_PARAMETERS_BY_DEGREE.items() +} diff --git a/jaxite_ckks/bconv.py b/jaxite_ckks/bconv.py new file mode 100644 index 0000000..4c50dd6 --- /dev/null +++ b/jaxite_ckks/bconv.py @@ -0,0 +1,17 @@ +"""TODO: jianmingt - DO NOT SUBMIT without one-line documentation for bconv. + +TODO: jianmingt - DO NOT SUBMIT without a detailed description of bconv. +""" + +from collections.abc import Sequence + +from absl import app + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main) diff --git a/jaxite_ckks/ciphertext.py b/jaxite_ckks/ciphertext.py new file mode 100644 index 0000000..7828149 --- /dev/null +++ b/jaxite_ckks/ciphertext.py @@ -0,0 +1,17 @@ +"""TODO: jianmingt - DO NOT SUBMIT without one-line documentation for ciphertext. + +TODO: jianmingt - DO NOT SUBMIT without a detailed description of ciphertext. +""" + +from collections.abc import Sequence + +from absl import app + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main) diff --git a/jaxite_ckks/ckks_ctx.py b/jaxite_ckks/ckks_ctx.py new file mode 100644 index 0000000..fbdb2df --- /dev/null +++ b/jaxite_ckks/ckks_ctx.py @@ -0,0 +1,17 @@ +"""TODO: jianmingt - DO NOT SUBMIT without one-line documentation for ckks_ctx. + +TODO: jianmingt - DO NOT SUBMIT without a detailed description of ckks_ctx. +""" + +from collections.abc import Sequence + +from absl import app + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main) diff --git a/jaxite_ckks/key_gen.py b/jaxite_ckks/key_gen.py new file mode 100644 index 0000000..06a6065 --- /dev/null +++ b/jaxite_ckks/key_gen.py @@ -0,0 +1,17 @@ +"""TODO: jianmingt - DO NOT SUBMIT without one-line documentation for key_gen. + +TODO: jianmingt - DO NOT SUBMIT without a detailed description of key_gen. +""" + +from collections.abc import Sequence + +from absl import app + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main) diff --git a/jaxite_ckks/ntt_mm.py b/jaxite_ckks/ntt_mm.py new file mode 100644 index 0000000..57c9038 --- /dev/null +++ b/jaxite_ckks/ntt_mm.py @@ -0,0 +1,17 @@ +"""TODO: jianmingt - DO NOT SUBMIT without one-line documentation for ntt_mm. + +TODO: jianmingt - DO NOT SUBMIT without a detailed description of ntt_mm. +""" + +from collections.abc import Sequence + +from absl import app + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main) diff --git a/jaxite_ckks/ntt_mm_test.py b/jaxite_ckks/ntt_mm_test.py new file mode 100644 index 0000000..712bbae --- /dev/null +++ b/jaxite_ckks/ntt_mm_test.py @@ -0,0 +1,16 @@ +"""TODO: jianmingt - DO NOT SUBMIT without either providing a detailed docstring or +removing it altogether. +""" + +from jaxite.jaxite_ckks import ntt_mm +from absl.testing import absltest + + +class NttMmTest(absltest.TestCase): + + def test_give_me_a_name(self): + pass + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite_ckks/profiler.py b/jaxite_ckks/profiler.py new file mode 100644 index 0000000..f02b784 --- /dev/null +++ b/jaxite_ckks/profiler.py @@ -0,0 +1,17 @@ +"""TODO: jianmingt - DO NOT SUBMIT without one-line documentation for profiler. + +TODO: jianmingt - DO NOT SUBMIT without a detailed description of profiler. +""" + +from collections.abc import Sequence + +from absl import app + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main)