diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml
index 7dbb06f..20daa31 100644
--- a/.github/workflows/build_and_test.yml
+++ b/.github/workflows/build_and_test.yml
@@ -27,13 +27,14 @@ jobs:
- name: "Run `bazel build`"
run: |
- bazel build //...
+ bazel --host_jvm_args=-Xmx32g build --jobs=8 //...
- name: "Run `bazel test`"
run: |
- bazel test \
+ bazel --host_jvm_args=-Xmx32g test \
--test_output=errors \
--test_size_filters=small \
--test_timeout=1800 \
--experimental_ui_max_stdouterr_bytes=10485760 \
+ --jobs=8 \
//...
diff --git a/.github/workflows/periodic_tests.yml b/.github/workflows/periodic_tests.yml
index a9eef38..935e159 100644
--- a/.github/workflows/periodic_tests.yml
+++ b/.github/workflows/periodic_tests.yml
@@ -14,13 +14,14 @@ jobs:
- name: "Run `bazel build`"
run: |
- bazel build //...
+ bazel --host_jvm_args=-Xmx32g build --jobs=8 //...
- name: "Run `bazel test`"
run: |
- bazel test \
+ bazel --host_jvm_args=-Xmx32g test \
--test_output=errors \
--test_size_filters=medium,large \
--test_timeout=3600 \
--experimental_ui_max_stdouterr_bytes=10485760 \
+ --jobs=8 \
//...
diff --git a/BUILD b/BUILD
index 7ce857c..82fa23a 100644
--- a/BUILD
+++ b/BUILD
@@ -78,6 +78,10 @@ py_library(
"jaxite_ckks/*",
],
),
+ data = [
+ "jaxite_ec/configurations.toml",
+ # "@jaxite//jaxite_ec/c_kernels:distribution.so",
+ ],
visibility = [":internal"],
deps = [
":jaxite_ckks",
@@ -88,7 +92,10 @@ py_library(
"@jaxite_deps//jaxlib",
# copybara: jax/experimental:pallas_lib
# copybara: jax/experimental:pallas_tpu
+ "@jaxite//jaxite_ec/c_kernels:build",
"@jaxite_deps//numpy",
+ # copybara: pandas
+ # copybara: toml
],
)
@@ -142,13 +149,21 @@ tpu_test(
)
tpu_test(
- name = "jaxite_ec_finite_field_test",
- size = "large",
- timeout = "moderate",
+ name = "ec_finite_field_test",
srcs = ["jaxite_ec/finite_field_test.py"],
- python_version = "PY3",
- shard_count = 3,
- srcs_version = "PY3ONLY",
+ deps = [
+ ":jaxite",
+ "@abseil-py//absl/testing:absltest",
+ "@abseil-py//absl/testing:parameterized",
+ "@jaxite_deps//jax",
+ "@jaxite_deps//jaxlib",
+ "@jaxite_deps//numpy",
+ ],
+)
+
+tpu_test(
+ name = "ec_finite_field_perf_test",
+ srcs = ["jaxite_ec/finite_field_perf_test.py"],
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
@@ -158,57 +173,44 @@ tpu_test(
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
+ # copybara: toml
],
)
tpu_test(
- name = "msm_test",
- size = "large",
- timeout = "eternal",
- srcs = [
- "jaxite_ec/msm_test.py",
+ name = "elliptic_curve_test",
+ srcs = ["jaxite_ec/elliptic_curve_test.py"],
+ deps = [
+ ":jaxite",
+ "@abseil-py//absl/testing:absltest",
+ "@abseil-py//absl/testing:parameterized",
+ "@jaxite_deps//jax",
+ "@jaxite_deps//jaxlib",
+ "@jaxite_deps//numpy",
+ # copybara: toml
],
- data = [
- "jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv",
- "jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv",
- "jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv",
- "jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv",
- "jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv",
- "jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv",
- "jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv",
- "jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv",
- "jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv",
- "jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv",
- "jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv",
- "jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv",
- "jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv",
- "jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv",
- "jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv",
- ],
- python_version = "PY3",
- shard_count = 3,
- srcs_version = "PY3ONLY",
+)
+
+tpu_test(
+ name = "elliptic_curve_perf_test",
+ srcs = ["jaxite_ec/elliptic_curve_perf_test.py"],
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
# copybara: xprof_session # buildcleaner: keep
- # copybara: resources
"@abseil-py//absl/testing:absltest",
"@abseil-py//absl/testing:parameterized",
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
+ # copybara: toml
],
)
tpu_test(
- name = "elliptic_curve_test",
- size = "large",
- timeout = "long",
- srcs = ["jaxite_ec/elliptic_curve_test.py"],
- python_version = "PY3",
- shard_count = 16,
- srcs_version = "PY3ONLY",
+ name = "multiscalar_multiplication_test",
+ srcs = ["jaxite_ec/multiscalar_multiplication_test.py"],
+ data = glob(["jaxite_ec/data/t1024/*.csv"]),
deps = [
":jaxite",
# copybara: xprof_analysis_client # buildcleaner: keep
@@ -218,6 +220,50 @@ tpu_test(
"@jaxite_deps//jax",
"@jaxite_deps//jaxlib",
"@jaxite_deps//numpy",
+ # copybara: toml
+ ],
+)
+
+tpu_test(
+ name = "multiscalar_multiplication_perf_test",
+ srcs = ["jaxite_ec/multiscalar_multiplication_perf_test.py"],
+ deps = [
+ ":jaxite",
+ # copybara: xprof_analysis_client # buildcleaner: keep
+ # copybara: xprof_session # buildcleaner: keep
+ "@abseil-py//absl/testing:absltest",
+ "@abseil-py//absl/testing:parameterized",
+ "@jaxite_deps//jax",
+ "@jaxite_deps//jaxlib",
+ "@jaxite_deps//numpy",
+ ],
+)
+
+tpu_test(
+ name = "number_theory_transform_test",
+ srcs = ["jaxite_ec/number_theory_transform_test.py"],
+ deps = [
+ ":jaxite",
+ "@abseil-py//absl/testing:absltest",
+ "@abseil-py//absl/testing:parameterized",
+ "@jaxite_deps//jax",
+ "@jaxite_deps//jaxlib",
+ "@jaxite_deps//numpy",
+ ],
+)
+
+tpu_test(
+ name = "number_theory_transform_perf_test",
+ size = "large",
+ timeout = "eternal",
+ srcs = ["jaxite_ec/number_theory_transform_perf_test.py"],
+ deps = [
+ ":jaxite",
+ "@abseil-py//absl/testing:absltest",
+ "@abseil-py//absl/testing:parameterized",
+ "@jaxite_deps//jax",
+ "@jaxite_deps//jaxlib",
+ "@jaxite_deps//numpy",
],
)
diff --git a/MODULE.bazel b/MODULE.bazel
index 3919bc0..9605595 100644
--- a/MODULE.bazel
+++ b/MODULE.bazel
@@ -4,6 +4,7 @@ bazel_dep(name = "bazel_skylib", version = "1.9.0")
bazel_dep(name = "rules_license", version = "1.0.0")
bazel_dep(name = "abseil-py", version = "2.1.0")
bazel_dep(name = "rules_python", version = "1.5.1")
+bazel_dep(name = "rules_cc", version = "0.0.17")
# Hermetic python setup
python = use_extension("@rules_python//python/extensions:python.bzl", "python")
diff --git a/jaxite_ec/Morph_logo.png b/jaxite_ec/Morph_logo.png
new file mode 100644
index 0000000..5890424
Binary files /dev/null and b/jaxite_ec/Morph_logo.png differ
diff --git a/jaxite_ec/README.md b/jaxite_ec/README.md
new file mode 100644
index 0000000..02a5d07
--- /dev/null
+++ b/jaxite_ec/README.md
@@ -0,0 +1,227 @@
+
+
+
+
+
+
+
+TPU-accelerated, free, immediate, fast, and cheap HE serving for everyone
+
+
+| paper |
+code |
+tutorial |
+
+
+🔥 We have delivered a tutorial at ASPLOS'26 to help you get started with MORPH. Please visit [CPA_tutorial](https://efficientppml.github.io/CROSS_Tutorial/) to learn more.
+For questions, please drop an email to our community [email](cpacommunity@googlegroups.com).
+
+---
+
+# MORPH: Enable AI Accelerator for Zero Knowledge Proof
+[](./LICENSE)
+
+# 1. What is MORPH?
+MORPH is the first project to enable AI Accelerator, such as Google TPUs, to accelerate Zero Knowledge Proof Primitives (Multi-scalar Multiplication and Number Theory Transformation) and achieves the State-of-the-art (SotA) throughput and energy efficiency (performance per watt). Together with [CROSS](https://github.com/EfficientPPML/CROSS), they enable AI ASICs to be SotA throughput machine for cryptography primitive with wide-range precision.
+
+
+
+It features
+- MXU Lazy Modular Reduction: bringing quadratic high-precision modular reduction down to linear operation.
+
+
+
+- dataflow optimization for MSM and NTT. Details in the [paper](https://arxiv.org/abs/2604.17808).
+
+This branch (`asplos`) contains demo scripts for profiling and comparing the two core workloads.
+
+## Project Structure
+
+```
+├── finite_field_context.py # Finite field arithmetic (MORPH & CROSS backends)
+├── elliptic_curve_context.py # Elliptic curve point arithmetic
+├── multiscalar_multiplication_context.py # Multi-scalar multiplication (MSM)
+├── number_theory_transform_context.py # Number Theoretic Transform (NTT)
+├── utils.py # JAX kernel utilities, number theory helpers
+├── profiler.py # Trace parsing and kernel profiling
+├── configurations.toml # Curve parameters (BLS12-377)
+├── c_kernels/ # Custom C kernels for TPU acceleration
+├── deployments/ # Serialized compiled JAX kernels
+```
+All functions have `_test.py` and `_perf_test.py` for correctness and performance testing.
+
+
+## Key Concepts
+
+| Concept | Description |
+|---------|-------------|
+| **DRNS (Double RNS)** | Residue Number System representation enabling efficient large-integer modular arithmetic on TPU |
+| **MORPH** | Alternative modular multiplication backend using chunk-based representation |
+| **MSM** | Multi-scalar multiplication — computing $\sum_i s_i \cdot P_i$ over elliptic curve points |
+| **Bucket Accumulation** | MSM decomposition strategy: scalars are sliced into windows, points accumulated into buckets per window |
+| **Compiled Kernels** | Pre-compiled JAX/C kernels stored in `deployments/` for fast TPU execution |
+| **Sharding** | Distribution of computation aMORPH TPU cores |
+
+
+# 2. Environment Setup
+
+Inside TPU VM, please do following setup to configure the environment.
+
+Step 1: install miniconda
+```
+wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
+chmod +x ./Miniconda3-latest-Linux-x86_64.sh
+./Miniconda3-latest-Linux-x86_64.sh
+```
+
+Step 2: create environment and install required packages
+```
+source ~/.bashrc
+conda create --name jaxite python=3.13
+conda activate jaxite
+pip install -U "jax[tpu]"
+pip install xprof
+pip install absl-py
+pip install toml
+pip install gdown
+pip install pandas
+pip install gmpy2
+```
+
+Step 3: Install the C++ toolchain for the MSM C kernel.
+
+The MSM path uses a CPU C kernel (`c_kernels/distribution.cpp`) that is
+compiled on the first import of `multiscalar_multiplication_context`. You need
+a host `g++` with OpenMP, and the conda env's bundled `libstdc++` must be
+recent enough to satisfy the symbols emitted by that compiler. On modern Ubuntu
+(g++ 13+) this means `GLIBCXX_3.4.32`, which the conda default
+`libstdcxx-ng 11.2.0` does **not** ship — so install a newer one from
+`conda-forge`:
+```
+sudo apt-get install -y g++ # skip if already installed
+conda install -n jaxite -c conda-forge 'libstdcxx-ng>=13' 'libgcc-ng>=13'
+```
+If you see `OSError: ... libstdc++.so.6: version 'GLIBCXX_3.4.XX' not found`
+when importing `multiscalar_multiplication_context`, the conda libstdc++ is
+older than your system `g++` — re-run the `conda install` line above (or
+`LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6` for a one-shot
+workaround).
+
+Step 4: Download the reference data
+```
+mkdir -p data && gdown 1aJhANlS8hWrjSt9j0nBKoFRBoZh0W1aa -O data/data.tar.gz && cd data && tar -xvf data.tar.gz
+```
+
+Step 5 (optional): Pre-build the MSM C kernel.
+
+The C kernel is compiled automatically by `c_kernels/build.py` on the first
+import of `multiscalar_multiplication_context`, so no separate build step is
+required. To pre-build (or force a rebuild) ahead of time:
+```
+python -m c_kernels.build # build if missing or stale
+python -m c_kernels.build --force # always rebuild
+```
+The compiler defaults to `g++` with `-std=c++17 -fopenmp -O2 -fPIC -shared`
+plus `-I/include`. Override via the `CXX` and `CXXFLAGS` env vars
+(e.g. point `CXX` at conda's `gxx_linux-64` to keep the build inside the env).
+
+# 3. TPU Setup
+The code is optimized for TPU execution, but it also runs on NVIDIA GPU and CPU for functional preview (not optimized for these devices).
+
+- Step 1: Create a Google Project [tutorial](https://cloud.google.com/appengine/docs/standard/nodejs/building-app/creating-project).
+
+Obtain the name of the project as and **Google Project ID** from the created project.
+
+- Step 2: Apply for the Tree-tier TPU trail for 30 days[TRC](https://sites.research.google/trc/about/)
+
+Once submitted the request, an email will be shot to you within one day, where there is a link to fill in a survey with your **Google project ID**.
+
+- Step 3: Launch TPU VM.
+You could do it over GUI or gcloud cli (in your local machine) to create a TPU VM. I give the gcloud cli as it works for all generations (>=v4) of TPUs.
+
+For TPUv6e,
+```bash
+gcloud config set project
+gcloud config set compute/zone us-east1-d
+gcloud alpha compute tpus queued-resources create --node-id= \
+ --zone=us-east1-d \
+ --accelerator-type=v6e-1 \
+ --runtime-version=v2-alpha-tpuv6e \
+ --provisioning-model=spot
+```
+
+Note that TPUv5e and TPUv6e could only work with provisioning-model as spot, because they are popular resources, and Google cloud can preempt it if there are tasks with higher priority requiring these resources. But you could get a long-term active TPUv4 VM as it's less demanding by other tasks.
+
+- Step 4: Setup Remote SSH (VSCode or Cursor) to TPU VM
+Once the requested TPU vm is up and running as shown in Google console, you could use gcloud to forward the SSH port of the remote machine to a port of local machine and setup VSCode remote ssh.
+
+You need to first setup local ssh key to Google's compute engine, following [link](https://cloud.google.com/compute/docs/connect/create-ssh-keys#gcloud). After your follow the instructions on the page, the ssh key will be dumped here `/.ssh/google_compute_engine`.
+
+
+```bash
+gcloud compute tpus tpu-vm ssh @ -- -L 9009:localhost:22
+```
+Where 9009 is the port of local machine, while 22 is the SSH port of the TPU vm.
+
+After you set it up, you could configure VSCode to use the remote SSH package [link](https://code.visualstudio.com/docs/remote/ssh) to remotely access into TPUvm.
+```bash
+Host tpu-vm
+ User
+ HostName localhost
+ Port 9009
+ IdentityFile /.ssh/google_compute_engine
+```
+
+After this, you should follow the steps on [link](https://code.visualstudio.com/docs/remote/ssh) to log into TPU VM.
+
+# 4. Ready to Play?
+
+Run functional correctness tests for both NTT and MSM:
+```
+python3 number_theory_transform_test.py
+python3 multiscalar_multiplication_test.py
+```
+
+Run performance tests for both NTT and MSM:
+
+```
+python3 number_theory_transform_perf_test.py
+python3 multiscalar_multiplication_perf_test.py
+```
+
+Notes:
+- The first MSM test run auto-compiles `c_kernels/distribution.cpp` into
+ `c_kernels/distribution.so` (a few seconds). Subsequent runs reuse the cached
+ `.so` and rebuild only when the source is newer.
+- The first run of each test also JIT-compiles JAX kernels; expect a longer
+ first iteration that is then cached under `deployments/`.
+- Performance tests assume the reference data from Step 4 is present under
+ `./data/`.
+
+# 5. Call for Actions
+Our mission is to build an open-sourced SoTA library for the community.
+- If you find this repository helpful, please consider giving it a star :)
+- For any questions, please feel free to open an issue.
+- For any suggestions or new features, please feel free to open a pull request.
+
+# Contact
+- Jianming Tong, Georgia Institute of Technology / Google, jianming.tong@gatech.edu/jianmingt@google.com
+- Jingtian Dang, Georgia Institute of Technology, dangjingtian@gatech.edu
+- Tushar Krishna, Georgia Institute of Technology, tushar@ece.gatech.edu
+
+
+# Citation
+
+```
+@inproceedings{tong2025MORPH,
+author = {Jianming Tong and Jingtian Dang and Simon Langowski and Tianhao Huang and Asra Ali and Jeremy Kun and Srini Devadas and Tushar Krishna},
+title = {MORPH: Enabling AI ASICs for Zero Knowledge Proof},
+year = {2026},
+publisher = {IEEE Press},
+booktitle = {Proceedings of the 63nd Annual ACM/IEEE Design Automation Conference},
+location = {Los Angeles, California, United States},
+series = {DAC '26}
+}
+```
+
+Enjoy! :D
diff --git a/jaxite_ec/__init__.py b/jaxite_ec/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/jaxite_ec/algorithm/big_integer.py b/jaxite_ec/algorithm/big_integer.py
deleted file mode 100644
index d79b763..0000000
--- a/jaxite_ec/algorithm/big_integer.py
+++ /dev/null
@@ -1,169 +0,0 @@
-"""Big integer classes for jaxite_ec."""
-
-import gmpy2
-
-
-class GMPBigInteger:
- """A class representing a big integer using gmpy2.
-
- This class provides basic arithmetic operations for big integers using gmpy2.
- """
-
- def __init__(self, value) -> None:
- if isinstance(value, (int, gmpy2.mpz)):
- self.value = gmpy2.mpz(value)
- elif isinstance(value, GMPBigInteger):
- self.value = value.value
- else:
- raise TypeError("Unsupported type for GMPBigInteger initialization")
-
- def __add__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return GMPBigInteger(
- self.value
- + gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- )
- return NotImplemented
-
- def __sub__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return GMPBigInteger(
- self.value
- - gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- )
- return NotImplemented
-
- def __mul__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return GMPBigInteger(
- self.value
- * gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- )
- return NotImplemented
-
- def __truediv__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- if (
- gmpy2.mpz(other.value if isinstance(other, GMPBigInteger) else other)
- == 0
- ):
- raise ZeroDivisionError("division by zero")
- return GMPBigInteger(
- self.value
- // gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- )
- return NotImplemented
-
- def __mod__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return GMPBigInteger(
- self.value
- % gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- )
- return NotImplemented
-
- def __eq__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return self.value == gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- return NotImplemented
-
- def __ne__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return self.value != gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- return NotImplemented
-
- def __lt__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return self.value < gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- return NotImplemented
-
- def __le__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return self.value <= gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- return NotImplemented
-
- def __gt__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return self.value > gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- return NotImplemented
-
- def __ge__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return self.value >= gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- return NotImplemented
-
- def __pow__(self, exponent, modulus=None):
- if isinstance(exponent, (GMPBigInteger, int, gmpy2.mpz)):
- if isinstance(exponent, GMPBigInteger):
- exponent = gmpy2.mpz(exponent.value)
- if isinstance(modulus, GMPBigInteger):
- modulus = gmpy2.mpz(modulus.value)
- if modulus is None:
- return GMPBigInteger(self.value**exponent)
- else:
- return GMPBigInteger(gmpy2.powmod(self.value, exponent, modulus))
- else:
- print(type(exponent))
- raise TypeError("Exponent must be an integer")
-
- def __lshift__(self, shift):
- """Left shift operator (<<)."""
- if isinstance(shift, GMPBigInteger):
- shift = shift.value
- return GMPBigInteger(self.value << shift)
-
- def __rshift__(self, shift):
- """Right shift operator (>>)."""
- if isinstance(shift, GMPBigInteger):
- shift = shift.value
- return GMPBigInteger(self.value >> shift)
-
- def __and__(self, other):
- if isinstance(other, (GMPBigInteger, int)):
- return GMPBigInteger(
- self.value
- & gmpy2.mpz(
- other.value if isinstance(other, GMPBigInteger) else other
- )
- )
- return NotImplemented
-
- def ceil_log2(self):
- """Calculate the base-2 logarithm of the GMPBigInteger."""
- if self.value <= 0:
- raise ValueError("log2 is only defined for positive integers")
- return GMPBigInteger(gmpy2.ceil(gmpy2.log2(self.value)))
-
- def __int__(self):
- return int(self.value)
-
- def __str__(self):
- return str(self.value)
-
- def __repr__(self):
- return f"GMPBigInteger({self.value})"
-
- def hex_value_str(self) -> str:
- return hex(self.value)
diff --git a/jaxite_ec/algorithm/config_file.py b/jaxite_ec/algorithm/config_file.py
deleted file mode 100644
index 0233520..0000000
--- a/jaxite_ec/algorithm/config_file.py
+++ /dev/null
@@ -1,46 +0,0 @@
-"""Jaxite EC algorithm configuration file."""
-
-config_BLS12_377 = {
- # A small prime field for simplicity
- 'prime': 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001,
- 'order': 0x12AB655E9A2CA55660B44D1E5C37B00159AA76FED00000010A11800000000001,
- 'a': 0, # Coefficient a = 0
- 'b': 1, # Coefficient b = 1
- 'generator': [
- 0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF,
- 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6,
- ],
-}
-
-# https://docs.rs/ark-ed-on-bls12-377/latest/ark_ed_on_bls12_377/index.html
-config_BLS12_377_t = {
- 'prime': (
- 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177
- ),
- 'order': (
- 8444461749428370424248824938781546531375899335154063827935233455917409239041
- ),
- 'a': -1,
- 'd': (
- 122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179
- ),
- 'alpha': -1,
- 'b': 1,
- 's': (
- 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
- ),
- 'MA': (
- 228097355113300204138531148905234651262148041026195375645000724271212049151994375092458297304264351187709081232384
- ),
- 'MB': (
- 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
- ),
- 't': (
- 23560188534917577818843641916571445935985386319233886518929971599490231428764380923487987729215299304184915158756
- ),
- # 't': 235104237478051516191809091778322087600408126435680774020954291067230236919576441851480900410358060820255406299421,
- 'generator': [
- 71222569531709137229370268896323705690285216175189308202338047559628438110820800641278662592954630774340654489393,
- 6177051365529633638563236407038680211609544222665285371549726196884440490905471891908272386851767077598415378235,
- ],
-}
diff --git a/jaxite_ec/algorithm/elliptic_curve.py b/jaxite_ec/algorithm/elliptic_curve.py
deleted file mode 100644
index 6b649d9..0000000
--- a/jaxite_ec/algorithm/elliptic_curve.py
+++ /dev/null
@@ -1,775 +0,0 @@
-"""Elliptic curve coordinate systems and points."""
-
-import abc
-import copy
-import enum
-from typing import Dict, Generic, Iterable, List, Optional, TypeVar, Union
-
-from jaxite.jaxite_ec.algorithm import big_integer
-from jaxite.jaxite_ec.algorithm import finite_field
-
-
-BigInt = big_integer.GMPBigInteger
-
-abstractmethod = abc.abstractmethod
-ABC = abc.ABC
-Auto = enum.auto
-Enum = enum.Enum
-T = TypeVar('T')
-FieldEle = finite_field.FiniteFieldElement
-
-
-class CoordinateSystemType(Enum):
- """Enum to represent different types of coordinate systems for elliptic curves.
-
- AFFINE: The affine coordinate system, the standard coordinate system for
- elliptic curves: https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
-
- PROJECTIVE: The projective coordinate system,
- https://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#addition-madd-1998-cmo
-
- XYZZ: The XYZZ coordinate system:
- https://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html
- """
-
- NONE = Auto()
- WEIERSTRASS_AFFINE = Auto()
- WEIERSTRASS_PROJECTIVE = Auto()
- WEIERSTRASS_XYZZ = Auto()
-
-
-class ECPoint(Generic[T]):
- """Represents a point in an elliptic curve coordinate system."""
-
- def __init__(
- self,
- coordinates: List[T],
- coordinate_system: 'EllipticCurveCoordinateSystem',
- zero: bool = False,
- ) -> None:
- self.coordinate_system = coordinate_system
- self.zero = zero
- self.coordinates = self.coordinate_system.generate_formal_coordinates(
- coordinates, zero
- )
- self.type = coordinate_system.get_type()
-
- def __getitem__(self, index: Union[int, slice]):
- return self.coordinates[index]
-
- def __setitem__(
- self, index: Union[int, slice], value: Union[T, Iterable[T]]
- ) -> None:
- self.coordinates[index] = value # type: ignore
-
- def __eq__(self, other: 'ECPoint') -> bool:
- if not isinstance(other, ECPoint):
- return NotImplemented
-
- return self.coordinates == other.coordinates
-
- def __add__(self, other: 'ECPoint') -> 'ECPoint':
- return self.coordinate_system.point_add(self, other)
-
- def __lshift__(self, shift) -> 'ECPoint':
- if self.zero:
- return self.copy()
- return self.coordinate_system.point_lshift(self, shift)
-
- def is_zero(self):
- return self.zero
-
- def get_type(self) -> CoordinateSystemType:
- return self.type
-
- def set_type(self, cs_type: CoordinateSystemType):
- self.type = cs_type
-
- def set_coordinate_system(
- self, coordinate_system: 'EllipticCurveCoordinateSystem'
- ):
- self.coordinate_system = coordinate_system
- self.type = coordinate_system.get_type()
-
- def append(self, coordinate: T):
- assert self.coordinates is not None
- self.coordinates.append(coordinate)
-
- def copy(self):
- obj = copy.copy(self)
- if not self.is_zero():
- obj.coordinates = self.coordinate_system.generate_formal_coordinates(
- self.coordinates, self.zero
- )
- return obj
-
- def convert_to_affine(self):
- return self.coordinate_system.convert_to_affine(self)
-
- def __str__(self) -> str:
- if self.is_zero():
- return 'Point, O'
- coord_strs = [coord.hex_value_str() for coord in self.coordinates] # pytype: disable=attribute-error
- return 'Point, ' + ', '.join(coord_strs)
-
-
-class EllipticCurveCoordinateSystem(ABC, Generic[T]):
- """Abstract base class for elliptic curve coordinate systems."""
-
- def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None:
- """Initialize the coordinate system.
-
- Args:
- config: A dictionary containing the configuration of the coordinate
- system. The dictionary should contain the following keys: - ff_zero: The
- value "zero" in the finite field. - ff_one: The value "one" in the
- finite field. - type: The type of the coordinate system.
- """
- super().__init__()
- self.config = config
- self.ff_zero: FieldEle = FieldEle(0, config['prime'])
- self.ff_one: FieldEle = FieldEle(1, config['prime'])
- self.type = CoordinateSystemType.NONE
-
- @abstractmethod
- def generate_formal_coordinates(
- self, coordinates: List[Union[int, FieldEle]], zero: bool
- ) -> List[T]:
- pass
-
- def get_type(self):
- return self.type
-
- def generate_point(
- self,
- coordinates: Optional[Union[List[T], List[int]]] = None,
- zero: bool = False,
- ) -> ECPoint[T]:
- return ECPoint[T](coordinates, self, zero)
-
- @abstractmethod
- def point_add(self, point_a: ECPoint[T], point_b: ECPoint[T]):
- pass
-
- @abstractmethod
- def point_lshift(self, point_a: ECPoint[T], index):
- pass
-
- @abstractmethod
- def convert_to_affine(self, point_a: ECPoint[T]):
- pass
-
-
-class ECCSWeierstrass(EllipticCurveCoordinateSystem[FieldEle]):
- """Base class for Weierstrass coordinate systems."""
-
- def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None:
- """Initialize the EC Weierstrass coordinate systems.
-
- Weierstrass: y^2 = x^3 + a*x + b
-
- Args:
- config: A dictionary containing the configuration of the coordinate
- system. The dictionary should contain the following keys: - generator: A
- special point "G" in the Elliptic Curve. - order: order * generator == 0
- in the Elliptic Curve. - prime: The modulus of the Elliptic Curve. - a:
- The parameter a of the Elliptic Curve. - b: The parameter b of the
- Elliptic Curve.
- """
- super().__init__(config)
- self.prime = BigInt(config['prime'])
- self.order = BigInt(config['order'])
-
- self.generator: List[FieldEle] = []
- for coordinate in config['generator']:
- element = self.ff_zero.copy(coordinate)
- self.generator.append(element)
- self.a = self.ff_zero.copy(config['a'])
- self.b = self.ff_zero.copy(config['b'])
-
- def generate_formal_coordinates(
- self, coordinates: List[Union[int, FieldEle]], zero: bool = False
- ):
- if zero:
- return None
- formal_coordinate: List[FieldEle] = []
- for coordinate in coordinates:
- if isinstance(coordinate, FieldEle):
- formal_coordinate.append(coordinate)
- else:
- formal_coordinate.append(self.ff_zero.copy(coordinate))
- return formal_coordinate
-
-
-class ECCSWeierstrassAffine(ECCSWeierstrass):
- """Weierstrass affine coordinate system.
-
- It has unique point add calculation. Details pleaase refer to:
- AFFINE: The affine coordinate system, the standard coordinate system for
- elliptic curves: https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
- """
-
- def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None:
- super().__init__(config)
- self.type = CoordinateSystemType.WEIERSTRASS_AFFINE
-
- def add_general(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]):
- slope = (point_b[1] - point_a[1]) / (point_b[0] - point_a[0])
- cx = (slope * slope) - point_a[0] - point_b[0]
- cy = slope * (point_a[0] - cx) - point_a[1]
- return ECPoint[FieldEle]([cx, cy], self)
-
- def double_general(self, point_a: ECPoint[FieldEle]):
- x1 = point_a[0]
- y1 = point_a[1]
- slope = (x1 * x1 * 3 + self.a) / (y1 * 2)
- cx = slope * slope - x1 - x1
- cy = slope * (x1 - cx) - y1
- return ECPoint[FieldEle]([cx, cy], self)
-
- def point_add(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]):
- if point_b.is_zero():
- return point_a.copy()
- elif point_a.is_zero():
- return point_b.copy()
-
- if point_a == point_b:
- result = self.double_general(point_a)
- else:
- result = self.add_general(point_a, point_b)
-
- return result
-
- def point_lshift(self, point_a: ECPoint[FieldEle], shift: int):
- if point_a.is_zero():
- return point_a.copy()
-
- for _ in range(shift):
- point_a = self.double_general(point_a)
- return point_a
-
- def convert_to_affine(self, point_a: ECPoint[FieldEle]) -> ECPoint[FieldEle]:
- return point_a
-
- def generate_formal_coordinates(
- self, coordinates: List[Union[int, FieldEle]], zero: bool = False
- ) -> Optional[List[FieldEle]]:
- if zero:
- return None
- coordinate_length = len(coordinates)
- assert coordinate_length == 2 or coordinate_length == 3
- formal_coordinates: List[FieldEle] = []
- for coordinate in coordinates:
- if isinstance(coordinate, FieldEle):
- formal_coordinates.append(coordinate)
- else:
- formal_coordinates.append(self.ff_zero.copy(coordinate))
- if coordinate_length == 2:
- assert len(formal_coordinates) == 3
- return formal_coordinates
-
-
-class ECCSWeierstrassProjective(ECCSWeierstrass):
- """Weierstrass projective coordinate system.
-
- PROJECTIVE: The projective coordinate system,
- https://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#addition-madd-1998-cmo
- """
-
- def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None:
- super().__init__(config)
- self.prime = BigInt(config['prime'])
- self.order = BigInt(config['order'])
- self.type = CoordinateSystemType.WEIERSTRASS_PROJECTIVE
-
- self.generator: List[FieldEle] = []
- for coordinate in config['generator']:
- element = self.ff_zero.copy(coordinate)
- self.generator.append(element)
-
- self.a = self.ff_zero.copy(config['a'])
- self.b = self.ff_zero.copy(config['b'])
-
- def add_z2_eq_1(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]):
- if point_b[2] == self.ff_one:
- x1, y1, z1 = point_a
- x2, y2, _ = point_b
- elif point_a[2] == self.ff_one:
- x1, y1, z1 = point_b
- x2, y2, _ = point_a
- else:
- raise NotImplementedError
-
- u = y2 * z1 - y1
- uu = u * u
- v = x2 * z1 - x1
- vv = v * v
- vvv = v * vv
- r = vv * x1
- a = uu * z1 - vvv - (r + r)
- x3 = v * a
- y3 = u * (r - a) - vvv * y1
- z3 = vvv * z1
- return ECPoint[FieldEle]([x3, y3, z3], self)
-
- def add_general(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]):
- x1, y1, z1 = point_a
- x2, y2, z2 = point_b
-
- b3 = self.b * self.ff_zero.copy(3)
- a = self.a
-
- # Perform the operations
- t0 = x1 * x2
- t1 = y1 * y2
- t2 = z1 * z2
- t3 = (x1 + y1) * (x2 + y2)
- t4 = t0 + t1
- t3 = t3 - t4
- t4 = (x1 + z1) * (x2 + z2)
- t5 = t0 + t2
- t4 = t4 - t5
- t5 = (y1 + z1) * (y2 + z2)
- x3 = t1 + t2
- t5 = t5 - x3
- z3 = a * t4
- x3 = b3 * t2
- z3 = x3 + z3
- x3 = t1 - z3
- z3 = t1 + z3
- y3 = x3 * z3
- t1 = t0 + t0
- t1 = t1 + t0
- t2 = a * t2
- t4 = b3 * t4
- t1 = t1 + t2
- t2 = t0 - t2
- t2 = a * t2
- t4 = t4 + t2
- t0 = t1 * t4
- y3 = y3 + t0
- t0 = t5 * t4
- x3 = t3 * x3
- x3 = x3 - t0
- t0 = t3 * t1
- z3 = t5 * z3
- z3 = z3 + t0
-
- return ECPoint[FieldEle]([x3, y3, z3], self)
-
- def double_general(self, point_a: ECPoint):
- x1, y1, z1 = point_a
- a = self.a
- ff2 = self.ff_zero.copy(2)
-
- # Perform the operations based on the pseudocode
- xx = x1 * x1
- zz = z1 * z1
- w = a * zz + xx * self.ff_zero.copy(3)
- s = y1 * z1 * ff2
- ss = s * s
- sss = s * ss
- r = y1 * s
- rr = r * r
- b = (x1 + r) * (x1 + r) - xx - rr
- h = w * w - b * ff2
- x3 = h * s
- y3 = w * (b - h) - rr * ff2
- z3 = sss
-
- return ECPoint([x3, y3, z3], self)
-
- def point_add(
- self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]
- ) -> ECPoint[FieldEle]:
- if point_b.is_zero():
- return point_a.copy()
- elif point_a.is_zero():
- return point_b.copy()
-
- if point_a == point_b:
- result = self.double_general(point_a)
- elif point_a[2] == self.ff_one or point_b[2] == self.ff_one:
- result = self.add_z2_eq_1(point_a, point_b)
- else:
- result = self.add_general(point_a, point_b)
- return result
-
- def point_lshift(self, point: ECPoint[FieldEle], shift: int):
- if point.is_zero():
- return point.copy()
-
- for _ in range(shift):
- point = self.double_general(point)
- return point
-
- def convert_from_affine(
- self, point_a: ECPoint[FieldEle]
- ) -> ECPoint[FieldEle]:
- assert point_a.get_type() == self.type
- new_point = point_a.copy()
- if new_point.coordinates is not None:
- new_point.coordinates.clear()
- z_invert = self.ff_one / point_a[2]
- new_point.append(point_a[0] * z_invert)
- new_point.append(point_a[1] * z_invert)
- new_point.append(self.ff_one)
- new_point.set_type(CoordinateSystemType.WEIERSTRASS_AFFINE)
- return new_point
- else:
- return new_point
-
- def convert_to_affine(self, point: ECPoint) -> ECPoint:
- assert point.get_type() == self.type
- new_point = point.copy()
- if new_point.coordinates is not None:
- new_point.coordinates.clear()
- z_invert = self.ff_one / point[2]
- new_point.append(point[0] * z_invert)
- new_point.append(point[1] * z_invert)
- new_point.append(self.ff_one)
- new_point.set_type(CoordinateSystemType.WEIERSTRASS_AFFINE)
- return new_point
- else:
- return new_point
-
-
-class ECCSWeierstrassXYZZ(ECCSWeierstrass):
- """Weierstrass XYZZ coordinate system."""
-
- def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None:
- super().__init__(config)
- self.type = CoordinateSystemType.WEIERSTRASS_XYZZ
-
- def generate_formal_coordinates(
- self, coordinates: List[Union[int, FieldEle]], zero: bool = False
- ) -> Optional[List[FieldEle]]:
- if zero:
- return None
- coordinate_length = len(coordinates)
- assert coordinate_length == 2 or coordinate_length == 4
- formal_coordinates: List[FieldEle] = []
- for coordinate in coordinates:
- if isinstance(coordinate, FieldEle):
- formal_coordinates.append(coordinate)
- else:
- formal_coordinates.append(self.ff_zero.copy(coordinate))
- if coordinate_length == 2:
- formal_coordinates.append(self.ff_one.copy())
- formal_coordinates.append(self.ff_one.copy())
- assert len(formal_coordinates) == 4
- return formal_coordinates
-
- def add_z2_eq_1(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]):
- """Add the general coordinates of two points in the XYZZ coordinate system.
-
- This function is not implemented for the XYZZ coordinate system.
-
- Args:
- point_a: The first point.
- point_b: The second point.
-
- Returns:
- The added general coordinates of the two points.
- """
- raise NotImplementedError
- # return ECPoint[FieldEle]([], self)
-
- def add_general(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]):
- x1, y1, zz1, zzz1 = point_a
- x2, y2, zz2, zzz2 = point_b
-
- u1 = x1 * zz2
- u2 = x2 * zz1
- s1 = y1 * zzz2
- s2 = y2 * zzz1
- p = u2 - u1
- r = s2 - s1
- pp = p * p
- ppp = p * pp
- q = u1 * pp
- x3 = r * r - ppp - (q + q)
- y3 = r * (q - x3) - s1 * ppp
- zz3 = zz1 * zz2 * pp
- zzz3 = zzz1 * zzz2 * ppp
-
- return ECPoint[FieldEle]([x3, y3, zz3, zzz3], self)
-
- def double_general(self, point_a: ECPoint[FieldEle]):
- """Double the general coordinates of a point in the XYZZ coordinate system.
-
- Args:
- point_a: The point to double the general coordinates of.
-
- Returns:
- The doubled general coordinates of the point.
- """
- x1, y1, zz1, zzz1 = point_a
- a = self.a
-
- # Perform the operations based on the pseudocode
- u = y1 + y1
- v = u * u
- w = u * v
- s = x1 * v
- x1_sq = x1 * x1
- m = (x1_sq + x1_sq + x1_sq) + a * (zz1 * zz1)
- x3 = (m * m) - (s + s)
- y3 = m * (s - x3) - w * y1
- zz3 = v * zz1
- zzz3 = w * zzz1
-
- return ECPoint[FieldEle]([x3, y3, zz3, zzz3], self)
-
- def point_add(
- self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]
- ) -> ECPoint[FieldEle]:
- if point_b.is_zero():
- return point_a.copy()
- elif point_a.is_zero():
- return point_b.copy()
-
- if point_a == point_b:
- result = self.double_general(point_a)
- else:
- result = self.add_general(point_a, point_b)
- return result
-
- def point_lshift(self, point_a: ECPoint[FieldEle], shift: int):
- if point_a.is_zero():
- return point_a.copy()
-
- for _ in range(shift):
- point_a = self.double_general(point_a)
- return point_a
-
- def convert_from_affine(
- self, point_a: ECPoint[FieldEle]
- ) -> ECPoint[FieldEle]:
- assert (
- point_a.coordinate_system.get_type()
- == CoordinateSystemType.WEIERSTRASS_AFFINE
- )
- new_point = point_a.copy()
- new_point.append(self.ff_one)
- new_point.append(self.ff_one)
- new_point.set_coordinate_system(self)
- return new_point
-
- def convert_to_affine(self, point_a: ECPoint[FieldEle]) -> ECPoint[FieldEle]:
- assert point_a.get_type() == self.type
- new_point = point_a.copy()
- if new_point.coordinates is not None:
- new_point.coordinates.clear()
- a = self.ff_one / point_a[3]
- b = point_a[2] * a
- b_sq = b * b
- new_point.append(point_a[0] * b_sq)
- new_point.append(point_a[1] * a)
- new_point.append(self.ff_one)
- new_point.append(self.ff_one)
- new_point.set_type(CoordinateSystemType.WEIERSTRASS_AFFINE)
- return new_point
- else:
- return new_point
-
-
-class ECCSTwistedEdwardsExtended(ECCSWeierstrass):
- """Twisted Edwards Extended coordinate system."""
-
- def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None:
- super().__init__(config)
- self.type = CoordinateSystemType.WEIERSTRASS_PROJECTIVE
- self.a = self.ff_zero.copy(config['a'])
- self.d = self.ff_zero.copy(config['d'])
- self.alpha = self.ff_zero.copy(config['alpha'])
- self.s = self.ff_zero.copy(config['s'])
- self.ma = self.ff_zero.copy(config['MA'])
- self.mb = self.ff_zero.copy(config['MB'])
- self.t = self.ff_zero.copy(config['t'])
- self.k = self.d + self.d
-
- def generate_point(
- self,
- coordinates: Optional[Union[List[T], List[int]]] = None,
- zero: bool = False,
- twist: bool = True,
- ) -> ECPoint[T]:
- if twist:
- ff_coordinates = []
- for coordinate in coordinates:
- if isinstance(coordinate, FieldEle):
- ff_coordinates.append(coordinate)
- else:
- ff_coordinates.append(self.ff_zero.copy(coordinate))
- coordinates = self.twist(ff_coordinates)
- return ECPoint[T](coordinates, self, zero)
-
- def generate_formal_coordinates(
- self,
- coordinates: List[Union[int, FieldEle]],
- zero: bool = False,
- ) -> Optional[List[FieldEle]]:
- if zero:
- coordinates = self.zero()
- coordinate_length = len(coordinates)
- assert coordinate_length in {2, 4}
- formal_coordinates: List[FieldEle] = []
- for coordinate in coordinates:
- if isinstance(coordinate, FieldEle):
- formal_coordinates.append(coordinate)
- else:
- formal_coordinates.append(self.ff_zero.copy(coordinate))
-
- if coordinate_length == 2:
- formal_coordinates.append(self.ff_one.copy())
- formal_coordinates.append(formal_coordinates[0] * formal_coordinates[1])
-
- return formal_coordinates
-
- def point_add(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]):
- # https://www.hyperelliptic.org/EFD/g1p/data/twisted/extended/addition/madd-2008-hwcd
- # copied from arkworks_bls12_377
- x1, y1, z1, t1 = point_a
- x2, y2, z2, t2 = point_b # 0, 1, 1, 0
- a = x1 * x2 # 0
- b = y1 * y2 # y1
- c = self.d * t1 * t2 # 0
- d = z1 * z2 # z1
- # h = b - (self.a * a) # self.a = -1 in standard form
- h = b + a # y1
- # karatsuba
- e = (x1 + y1) * (x2 + y2) - h # (x1 + y1) - y1 = x1
- f = d - c # z1
- g = d + c # z1
- x3 = e * f # x1 * z1
- y3 = g * h # y1 * z1
- z3 = f * g # z1 * z1
- t3 = e * h # x1 * y1
-
- return ECPoint[FieldEle]([x3, y3, z3, t3], self)
-
- def double_general(self, point_a: ECPoint[FieldEle]):
- # https://www.hyperelliptic.org/EFD/g1p/data/twisted/extended/doubling/dbl-2008-hwcd
- # copied from arkworks_bls12_377
- x1, y1, z1, _ = point_a
- a = x1 * x1
- b = y1 * y1
- ct = z1 * z1
- # d = self.a * a
- # d = self.ff_zero - a
- h = self.ff_zero - a - b
- et = x1 + y1
- e = (et * et) + h
- g = b - a
- f = g - ct - ct
- x3 = e * f
- y3 = g * h
- z3 = f * g
- t3 = e * h
- return ECPoint[FieldEle]([x3, y3, z3, t3], self)
-
- def point_lshift(self, point_a: ECPoint[FieldEle], shift: int):
- if point_a.is_zero():
- return point_a.copy()
-
- for _ in range(shift):
- point_a = self.double_general(point_a)
- return point_a
-
- def twist(self, point_a: List[FieldEle]):
- # https://en.wikipedia.org/wiki/Montgomery_curve
- x, y = point_a
- # Convert to montgomery
- xm = self.s * (x - self.alpha)
- ym = self.s * y
- # Convert to edwards
- if ym == self.ff_zero:
- return None
- xt = xm / ym
-
- yt_denom = xm + self.ff_one
- if yt_denom == self.ff_zero:
- return None
- yt = (xm - self.ff_one) / (yt_denom)
-
- xt = xt * self.t
- return [xt, yt]
-
- def untwist(self, point_a: List[FieldEle]):
- xt, yt = point_a
- xt = xt / self.t
- # print("ut", xt, yt)
- # Convert to montgomery
- xm = (self.ff_one + yt) / ((self.ff_one - yt))
- ym = (self.ff_one + yt) / ((self.ff_one - yt) * xt)
- # print("um", xm, ym)
- # Convert to weierstrass
- three = self.ff_zero.copy(3)
- x = (xm / self.mb) + (self.ma / (three * self.mb))
- y = ym / self.mb
- # print("uw", x, y)
- return [x, y]
-
- def convert_from_affine(self, point_a):
- # apply twist?
- return NotImplementedError
-
- def convert_to_twisted_edwards_affine(
- self, point_a: ECPoint[FieldEle]
- ) -> ECPoint[FieldEle]:
- new_point = point_a.copy()
- if new_point.coordinates is not None:
- new_point.coordinates.clear()
- x, y, z = point_a[:3]
- inv_z = self.ff_one / z
- x2 = x * inv_z
- y2 = y * inv_z
- new_point.append(x2)
- new_point.append(y2)
- new_point.append(self.ff_one)
- new_point.append(self.ff_one)
- return new_point
-
- def convert_to_affine(self, point_a: ECPoint[FieldEle]) -> ECPoint[FieldEle]:
- new_point = self.convert_to_twisted_edwards_affine(point_a)
- twisted_coords = new_point[:2]
- new_point[:2] = self.untwist(twisted_coords)
- return new_point
-
- def zero(self):
- return [self.ff_zero, self.ff_one, self.ff_one, self.ff_zero]
-
- def twist_int_coordinates(self, coordinates: List[int]) -> List[int]:
- ff_coordinates = [
- self.ff_zero.copy(coordinate) for coordinate in coordinates
- ]
- ff_coordinates = self.twist(ff_coordinates)
- if ff_coordinates is None:
- return [0, 1, 1, 0] # zero
- ff_coordinates.append(self.ff_one.copy())
- ff_coordinates.append(ff_coordinates[0] * ff_coordinates[1])
- return [int(coord.value) for coord in ff_coordinates]
-
- def twist_projective(self, point_a):
- x, y, z = point_a
- xm = self.s * (x - self.alpha)
- ym = self.s * y
- # Convert to edwards
- # Common denominator
- zt = z * ym * (xm + self.ff_one)
- xt = xm * (xm + self.ff_one)
- yt = (xm - self.ff_one) * ym
-
- xt = xt * self.t
- return [xt, yt, zt]
-
-
-class TwistedInvertedXYZ(ECCSWeierstrass):
-
- def convert_from_affine(self, point_a):
- x, y = point_a
- return [y, x, x * y]
-
- # Looks like the cheapest strongly unified three coordinate option
- # https://www.hyperelliptic.org/EFD/g1p/data/twisted/inverted/addition/add-2008-bbjlp
diff --git a/jaxite_ec/algorithm/finite_field.py b/jaxite_ec/algorithm/finite_field.py
deleted file mode 100644
index 6a3ea36..0000000
--- a/jaxite_ec/algorithm/finite_field.py
+++ /dev/null
@@ -1,304 +0,0 @@
-"""Finite field elements for elliptic curve cryptography."""
-
-import copy
-import math
-
-from jaxite.jaxite_ec.algorithm import big_integer
-
-BigInt = big_integer.GMPBigInteger
-
-
-class FiniteFieldElement:
- """Finite field element for elliptic curve cryptography."""
-
- def __init__(self, value, prime):
- if isinstance(value, BigInt):
- self.value = value
- else:
- self.value = BigInt(value)
-
- if isinstance(prime, BigInt):
- self.prime = prime
- else:
- self.prime = BigInt(prime)
-
- if self.value < 0 or self.value >= self.prime:
- raise ValueError(f"Value {self.value} not in range 0 to {self.prime - 1}")
-
- def set_value(self, value):
- """Set the value of the finite field element, with validation."""
- if isinstance(value, BigInt):
- new_value = value
- else:
- new_value = BigInt(value)
-
- if new_value < BigInt(0) or new_value >= self.prime:
- raise ValueError(f"Value {new_value} not in range 0 to {self.prime - 1}")
-
- self.value = new_value
-
- def get_value(self) -> int:
- return int(self.value)
-
- def get_prime(self):
- return copy.deepcopy(self.prime)
-
- def copy(self, value=None, transform=False, reduction=False):
- """Create a deep copy of the current finite field element.
-
- transform and reduction are only place holders for unified API.
-
- Args:
- value (int): The value of the new finite field element. If None, the
- value of the current finite field element is used.
- transform (bool): Whether to transform the value of the new finite field
- element.
- reduction (bool): Whether to reduce the value of the new finite field
- element.
-
- Returns:
- FiniteFieldElement: A deep copy of the current finite field element.
- """
- obj = copy.copy(self)
- if not value:
- obj.value = BigInt(int(self.value))
- else:
- obj.value = BigInt(value)
- return obj
-
- def __add__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot add two numbers in different Fields")
- result = (self.value + other.value) % self.prime
- return FiniteFieldElement(result, self.prime.value)
-
- def __sub__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot subtract two numbers in different Fields")
- result = (self.value - other.value) % self.prime
- return FiniteFieldElement(result, self.prime.value)
-
- def __mul__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot multiply two numbers in different Fields")
- result = (self.value * other.value) % self.prime
- return FiniteFieldElement(result, self.prime.value)
-
- def __truediv__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot divide two numbers in different Fields")
- # Use Fermat's Little Theorem to find the inverse: a^(p-1) ≡ 1 (mod p)
- inverse = other.value.__pow__(self.prime.value - 2, self.prime.value)
- result = (self.value * inverse) % self.prime
- return FiniteFieldElement(result, self.prime.value)
-
- def __pow__(self, exponent):
- result = self.value.__pow__(exponent, self.prime.value)
- return FiniteFieldElement(result.value, self.prime.value)
-
- def __eq__(self, other):
- return self.value == other.value and self.prime == other.prime
-
- def __str__(self):
- return f"FieldElement_{self.prime.value}({self.value.value})"
-
- def __repr__(self):
- return (
- f"FiniteFieldElement(value={self.value.value},"
- f" prime={self.prime.value})"
- )
-
- def __hex__(self):
- return hex(int(self.value.value))
-
- def hex_value_str(self) -> str:
- return self.value.hex_value_str()
-
-
-class FiniteFieldElementBarrett(FiniteFieldElement):
- """Finite field element for elliptic curve cryptography using Barrett reduction."""
-
- def __init__(self, value, prime, k=None):
- super().__init__(value, prime)
- if k == None:
- if isinstance(prime, BigInt):
- self.two_k = prime.ceil_log2() * 2
- else:
- self.two_k = BigInt(math.ceil(math.log2(prime))) * 2
- else:
- self.two_k = BigInt(2 * k)
- self.mu = BigInt(2) ** self.two_k / prime
-
- def barrett_reduction(self, x):
- # q = (x * mu) >> 2k
- q = (x * self.mu) >> self.two_k
- # r = x - q * prime
- r = x - q * self.prime
- if r >= self.prime:
- r -= self.prime
- return r
-
- def __add__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot add two numbers in different Fields")
- result = self.value + other.value
- if result > self.prime:
- result -= self.prime
- new_instance = self.copy()
- new_instance.value = result
- return new_instance
-
- def __sub__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot subtract two numbers in different Fields")
- if self.value < other.value:
- result = self.value + self.prime - other.value
- else:
- result = self.value - other.value
- new_instance = self.copy()
- new_instance.value = result
- return new_instance
-
- def __mul__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot multiply two numbers in different Fields")
- result = self.value * other.value
- reduced_result = self.barrett_reduction(result)
- new_instance = self.copy()
- new_instance.value = reduced_result
- return new_instance
-
- def __truediv__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot divide two numbers in different Fields")
- inverse = other.value.__pow__(self.prime.value - 2, self.prime.value)
- result = self.value * inverse
- reduced_result = self.barrett_reduction(result)
- new_instance = self.copy()
- new_instance.value = reduced_result
- return new_instance
-
- def copy(self, value=None, transform=False, reduction=False):
- """Create a deep copy of the current finite field element."""
- obj = copy.copy(self)
- if not value:
- obj.value = BigInt(self.value)
- else:
- if reduction:
- value = self.barrett_reduction(BigInt(value))
- obj.value = BigInt(value)
- return obj
-
-
-class FiniteFieldElementMontgomery(FiniteFieldElement):
- """Finite field element for elliptic curve cryptography using Montgomery reduction."""
-
- def __init__(self, value, prime, k=None):
- super().__init__(value, prime)
- if not k:
- if isinstance(prime, BigInt):
- self.k = prime.ceil_log2()
- else:
- self.k = BigInt(math.ceil(math.log2(prime)))
- else:
- self.k = BigInt(k)
-
- self.r = BigInt(2) ** self.k
- # self.r_inverse = (self.r ** (self.prime - 2)) % self.prime
- self.r_inverse = self.r.__pow__(self.prime - 2, self.prime)
- self.n_prime = (self.r * self.r_inverse - 1) / self.prime
- self.r_mask = self.r - 1
- self.value = self.montgomeryize(self.value)
- self.montgomeryized = True
- self.one_bar = self.montgomeryize(BigInt(1))
-
- def montgomery_reduction(self, x):
- m = ((x & self.r_mask) * self.n_prime) & self.r_mask
- u = (x + m * self.prime) >> self.k
- if u >= self.prime:
- u -= self.prime
- return u
-
- def montgomeryize(self, x):
- x_bar = (x * self.r) % self.prime
- return x_bar
-
- def de_montgomeryize(self, x_bar):
- x = self.montgomery_reduction(x_bar)
- return x
-
- def change_montgomery_form(self):
- if self.montgomeryized:
- self.value = self.de_montgomeryize(self.value)
- self.montgomeryized = False
- else:
- self.value = self.montgomeryize(self.value)
- self.montgomeryized = True
- return self
-
- def __add__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot add two numbers in different Fields")
- assert other.montgomeryized
- result = self.value + other.value
- if result > self.prime:
- result -= self.prime
- new_instance = self.copy()
- new_instance.value = result
- return new_instance
-
- def __sub__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot subtract two numbers in different Fields")
- assert other.montgomeryized
- if self.value < other.value:
- result = self.value + self.prime - other.value
- else:
- result = self.value - other.value
- new_instance = self.copy()
- new_instance.value = result
- return new_instance
-
- def __mul__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot multiply two numbers in different Fields")
- assert other.montgomeryized
- result = self.value * other.value
- reduced_result = self.montgomery_reduction(result)
- new_instance = self.copy()
- new_instance.value = reduced_result
- return new_instance
-
- def __truediv__(self, other):
- if self.prime != other.prime:
- raise ValueError("Cannot divide two numbers in different Fields")
-
- if other.montgomeryized:
- other_value = self.de_montgomeryize(other.value)
- else:
- other_value = other.value
-
- inverse = other_value.__pow__(self.prime.value - 2, self.prime.value)
-
- inverse_bar = self.montgomeryize(inverse)
-
- result = self.value * inverse_bar
- reduced_result = self.montgomery_reduction(result)
- new_instance = self.copy()
- new_instance.value = reduced_result
- return new_instance
-
- def copy(self, value=None, transform=False, reduction=False):
- """Create a deep copy of the current finite field element."""
- obj = copy.copy(self)
- if not value:
- obj.value = BigInt(self.value)
- else:
- if reduction:
- obj.value = self.montgomery_reduction(BigInt(value))
- elif transform:
- obj.value = self.montgomeryize(BigInt(value))
- else:
- obj.value = BigInt(value)
-
- return obj
diff --git a/jaxite_ec/algorithm/lazy_reduction.py b/jaxite_ec/algorithm/lazy_reduction.py
deleted file mode 100644
index a13b093..0000000
--- a/jaxite_ec/algorithm/lazy_reduction.py
+++ /dev/null
@@ -1,91 +0,0 @@
-"""This file implements the lazy reduction algorithm."""
-
-import random
-
-import numpy as np
-
-randint = random.randint
-
-
-def byte_len(x: int):
- return (x.bit_length() + 7) // 8
-
-
-def chunk_decomposition(x, length=-1):
- if length == -1:
- length = byte_len(x)
- return [(x >> (8 * i)) & 255 for i in range(length)]
-
-
-# Input: a_list: (t+2) byte numbers for a
-# b_list: (t+2) byte numbers for b
-# p: modulus
-# Output: c_list: (t_2) byte numbers for (ab % p)
-def lazy_reduction_via_matrix(a_list, b_list, p):
- """Performs lazy reduction via matrix multiplication.
-
- Args:
- a_list: (t+2) byte numbers for a
- b_list: (t+2) byte numbers for b
- p: modulus
-
- Returns:
- c_list: (t_2) byte numbers for (ab % p)
- """
- t = len(chunk_decomposition(p))
- # Precomputation
- lazy_mat = np.zeros((t + 4, t), dtype=np.uint8)
- for i in range(t + 4):
- val = (256 ** (t + i)) % p
- lazy_mat[i, :] = chunk_decomposition(val, t)
-
- # Begin computation
- assert len(a_list) == len(b_list)
- batch_size = len(a_list)
- batch_mat = np.zeros((batch_size, t + 4), dtype=np.uint8)
- standard_product = [a_list[i] * b_list[i] for i in range(batch_size)]
- standard_product_low = [s & (256**t - 1) for s in standard_product]
- standard_product_high = [s >> (8 * t) for s in standard_product]
- # Matrix packing
- for i in range(batch_size):
- batch_mat[i] = chunk_decomposition(standard_product_high[i], t + 4)
- # Matrix product
- # Upcast to get proper accumulators
- reduced = np.matmul(batch_mat.astype(np.uint32), lazy_mat.astype(np.uint32))
- c_list = []
- # Recombine into integers
- for i in range(batch_size):
- val = 0
- for j in reversed(range(t)):
- val *= 256
- val += int(reduced[i][j])
- c_list.append(val)
- # Add in standard_product_low; this could be done in u8 form before the
- # carry-add chain above, perhaps from the raw toeplitz output.
- # Since we are using a redundant form the upper part doesn't have to be
- # accurate.
- c_list = [c_list[i] + standard_product_low[i] for i in range(batch_size)]
- return c_list
-
-
-def main():
- """This test case check the functionality of the lazy reduction algorithm."""
- p = randint(2**381, 2**384)
- batch_size = 16
- bound = p * 256 * 256
- # a and b are both < bound
- a_list = [randint(0, bound) for _ in range(batch_size)]
- b_list = [randint(0, bound) for _ in range(batch_size)]
- c_list = lazy_reduction_via_matrix(a_list, b_list, p)
- for i in range(batch_size):
- # each output is congruent modulo p
- assert c_list[i] % p == (a_list[i] * b_list[i]) % p
- # each output is < bound
- assert c_list[i] < bound
- print(a_list)
- print(b_list)
- print(c_list)
-
-
-if __name__ == "__main__":
- main()
diff --git a/jaxite_ec/algorithm/rns_mod_reduce.py b/jaxite_ec/algorithm/rns_mod_reduce.py
deleted file mode 100644
index 910d312..0000000
--- a/jaxite_ec/algorithm/rns_mod_reduce.py
+++ /dev/null
@@ -1,419 +0,0 @@
-"""Implementation of RNS modular reduction."""
-
-import math
-import random
-
-import jax
-import jax.numpy as jnp
-
-randint = random.randint
-
-
-jax.config.update("jax_enable_x64", True)
-chunk_dtype = jnp.uint16
-mul_res_dtype = jnp.uint32
-
-
-def to_tuple(a):
- """Create to convert numpy array into tuple."""
- try:
- return tuple(to_tuple(i) for i in a)
- except TypeError:
- return a
-
-
-def find_moduli(total_modulus, precision):
- """Find moduli for RNS."""
- initial_moduli = 2**precision
- overall_moduli = []
- constant_offset_list = []
- overall_modulus = 1
- for i in range(2 ** (precision >> 1) - 1):
- cur_moduli = initial_moduli - i
- if math.gcd(cur_moduli, overall_modulus) == 1:
- overall_moduli.append(cur_moduli)
- constant_offset_list.append(i)
- overall_modulus *= cur_moduli
- if overall_modulus > total_modulus:
- return overall_moduli, constant_offset_list
-
- # Find 2**15 - v too
- initial_moduli = 2 ** (precision - 1)
- if overall_modulus < total_modulus:
- for i in range(2 ** (precision >> 1) - 1):
- cur_moduli = initial_moduli - i
- if math.gcd(cur_moduli, overall_modulus) == 1:
- overall_moduli.append(cur_moduli)
- constant_offset_list.append(i << 1)
- overall_modulus *= cur_moduli
- if overall_modulus > total_modulus:
- return overall_moduli, constant_offset_list
-
- return overall_moduli, constant_offset_list
-
-
-def hardware_friendly_mod_reduce(x, moduli_t):
- """Convert input value x into the RNS form.
-
- Args:
- x: Input value as a jnp.uint32 array.
- moduli_t: List of hardware friendly moduli_t.
-
- Returns:
- The RNS representation of x as a jnp.uint16 array.
- """
- assert x.dtype == jnp.uint32
- x_h = (x >> 16) & 0xFFFF
- x_l = x & 0xFFFF
- x_reduce = x_h * moduli_t + x_l
-
- x_h_sec = (x_reduce >> 16) & 0xFFFF
- x_l_sec = x_reduce & 0xFFFF
- x_reduce_sec = x_h_sec * moduli_t + x_l_sec
-
- x_h_third = (x_reduce_sec >> 16) & 0xFFFF
- x_l_third = x_reduce_sec & 0xFFFF
- x_reduce_third = x_h_third * moduli_t + x_l_third
- return x_reduce_third.astype(jnp.uint16)
-
-
-def to_rns(x, moduli):
- return [x % m for m in moduli]
-
-
-def rns_reconstruct(x, overall_moduli, icrt_factors):
- x_reconstruct = []
- for value_rns in x:
- big_int = 0
- for idx, tower in enumerate(value_rns):
- big_int += int(tower) * int(icrt_factors[idx])
- x_reconstruct.append(big_int % overall_moduli)
- return x_reconstruct
-
-
-def rns_icrt_factors_compute(modulus, moduli):
- precomputed = []
- for m in moduli:
- rest = modulus // m # 0 mod all the other moduli
- inverse = pow(rest % m, -1, m) # factor to make 1 mod this moduli
- icrt_val = (rest * inverse) % modulus # combine
- precomputed.append(icrt_val)
- return precomputed
-
-
-def rns_coefficients_precompute(
- icrt_factors,
- overall_moduli,
- num_bytes,
- moduli_precision,
- overall_modulus,
- q,
-):
- """Precompute RNS coefficients.
-
- Args:
- icrt_factors: Precomputed inverse CRT factors.
- overall_moduli: Array of moduli.
- num_bytes: Number of bytes.
- moduli_precision: Precision of the moduli.
- overall_modulus: Overall modulus.
- q: Target modulus.
-
- Returns:
- Precomputed RNS coefficients and correction coefficients.
- """
- num_residues = len(overall_moduli)
- # icrt_factors_byteshifted -- (num_residues, num_bytes)
- icrt_factors_byteshifted = [
- [
- (((1 << (8 * pre_id)) * factor) % overall_modulus)
- for pre_id in range(num_bytes)
- ]
- for factor in icrt_factors
- ]
- # icrt_factors_byteshifted_modq -- (num_residues, num_bytes)
- icrt_factors_byteshifted_modq = [
- [(chunk % q) for chunk in factors] for factors in icrt_factors_byteshifted
- ]
- # icrt_factors_byteshifted_modq_rns
- # (num_residues, num_bytes, num_residues) [Convert each byte range into RNS]
- icrt_factors_byteshifted_modq_rns = [
- [to_rns(chunk, overall_moduli) for chunk in factors]
- for factors in icrt_factors_byteshifted_modq
- ]
-
- rns_mat = jnp.array(
- icrt_factors_byteshifted_modq_rns, dtype=jnp.uint16
- ).reshape(-1, num_residues)
-
- # calculate quotient estimation
- fix_point = 1 << moduli_precision
-
- shifted_quotient_estimations = []
- for factors in icrt_factors_byteshifted:
- for chunk in factors:
- shifted_quotient_estimations.append(
- [math.ceil((chunk * fix_point) / overall_modulus)]
- )
- sqe_mat = jnp.array(shifted_quotient_estimations, dtype=jnp.uint16)
-
- cor_mat = jnp.array(
- [to_rns(-overall_modulus % q, overall_moduli)], dtype=jnp.uint16
- )
-
- # Convert rns_mat and sqe_mat into various bytes.
- rns_mat_u8 = jax.lax.bitcast_convert_type(rns_mat, jnp.uint8)
- seq_mat_u8 = jax.lax.bitcast_convert_type(sqe_mat, jnp.uint8)
- rns_stack_mat_u8 = jnp.hstack((
- rns_mat_u8[..., 0],
- seq_mat_u8[..., 0],
- rns_mat_u8[..., 1],
- seq_mat_u8[..., 1],
- ))
- return to_tuple(rns_stack_mat_u8.tolist()), to_tuple(cor_mat.tolist())
-
-
-def rns_mod_reduce(
- data_a_rns,
- data_b_rns,
- moduli,
- moduli_t,
- rns_stack_mat_u8,
- cor_mat,
- icrt_factors,
- overall_modulus,
-):
- """Performs RNS modular reduction.
-
- Args:
- data_a_rns: First input in RNS form.
- data_b_rns: Second input in RNS form.
- moduli: Array of moduli.
- moduli_t: Array of constant offsets.
- rns_stack_mat_u8: Precomputed RNS coefficients.
- cor_mat: Precomputed correction coefficients.
- icrt_factors: Precomputed inverse CRT factors.
- overall_modulus: Overall modulus.
-
- Returns:
- Result of the modular reduction.
- """
- num_residues = moduli.shape[0]
- mul_res = jnp.multiply(
- data_a_rns.astype(mul_res_dtype), data_b_rns.astype(mul_res_dtype)
- )
- mul_res_tower_red = hardware_friendly_mod_reduce(mul_res, moduli_t)
-
- # Global Modular reduction
- mul_res_glb_red = jnp.matmul(
- mul_res_tower_red.view(jnp.uint8),
- rns_stack_mat_u8,
- preferred_element_type=mul_res_dtype,
- )
-
- mul_res_glb_red_u32_l, mul_res_glb_red_u32_h = jnp.split(
- mul_res_glb_red, [num_residues + 1], axis=1
- )
- mul_res_glb_red_u32 = mul_res_glb_red_u32_l + (mul_res_glb_red_u32_h << 8)
- rns_reduce_u32, qe_u32 = jnp.split(
- mul_res_glb_red_u32, [num_residues], axis=1
- )
-
- # obtain the high 16 bits from the quotient estimation results qe_u32
- c_corrected = rns_reduce_u32 + jnp.matmul(
- qe_u32 >> moduli_precision, cor_mat, preferred_element_type=mul_res_dtype
- )
- c_corrected_reduce = hardware_friendly_mod_reduce(c_corrected, moduli_t)
-
- result_rns = rns_reconstruct(
- c_corrected_reduce.tolist(), overall_modulus, icrt_factors
- )
-
- return result_rns
-
-
-if __name__ == "__main__":
- ###########################
- # User Configured Input
- ###########################
- q = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
- moduli_precision = 16
- extra_bit_to_avoid_addition_overflow = 4
- num_bytes = moduli_precision // 8
- num_residues_for_q = (
- q.bit_length() + moduli_precision - 1
- ) // moduli_precision + 1
- overall_modulus = (
- ((q * 256 * num_bytes * num_residues_for_q) ** 2)
- * extra_bit_to_avoid_addition_overflow
- * 2
- )
- assert overall_modulus == q * q * 256 * 256 * 50 * 50 * 4 * 2
- num_elements = 1024
-
- ###########################
- # Offline Precompute
- ###########################
-
- overall_moduli, constant_offset_list = find_moduli(
- overall_modulus, moduli_precision
- )
- assert overall_moduli == [
- 65536,
- 65535,
- 65533,
- 65531,
- 65527,
- 65521,
- 65519,
- 65509,
- 65503,
- 65497,
- 65491,
- 65489,
- 65479,
- 65477,
- 65473,
- 65459,
- 65449,
- 65447,
- 65437,
- 65431,
- 65423,
- 65419,
- 65413,
- 65411,
- 65407,
- 65393,
- 65383,
- 65381,
- 65371,
- 65369,
- 65363,
- 65357,
- 65353,
- 65347,
- 65339,
- 65327,
- 65323,
- 65321,
- 65311,
- 65309,
- 65293,
- 65287,
- 32761,
- 32749,
- 32743,
- 32741,
- 32719,
- 32717,
- 32713,
- 32707,
- ]
- overall_modulus = 1
- for moduli in overall_moduli:
- overall_modulus *= moduli
- overall_modulus = int(overall_modulus)
- assert len(overall_moduli) == (
- (overall_modulus.bit_length() + moduli_precision - 1) // moduli_precision
- )
- icrt_factors = rns_icrt_factors_compute(overall_modulus, overall_moduli)
-
- # hardware friendly moduli is 2**precision - t
- # moduli is the jax.array of "2**precision - t"
- # moduli_t is the jax.array of "t"
- rns_stack_mat_u8, cor_mat = rns_coefficients_precompute(
- icrt_factors,
- overall_moduli,
- num_bytes,
- moduli_precision,
- overall_modulus,
- q,
- )
- assert cor_mat == (
- (
- 57491,
- 26379,
- 24673,
- 4733,
- 47122,
- 11996,
- 11119,
- 12151,
- 45048,
- 10179,
- 3397,
- 45514,
- 12274,
- 62018,
- 4316,
- 141,
- 20271,
- 17626,
- 20758,
- 57875,
- 41612,
- 44321,
- 30081,
- 6090,
- 16501,
- 13984,
- 14909,
- 14581,
- 47918,
- 44932,
- 34016,
- 7605,
- 33574,
- 30236,
- 15843,
- 26521,
- 52723,
- 28347,
- 32242,
- 11676,
- 31854,
- 34463,
- 30291,
- 29806,
- 1344,
- 25148,
- 23069,
- 4869,
- 6178,
- 32502,
- ),
- )
- rns_stack_mat_u8 = jnp.array(rns_stack_mat_u8, dtype=jnp.uint8)
- cor_mat = jnp.array(cor_mat, dtype=jnp.uint16)
-
- moduli = jnp.array(overall_moduli, dtype=mul_res_dtype)
- moduli_t = jnp.array(constant_offset_list, dtype=chunk_dtype)
-
- ###########################
- # Generate Random Data
- ###########################
-
- random_data = [randint(0, q) for _ in range(num_elements)]
- result_ref = [val * val % q for val in random_data]
- data_rns = jnp.array(
- [to_rns(ele, overall_moduli) for ele in random_data], jnp.uint32
- )
-
- ###########################
- # Compute Modular Reduction in RNS Form
- ###########################
- # Limb-wise modular multiplication
- # (num_elements, num_residue)
- result_rns = rns_mod_reduce(
- data_rns,
- data_rns,
- moduli,
- moduli_t,
- rns_stack_mat_u8,
- cor_mat,
- icrt_factors,
- overall_modulus,
- )
- result_mod_q = [val % q for val in result_rns]
- assert result_ref == result_mod_q
diff --git a/jaxite_ec/algorithm/util.py b/jaxite_ec/algorithm/util.py
deleted file mode 100644
index cfbef87..0000000
--- a/jaxite_ec/algorithm/util.py
+++ /dev/null
@@ -1,68 +0,0 @@
-"""Utility functions for the Bukect distribution algorithm."""
-
-import numpy as np
-
-
-def bits_to_numpy_dtype(bits):
- if bits == 8:
- return np.uint8
- elif bits == 16:
- return np.uint16
- elif bits == 32:
- return np.uint32
- elif bits == 64:
- return np.uint64
- else:
- raise ValueError("Unsupported bit size. Use 8, 16, 32, or 64.")
-
-
-def int_to_array(python_int, base, array_size=None) -> np.ndarray:
- """Converts a Python integer to a JAX array.
-
- Args:
- python_int: The Python integer to convert.
- base: The number of bits per element in the integer.
- array_size: The desired size of the resulting JAX array. If provided, the
- integer will be padded or trimmed to match this size.
-
- Returns:
- A JAX array containing the elements of the Python integer.
- """
- chunks = []
- mask = (
- 1 << base
- ) - 1 # Mask to extract the lower bits (e.g., 32 bits -> 0xFFFFFFFF)
- dtype = bits_to_numpy_dtype(base)
- # Extract each element from the integer
- while python_int > 0:
- chunks.append(python_int & mask) # Extract the lower bits
- python_int >>= base # Shift to remove the extracted bits
-
- # If array_size is provided, we pad or trim the result to match the
- # desired size
- if array_size is not None:
- assert array_size >= len(chunks)
- chunks = chunks[:array_size] + [0] * (array_size - len(chunks))
-
- # Convert the list to a JAX array
- return np.array(chunks, dtype=dtype)
-
-
-def array_to_int(np_array: np.ndarray, base):
- # Initialize the result as a Python integer
- result = 0
-
- # Iterate over the elements in the array
- for i, elem in enumerate(np_array):
- # Convert each element to an int and shift it by the appropriate
- # number of bits
- result |= int(elem) << (i * base)
-
- return result
-
-
-def int_list_to_array(int_list, base, array_size=None) -> np.ndarray:
- chunked_arrays = [
- int_to_array(int_value, base, array_size) for int_value in int_list
- ]
- return np.array(chunked_arrays)
diff --git a/jaxite_ec/c_kernels/BUILD b/jaxite_ec/c_kernels/BUILD
new file mode 100644
index 0000000..f762ed2
--- /dev/null
+++ b/jaxite_ec/c_kernels/BUILD
@@ -0,0 +1,24 @@
+load("@rules_python//python:defs.bzl", "py_library")
+# load("//third_party/bazel_rules/rules_cc/cc:cc_binary.bzl", "cc_binary")
+
+package(
+ default_applicable_licenses = ["@jaxite//:license"],
+ default_visibility = ["//visibility:public"],
+)
+
+# cc_binary(
+# name = "distribution.so",
+# srcs = ["distribution.cpp"],
+# features = ["-layering_check"],
+# linkshared = 1,
+# deps = [
+# "//third_party/tensorflow/compiler/xla/ffi/api:c_api",
+# "//third_party/tensorflow/compiler/xla/ffi/api:ffi",
+# ],
+# )
+
+py_library(
+ name = "build",
+ srcs = ["build.py"],
+ visibility = ["//visibility:public"],
+)
diff --git a/jaxite_ec/c_kernels/build.py b/jaxite_ec/c_kernels/build.py
new file mode 100644
index 0000000..9b44d73
--- /dev/null
+++ b/jaxite_ec/c_kernels/build.py
@@ -0,0 +1,103 @@
+"""Python-side build driver for the distribution C kernel.
+
+:func:`ensure_distribution_kernel` compiles ``distribution.cpp`` (or returns
+the cached build) and yields the absolute path of the resulting shared
+library, so callers can ``ctypes.cdll.LoadLibrary`` it without a separate
+build step.
+
+Default flags: ``-std=c++17 -fopenmp -O2 -fPIC -shared -I/include``.
+``CXX`` and ``CXXFLAGS`` env vars override the compiler and flags.
+
+Run as a script (``python -m c_kernels.build``) to pre-build without importing
+the rest of the package.
+"""
+
+from __future__ import annotations
+
+import os
+import pathlib
+import shlex
+import subprocess
+import sys
+import threading
+
+_KERNEL_DIR = pathlib.Path(__file__).parent
+_SRC_PATH = _KERNEL_DIR / "distribution.cpp"
+_LIB_PATH = _KERNEL_DIR / "distribution.so"
+
+_DEFAULT_CXXFLAGS = ("-std=c++17", "-fopenmp", "-O2", "-fPIC", "-shared")
+
+_BUILD_LOCK = threading.Lock()
+
+
+def _jaxlib_include_dir() -> pathlib.Path:
+ import jaxlib # pytype: disable=import-error
+
+ return pathlib.Path(jaxlib.__file__).resolve().parent / "include"
+
+
+def _needs_rebuild() -> bool:
+ if not _LIB_PATH.exists():
+ return True
+ if not _SRC_PATH.exists():
+ return False
+ return _SRC_PATH.stat().st_mtime > _LIB_PATH.stat().st_mtime
+
+
+def _build() -> None:
+ cxx = os.environ.get("CXX", "g++")
+ flags_env = os.environ.get("CXXFLAGS")
+ cxxflags = shlex.split(flags_env) if flags_env else list(_DEFAULT_CXXFLAGS)
+ cmd = [
+ cxx,
+ *cxxflags,
+ f"-I{_jaxlib_include_dir()}",
+ str(_SRC_PATH),
+ "-o",
+ str(_LIB_PATH),
+ ]
+ print(
+ f"[c_kernels.build] compiling: {' '.join(shlex.quote(c) for c in cmd)}",
+ file=sys.stderr,
+ )
+ subprocess.run(cmd, check=True, cwd=_KERNEL_DIR)
+
+
+def ensure_distribution_kernel(force: bool = False) -> str:
+ """Compile :file:`distribution.cpp` on demand and return the .so path.
+
+ Args:
+ force: rebuild even if the cached library is up to date.
+
+ Returns:
+ Absolute path to ``distribution.so``.
+
+ Raises:
+ FileNotFoundError: if pre-built kernel not found in test environment.
+ """
+ if "TEST_SRCDIR" in os.environ:
+ runfiles_dir = pathlib.Path(os.environ["TEST_SRCDIR"])
+ lib_path = (
+ runfiles_dir
+ / "google3/third_party/py/jaxite/jaxite_ec/c_kernels/distribution.so"
+ )
+ if lib_path.exists():
+ return str(lib_path)
+ raise FileNotFoundError(
+ f"Pre-built kernel not found in runfiles at {lib_path}"
+ )
+
+ if "UNITTEST_ON_FORGE" in os.environ:
+ if _LIB_PATH.exists():
+ return str(_LIB_PATH)
+ raise FileNotFoundError(f"Pre-built kernel not found at {_LIB_PATH}")
+
+ with _BUILD_LOCK:
+ if force or _needs_rebuild():
+ _build()
+ return str(_LIB_PATH)
+
+
+if __name__ == "__main__":
+ path = ensure_distribution_kernel(force="--force" in sys.argv[1:])
+ print(path)
diff --git a/jaxite_ec/c_kernels/distribution.cpp b/jaxite_ec/c_kernels/distribution.cpp
new file mode 100644
index 0000000..e5644ba
--- /dev/null
+++ b/jaxite_ec/c_kernels/distribution.cpp
@@ -0,0 +1,1151 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#ifdef _OPENMP
+#include
+#endif
+#include
+
+#include "third_party/tensorflow/compiler/xla/ffi/api/c_api.h"
+#include "third_party/tensorflow/compiler/xla/ffi/api/ffi.h"
+
+namespace ffi = xla::ffi;
+#define COORDINATE_NUM 4
+#define CHUNK_NUM 32
+#define RESERVE_RATIO 2.0
+#define SCALAR_BITS 253
+#define ORDER_HIGH 0x12AB655E
+#define ORDER_LOW_BITS 224
+// #define PROFILE
+
+typedef struct {
+ uint32_t chunks[CHUNK_NUM * COORDINATE_NUM];
+} point_t;
+
+typedef struct {
+ // Parameters for the distributor
+ uint32_t window_num;
+ uint32_t regular_bucket_num;
+ uint32_t special_bucket_num;
+ uint32_t msm_length;
+ uint32_t fixed_regular_padding_size;
+ uint32_t fixed_special_padding_size;
+
+ // Inputs buffers
+ uint32_t* zero;
+ int32_t* slices_list;
+ uint32_t* points_list;
+ uint32_t* neg_points_list;
+
+ // Output buffers
+ uint32_t* regular_buckets;
+ uint32_t* special_buckets;
+} distributor_params_t;
+
+class Distributor {
+ public:
+ Distributor(uint32_t window_num, uint32_t regular_bucket_num,
+ uint32_t special_bucket_num, uint32_t msm_length)
+ : window_num_(window_num),
+ regular_bucket_num_(regular_bucket_num),
+ special_bucket_num_(special_bucket_num),
+ msm_length_(msm_length) {
+ this->reserve_ratio_ = RESERVE_RATIO;
+ this->windows_.resize(window_num);
+ this->fixed_regular_padding_size_ = 0;
+ this->fixed_special_padding_size_ = 0;
+ this->truncated = 0;
+ }
+
+ void set_slices_list(int32_t* slices_list) {
+ this->slices_list_pointers_.resize(this->window_num_);
+ for (uint32_t i = 0; i < this->window_num_; ++i) {
+ slices_list_pointers_[i] = slices_list + i * this->msm_length_;
+ }
+ }
+
+ void set_output_buffers(uint32_t* regular_buckets,
+ uint32_t* special_buckets) {
+ this->regular_buckets_ = regular_buckets;
+ this->special_buckets_ = special_buckets;
+ this->bucket_pointers_.resize(window_num_);
+ this->bucket_sizes_.resize(window_num_);
+ uint32_t regular_bucket_capacity_in_bytes =
+ fixed_regular_padding_size_ * sizeof(point_t);
+ uint32_t special_bucket_capacity_in_bytes =
+ fixed_special_padding_size_ * sizeof(point_t);
+ uint32_t regular_window_capacity_in_bytes =
+ regular_bucket_num_ * regular_bucket_capacity_in_bytes;
+ // uint32_t special_window_capacity_in_bytes = special_bucket_num_ *
+ // special_bucket_capacity_in_bytes;
+
+ // Initialize the bucket pointers and sizes
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ this->bucket_pointers_[i].resize(regular_bucket_num_);
+ this->bucket_sizes_[i].resize(regular_bucket_num_);
+ for (uint32_t j = 0; j < this->bucket_pointers_[i].size(); ++j) {
+ this->bucket_sizes_[i][j] = 0;
+ this->bucket_pointers_[i][j] = reinterpret_cast(
+ reinterpret_cast(regular_buckets_) +
+ i * regular_window_capacity_in_bytes +
+ j * regular_bucket_capacity_in_bytes);
+ }
+ }
+ // For the last window, which has special buckets
+ this->bucket_pointers_[window_num_ - 1].resize(special_bucket_num_);
+ this->bucket_sizes_[window_num_ - 1].resize(special_bucket_num_);
+ for (uint32_t j = 0; j < this->bucket_pointers_[window_num_ - 1].size();
+ ++j) {
+ this->bucket_sizes_[window_num_ - 1][j] = 0;
+ this->bucket_pointers_[window_num_ - 1][j] = reinterpret_cast(
+ reinterpret_cast(special_buckets_) +
+ j * special_bucket_capacity_in_bytes);
+ }
+ }
+
+ void set_points_list(uint32_t* points_list) {
+ this->points_list_ = reinterpret_cast(points_list);
+ }
+
+ void set_neg_points_list(uint32_t* neg_points_list) {
+ this->neg_points_list_ = reinterpret_cast(neg_points_list);
+ }
+
+ void set_zeros(uint32_t* zero) {
+ this->zero_ = reinterpret_cast(zero);
+ }
+
+ void set_fixed_padding_size(uint32_t regular_size, uint32_t special_size) {
+ this->fixed_regular_padding_size_ = regular_size;
+ this->fixed_special_padding_size_ = special_size;
+ }
+
+ void set_tile_length(uint32_t tile_length) {
+ this->tile_length_ = tile_length;
+ assert(msm_length_ % tile_length == 0);
+ this->tile_num_ = msm_length_ / tile_length;
+ }
+
+ void distribute_a_window(int32_t* slices, point_t* points,
+ uint32_t bucket_num,
+ std::vector>& buckets,
+ uint32_t& max_size) {
+ // Implement the distribution logic for a single window here
+ buckets.resize(bucket_num);
+ uint32_t reserve_num =
+ static_cast(msm_length_ * reserve_ratio_ / bucket_num);
+ for (uint32_t i = 0; i < bucket_num; ++i) {
+ buckets[i].reserve(reserve_num);
+ }
+
+ for (uint32_t i = 0; i < msm_length_; ++i) {
+ if (slices[i] == 0) {
+ continue;
+ }
+ uint32_t bucket_index = slices[i] - 1;
+ buckets[bucket_index].push_back(points[i]);
+ }
+ // max_size = 0;
+ for (uint32_t i = 0; i < bucket_num; ++i) {
+ if (buckets[i].size() > max_size) {
+ max_size = buckets[i].size();
+ }
+ }
+ }
+
+ void distribute_a_window_to_buffer(int32_t* slices, point_t* points,
+ uint32_t tile_length,
+ uint32_t max_bucket_size,
+ std::vector& buckets,
+ std::vector& bucket_sizes,
+ uint32_t& overflow) {
+ for (uint32_t i = 0; i < tile_length; ++i) {
+ uint32_t prefetch_distance = 32; // Adjust prefetch distance as needed
+ if (i + prefetch_distance < tile_length) {
+ __builtin_prefetch(&slices[i + prefetch_distance], 0, 1);
+ __builtin_prefetch(&points[i + prefetch_distance], 0, 1);
+ }
+ if (slices[i] == 0) {
+ continue;
+ }
+ int32_t bucket_index = slices[i] - 1;
+ if (bucket_index >= buckets.size()) {
+ std::cerr << "Error: bucket index out of range." << std::endl;
+ continue;
+ }
+ if (bucket_sizes[bucket_index] < max_bucket_size) {
+ buckets[bucket_index][bucket_sizes[bucket_index]] = points[i];
+ bucket_sizes[bucket_index]++;
+ } else {
+ overflow++;
+ }
+ }
+ }
+
+ void distribute_a_window_to_buffer_with_sign(
+ int32_t* slices, point_t* points, point_t* neg_points,
+ uint32_t tile_length, uint32_t max_bucket_size,
+ std::vector& buckets, std::vector& bucket_sizes,
+ uint32_t& overflow) {
+ for (uint32_t i = 0; i < tile_length; ++i) {
+ if (slices[i] == 0) {
+ continue;
+ }
+ point_t* points_to_use = (slices[i] > 0) ? &points[i] : &neg_points[i];
+ int32_t bucket_index = abs(slices[i]) - 1;
+ if (bucket_index >= buckets.size()) {
+ std::cerr << "Error: bucket index out of range." << std::endl;
+ continue;
+ }
+ if (bucket_sizes[bucket_index] < max_bucket_size) {
+ buckets[bucket_index][bucket_sizes[bucket_index]] = *points_to_use;
+ bucket_sizes[bucket_index]++;
+ } else {
+ overflow++;
+ }
+ }
+ }
+
+ void pad_a_window(point_t* zero, uint32_t target_size,
+ std::vector>& buckets) {
+ for (uint32_t i = 0; i < buckets.size(); ++i) {
+ uint32_t current_size = buckets[i].size();
+ // assert(current_size <= target_size);
+ if (current_size <= target_size) {
+ uint32_t pad_num = target_size - current_size;
+ buckets[i].insert(buckets[i].end(), pad_num, *zero);
+ } else {
+ // If the current size exceeds the target size, truncate it
+ buckets[i].resize(target_size);
+ }
+ }
+ }
+
+ void pad_a_window_to_buffer(point_t* zero, uint32_t target_size,
+ std::vector& buckets,
+ std::vector& bucket_sizes) {
+ for (uint32_t i = 0; i < buckets.size(); ++i) {
+ uint32_t current_size = bucket_sizes[i];
+ // assert(current_size <= target_size);
+ if (current_size < target_size) {
+ uint32_t pad_num = target_size - current_size;
+ for (uint32_t j = 0; j < pad_num; ++j) {
+ buckets[i][current_size + j] = *zero;
+ }
+ bucket_sizes[i] += pad_num;
+ }
+ }
+ }
+
+ void distribute() {
+ uint32_t max_regular_bucket_size = 0;
+ auto start = std::chrono::high_resolution_clock::now();
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ start = std::chrono::high_resolution_clock::now();
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ int32_t* slices = slices_list_pointers_[i];
+ distribute_a_window(slices, points_list_, regular_bucket_num_,
+ windows_[i], max_regular_bucket_size);
+ }
+ uint32_t max_special_bucket_size = 0;
+ auto end2 = std::chrono::high_resolution_clock::now();
+ distribute_a_window(slices_list_pointers_[window_num_ - 1], points_list_,
+ special_bucket_num_, windows_[window_num_ - 1],
+ max_special_bucket_size);
+ end = std::chrono::high_resolution_clock::now();
+
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ auto duration2_us =
+ std::chrono::duration_cast(end2 - start)
+ .count();
+ std::cout << "dist_r time: " << duration2_us << " us" << std::endl;
+ std::cout << "dist time: " << duration_us << " us" << std::endl;
+
+ this->real_regular_padding_size_ = max_regular_bucket_size;
+ this->real_special_padding_size_ = max_special_bucket_size;
+
+ if (this->fixed_regular_padding_size_ > 0) {
+ if (this->fixed_regular_padding_size_ < max_regular_bucket_size) {
+ this->truncated = 1;
+ }
+ // assert(this->fixed_regular_padding_size_ >= max_regular_bucket_size);
+ max_regular_bucket_size = this->fixed_regular_padding_size_;
+ }
+ if (this->fixed_special_padding_size_ > 0) {
+ if (this->fixed_special_padding_size_ < max_special_bucket_size) {
+ this->truncated = 1;
+ }
+ // assert(this->fixed_special_padding_size_ >= max_special_bucket_size);
+ max_special_bucket_size = this->fixed_special_padding_size_;
+ }
+
+ start = std::chrono::high_resolution_clock::now();
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ uint32_t target_size = (i == window_num_ - 1) ? max_special_bucket_size
+ : max_regular_bucket_size;
+ pad_a_window(zero_, target_size, windows_[i]);
+ }
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ std::cout << "pad time: " << duration_us << " us" << std::endl;
+ }
+
+ void distribute_parallel_v1() {
+ uint32_t max_regular_bucket_size = 0;
+ std::vector max_regular_bucket_sizes(window_num_ - 1, 0);
+ auto start = std::chrono::high_resolution_clock::now();
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ start = std::chrono::high_resolution_clock::now();
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ int32_t* slices = slices_list_pointers_[i];
+ distribute_a_window(slices, points_list_, regular_bucket_num_,
+ windows_[i], max_regular_bucket_sizes[i]);
+ }
+ auto end2 = std::chrono::high_resolution_clock::now();
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ if (max_regular_bucket_sizes[i] > max_regular_bucket_size) {
+ max_regular_bucket_size = max_regular_bucket_sizes[i];
+ }
+ }
+ uint32_t max_special_bucket_size = 0;
+
+ distribute_a_window(slices_list_pointers_[window_num_ - 1], points_list_,
+ special_bucket_num_, windows_[window_num_ - 1],
+ max_special_bucket_size);
+ this->real_regular_padding_size_ = max_regular_bucket_size;
+ this->real_special_padding_size_ = max_special_bucket_size;
+ end = std::chrono::high_resolution_clock::now();
+
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ auto duration2_us =
+ std::chrono::duration_cast(end2 - start)
+ .count();
+ std::cout << "dist_r time: " << duration2_us << " us" << std::endl;
+ std::cout << "dist time: " << duration_us << " us" << std::endl;
+
+ if (this->fixed_regular_padding_size_ > 0) {
+ if (this->fixed_regular_padding_size_ < max_regular_bucket_size) {
+ this->truncated = 1;
+ }
+ // assert(this->fixed_regular_padding_size_ >= max_regular_bucket_size);
+ max_regular_bucket_size = this->fixed_regular_padding_size_;
+ }
+ if (this->fixed_special_padding_size_ > 0) {
+ if (this->fixed_special_padding_size_ < max_special_bucket_size) {
+ this->truncated = 1;
+ }
+ // assert(this->fixed_special_padding_size_ >=
+ // max_special_bucket_size);
+ max_special_bucket_size = this->fixed_special_padding_size_;
+ }
+
+ start = std::chrono::high_resolution_clock::now();
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ uint32_t target_size = (i == window_num_ - 1) ? max_special_bucket_size
+ : max_regular_bucket_size;
+ pad_a_window(zero_, target_size, windows_[i]);
+ }
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ std::cout << "pad time: " << duration_us << " us" << std::endl;
+ }
+
+ void distribute_a_window_to_buffer_zero_indexed(
+ int32_t* slices, point_t* points, uint32_t tile_length,
+ uint32_t max_bucket_size, std::vector& buckets,
+ std::vector& bucket_sizes, uint32_t& overflow) {
+ for (uint32_t i = 0; i < tile_length; ++i) {
+ uint32_t prefetch_distance = 32;
+ if (i + prefetch_distance < tile_length) {
+ __builtin_prefetch(&slices[i + prefetch_distance], 0, 1);
+ __builtin_prefetch(&points[i + prefetch_distance], 0, 1);
+ }
+ if (slices[i] <= 0) {
+ continue;
+ }
+ int32_t bucket_index = slices[i];
+ if (bucket_index >= static_cast(buckets.size())) {
+ std::cerr << "Error: bucket index out of range." << std::endl;
+ continue;
+ }
+ if (bucket_sizes[bucket_index] < max_bucket_size) {
+ buckets[bucket_index][bucket_sizes[bucket_index]] = points[i];
+ bucket_sizes[bucket_index]++;
+ } else {
+ overflow++;
+ }
+ }
+ }
+
+ void distribute_to_buffer_zero_indexed_parallel() {
+ std::vector bucket_overflow(window_num_, 0);
+#ifdef PROFILE
+ auto start = std::chrono::high_resolution_clock::now();
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ start = std::chrono::high_resolution_clock::now();
+#endif
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ int32_t* slices = slices_list_pointers_[i];
+ point_t* points = points_list_;
+ std::vector& buckets = bucket_pointers_[i];
+ std::vector& bucket_sizes = bucket_sizes_[i];
+ uint32_t max_bucket_size = this->fixed_regular_padding_size_;
+ distribute_a_window_to_buffer_zero_indexed(
+ slices, points, msm_length_, max_bucket_size, buckets, bucket_sizes,
+ bucket_overflow[i]);
+ }
+#ifdef PROFILE
+ auto end2 = std::chrono::high_resolution_clock::now();
+#endif
+
+ distribute_a_window_to_buffer_zero_indexed(
+ slices_list_pointers_[window_num_ - 1], points_list_, msm_length_,
+ this->fixed_special_padding_size_, bucket_pointers_[window_num_ - 1],
+ bucket_sizes_[window_num_ - 1], bucket_overflow[window_num_ - 1]);
+
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ auto duration2_us =
+ std::chrono::duration_cast(end2 - start)
+ .count();
+ std::cout << "dist_r time: " << duration2_us << " us" << std::endl;
+ std::cout << "dist time: " << duration_us << " us" << std::endl;
+#endif
+
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ if (bucket_overflow[i] > 0) {
+ this->truncated = 1;
+ break;
+ }
+ }
+
+#ifdef PROFILE
+ start = std::chrono::high_resolution_clock::now();
+#endif
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ uint32_t target_size = (i == window_num_ - 1)
+ ? this->fixed_special_padding_size_
+ : this->fixed_regular_padding_size_;
+ pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i],
+ bucket_sizes_[i]);
+ }
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ std::cout << "pad time: " << duration_us << " us" << std::endl;
+#endif
+ }
+
+ void distribute_to_buffer_parallel() {
+ std::vector bucket_overflow(window_num_, 0);
+#ifdef PROFILE
+ auto start = std::chrono::high_resolution_clock::now();
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ start = std::chrono::high_resolution_clock::now();
+#endif
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ int32_t* slices = slices_list_pointers_[i];
+ point_t* points = points_list_;
+ std::vector& buckets = bucket_pointers_[i];
+ std::vector& bucket_sizes = bucket_sizes_[i];
+ uint32_t max_bucket_size = this->fixed_regular_padding_size_;
+ distribute_a_window_to_buffer(slices, points, msm_length_,
+ max_bucket_size, buckets, bucket_sizes,
+ bucket_overflow[i]);
+ }
+#ifdef PROFILE
+ auto end2 = std::chrono::high_resolution_clock::now();
+#endif
+
+ distribute_a_window_to_buffer(
+ slices_list_pointers_[window_num_ - 1], points_list_, msm_length_,
+ this->fixed_special_padding_size_, bucket_pointers_[window_num_ - 1],
+ bucket_sizes_[window_num_ - 1], bucket_overflow[window_num_ - 1]);
+
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ auto duration2_us =
+ std::chrono::duration_cast(end2 - start)
+ .count();
+ std::cout << "dist_r time: " << duration2_us << " us" << std::endl;
+ std::cout << "dist time: " << duration_us << " us" << std::endl;
+#endif
+
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ if (bucket_overflow[i] > 0) {
+ this->truncated = 1;
+ break;
+ }
+ }
+
+#ifdef PROFILE
+ start = std::chrono::high_resolution_clock::now();
+#endif
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ uint32_t target_size = (i == window_num_ - 1)
+ ? this->fixed_special_padding_size_
+ : this->fixed_regular_padding_size_;
+ pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i],
+ bucket_sizes_[i]);
+ }
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ std::cout << "pad time: " << duration_us << " us" << std::endl;
+#endif
+ }
+
+ void distribute_to_buffer_signed_parallel() {
+ std::vector bucket_overflow(window_num_, 0);
+#ifdef PROFILE
+ auto start = std::chrono::high_resolution_clock::now();
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ start = std::chrono::high_resolution_clock::now();
+#endif
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ int32_t* slices = slices_list_pointers_[i];
+ point_t* points = points_list_;
+ point_t* neg_points = neg_points_list_;
+ std::vector& buckets = bucket_pointers_[i];
+ std::vector& bucket_sizes = bucket_sizes_[i];
+ uint32_t max_bucket_size = this->fixed_regular_padding_size_;
+ distribute_a_window_to_buffer_with_sign(
+ slices, points, neg_points, msm_length_, max_bucket_size, buckets,
+ bucket_sizes, bucket_overflow[i]);
+ }
+#ifdef PROFILE
+ auto end2 = std::chrono::high_resolution_clock::now();
+#endif
+
+ distribute_a_window_to_buffer_with_sign(
+ slices_list_pointers_[window_num_ - 1], points_list_, neg_points_list_,
+ msm_length_, this->fixed_special_padding_size_,
+ bucket_pointers_[window_num_ - 1], bucket_sizes_[window_num_ - 1],
+ bucket_overflow[window_num_ - 1]);
+
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ auto duration2_us =
+ std::chrono::duration_cast(end2 - start)
+ .count();
+ std::cout << "dist_r time: " << duration2_us << " us" << std::endl;
+ std::cout << "dist time: " << duration_us << " us" << std::endl;
+#endif
+
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ if (bucket_overflow[i] > 0) {
+ this->truncated = 1;
+ break;
+ }
+ }
+
+#ifdef PROFILE
+ start = std::chrono::high_resolution_clock::now();
+#endif
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ uint32_t target_size = (i == window_num_ - 1)
+ ? this->fixed_special_padding_size_
+ : this->fixed_regular_padding_size_;
+ pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i],
+ bucket_sizes_[i]);
+ }
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ std::cout << "pad time: " << duration_us << " us" << std::endl;
+#endif
+ }
+
+ void distribute_to_buffer_parallel_v2() {
+ std::vector bucket_overflow(window_num_, 0);
+#ifdef PROFILE
+ auto start = std::chrono::high_resolution_clock::now();
+ auto end = std::chrono::high_resolution_clock::now();
+ auto duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ start = std::chrono::high_resolution_clock::now();
+#endif
+ for (uint32_t t = 0; t < tile_num_; ++t) {
+// Distribute each tile to the corresponding window
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ int32_t* slices = slices_list_pointers_[i] + t * tile_length_;
+ point_t* points = points_list_ + t * tile_length_;
+ std::vector& buckets = bucket_pointers_[i];
+ std::vector& bucket_sizes = bucket_sizes_[i];
+ uint32_t max_bucket_size = this->fixed_regular_padding_size_;
+ distribute_a_window_to_buffer(slices, points, tile_length_,
+ max_bucket_size, buckets, bucket_sizes,
+ bucket_overflow[i]);
+ }
+ }
+#ifdef PROFILE
+ auto end2 = std::chrono::high_resolution_clock::now();
+#endif
+
+ distribute_a_window_to_buffer(
+ slices_list_pointers_[window_num_ - 1], points_list_, msm_length_,
+ this->real_special_padding_size_, bucket_pointers_[window_num_ - 1],
+ bucket_sizes_[window_num_ - 1], bucket_overflow[window_num_ - 1]);
+
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ auto duration2_us =
+ std::chrono::duration_cast(end2 - start)
+ .count();
+ std::cout << "dist_r time: " << duration2_us << " us" << std::endl;
+ std::cout << "dist time: " << duration_us << " us" << std::endl;
+#endif
+
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ if (bucket_overflow[i] > 0) {
+ this->truncated = 1;
+ break;
+ }
+ }
+
+#ifdef PROFILE
+ start = std::chrono::high_resolution_clock::now();
+#endif
+#pragma omp parallel for
+ for (uint32_t i = 0; i < window_num_; ++i) {
+ uint32_t target_size = (i == window_num_ - 1)
+ ? this->fixed_special_padding_size_
+ : this->fixed_regular_padding_size_;
+ pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i],
+ bucket_sizes_[i]);
+ }
+#ifdef PROFILE
+ end = std::chrono::high_resolution_clock::now();
+ duration_us =
+ std::chrono::duration_cast(end - start)
+ .count();
+ std::cout << "pad time: " << duration_us << " us" << std::endl;
+#endif
+ }
+
+ void get_merged_regular_buckets(uint32_t* regular_buckets) {
+ uint32_t offset = 0;
+ for (uint32_t i = 0; i < window_num_ - 1; ++i) {
+ uint32_t bucket_size = windows_[i].size();
+ for (uint32_t j = 0; j < bucket_size; ++j) {
+ std::memcpy(regular_buckets + offset, windows_[i][j].data(),
+ windows_[i][j].size() * sizeof(point_t));
+ offset += windows_[i][j].size() * CHUNK_NUM * COORDINATE_NUM;
+ }
+ }
+ }
+
+ void get_merged_special_buckets(uint32_t* special_buckets) {
+ uint32_t offset = 0;
+ uint32_t window_idx = window_num_ - 1;
+ uint32_t bucket_size = windows_[window_num_ - 1].size();
+ for (uint32_t j = 0; j < bucket_size; ++j) {
+ std::memcpy(special_buckets + offset, windows_[window_idx][j].data(),
+ windows_[window_idx][j].size() * sizeof(point_t));
+ offset += windows_[window_idx][j].size() * CHUNK_NUM * COORDINATE_NUM;
+ }
+ }
+
+ void get_merged_metadata(uint32_t* metadata) {
+ metadata[0] = real_regular_padding_size_;
+ metadata[1] = real_special_padding_size_;
+ }
+
+ private:
+ uint32_t window_num_;
+ uint32_t regular_bucket_num_;
+ uint32_t special_bucket_num_;
+ uint32_t msm_length_;
+
+ uint32_t tile_length_;
+ uint32_t tile_num_;
+ float reserve_ratio_;
+
+ std::vector slices_list_pointers_;
+ point_t* points_list_;
+ point_t* neg_points_list_;
+ point_t* zero_;
+ std::vector>> windows_;
+
+ uint32_t fixed_regular_padding_size_;
+ uint32_t real_regular_padding_size_;
+ uint32_t fixed_special_padding_size_;
+ uint32_t real_special_padding_size_;
+
+ uint32_t* regular_buckets_;
+ uint32_t* special_buckets_;
+ // windows
+ std::vector> bucket_pointers_;
+ std::vector> bucket_sizes_;
+
+ /*When truncated is true, the reuslt will be incorrect.
+ It is only for perfmance profiling goal.*/
+ uint32_t truncated;
+
+ /* For profiling*/
+};
+
+distributor_params_t init_distributor_param(uint32_t slice_length,
+ uint32_t msm_length,
+ double buf_extend_ratio,
+ bool signed_bucket) {
+ // General parameter initialization
+ uint32_t window_num = (SCALAR_BITS + slice_length - 1) / slice_length;
+ uint32_t regular_bucket_num = (1 << slice_length) - 1; // 2^slice_length - 1
+ uint32_t shift_bits = ((window_num - 1) * slice_length) - ORDER_LOW_BITS;
+ assert(shift_bits >= 0);
+ uint32_t special_bucket_num = ORDER_HIGH >> shift_bits;
+ if (signed_bucket) {
+ regular_bucket_num = 1 << (slice_length - 1); // 2^(slice_length - 1)
+ special_bucket_num++;
+ assert(special_bucket_num <= regular_bucket_num);
+ }
+
+ // Special bucket optoimization parameter initialization
+ uint32_t log_special_duplication_ratio = static_cast(std::ceil(
+ std::log2(static_cast(regular_bucket_num) / special_bucket_num)));
+ uint32_t special_duplication_ratio = 1U << log_special_duplication_ratio;
+ uint32_t bucket_num_duplication =
+ special_bucket_num * special_duplication_ratio;
+
+ // Output buffer parameter initialization
+ double expected_regular_bucket_size =
+ static_cast(msm_length) / (regular_bucket_num + 1);
+ double expected_special_bucket_size =
+ static_cast(msm_length) / bucket_num_duplication;
+ uint32_t regular_bucket_size =
+ static_cast(expected_regular_bucket_size * buf_extend_ratio);
+ uint32_t special_bucket_size =
+ static_cast(expected_special_bucket_size * buf_extend_ratio);
+ uint32_t regular_buffer_size_in_U32 = (window_num - 1) * regular_bucket_num *
+ regular_bucket_size * COORDINATE_NUM *
+ CHUNK_NUM;
+ uint32_t special_buffer_size_in_U32 =
+ special_bucket_num * special_duplication_ratio * special_bucket_size *
+ COORDINATE_NUM * CHUNK_NUM;
+
+ distributor_params_t params;
+ params.window_num = window_num;
+ params.regular_bucket_num = regular_bucket_num;
+ params.special_bucket_num = special_bucket_num;
+ params.msm_length = msm_length;
+ params.fixed_regular_padding_size = regular_bucket_size;
+ params.fixed_special_padding_size =
+ special_bucket_size * special_duplication_ratio;
+ params.zero =
+ new uint32_t[COORDINATE_NUM * CHUNK_NUM](); // Initialize to zero
+ params.slices_list = new int32_t[window_num * msm_length];
+ params.points_list = new uint32_t[msm_length * COORDINATE_NUM * CHUNK_NUM];
+ params.neg_points_list =
+ new uint32_t[msm_length * COORDINATE_NUM * CHUNK_NUM];
+ params.regular_buckets = new uint32_t[regular_buffer_size_in_U32];
+ params.special_buckets = new uint32_t[special_buffer_size_in_U32];
+ std::random_device rd;
+ std::mt19937 gen(rd());
+ std::uniform_int_distribution regular_dist(0, regular_bucket_num);
+ std::uniform_int_distribution special_dist(0, special_bucket_num);
+ // Initialize slices_list with random values
+ for (uint32_t w = 0; w < window_num - 1; ++w) {
+ for (uint32_t i = 0; i < msm_length; ++i) {
+ params.slices_list[w * msm_length + i] = regular_dist(gen);
+ }
+ }
+ for (uint32_t i = 0; i < msm_length; ++i) {
+ params.slices_list[(window_num - 1) * msm_length + i] = special_dist(gen);
+ }
+
+ return params;
+}
+
+void free_distributor_param(distributor_params_t& params) {
+ delete[] params.zero;
+ delete[] params.slices_list;
+ delete[] params.points_list;
+ delete[] params.neg_points_list;
+ delete[] params.regular_buckets;
+ delete[] params.special_buckets;
+}
+
+int main_old() {
+ // Example usage of the Distributor class
+ uint32_t slice_length = 10; // Example slice length
+ uint32_t msm_length = 1 << 16; // Example MSM length
+ double buf_extend_ratio = 1.1; // Example buffer extend ratio
+ bool signed_bucket = false; // Example signed bucket flag
+ uint32_t repeat = 50;
+ double duration_us = 0;
+ // printing the parameters
+ std::cout << "Slice Length: " << slice_length << std::endl;
+ std::cout << "MSM Length: " << msm_length << std::endl;
+ std::cout << "Buffer Extend Ratio: " << buf_extend_ratio << std::endl;
+ std::cout << "Signed Bucket: " << (signed_bucket ? "true" : "false")
+ << std::endl;
+
+ for (uint32_t i = 0; i < repeat; ++i) {
+ std::cout << "Run " << i + 1 << " of " << repeat << std::endl;
+ distributor_params_t params = init_distributor_param(
+ slice_length, msm_length, buf_extend_ratio, signed_bucket);
+
+ auto start_time = std::chrono::high_resolution_clock::now();
+
+ Distributor* distributor =
+ new Distributor(params.window_num, params.regular_bucket_num,
+ params.special_bucket_num, params.msm_length);
+
+ distributor->set_fixed_padding_size(params.fixed_regular_padding_size,
+ params.fixed_special_padding_size);
+ distributor->set_slices_list(params.slices_list);
+ distributor->set_points_list(params.points_list);
+ distributor->set_neg_points_list(params.neg_points_list);
+ distributor->set_zeros(params.zero);
+ distributor->set_output_buffers(params.regular_buckets,
+ params.special_buckets);
+
+ distributor->distribute_to_buffer_parallel();
+ auto end_time = std::chrono::high_resolution_clock::now();
+ auto duration = std::chrono::duration_cast(
+ end_time - start_time)
+ .count();
+ std::cout << "Distribution time: " << duration << " us" << std::endl;
+ duration_us += duration;
+
+ // Free allocated memory
+ free_distributor_param(params);
+ delete distributor;
+ }
+ std::cout << "Average distribution time: " << (duration_us / repeat) << " us"
+ << std::endl;
+
+ return 0;
+}
+
+// Dummy performance test function (replace with your real test)
+double run_performance_test(uint32_t slice_length, uint32_t msm_length,
+ double buf_extend_ratio, bool signed_bucket) {
+ distributor_params_t params = init_distributor_param(
+ slice_length, msm_length, buf_extend_ratio, signed_bucket);
+
+ auto start_time = std::chrono::high_resolution_clock::now();
+
+ Distributor distributor =
+ Distributor(params.window_num, params.regular_bucket_num,
+ params.special_bucket_num, params.msm_length);
+
+ distributor.set_fixed_padding_size(params.fixed_regular_padding_size,
+ params.fixed_special_padding_size);
+ distributor.set_slices_list(params.slices_list);
+ distributor.set_points_list(params.points_list);
+ distributor.set_neg_points_list(params.neg_points_list);
+ distributor.set_zeros(params.zero);
+ distributor.set_output_buffers(params.regular_buckets,
+ params.special_buckets);
+
+ distributor.distribute_to_buffer_parallel();
+ auto end_time = std::chrono::high_resolution_clock::now();
+ auto duration = std::chrono::duration_cast(
+ end_time - start_time)
+ .count();
+ // std::cout << "Distribution time: " << duration << " us" << std::endl;
+ // Free allocated memory
+ free_distributor_param(params);
+ return static_cast(duration);
+}
+
+// Compute statistics
+void compute_stats(const std::vector& data, double& min, double& max,
+ double& avg, double& stddev) {
+ if (data.empty()) {
+ min = max = avg = stddev = 0.0;
+ return;
+ }
+ min = *std::min_element(data.begin(), data.end());
+ max = *std::max_element(data.begin(), data.end());
+ avg = std::accumulate(data.begin(), data.end(), 0.0) / data.size();
+ double sum_sq = 0.0;
+ for (double v : data) sum_sq += (v - avg) * (v - avg);
+ stddev = std::sqrt(sum_sq / data.size());
+}
+
+int main() {
+ std::ofstream csv("C_kernel_perf_in_C.csv");
+ int trial_num = 50;
+
+ csv << "slice_length,log2_msm_length,msm_length,buf_extend_ratio,signed_"
+ "bucket,min,max,avg(us),stddev\n";
+
+ std::vector slice_lengths = {6, 8, 10, 12, 14};
+ std::vector log2_msm_lengths = {10, 12, 14,
+ 16, 18, 20}; // log2(msm_length)
+ std::vector buf_extend_ratios = {1.1};
+ std::vector signed_buckets = {false, true};
+
+ for (uint32_t slice_length : slice_lengths) {
+ for (uint32_t log2_msm_length : log2_msm_lengths) {
+ uint32_t msm_length = 1U << log2_msm_length;
+ for (double buf_extend_ratio : buf_extend_ratios) {
+ for (bool signed_bucket : signed_buckets) {
+ if (slice_length >= log2_msm_length) {
+ // std::cerr << "Error: slice_length must be less than
+ // log2_msm_length." << std::endl;
+ continue;
+ }
+ std::cout << "Testing slice_length: " << slice_length
+ << ", log2_msm_length: " << log2_msm_length
+ << ", buf_extend_ratio: " << buf_extend_ratio
+ << ", signed_bucket: " << (signed_bucket ? "true" : "false")
+ << std::endl;
+ // Run the test multiple times for statistics
+ std::vector results;
+ for (int trial = 0; trial < trial_num; ++trial) {
+ double perf = run_performance_test(slice_length, msm_length,
+ buf_extend_ratio, signed_bucket);
+ results.push_back(perf);
+ }
+ double min, max, avg, stddev;
+ compute_stats(results, min, max, avg, stddev);
+
+ csv << slice_length << "," << log2_msm_length << "," << msm_length
+ << "," << buf_extend_ratio << "," << signed_bucket << "," << min
+ << "," << max << "," << avg << "," << stddev << "\n";
+ }
+ }
+ }
+ }
+
+ csv.close();
+ std::cout << "Performance results written to performance_results.csv\n";
+ return 0;
+}
+
+namespace ffi = xla::ffi;
+ffi::Error DistributeImpl(uint32_t window_num, uint32_t regular_bucket_num,
+ uint32_t special_bucket_num, uint32_t msm_length,
+ uint32_t fixed_regular_padding_size,
+ uint32_t fixed_special_padding_size,
+ ffi::Buffer slices_list,
+ ffi::Buffer points_list,
+ ffi::Buffer zero,
+ ffi::ResultBuffer regular_buckets,
+ ffi::ResultBuffer special_buckets,
+ ffi::ResultBuffer metadata) {
+ // Create a Distributor object
+ Distributor distributor(window_num, regular_bucket_num, special_bucket_num,
+ msm_length);
+ distributor.set_fixed_padding_size(fixed_regular_padding_size,
+ fixed_special_padding_size);
+ distributor.set_slices_list(slices_list.typed_data());
+ distributor.set_points_list(points_list.typed_data());
+ distributor.set_zeros(zero.typed_data());
+
+ distributor.distribute();
+
+ distributor.get_merged_regular_buckets(regular_buckets->typed_data());
+ distributor.get_merged_special_buckets(special_buckets->typed_data());
+ distributor.get_merged_metadata(metadata->typed_data());
+ // Return success
+ return xla::ffi::Error::Success();
+}
+
+ffi::Error DistributeBufImpl(uint32_t window_num, uint32_t regular_bucket_num,
+ uint32_t special_bucket_num, uint32_t msm_length,
+ uint32_t fixed_regular_padding_size,
+ uint32_t fixed_special_padding_size,
+ ffi::Buffer slices_list,
+ ffi::Buffer points_list,
+ ffi::Buffer zero,
+ ffi::ResultBuffer regular_buckets,
+ ffi::ResultBuffer special_buckets,
+ ffi::ResultBuffer metadata) {
+ // Create a Distributor object
+ Distributor distributor(window_num, regular_bucket_num, special_bucket_num,
+ msm_length);
+ distributor.set_fixed_padding_size(fixed_regular_padding_size,
+ fixed_special_padding_size);
+ distributor.set_slices_list(slices_list.typed_data());
+ distributor.set_points_list(points_list.typed_data());
+ distributor.set_zeros(zero.typed_data());
+ distributor.set_output_buffers(regular_buckets->typed_data(),
+ special_buckets->typed_data());
+
+ distributor.distribute_to_buffer_parallel();
+
+ // Return success
+ return ffi::Error::Success();
+}
+
+ffi::Error DistributeBufSignedImpl(
+ uint32_t window_num, uint32_t regular_bucket_num,
+ uint32_t special_bucket_num, uint32_t msm_length,
+ uint32_t fixed_regular_padding_size, uint32_t fixed_special_padding_size,
+ ffi::Buffer slices_list, ffi::Buffer points_list,
+ ffi::Buffer neg_points_list, ffi::Buffer zero,
+ ffi::ResultBuffer regular_buckets,
+ ffi::ResultBuffer special_buckets,
+ ffi::ResultBuffer metadata) {
+ // Create a Distributor object
+ Distributor distributor(window_num, regular_bucket_num, special_bucket_num,
+ msm_length);
+ distributor.set_fixed_padding_size(fixed_regular_padding_size,
+ fixed_special_padding_size);
+ distributor.set_slices_list(slices_list.typed_data());
+ distributor.set_points_list(points_list.typed_data());
+ distributor.set_neg_points_list(neg_points_list.typed_data());
+ distributor.set_zeros(zero.typed_data());
+ distributor.set_output_buffers(regular_buckets->typed_data(),
+ special_buckets->typed_data());
+
+ distributor.distribute_to_buffer_signed_parallel();
+
+ // Return success
+ return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ Distribute, DistributeImpl,
+ ffi::Ffi::Bind()
+ .Attr("window_num")
+ .Attr("regular_bucket_num")
+ .Attr("special_bucket_num")
+ .Attr("msm_length")
+ .Attr("fixed_regular_padding_size")
+ .Attr("fixed_special_padding_size")
+ .Arg>() // slices_list
+ .Arg>() // points_list
+ .Arg>() // zero
+ .Ret>() // regular_buckets
+ .Ret>() // special_buckets
+ .Ret>() // metadata
+);
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ DistributeBuf, DistributeBufImpl,
+ ffi::Ffi::Bind()
+ .Attr("window_num")
+ .Attr("regular_bucket_num")
+ .Attr("special_bucket_num")
+ .Attr("msm_length")
+ .Attr("fixed_regular_padding_size")
+ .Attr("fixed_special_padding_size")
+ .Arg>() // slices_list
+ .Arg>() // points_list
+ .Arg>() // zero
+ .Ret>() // regular_buckets
+ .Ret>() // special_buckets
+ .Ret>() // metadata
+);
+
+ffi::Error DistributeBufZeroImpl(
+ uint32_t window_num, uint32_t regular_bucket_num,
+ uint32_t special_bucket_num, uint32_t msm_length,
+ uint32_t fixed_regular_padding_size, uint32_t fixed_special_padding_size,
+ ffi::Buffer slices_list, ffi::Buffer points_list,
+ ffi::Buffer zero, ffi::ResultBuffer regular_buckets,
+ ffi::ResultBuffer special_buckets,
+ ffi::ResultBuffer metadata) {
+ Distributor distributor(window_num, regular_bucket_num, special_bucket_num,
+ msm_length);
+ distributor.set_fixed_padding_size(fixed_regular_padding_size,
+ fixed_special_padding_size);
+ distributor.set_slices_list(slices_list.typed_data());
+ distributor.set_points_list(points_list.typed_data());
+ distributor.set_zeros(zero.typed_data());
+ distributor.set_output_buffers(regular_buckets->typed_data(),
+ special_buckets->typed_data());
+
+ distributor.distribute_to_buffer_zero_indexed_parallel();
+
+ return ffi::Error::Success();
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ DistributeBufZero, DistributeBufZeroImpl,
+ ffi::Ffi::Bind()
+ .Attr("window_num")
+ .Attr("regular_bucket_num")
+ .Attr("special_bucket_num")
+ .Attr("msm_length")
+ .Attr("fixed_regular_padding_size")
+ .Attr("fixed_special_padding_size")
+ .Arg>() // slices_list
+ .Arg>() // points_list
+ .Arg>() // zero
+ .Ret>() // regular_buckets
+ .Ret>() // special_buckets
+ .Ret>() // metadata
+);
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(
+ DistributeBufSigned, DistributeBufSignedImpl,
+ ffi::Ffi::Bind()
+ .Attr("window_num")
+ .Attr("regular_bucket_num")
+ .Attr("special_bucket_num")
+ .Attr("msm_length")
+ .Attr("fixed_regular_padding_size")
+ .Attr("fixed_special_padding_size")
+ .Arg>() // slices_list
+ .Arg>() // points_list
+ .Arg>() // neg_points_list
+ .Arg>() // zero
+ .Ret>() // regular_buckets
+ .Ret>() // special_buckets
+ .Ret>() // metadata
+);
diff --git a/jaxite_ec/configurations.toml b/jaxite_ec/configurations.toml
new file mode 100644
index 0000000..bec8438
--- /dev/null
+++ b/jaxite_ec/configurations.toml
@@ -0,0 +1,30 @@
+########################################################
+# General
+########################################################
+serialized_jax_kernel_dir = "./deployments/"
+hash_length = 4 # length of hash strings used in file names (alphanumeric, base-62)
+
+
+
+########################################################
+# Elliptic Curve Parameters
+########################################################
+[ec_parameters_bls12_377_affine]
+prime = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
+order = 0x12AB655E9A2CA55660B44D1E5C37B00159AA76FED00000010A11800000000001
+a = 0
+b = 1
+generator = [0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF, 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6]
+
+[ec_parameters_bls12_377_extended_twisted_edwards]
+prime = 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177
+order = 8444461749428370424248824938781546531375899335154063827935233455917409239041
+a = -1
+d = 122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179
+alpha = -1
+b = 1
+s = 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
+MA = 228097355113300204138531148905234651262148041026195375645000724271212049151994375092458297304264351187709081232384
+MB = 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
+t = 23560188534917577818843641916571445935985386319233886518929971599490231428764380923487987729215299304184915158756
+generator = [71222569531709137229370268896323705690285216175189308202338047559628438110820800641278662592954630774340654489393, 6177051365529633638563236407038680211609544222665285371549726196884440490905471891908272386851767077598415378235]
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv b/jaxite_ec/data/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv
similarity index 99%
rename from jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv
rename to jaxite_ec/data/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv
index 1affbca..4d5efb5 100644
--- a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv
+++ b/jaxite_ec/data/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv
@@ -1021,4 +1021,4 @@
1020, x, 012c98c9f7a2a5d1, ea874762b1e8c4a1, 35fbe0fb5894980c, 44f8e64cd96d180c, 85ba13911acb084f, b3b26ad0609738fe, zz: Fp384 "(0188570868A82EB0E0DDD97F55DAD85FF37B7637DAD6A12A80F9C86AB30F2CE853CCB5833FC3394CB83B187B9913FE7E)", y, 015b19304056067c, 908e5c19fe908420, e0f26a2c8fd72c93, 0840903a30687f8a, bccf44f4ea625176, 501ba7f12adc27a8, zz: Fp384 "(017ED2A2320E71894D670B2226830E2411AC8A717D1768A4CBC6B353571220C96496127BC7362B6B9D91373E8CD2681B)"
1021, x, 00f2048bceae505c, fe87dff235238fb0, 4ded13aa23396358, 84d5a9690203a9b5, d0ac634f189abad8, 4cdbb5659e84a52c, zz: Fp384 "(010AC4C7D2713EEC119C408FFAE8D3384DF4E7C4BCB12E001FDC8A7ADE7FC248FCCB7D919281EA13BBB7872B4D95044C)", y, 0098bdfaec3f29a0, 2e58bfc101ea8d74, a557b0596c9a161c, 735751360474b38c, 921c1d87ec7be426, 3d63974b26fa7705, zz: Fp384 "(0145B8E22FD6E243A109F4D2EED6B6E4AC4E59EB640A2494209777B50D5321A186F0A301EB762D5DB91401FF0C3D46DB)"
1022, x, 00356c16958fb477, 8ae26c11a4828367, 34db5eb3749d3968, 17041b9560c36b01, d5b3a6e348b8f643, c55b73c307c9329b, zz: Fp384 "(010E4E39A08185356BFCB7A9A2A35B9920389D094C35A80CD139B6E9ACDE311D73FBF42845B9322A6C6D9B52B711522D)", y, 0088810ee70cf6b5, 9d2312a04fa2cc6c, 71a6895532c12549, 4dbb59f3b1f28742, 5de8969809d4048b, 4245fa5364a4b35d, zz: Fp384 "(00C00B0BDF9032FBE826E6F890EAA7FDA47C5165934F6CED340B9C3012636D780E3C7DE8DA62FBC4863EC79E1936F8F5)"
-1023, x, 0133064a88647c9a, eef3ffdd08549f02, fcd5b94f6e966386, 96f66ce8af9b507d, 2525c4c66cdde115, 1e52e82abddda9e9, zz: Fp384 "(0103B2D12B9ECAF92133B192BF3DD477160BF1F637EFDC393C31BDB9579F196B1FAD1E4C5AF9FF565AAAACD3A5867659)", y, 0098742883b34a4a, 6290ee8ddccf15b5, 2790a01ac2896430, 5aafe8673f723240, b2149c9d67dced62, 7160bb12c0fc2da8, zz: Fp384 "(015155F210905A6D7700FCC1098F6A9DD895BCADE864008EDA180DB8C32934CAEF5BBD97CC1A49C52E5F593FF31E50A8)"
\ No newline at end of file
+1023, x, 0133064a88647c9a, eef3ffdd08549f02, fcd5b94f6e966386, 96f66ce8af9b507d, 2525c4c66cdde115, 1e52e82abddda9e9, zz: Fp384 "(0103B2D12B9ECAF92133B192BF3DD477160BF1F637EFDC393C31BDB9579F196B1FAD1E4C5AF9FF565AAAACD3A5867659)", y, 0098742883b34a4a, 6290ee8ddccf15b5, 2790a01ac2896430, 5aafe8673f723240, b2149c9d67dced62, 7160bb12c0fc2da8, zz: Fp384 "(015155F210905A6D7700FCC1098F6A9DD895BCADE864008EDA180DB8C32934CAEF5BBD97CC1A49C52E5F593FF31E50A8)"
diff --git a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv b/jaxite_ec/data/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv
similarity index 86%
rename from jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv
rename to jaxite_ec/data/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv
index d77e7dd..363b1a8 100644
--- a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv
+++ b/jaxite_ec/data/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv
@@ -1 +1 @@
-x, 01543e9a8f6cd81c, 2c5c084d7e5a0ff0, 408dc7edb8382513, 8ddfcfd7249d9b0c, d8e6f994e24cbb4d, 77682f864dc1ce64, zz: Fp384 "(00B5A8DC793FF0EBBD3A0E7170D83EA721FA80ABA6E306D0AF2703EF9762DFB47A80BE839B97E72DAE7C2D0B6757106C)", y, 01729e1ff869e64f, 9ad1b3c12f3f5481, e8f1f4e5a91bed5d, a874614f38e0ded6, 31531869aa1785fb, 4d8e2298435ee696, zz: Fp384 "(00F9A8D085C36B7CD26A6AD1C3DFF2DDA1FE33D4828A527AE6475E1863AE685762B6B9CD59B8FB67C58703D16B340F09)"
\ No newline at end of file
+x, 01543e9a8f6cd81c, 2c5c084d7e5a0ff0, 408dc7edb8382513, 8ddfcfd7249d9b0c, d8e6f994e24cbb4d, 77682f864dc1ce64, zz: Fp384 "(00B5A8DC793FF0EBBD3A0E7170D83EA721FA80ABA6E306D0AF2703EF9762DFB47A80BE839B97E72DAE7C2D0B6757106C)", y, 01729e1ff869e64f, 9ad1b3c12f3f5481, e8f1f4e5a91bed5d, a874614f38e0ded6, 31531869aa1785fb, 4d8e2298435ee696, zz: Fp384 "(00F9A8D085C36B7CD26A6AD1C3DFF2DDA1FE33D4828A527AE6475E1863AE685762B6B9CD59B8FB67C58703D16B340F09)"
diff --git a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv b/jaxite_ec/data/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv
similarity index 99%
rename from jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv
rename to jaxite_ec/data/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv
index bec8106..fa3fff7 100644
--- a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv
+++ b/jaxite_ec/data/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv
@@ -1021,4 +1021,4 @@
1020, 0ecccbe98754c222, 03a75f6652abc4b2, 943173a9fdfec684, eda2e4c7a7bf0baf, zz: Fp256 "(105BE4B38C035D1848D9A285173F8716BAB21F59E0371871E048033B5980544F)"
1021, 0c08c6b9a73f213b, 895b7d0f13cd478b, 1ee44d69464688cc, 509d47f371f7185d, zz: Fp256 "(02CA79487BE3365732763ACF0E818B6552FC45A786480EC7BBA27DC26EE93858)"
1022, 0855c8b333f7c1c0, 1b6f4de421fb8a6d, 64d1e921237109a7, ef194ae414988390, zz: Fp256 "(09B817A72336D6D78EBD429DAA73F178A0C10AAE103DEC47401A8ED4F089C9FE)"
-1023, 0d30713603e6c901, 56e77dbc8c902e50, 88c1dbdc29e27097, c1516a88192dcbae, zz: Fp256 "(103A3BF5E3AAC3D6C62F18DEC884B7A7CBAF9C7BCEB861C64C5441B2D210BE76)"
\ No newline at end of file
+1023, 0d30713603e6c901, 56e77dbc8c902e50, 88c1dbdc29e27097, c1516a88192dcbae, zz: Fp256 "(103A3BF5E3AAC3D6C62F18DEC884B7A7CBAF9C7BCEB861C64C5441B2D210BE76)"
diff --git a/jaxite_ec/elliptic_curve.py b/jaxite_ec/elliptic_curve.py
deleted file mode 100644
index bf846d2..0000000
--- a/jaxite_ec/elliptic_curve.py
+++ /dev/null
@@ -1,896 +0,0 @@
-"""The jaxite_ec implementation of the Elliptic curve operations on TPU.
-
-Detailed algorithms come from the following papers:
-xyzz: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html
-affine: https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html
-projective:
-https://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#addition-madd-1998-cmo
-
-A non-TPU version of the same functions can be found in
-jaxite_ec/algorithm/elliptic_curve.py
-
-To test the functionalities of this library, please refer to
-jaxite_ec/elliptic_curve_test.py
-"""
-
-import functools
-
-import jax
-import jax.numpy as jnp
-from jaxite.jaxite_ec import finite_field
-from jaxite.jaxite_ec import util
-
-
-add_3u16 = finite_field.add_3u16
-add_2u16 = finite_field.add_2u16
-sub_2u16 = finite_field.sub_2u16
-cond_sub_2u16 = finite_field.cond_sub_2u16
-cond_sub_mod_u16 = finite_field.cond_sub_mod_u16
-mod_mul_barrett_2u16 = finite_field.mod_mul_barrett_2u16
-mod_mul_lazy_2u16 = finite_field.mod_mul_lazy_2u16
-mod_mul_rns_2u16 = finite_field.mod_mul_rns_2u16
-add_rns_2u16 = finite_field.add_rns_2u16
-add_rns_3u16 = finite_field.add_rns_3u16
-add_sub_rns_var = finite_field.add_sub_rns_var
-negate_rns_for_var_add = finite_field.negate_rns_for_var_add
-negate_rns_for_var_add_zero_check = (
- finite_field.negate_rns_for_var_add_zero_check
-)
-rns_constant = finite_field.rns_constant
-
-
-# Barrett Reduction Based Functions
-@jax.named_call
-def padd_barret_xyzz(
- x1: jax.Array,
- y1: jax.Array,
- zz1: jax.Array,
- zzz1: jax.Array,
- x2: jax.Array,
- y2: jax.Array,
- zz2: jax.Array,
- zzz2: jax.Array,
-):
- """PADD-BARRETT elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general
-
- This function implements the PADD-BARRETT elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- zz1: The third generator element.
- zzz1: The third generator element.
- x2: The first generator element.
- y2: The second generator element.
- zz2: The third generator element.
- zzz2: The third generator element.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- u1 = mod_mul_barrett_2u16(x1, zz2)
- u2 = mod_mul_barrett_2u16(x2, zz1)
- s1 = mod_mul_barrett_2u16(y1, zzz2)
- s2 = mod_mul_barrett_2u16(y2, zzz1)
- zz1_zz2 = mod_mul_barrett_2u16(zz1, zz2)
- zzz1_zzz2 = mod_mul_barrett_2u16(zzz1, zzz2)
-
- p = cond_sub_2u16(u2, u1)
- r = cond_sub_2u16(s2, s1)
-
- pp = mod_mul_barrett_2u16(p, p)
- rr = mod_mul_barrett_2u16(r, r)
-
- ppp = mod_mul_barrett_2u16(pp, p)
- q = mod_mul_barrett_2u16(u1, pp)
- zz3 = mod_mul_barrett_2u16(zz1_zz2, pp)
-
- ppp_q_2 = add_3u16(ppp, q, q)
- ppp_q_2 = cond_sub_mod_u16(ppp_q_2)
- ppp_q_2 = cond_sub_mod_u16(ppp_q_2)
-
- x3 = cond_sub_2u16(rr, ppp_q_2)
-
- q_x3 = cond_sub_2u16(q, x3)
- s1_ppp = mod_mul_barrett_2u16(s1, ppp)
- zzz3 = mod_mul_barrett_2u16(zzz1_zzz2, ppp)
-
- y3 = mod_mul_barrett_2u16(r, q_x3)
- y3 = cond_sub_2u16(y3, s1_ppp)
-
- return jnp.array([x3, y3, zz3, zzz3])
-
-
-@jax.named_call
-def pdul_barret_xyzz(
- x1: jax.Array, y1: jax.Array, zz1: jax.Array, zzz1: jax.Array
-):
- """PDUL-BARRET elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general
-
- This function implements the PDUL-BARRET elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- zz1: The third generator element.
- zzz1: The third generator element.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- u = add_2u16(y1, y1)
- u = cond_sub_mod_u16(u)
-
- x1x1 = mod_mul_barrett_2u16(x1, x1)
- v = mod_mul_barrett_2u16(u, u)
-
- w = mod_mul_barrett_2u16(u, v)
- s = mod_mul_barrett_2u16(x1, v)
-
- s_2 = add_2u16(s, s)
- s_2 = cond_sub_mod_u16(s_2)
-
- m = add_3u16(x1x1, x1x1, x1x1)
- m = cond_sub_mod_u16(m)
- m = cond_sub_mod_u16(m)
-
- mm = mod_mul_barrett_2u16(m, m)
- w_y1 = mod_mul_barrett_2u16(w, y1)
- zz3 = mod_mul_barrett_2u16(v, zz1)
- zzz3 = mod_mul_barrett_2u16(w, zzz1)
-
- x3 = cond_sub_2u16(mm, s_2)
-
- s_x3 = cond_sub_2u16(s, x3)
-
- y3 = mod_mul_barrett_2u16(m, s_x3)
- y3 = cond_sub_2u16(y3, w_y1)
-
- return jnp.array([x3, y3, zz3, zzz3])
-
-
-@jax.named_call
-def pdul_barrett_xyzz_pack(x1_y1_zz1_zzz1: jax.Array):
- return pdul_barret_xyzz(
- x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1], x1_y1_zz1_zzz1[2], x1_y1_zz1_zzz1[3]
- )
-
-
-@jax.named_call
-def padd_barrett_xyzz_pack(
- x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array
-):
- return padd_barret_xyzz(
- x1_y1_zz1_zzz1[0],
- x1_y1_zz1_zzz1[1],
- x1_y1_zz1_zzz1[2],
- x1_y1_zz1_zzz1[3],
- x2_y2_zz2_zzz2[0],
- x2_y2_zz2_zzz2[1],
- x2_y2_zz2_zzz2[2],
- x2_y2_zz2_zzz2[3],
- )
-
-
-@jax.named_call
-def pdul_barrett_xyzz_pack_batch_first(
- x1_y1_zz1_zzz1: jax.Array, transpose=(0, 1, 2)
-):
- return pdul_barret_xyzz(
- x1_y1_zz1_zzz1[:, 0],
- x1_y1_zz1_zzz1[:, 1],
- x1_y1_zz1_zzz1[:, 2],
- x1_y1_zz1_zzz1[:, 3],
- ).transpose(transpose[0], transpose[1], transpose[2])
-
-
-@jax.named_call
-def padd_barrett_xyzz_pack_batch_first(
- x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array, transpose=(0, 1, 2)
-):
- return padd_barret_xyzz(
- x1_y1_zz1_zzz1[:, 0],
- x1_y1_zz1_zzz1[:, 1],
- x1_y1_zz1_zzz1[:, 2],
- x1_y1_zz1_zzz1[:, 3],
- x2_y2_zz2_zzz2[:, 0],
- x2_y2_zz2_zzz2[:, 1],
- x2_y2_zz2_zzz2[:, 2],
- x2_y2_zz2_zzz2[:, 3],
- ).transpose(transpose[0], transpose[1], transpose[2])
-
-
-# Lazy Reduction Based Functions
-@jax.named_call
-def padd_lazy_xyzz(
- x1: jax.Array,
- y1: jax.Array,
- zz1: jax.Array,
- zzz1: jax.Array,
- x2: jax.Array,
- y2: jax.Array,
- zz2: jax.Array,
- zzz2: jax.Array,
-):
- """PADD-LAZY elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general
-
- This function implements the PADD-LAZY elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- zz1: The third generator element.
- zzz1: The third generator element.
- x2: The first generator element.
- y2: The second generator element.
- zz2: The third generator element.
- zzz2: The third generator element.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- cond_sub_2u16_ext = functools.partial(
- cond_sub_2u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
- cond_sub_mod_u16_ext = functools.partial(
- cond_sub_mod_u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
-
- u1 = mod_mul_lazy_2u16(x1, zz2)
- u2 = mod_mul_lazy_2u16(x2, zz1)
- s1 = mod_mul_lazy_2u16(y1, zzz2)
- s2 = mod_mul_lazy_2u16(y2, zzz1)
- zz1_zz2 = mod_mul_lazy_2u16(zz1, zz2)
- zzz1_zzz2 = mod_mul_lazy_2u16(zzz1, zzz2)
-
- p = cond_sub_2u16_ext(u2, u1)
- r = cond_sub_2u16_ext(s2, s1)
-
- pp = mod_mul_lazy_2u16(p, p)
- rr = mod_mul_lazy_2u16(r, r)
-
- ppp = mod_mul_lazy_2u16(pp, p)
- q = mod_mul_lazy_2u16(u1, pp)
- zz3 = mod_mul_lazy_2u16(zz1_zz2, pp)
-
- # Can be replaced by mod_add_lazy.
- ppp_q_2 = add_3u16(ppp, q, q)
- ppp_q_2 = cond_sub_mod_u16_ext(ppp_q_2)
- ppp_q_2 = cond_sub_mod_u16_ext(ppp_q_2)
-
- x3 = cond_sub_2u16_ext(rr, ppp_q_2)
-
- q_x3 = cond_sub_2u16_ext(q, x3)
- s1_ppp = mod_mul_lazy_2u16(s1, ppp)
- zzz3 = mod_mul_lazy_2u16(zzz1_zzz2, ppp)
-
- y3 = mod_mul_lazy_2u16(r, q_x3)
- y3 = cond_sub_2u16_ext(y3, s1_ppp)
-
- return jnp.array([x3, y3, zz3, zzz3])
-
-
-@jax.named_call
-def pdul_lazy_xyzz(
- x1: jax.Array,
- y1: jax.Array,
- zz1: jax.Array,
- zzz1: jax.Array,
-):
- """PDUL-BARRET elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general
-
- This function implements the PDUL-BARRET elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- zz1: The third generator element.
- zzz1: The third generator element.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- cond_sub_2u16_ext = functools.partial(
- cond_sub_2u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
- cond_sub_mod_u16_ext = functools.partial(
- cond_sub_mod_u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
- u = add_2u16(y1, y1)
- u = cond_sub_mod_u16_ext(u)
-
- x1x1 = mod_mul_lazy_2u16(x1, x1)
- v = mod_mul_lazy_2u16(u, u)
-
- w = mod_mul_lazy_2u16(u, v)
- s = mod_mul_lazy_2u16(x1, v)
-
- s_2 = add_2u16(s, s)
- s_2 = cond_sub_mod_u16_ext(s_2)
-
- m = add_3u16(x1x1, x1x1, x1x1)
- m = cond_sub_mod_u16_ext(m)
- m = cond_sub_mod_u16_ext(m)
-
- mm = mod_mul_lazy_2u16(m, m)
- w_y1 = mod_mul_lazy_2u16(w, y1)
- zz3 = mod_mul_lazy_2u16(v, zz1)
- zzz3 = mod_mul_lazy_2u16(w, zzz1)
-
- x3 = cond_sub_2u16_ext(mm, s_2)
-
- s_x3 = cond_sub_2u16_ext(s, x3)
-
- y3 = mod_mul_lazy_2u16(m, s_x3)
- y3 = cond_sub_2u16_ext(y3, w_y1)
-
- return jnp.array([x3, y3, zz3, zzz3])
-
-
-@jax.named_call
-def padd_lazy_xyzz_pack(x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array):
- return padd_lazy_xyzz(
- x1_y1_zz1_zzz1[0],
- x1_y1_zz1_zzz1[1],
- x1_y1_zz1_zzz1[2],
- x1_y1_zz1_zzz1[3],
- x2_y2_zz2_zzz2[0],
- x2_y2_zz2_zzz2[1],
- x2_y2_zz2_zzz2[2],
- x2_y2_zz2_zzz2[3],
- )
-
-
-@jax.named_call
-def pdul_lazy_xyzz_pack(x1_y1_zz1_zzz1: jax.Array):
- return pdul_lazy_xyzz(
- x1_y1_zz1_zzz1[0],
- x1_y1_zz1_zzz1[1],
- x1_y1_zz1_zzz1[2],
- x1_y1_zz1_zzz1[3],
- )
-
-
-# Lazy Reduction Based Function
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="twisted_d_chunk")
-def padd_lazy_twisted(
- x1: jax.Array,
- y1: jax.Array,
- z1: jax.Array,
- t1: jax.Array,
- x2: jax.Array,
- y2: jax.Array,
- z2: jax.Array,
- t2: jax.Array,
- twisted_d_chunk=util.TWIST_D_INT_CHUNK,
-):
- """PADD-LAZY elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::add_general
-
- This function implements the PADD-LAZY elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- z1: The third generator element.
- t1: The fourth generator element.
- x2: The first generator element.
- y2: The second generator element.
- z2: The third generator element.
- t2: The fourth generator element.
- twisted_d_chunk: The twisted d parameter.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- cond_sub_2u16_ext = functools.partial(
- cond_sub_2u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
- cond_sub_mod_u16_ext = functools.partial(
- cond_sub_mod_u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
-
- twisted_d = jnp.asarray(twisted_d_chunk, dtype=jnp.uint16)
- twisted_d = jax.lax.broadcast(twisted_d, [x1.shape[0]])
-
- a = mod_mul_lazy_2u16(x1, x2)
- b = mod_mul_lazy_2u16(y1, y2)
- d = mod_mul_lazy_2u16(z1, z2)
- c = mod_mul_lazy_2u16(t1, t2)
- c = mod_mul_lazy_2u16(c, twisted_d)
-
- h = add_2u16(a, b)
- h = cond_sub_mod_u16_ext(h)
- e1 = add_2u16(x1, y1)
- e1 = cond_sub_mod_u16_ext(e1)
- e2 = add_2u16(x2, y2)
- e2 = cond_sub_mod_u16_ext(e2)
- e = mod_mul_lazy_2u16(e1, e2)
-
- e = cond_sub_2u16_ext(e, h)
-
- f = cond_sub_2u16_ext(d, c)
- g = add_2u16(d, c)
- g = cond_sub_mod_u16_ext(g)
-
- x3 = mod_mul_lazy_2u16(e, f)
- y3 = mod_mul_lazy_2u16(g, h)
- z3 = mod_mul_lazy_2u16(f, g)
- t3 = mod_mul_lazy_2u16(e, h)
-
- return jnp.array([x3, y3, z3, t3])
-
-
-def pdul_lazy_twisted(
- x1: jax.Array,
- y1: jax.Array,
- z1: jax.Array,
- t1: jax.Array,
-):
- """PDUL-LAZY elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general
-
- This function implements the PDUL-LAZY elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- z1: The third generator element.
- t1: The fourth generator element.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- cond_sub_2u16_ext = functools.partial(
- cond_sub_2u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
- cond_sub_mod_u16_ext = functools.partial(
- cond_sub_mod_u16,
- modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
- modulus_377_int_array = jnp.asarray(
- util.MODULUS_377_S16_INT_CHUNK, jnp.uint16
- )
-
- a = mod_mul_lazy_2u16(x1, x1)
- b = mod_mul_lazy_2u16(y1, y1)
-
- ct = mod_mul_lazy_2u16(z1, z1)
- ct2 = add_2u16(ct, ct) #
- ct2 = cond_sub_mod_u16_ext(ct2) #
-
- h = add_2u16(a, b)
- h = cond_sub_2u16_ext(modulus_377_int_array, h) #
-
- et = add_2u16(x1, y1) #
- et = cond_sub_mod_u16_ext(et) #
- e = mod_mul_lazy_2u16(et, et) #
- e = add_2u16(e, h) #
- e = cond_sub_mod_u16_ext(e) #
-
- g = cond_sub_2u16_ext(b, a) #
- f = cond_sub_2u16_ext(g, ct2) #
- x3 = mod_mul_lazy_2u16(e, f) #
- y3 = mod_mul_lazy_2u16(g, h)
- z3 = mod_mul_lazy_2u16(f, g)
- t3 = mod_mul_lazy_2u16(e, h)
- return jnp.array([x3, y3, z3, t3])
-
-
-def pneg_lazy_twisted(
- x1: jax.Array,
- y1: jax.Array,
- z1: jax.Array,
- t1: jax.Array,
-):
- """PDUL-LAZY elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general
-
- This function implements the PDUL-LAZY elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- z1: The third generator element.
- t1: The fourth generator element.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
-
- modulus_377_int_array = jnp.asarray(
- util.MODULUS_377_S16_INT_CHUNK, jnp.uint16
- )
- sub_2u16_ext = functools.partial(
- sub_2u16,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- )
-
- x3 = sub_2u16_ext(modulus_377_int_array, x1)
- y3 = y1
- z3 = z1
- t3 = sub_2u16_ext(modulus_377_int_array, t1)
-
- return jnp.array([x3, y3, z3, t3])
-
-
-def padd_lazy_twisted_pack(x1_y1_z1_t1: jax.Array, x2_y2_z2_t2: jax.Array):
- return padd_lazy_twisted(
- x1_y1_z1_t1[0],
- x1_y1_z1_t1[1],
- x1_y1_z1_t1[2],
- x1_y1_z1_t1[3],
- x2_y2_z2_t2[0],
- x2_y2_z2_t2[1],
- x2_y2_z2_t2[2],
- x2_y2_z2_t2[3],
- )
-
-
-def pdul_lazy_twisted_pack(x1_y1_z1_t1: jax.Array):
- return pdul_lazy_twisted(
- x1_y1_z1_t1[0],
- x1_y1_z1_t1[1],
- x1_y1_z1_t1[2],
- x1_y1_z1_t1[3],
- )
-
-
-def pneg_lazy_twisted_pack(x1_y1_z1_t1: jax.Array):
- return pneg_lazy_twisted(
- x1_y1_z1_t1[0],
- x1_y1_z1_t1[1],
- x1_y1_z1_t1[2],
- x1_y1_z1_t1[3],
- )
-
-
-# RNS Based Functions
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="rns_mat")
-def padd_rns_xyzz_pack(
- x1_y1_zz1_zzz1: jax.Array,
- x2_y2_zz2_zzz2: jax.Array,
- rns_mat=util.RNS_MAT,
-):
- """PADD-RNS elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general
-
- This function implements the PADD-RNS elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1_y1_zz1_zzz1: The first point.
- x2_y2_zz2_zzz2: The second point.
- rns_mat: The RNS matrix.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
-
- # u1 = x1 * zz2
- # u2 = x2 * zz1
- # s1 = y1 * zzz2
- # s2 = y2 * zzz1
- # zz1_zz2 = zz1 * zz2
- # zzz1_zzz2 = zzz1 * zzz2
- num_moduli = x1_y1_zz1_zzz1.shape[-1]
- inputsl = jnp.vstack((x1_y1_zz1_zzz1, x1_y1_zz1_zzz1[2:])).reshape(
- -1, num_moduli
- )
- inputsr = jnp.vstack(
- (x2_y2_zz2_zzz2[2:], x2_y2_zz2_zzz2[2:], x2_y2_zz2_zzz2[:2])
- ).reshape(-1, num_moduli)
- outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- u1, s1, zz1_zz2, zzz1_zzz2, u2, s2 = jnp.vsplit(outputs, 6)
-
- # p = u2 - u1
- # r = s2 - s1
- p = add_sub_rns_var(u2, negate_rns_for_var_add(u1))
- r = add_sub_rns_var(s2, negate_rns_for_var_add(s1))
-
- # pp = p * p
- # rr = r * r
- pp = mod_mul_rns_2u16(p, p, rns_mat)
- rr = mod_mul_rns_2u16(r, r, rns_mat)
-
- # ppp = p * pp
- # q = u1 * pp
- # zz3 = zz1_zz2 * pp
- inputsl = jnp.vstack((p, u1, zz1_zz2))
- inputsr = jnp.vstack((pp, pp, pp))
- outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- ppp, q, zz3 = jnp.vsplit(outputs, 3)
-
- # x3 = r * r - ppp - (q + q)
- # q_x3 = q - x3
- x3 = add_sub_rns_var(
- rr,
- negate_rns_for_var_add(ppp),
- negate_rns_for_var_add(q),
- negate_rns_for_var_add(q),
- )
- q_x3 = add_sub_rns_var(q, negate_rns_for_var_add(x3))
-
- # s1_ppp = s1 * ppp
- # r_q_x3 = r * q_x3
- # zzz3 = zzz1_zzz2 * ppp
- # y3 = r_q_x3 - s1_ppp
- inputsl = jnp.vstack((s1, zzz1_zzz2, r))
- inputsr = jnp.vstack((ppp, ppp, q_x3))
- outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- s1_ppp, zzz3, r_q_x3 = jnp.vsplit(outputs, 3)
-
- y3 = add_sub_rns_var(r_q_x3, negate_rns_for_var_add(s1_ppp))
-
- return jnp.array([x3, y3, zz3, zzz3])
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="rns_mat")
-def pdul_rns_xyzz(
- x1: jax.Array,
- y1: jax.Array,
- zz1: jax.Array,
- zzz1: jax.Array,
- rns_mat=util.RNS_MAT,
-):
- """PDUL-RNS elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general
-
- This function implements the PDUL-RNS elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- zz1: The third generator element.
- zzz1: The third generator element.
- rns_mat: The RNS matrix.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- # u = y1 + y1
- u = add_rns_2u16(y1, y1)
-
- # v = u * u
- v = mod_mul_rns_2u16(u, u, rns_mat)
-
- # x1x1 = x1 * x1
- # w = u * v
- # s = x1 * v
- inputsl = jnp.vstack((x1, u, x1))
- inputsr = jnp.vstack((x1, v, v))
- output = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- x1x1, w, s = jnp.vsplit(output, 3)
-
- # m = (x1x1 + x1x1 + x1x1) + a * (zz1 * zz1), Note: a = 0
- m = add_rns_3u16(x1x1, x1x1, x1x1)
-
- # mm = m * m
- # w_y1 = w * y1
- # zz3 = v * zz1
- # zzz3 = w * zzz1
- inputsl = jnp.vstack((m, w, v, w))
- inputsr = jnp.vstack((m, y1, zz1, zzz1))
- output = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- mm, w_y1, zz3, zzz3 = jnp.vsplit(output, 4)
-
- # x3 = mm - (s + s)
- x3 = add_sub_rns_var(mm, negate_rns_for_var_add(s), negate_rns_for_var_add(s))
-
- # s_x3 = s - x3
- s_x3 = add_sub_rns_var(s, negate_rns_for_var_add(mm), s, s)
-
- # m_s_x3 = m * s_x3 - w_y1
- m_s_x3 = mod_mul_rns_2u16(m, s_x3, rns_mat)
-
- # y3 = m_s_x3 - w_y1
- y3 = add_sub_rns_var(m_s_x3, negate_rns_for_var_add(w_y1))
-
- return jnp.array([x3, y3, zz3, zzz3])
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="rns_mat")
-def pdul_rns_xyzz_pack(x1_y1_zz1_zzz1: jax.Array, rns_mat=util.RNS_MAT):
- return pdul_rns_xyzz(
- x1_y1_zz1_zzz1[0],
- x1_y1_zz1_zzz1[1],
- x1_y1_zz1_zzz1[2],
- x1_y1_zz1_zzz1[3],
- rns_mat,
- )
-
-
-# RNS Based Functions
-@jax.named_call
-@functools.partial(jax.jit, static_argnames=("rns_mat", "twist_d"))
-def padd_rns_twisted_pack(
- x1_y1_zz1_zzz1: jax.Array,
- x2_y2_zz2_zzz2: jax.Array,
- rns_mat=util.RNS_MAT,
- twist_d=util.TWIST_D_RNS,
-):
- """PADD-RNS elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general
-
- This function implements the PADD-RNS elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1_y1_zz1_zzz1: The first point.
- x2_y2_zz2_zzz2: The second point.
- rns_mat: The RNS matrix.
- twist_d: curve parameter.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- twist_d = jnp.array(twist_d, dtype=jnp.uint16)
-
- inputsl = jnp.vstack(x1_y1_zz1_zzz1)
- inputsr = jnp.vstack(x2_y2_zz2_zzz2)
- outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- a, b, d, c = jnp.vsplit(outputs, 4)
-
- e1 = add_rns_2u16(x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1])
- e2 = add_rns_2u16(x2_y2_zz2_zzz2[0], x2_y2_zz2_zzz2[1])
- twist_d_here = jnp.broadcast_to(
- twist_d.reshape(-1, twist_d.shape[0]), c.shape
- )
- inputsl = jnp.vstack((e1, c))
- inputsr = jnp.vstack((e2, twist_d_here))
- outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- e3, c = jnp.vsplit(outputs, 2)
-
- # Issue happens here
- e = add_sub_rns_var(
- e3,
- negate_rns_for_var_add_zero_check(a),
- negate_rns_for_var_add_zero_check(b),
- )
- f = add_sub_rns_var(d, negate_rns_for_var_add_zero_check(c))
- g = add_rns_2u16(d, c)
- h = add_rns_2u16(a, b)
-
- inputsl = jnp.vstack((e, g, f, e))
- inputsr = jnp.vstack((f, h, g, h))
- outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- x3, y3, z3, t3 = jnp.vsplit(outputs, 4)
-
- return jnp.array([x3, y3, z3, t3])
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="rns_mat")
-def pdul_rns_twisted(
- x1: jax.Array,
- y1: jax.Array,
- z1: jax.Array,
- t1: jax.Array,
- rns_mat=util.RNS_MAT,
-):
- """PDUL-RNS elliptic curve operation with packed arguments.
-
- As for the algorithm, pls refer to
- jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general
-
- This function implements the PDUL-RNS elliptic curve operation with packed
- arguments, which is used to compute the elliptic curve points of a given
- group.
-
- Args:
- x1: The first generator element.
- y1: The second generator element.
- z1: The third generator element.
- t1: The third generator element.
- rns_mat: The RNS matrix.
-
- Returns:
- A tuple containing the third generator element and the elliptic curve points
- of the group.
- """
- et = add_rns_2u16(x1, y1)
- inputsl = jnp.vstack((x1, y1, z1, et))
- outputs = mod_mul_rns_2u16(inputsl, inputsl, rns_mat)
- a, b, ct, et2 = jnp.vsplit(outputs, 4)
-
- e = add_sub_rns_var(et2, negate_rns_for_var_add(a), negate_rns_for_var_add(b))
- g = add_sub_rns_var(b, negate_rns_for_var_add(a))
- f = add_sub_rns_var(g, negate_rns_for_var_add(ct), negate_rns_for_var_add(ct))
- h = add_sub_rns_var(negate_rns_for_var_add(a), negate_rns_for_var_add(b))
-
- inputsl = jnp.vstack((e, g, f, e))
- inputsr = jnp.vstack((f, h, g, h))
- outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat)
- x3, y3, z3, t3 = jnp.vsplit(outputs, 4)
-
- return jnp.array([x3, y3, z3, t3])
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="rns_mat")
-def pdul_rns_twisted_pack(x1_y1_zz1_zzz1: jax.Array, rns_mat=util.RNS_MAT):
- return pdul_rns_twisted(
- x1_y1_zz1_zzz1[0],
- x1_y1_zz1_zzz1[1],
- x1_y1_zz1_zzz1[2],
- x1_y1_zz1_zzz1[3],
- rns_mat,
- )
-
-
-@jax.named_call
-def rns_twist_zero():
- return jnp.array(
- [rns_constant(0), rns_constant(1), rns_constant(1), rns_constant(0)]
- )
diff --git a/jaxite_ec/elliptic_curve_context.py b/jaxite_ec/elliptic_curve_context.py
new file mode 100644
index 0000000..bda2862
--- /dev/null
+++ b/jaxite_ec/elliptic_curve_context.py
@@ -0,0 +1,818 @@
+import abc
+from concurrent.futures import ProcessPoolExecutor
+import functools
+import multiprocessing
+import os
+import warnings
+
+import jax
+import jax.numpy as jnp
+from jaxite.jaxite_ec import finite_field_context
+from jaxite.jaxite_ec import utils
+import numpy as np
+
+FiniteFieldContextBase = finite_field_context.FiniteFieldContextBase
+abstractmethod = abc.abstractmethod
+ABC = abc.ABC
+
+
+# Use 'forkserver' to avoid JAX multithreading + fork deadlock
+_MP_CONTEXT = multiprocessing.get_context("forkserver")
+
+JaxParameters = utils.JaxParameters
+JaxKernelContextBase = utils.JaxKernelContextBase
+hash_args = utils.hash_args
+pad_jax_array = utils.pad_jax_array
+store_jax_executable = utils.store_jax_kernel
+load_jax_executable = utils.load_jax_kernel
+jax_jit_lower_compile = utils.jax_jit_lower_compile
+jax.config.update("jax_enable_x64", True)
+
+
+class EllipticCurveContextBase(ABC):
+ """Abstract base class defining the interface for finite field operations.
+
+ Subclasses must implement all abstract methods to provide concrete
+ finite field arithmetic operations.
+ """
+
+ @abstractmethod
+ def __init__(self, parameters: dict):
+ """Initialize the finite field context.
+
+ Args:
+ parameters: Configuration dictionary containing field parameters.
+ """
+ self.parameters = parameters
+ self.prime = parameters.get("prime", None)
+ assert self.prime is not None, "prime must be provided"
+ self.zero_point = None
+ ff_ctx_class = parameters.get("finite_field_context_class", None)
+ assert (
+ ff_ctx_class is not None
+ ), "finite_field_context_class must be provided"
+ self.ff_ctx: FiniteFieldContextBase = ff_ctx_class(
+ parameters.get("finite_field_parameters", {})
+ )
+
+ @abstractmethod
+ def to_computational_format(self, a) -> jnp.ndarray:
+ """Convert input to the internal computational representation.
+
+ Args:
+ a: Input value in standard format.
+
+ Returns:
+ Value converted to computational format (e.g., Montgomery form).
+ """
+ pass
+
+ @abstractmethod
+ def to_original_format(self, a):
+ """Convert from computational format back to standard representation.
+
+ Args:
+ a: Value in computational format.
+
+ Returns:
+ Value in standard integer representation.
+ """
+ pass
+
+ @abstractmethod
+ def point_add(self, a: jnp.ndarray, b: jnp.ndarray):
+ """Perform point addition: (a + b)
+
+ Args:
+ a: First operand in computational format.
+ b: Second operand in computational format.
+
+ Returns:
+ Product in computational format.
+ """
+ pass
+
+ @abstractmethod
+ def point_double(self, a: jnp.ndarray):
+ """Perform point doubling: (2 * a)
+
+ Args:
+ a: First operand in computational format.
+
+ Returns:
+ Product in computational format.
+ """
+ pass
+
+ def get_finite_field_context(self) -> FiniteFieldContextBase:
+ return self.ff_ctx
+
+ def _modular_multiply(self, a: int, b: int) -> int:
+ return (a * b) % self.prime
+
+ def _modular_reduce(self, a: int) -> int:
+ return a % self.prime
+
+ def _modular_divide(self, a: int, b: int) -> int:
+ assert b != 0, "ec divide: b is zero"
+ b_inv = pow(b, self.prime - 2, self.prime)
+ return (a * b_inv) % self.prime
+
+
+class CPUWeierstrassAffineContext(EllipticCurveContextBase):
+ """CPU implementation of Weierstrass affine curve operations.
+
+ This class provides CPU-based implementations for point addition and doubling
+ on a Weierstrass curve in affine coordinates.
+ This class is only for private functional testing, not for production use.
+ """
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ # warnings.warn("CPUWeierstrassAffineContext is only for private functional testing, not for production use",UserWarning, stacklevel=2)
+
+ # Curve configuration
+ self.a = parameters["a"]
+ self.b = parameters["b"]
+
+ def point_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ raise NotImplementedError(
+ "CPUWeierstrassAffineContext: point_add is not implemented"
+ )
+
+ def point_double(self, point: jnp.ndarray) -> jnp.ndarray:
+ raise NotImplementedError(
+ "CPUWeierstrassAffineContext: point_double is not implemented"
+ )
+
+ def to_computational_format(self, a: list) -> jnp.ndarray:
+ raise NotImplementedError(
+ "CPUWeierstrassAffineContext: to_computational_format is not"
+ " implemented"
+ )
+
+ def to_original_format(self, a: jnp.ndarray) -> list:
+ raise NotImplementedError(
+ "CPUWeierstrassAffineContext: to_original_format is not implemented"
+ )
+
+ def _point_add(self, point_a: list, point_b: list) -> list:
+ def single_point_add(point_a: list, point_b: list) -> list[int]:
+ x1, y1 = point_a
+ x2, y2 = point_b
+ slope = self._modular_divide(
+ self._modular_reduce(y2 - y1), self._modular_reduce(x2 - x1)
+ )
+ x3 = self._modular_reduce(self._modular_multiply(slope, slope) - x1 - x2)
+ y3 = self._modular_reduce(
+ self._modular_multiply(slope, self._modular_reduce(x1 - x3)) - y1
+ )
+ return [x3, y3]
+
+ list_depth = utils.nested_list_depth(point_a)
+ if list_depth == 1:
+ return single_point_add(point_a, point_b)
+ elif list_depth == 2:
+ return [
+ single_point_add(point_a_i, point_b_i)
+ for point_a_i, point_b_i in zip(point_a, point_b)
+ ]
+ else:
+ raise ValueError(
+ f"Invalid list depth {list_depth} of input for point addition"
+ )
+
+ def _point_double(self, point: list) -> list:
+ def single_point_double(point: list) -> list[int]:
+ x, y = point
+ slope = self._modular_divide(
+ self._modular_reduce(3 * x * x + self.a), self._modular_reduce(2 * y)
+ )
+ x3 = self._modular_reduce(self._modular_multiply(slope, slope) - 2 * x)
+ y3 = self._modular_reduce(self._modular_multiply(slope, x - x3) - y)
+ return [x3, y3]
+
+ list_depth = utils.nested_list_depth(point)
+ if list_depth == 1:
+ return single_point_double(point)
+ elif list_depth == 2:
+ return [single_point_double(point_i) for point_i in point]
+ else:
+ raise ValueError("Invalid list depth of input for point doubling")
+
+
+class ExtendedTwistedEdwardsContextBase(EllipticCurveContextBase):
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+
+ # Curve configuration
+ self.a = parameters["a"]
+ self.twist_d = parameters["twist_d"]
+ self.alpha = parameters["alpha"]
+ self.s = parameters["s"]
+ self.A = parameters["MA"]
+ self.B = parameters["MB"]
+ self.t = parameters["t"]
+ self.k = self.twist_d + self.twist_d
+ self.zero_point = [0, 1, 1, 0]
+
+ def _twist(self, coordinates: list[int]) -> list[int]:
+ assert (
+ len(coordinates) == 2
+ ), "Twisted Edwards coordinates must be of length 2"
+ x, y = coordinates
+ # Convert to montgomery (Notel it is ec montgomery not field montgomery)
+ xm = self._modular_reduce(self.s * (x - self.alpha))
+ ym = self._modular_reduce(self.s * y)
+ # Convert to edwards
+ if ym == 0:
+ raise ValueError("ec twist: ym is zero")
+ xt = self._modular_divide(xm, ym)
+
+ yt_denom = xm + 1
+ if yt_denom == 0:
+ raise ValueError("ec twist: yt_denom is zero")
+ yt = self._modular_divide(xm - 1, yt_denom)
+
+ xt = self._modular_multiply(xt, self.t)
+ return [xt, yt]
+
+ def _untwist(self, coordinates: list[int]) -> list[int]:
+ assert (
+ len(coordinates) == 2
+ ), "Twisted Edwards coordinates must be of length 2"
+ xt, yt = coordinates
+ xt = self._modular_divide(xt, self.t)
+ # Convert to montgomery
+ xm = self._modular_divide((1 + yt), (1 - yt))
+ ym = self._modular_divide((1 + yt), self._modular_multiply((1 - yt), xt))
+ # Convert to weierstrass
+ x = self._modular_reduce(
+ self._modular_divide(xm, self.B)
+ + self._modular_divide(self.A, self._modular_multiply(3, self.B))
+ )
+ y = self._modular_divide(ym, self.B)
+ return [x, y]
+
+ def _convert_to_edwards_affine(self, coordinates: list[int]) -> list[int]:
+ assert (
+ len(coordinates) == 4
+ ), "Twisted Edwards coordinates must be of length 2"
+ x, y, z, t = coordinates
+ z_inv = self._modular_divide(1, z)
+ x = self._modular_multiply(x, z_inv)
+ y = self._modular_multiply(y, z_inv)
+ return [x, y]
+
+ def _convert_to_extended_twisted_edwards(
+ self, coordinates: list[int]
+ ) -> list[int]:
+ assert (
+ len(coordinates) == 2
+ ), "Twisted Edwards coordinates must be of length 2"
+ xt, yt = self._twist(coordinates)
+ return [xt, yt, 1, self._modular_multiply(xt, yt)]
+
+ def _convert_to_weierstrass_affine(self, coordinates: list[int]) -> list[int]:
+ assert (
+ len(coordinates) == 4
+ ), "Extended Twisted Edwards coordinates must be of length 4"
+ affine_coords = self._convert_to_edwards_affine(coordinates)
+ untwisted_coords = self._untwist(affine_coords)
+ return untwisted_coords
+
+
+class ExtendedTwistedEdwardsContext(
+ ExtendedTwistedEdwardsContextBase, JaxKernelContextBase
+):
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ JaxKernelContextBase.__init__(self)
+ self.jax_parameters = JaxParameters()
+ self._init_jax_parameters()
+
+ def to_computational_format(self, a: list) -> jnp.ndarray:
+ list_depth = utils.nested_list_depth(a)
+ # NOTE: the dimension is (batch, coordinates)
+ if list_depth == 1:
+ twisted_coords = self._convert_to_extended_twisted_edwards(a)
+ elif list_depth == 2:
+ twisted_coords = [
+ self._convert_to_extended_twisted_edwards(a_i) for a_i in a
+ ]
+ else:
+ raise ValueError(
+ "Invalid list depth of input for converting to extended twisted"
+ " edwards coordinates"
+ )
+ result = self.ff_ctx.to_computational_format(twisted_coords)
+ if list_depth == 1:
+ result = jnp.broadcast_to(result, (result.shape[0], 1, result.shape[1]))
+ elif list_depth == 2:
+ result = result.transpose(1, 0, 2)
+ # NOTE: the computational format dim is (coordinates, batch, precision)
+ if self.use_sharding:
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=result.shape, axes=[1]
+ )
+ result = pad_jax_array(result, padded_shape)
+ return result.to_device(named_sharding)
+ else:
+ return result.to_device(jax.devices()[0])
+
+ def to_original_format(self, a: jnp.ndarray) -> list:
+ dim = a.ndim
+ # NOTE: the computational format dim is (coordinates, batch, precision)
+ if dim == 3:
+ a = a.transpose(
+ 1, 0, 2
+ ) # (coordinates, batch, precision) -> (batch, coordinates, precision)
+ a = self.ff_ctx.to_original_format(a)
+ if dim == 2:
+ affine_coords = self._convert_to_weierstrass_affine(a)
+ elif dim == 3:
+ affine_coords = [self._convert_to_weierstrass_affine(a_i) for a_i in a]
+ else:
+ raise ValueError(
+ "Invalid dimension of input for converting to weierstrass affine"
+ " coordinates"
+ )
+ return affine_coords
+
+ def _init_jax_parameters(self):
+ self.jax_parameters.set_parameter(
+ twist_d=self.ff_ctx.to_computational_format(self.twist_d),
+ )
+
+ def _point_add(
+ self, point_a: jnp.ndarray, point_b: jnp.ndarray
+ ) -> jnp.ndarray:
+ twist_d = self.jax_parameters.twist_d
+
+ inputsl = point_a
+ inputsr = point_b
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ a, b, d, c = jnp.vsplit(outputs, 4)
+ # print(a.shape, b.shape, d.shape, c.shape)
+
+ pax, pay, _, _ = jnp.vsplit(point_a, 4)
+ pbx, pby, _, _ = jnp.vsplit(point_b, 4)
+
+ e1 = self.ff_ctx._modular_add(pax, pay)
+ e2 = self.ff_ctx._modular_add(pbx, pby)
+ twist_d_here = jnp.broadcast_to(
+ twist_d.reshape(-1, twist_d.shape[0]), c.shape
+ )
+ if self.use_sharding:
+ twist_d_here = jax.sharding.reshard(twist_d_here, jax.typeof(e2).sharding)
+ inputsl = jnp.concatenate((e1, c), axis=0)
+ inputsr = jnp.concatenate((e2, twist_d_here), axis=0)
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ e3, c = jnp.vsplit(outputs, 2)
+
+ e = self.ff_ctx._modular_subtract(self.ff_ctx._modular_subtract(e3, a), b)
+ f = self.ff_ctx._modular_subtract(d, c)
+ g = self.ff_ctx._modular_add(d, c)
+ h = self.ff_ctx._modular_add(a, b)
+
+ inputsl = jnp.concatenate((e, g, f, e), axis=0)
+ inputsr = jnp.concatenate((f, h, g, h), axis=0)
+ # print(inputsl.shape, inputsr.shape)
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ return outputs.reshape(4, -1, outputs.shape[-1])
+
+ def _point_double(self, point: jnp.ndarray) -> jnp.ndarray:
+ x, y, z, t = point
+
+ et = self.ff_ctx._modular_multiply(x, y)
+ inputsl = jnp.vstack((x, y, z, et))
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsl)
+ a, b, ct, et2 = jnp.vsplit(outputs, 4)
+
+ h = self.ff_ctx._modular_negate(self.ff_ctx._modular_add(a, b))
+ e = self.ff_ctx._modular_add(et2, h)
+ g = self.ff_ctx._modular_subtract(b, a)
+ f = self.ff_ctx._modular_subtract(g, self.ff_ctx._modular_add(ct, ct))
+
+ inputsl = jnp.vstack((e, g, f, e))
+ inputsr = jnp.vstack((f, h, g, h))
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ return outputs.reshape(4, -1, outputs.shape[-1])
+
+ def point_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ return self.compiled_kernels[kernel_hash]["point_add"](a, b)
+ else:
+ return self._point_add(a, b)
+
+ # TODO: fix the functional bug in point_double
+ def point_double(self, point: jnp.ndarray) -> jnp.ndarray:
+ raise ValueError("point_double has logic bug")
+ shape_dtype_struct = jax.ShapeDtypeStruct(point.shape, point.dtype)
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[shape_dtype_struct.__hash__()][
+ "point_double"
+ ](point)
+ else:
+ return self._point_double(point)
+
+ def _get_shape_dtype_structs(
+ self, parameters: dict
+ ) -> list[jax.ShapeDtypeStruct]:
+ batch_size = parameters["batch_size"]
+ num_moduli = self.jax_parameters.twist_d.shape[0]
+ point_shape = (4, batch_size, num_moduli)
+ if self.use_sharding:
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=point_shape, axes=[1]
+ )
+ return [
+ jax.ShapeDtypeStruct(
+ padded_shape, jnp.uint32, sharding=named_sharding
+ )
+ ]
+ return [jax.ShapeDtypeStruct(point_shape, jnp.uint32)]
+
+ def context_hash(self) -> str:
+ return hash_args(
+ self.__class__.__name__,
+ self.ff_ctx.context_hash(),
+ self.a,
+ self.twist_d,
+ self.alpha,
+ self.s,
+ self.A,
+ self.B,
+ self.t,
+ self.use_sharding,
+ )
+
+ def serialize(self, parameters: dict):
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ store_jax_executable(
+ self._point_add,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_point_add_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._point_double,
+ shape_dtype_structs[0],
+ name=f"{class_name}_point_double_{kernel_hash}",
+ )
+
+ def compile(self, parameters: dict):
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ point_add_kernel = load_jax_executable(
+ f"{class_name}_point_add_{kernel_hash}"
+ )
+ point_double_kernel = load_jax_executable(
+ f"{class_name}_point_double_{kernel_hash}"
+ )
+
+ if None in [point_add_kernel, point_double_kernel]:
+ warnings.warn(
+ f"Not found stored serialized compiled kernels, compiling...",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ kernel_hash = hash_args(
+ shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__()
+ )
+ self.compiled_kernels[kernel_hash] = {
+ "point_add": (
+ point_add_kernel
+ if point_add_kernel is not None
+ else jax_jit_lower_compile(
+ self._point_add, shape_dtype_structs[0], shape_dtype_structs[0]
+ )
+ ),
+ "point_double": (
+ point_double_kernel
+ if point_double_kernel is not None
+ else jax_jit_lower_compile(
+ self._point_double, shape_dtype_structs[0]
+ )
+ ),
+ }
+ self.use_compiled_kernels = True
+
+
+def _twist_extend_and_rns_worker(
+ point, s, alpha, prime, prime_m2, t, rns_moduli, radix_bits
+):
+ """Module-level worker: twist + extend + RNS conversion using gmpy2 (must be picklable)."""
+ import gmpy2
+
+ x, y = gmpy2.mpz(point[0]), gmpy2.mpz(point[1])
+ xm = (s * (x - alpha)) % prime
+ ym = (s * y) % prime
+ if ym == 0:
+ raise ValueError("ec twist: ym is zero")
+ xt = (xm * gmpy2.powmod(ym, prime_m2, prime)) % prime
+ yt_denom = xm + 1
+ if yt_denom == 0:
+ raise ValueError("ec twist: yt_denom is zero")
+ yt = ((xm - 1) * gmpy2.powmod(yt_denom, prime_m2, prime)) % prime
+ xt = (xt * t) % prime
+ coords = [int(xt), int(yt), 1, int((xt * yt) % prime)]
+ # RNS conversion inline: (a % m) << radix_bits) % m for each coordinate
+ return [[(((c % m) << radix_bits) % m) for m in rns_moduli] for c in coords]
+
+
+class ExtendedTwistedEdwardsNDContext(
+ ExtendedTwistedEdwardsContextBase, JaxKernelContextBase
+):
+ """Extended Twisted Edwards context supporting arbitrary batch dimensions.
+
+ Computational format layout: (coordinates=4, *batch_dims, precision)
+ where batch_dims can be any number of dimensions, e.g.:
+ - (4, batch, precision) -- 1D batch
+ - (4, batch1, batch2, precision) -- 2D batch
+ - (4, batch1, batch2, batch3, precision) -- 3D batch
+
+ Input points are nested lists of [x, y] affine Weierstrass coordinates.
+ Nesting depth determines batch dimensions:
+ - [x, y] → (4, 1, precision)
+ - [[x,y], ...] → (4, batch1, precision)
+ - [[[x,y], ...], ...] → (4, batch1, batch2, precision)
+ """
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ JaxKernelContextBase.__init__(self)
+ self.jax_parameters = JaxParameters()
+ self._init_jax_parameters()
+
+ def to_computational_format(self, a: list) -> jnp.ndarray:
+ list_depth = utils.nested_list_depth(a)
+ if list_depth < 1:
+ raise ValueError(f"Invalid list depth {list_depth} for point conversion")
+
+ # Use parallel processing with gmpy2 for large flat batches of points (depth==2)
+ # Fuses twist + extend + RNS conversion into one parallel step
+ _PARALLEL_THRESHOLD = 2048
+ if list_depth == 2 and len(a) >= _PARALLEL_THRESHOLD:
+ import gmpy2
+
+ ff_ctx = self.ff_ctx
+ worker = functools.partial(
+ _twist_extend_and_rns_worker,
+ s=gmpy2.mpz(self.s),
+ alpha=gmpy2.mpz(self.alpha),
+ prime=gmpy2.mpz(self.prime),
+ prime_m2=gmpy2.mpz(self.prime - 2),
+ t=gmpy2.mpz(self.t),
+ rns_moduli=ff_ctx.rns_moduli,
+ radix_bits=ff_ctx.radix_bits,
+ )
+ num_workers = min(64, os.cpu_count() or 1, max(1, len(a) // 256))
+ with ProcessPoolExecutor(
+ max_workers=num_workers, mp_context=_MP_CONTEXT
+ ) as pool:
+ rns_coords = list(
+ pool.map(worker, a, chunksize=max(1, len(a) // num_workers))
+ )
+ # rns_coords: list of (4, moduli_num) per point → (N, 4, moduli_num) array
+ result = jnp.array(
+ np.array(rns_coords, dtype=np.uint32), dtype=jnp.uint32
+ )
+ # (N, 4, moduli_num) → (4, N, moduli_num)
+ result = result.transpose(1, 0, 2)
+ else:
+
+ def recursive_twist(lst, depth):
+ if depth == 1:
+ return self._convert_to_extended_twisted_edwards(lst)
+ return [recursive_twist(item, depth - 1) for item in lst]
+
+ twisted_coords = recursive_twist(a, list_depth)
+
+ result = self.ff_ctx.to_computational_format(twisted_coords)
+
+ if list_depth == 1:
+ # (4, precision) → (4, 1, precision)
+ result = jnp.expand_dims(result, axis=1)
+ else:
+ # (*batch_dims, 4, precision) → (4, *batch_dims, precision)
+ ndim = result.ndim
+ perm = (ndim - 2,) + tuple(range(ndim - 2)) + (ndim - 1,)
+ result = result.transpose(perm)
+
+ if self.use_sharding:
+ shard_axes = list(range(1, min(3, result.ndim - 1)))
+ if not shard_axes:
+ shard_axes = [1]
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=result.shape, axes=shard_axes
+ )
+ result = pad_jax_array(result, padded_shape)
+ return result.to_device(named_sharding)
+ else:
+ return result.to_device(jax.devices()[0])
+
+ def to_original_format(self, a: jnp.ndarray) -> list:
+ ndim = a.ndim
+ if ndim < 2:
+ raise ValueError(f"Expected at least 2D array, got {ndim}D")
+
+ if ndim == 2:
+ a_orig = self.ff_ctx.to_original_format(a)
+ return self._convert_to_weierstrass_affine(a_orig)
+
+ # (4, *batch_dims, precision) → (*batch_dims, 4, precision)
+ perm = tuple(range(1, ndim - 1)) + (0, ndim - 1)
+ a = a.transpose(perm)
+ a_orig = self.ff_ctx.to_original_format(a)
+
+ batch_depth = ndim - 2
+
+ def recursive_untwist(lst, depth):
+ if depth == 0:
+ return self._convert_to_weierstrass_affine(lst)
+ return [recursive_untwist(item, depth - 1) for item in lst]
+
+ return recursive_untwist(a_orig, batch_depth)
+
+ def _init_jax_parameters(self):
+ self.jax_parameters.set_parameter(
+ twist_d=self.ff_ctx.to_computational_format(self.twist_d),
+ )
+
+ def _point_add(
+ self, point_a: jnp.ndarray, point_b: jnp.ndarray
+ ) -> jnp.ndarray:
+ twist_d = self.jax_parameters.twist_d
+
+ inputsl = point_a
+ inputsr = point_b
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ a, b, d, c = jnp.vsplit(outputs, 4)
+
+ pax, pay, _, _ = jnp.vsplit(point_a, 4)
+ pbx, pby, _, _ = jnp.vsplit(point_b, 4)
+
+ e1 = self.ff_ctx._modular_add(pax, pay)
+ e2 = self.ff_ctx._modular_add(pbx, pby)
+ twist_d_here = jnp.broadcast_to(twist_d, c.shape)
+ try:
+ _e2_sh = jax.typeof(e2).sharding
+ if _e2_sh is not None:
+ twist_d_here = jax.sharding.reshard(twist_d_here, _e2_sh)
+ except Exception:
+ pass
+ inputsl = jnp.concatenate((e1, c), axis=0)
+ inputsr = jnp.concatenate((e2, twist_d_here), axis=0)
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ e3, c = jnp.vsplit(outputs, 2)
+
+ e = self.ff_ctx._modular_subtract(self.ff_ctx._modular_subtract(e3, a), b)
+ f = self.ff_ctx._modular_subtract(d, c)
+ g = self.ff_ctx._modular_add(d, c)
+ h = self.ff_ctx._modular_add(a, b)
+
+ inputsl = jnp.concatenate((e, g, f, e), axis=0)
+ inputsr = jnp.concatenate((f, h, g, h), axis=0)
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ return outputs
+
+ def _point_double(self, point: jnp.ndarray) -> jnp.ndarray:
+ original_shape = point.shape
+ x, y, z, t = point
+
+ et = self.ff_ctx._modular_multiply(x, y)
+ inputsl = jnp.vstack((x, y, z, et))
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsl)
+ a, b, ct, et2 = jnp.vsplit(outputs, 4)
+
+ h = self.ff_ctx._modular_negate(self.ff_ctx._modular_add(a, b))
+ e = self.ff_ctx._modular_add(et2, h)
+ g = self.ff_ctx._modular_subtract(b, a)
+ f = self.ff_ctx._modular_subtract(g, self.ff_ctx._modular_add(ct, ct))
+
+ inputsl = jnp.vstack((e, g, f, e))
+ inputsr = jnp.vstack((f, h, g, h))
+ outputs = self.ff_ctx._modular_multiply(inputsl, inputsr)
+ return outputs.reshape(original_shape)
+
+ def point_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ return self.compiled_kernels[kernel_hash]["point_add"](a, b)
+ else:
+ return self._point_add(a, b)
+
+ # TODO: fix the functional bug in point_double
+ def point_double(self, point: jnp.ndarray) -> jnp.ndarray:
+ raise ValueError("point_double has logic bug")
+ shape_dtype_struct = jax.ShapeDtypeStruct(point.shape, point.dtype)
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[shape_dtype_struct.__hash__()][
+ "point_double"
+ ](point)
+ else:
+ return self._point_double(point)
+
+ def _get_shape_dtype_structs(
+ self, parameters: dict
+ ) -> list[jax.ShapeDtypeStruct]:
+ batch_shape = parameters.get("batch_shape", None)
+ if batch_shape is None:
+ batch_shape = (parameters["batch_size"],)
+ num_moduli = self.jax_parameters.twist_d.shape[0]
+ point_shape = (4,) + tuple(batch_shape) + (num_moduli,)
+ if self.use_sharding:
+ shard_axes = list(range(1, min(3, len(point_shape) - 1)))
+ if not shard_axes:
+ shard_axes = [1]
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=point_shape, axes=shard_axes
+ )
+ return [
+ jax.ShapeDtypeStruct(
+ padded_shape, jnp.uint32, sharding=named_sharding
+ )
+ ]
+ return [jax.ShapeDtypeStruct(point_shape, jnp.uint32)]
+
+ def context_hash(self) -> str:
+ return hash_args(
+ self.__class__.__name__,
+ self.ff_ctx.context_hash(),
+ self.a,
+ self.twist_d,
+ self.alpha,
+ self.s,
+ self.A,
+ self.B,
+ self.t,
+ self.use_sharding,
+ )
+
+ def serialize(self, parameters: dict):
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ store_jax_executable(
+ self._point_add,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_point_add_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._point_double,
+ shape_dtype_structs[0],
+ name=f"{class_name}_point_double_{kernel_hash}",
+ )
+
+ def compile(self, parameters: dict):
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ point_add_kernel = load_jax_executable(
+ f"{class_name}_point_add_{kernel_hash}"
+ )
+ point_double_kernel = load_jax_executable(
+ f"{class_name}_point_double_{kernel_hash}"
+ )
+
+ if None in [point_add_kernel, point_double_kernel]:
+ warnings.warn(
+ f"Not found stored serialized compiled kernels, compiling...",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ kernel_hash = hash_args(
+ shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__()
+ )
+ self.compiled_kernels[kernel_hash] = {
+ "point_add": (
+ point_add_kernel
+ if point_add_kernel is not None
+ else jax_jit_lower_compile(
+ self._point_add, shape_dtype_structs[0], shape_dtype_structs[0]
+ )
+ ),
+ "point_double": (
+ point_double_kernel
+ if point_double_kernel is not None
+ else jax_jit_lower_compile(
+ self._point_double, shape_dtype_structs[0]
+ )
+ ),
+ }
+ self.use_compiled_kernels = True
diff --git a/jaxite_ec/elliptic_curve_perf_test.py b/jaxite_ec/elliptic_curve_perf_test.py
new file mode 100644
index 0000000..5007608
--- /dev/null
+++ b/jaxite_ec/elliptic_curve_perf_test.py
@@ -0,0 +1,73 @@
+import os
+
+import jax
+from jaxite.jaxite_ec import elliptic_curve_context as ec_context
+from jaxite.jaxite_ec import finite_field_context as ff_context
+from jaxite.jaxite_ec import utils
+import toml
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+jax.config.update("jax_enable_x64", True)
+
+CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configurations.toml")
+
+BATCH_SIZE_LIST = [128, 256, 512, 1024, 2048, 4096]
+
+NUM_MODULI = 32
+
+TEST_PARAMS_POINT_ADD = [
+ ("point_add", BATCH_SIZE_LIST),
+]
+
+TEST_PARAMS_POINT_DOUBLE = [
+ ("point_double", BATCH_SIZE_LIST),
+]
+
+
+def _build_ec_context():
+ ec_config = toml.load(CONFIG_PATH)
+ rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, 28)
+ finite_field_parameters = {
+ "prime": ec_config["ec_parameters_bls12_377_affine"]["prime"],
+ "rns_moduli": rns_moduli,
+ "precision_bits": 28,
+ "radix_bits": 32,
+ }
+ ete_cfg = ec_config["ec_parameters_bls12_377_extended_twisted_edwards"]
+ ec_parameters = {
+ "finite_field_context_class": ff_context.DRNSlazyContext,
+ "finite_field_parameters": finite_field_parameters,
+ "prime": ete_cfg["prime"],
+ "order": ete_cfg["order"],
+ "a": ete_cfg["a"],
+ "twist_d": ete_cfg["d"],
+ "alpha": ete_cfg["alpha"],
+ "b": ete_cfg["b"],
+ "s": ete_cfg["s"],
+ "MA": ete_cfg["MA"],
+ "MB": ete_cfg["MB"],
+ "t": ete_cfg["t"],
+ "generator": ete_cfg["generator"],
+ }
+ return ec_context.ExtendedTwistedEdwardsContext(ec_parameters)
+
+
+def _point_add_kernel(point_a, point_b, parameters):
+ return parameters["ctx"]._point_add(point_a, point_b)
+
+
+def _point_double_kernel(point, parameters):
+ return parameters["ctx"]._point_double(point)
+
+
+class ECPointAddPerformanceTest(parameterized.TestCase):
+
+ @parameterized.named_parameters(*TEST_PARAMS_POINT_ADD)
+ def test_point_add_performance(self, batch_size_list):
+ pass
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/jaxite_ec/elliptic_curve_test.py b/jaxite_ec/elliptic_curve_test.py
index 4b09116..10a3410 100644
--- a/jaxite_ec/elliptic_curve_test.py
+++ b/jaxite_ec/elliptic_curve_test.py
@@ -1,939 +1,98 @@
-import functools
+import os
-import jax
-import jax.numpy as jnp
-from jaxite.jaxite_ec import util
-from jaxite.jaxite_ec.algorithm import config_file
-import jaxite.jaxite_ec.algorithm.elliptic_curve as ec
-import jaxite.jaxite_ec.elliptic_curve as jec
+from jaxite.jaxite_ec import elliptic_curve_context as ec_context
+from jaxite.jaxite_ec import finite_field_context as ff_context
+from jaxite.jaxite_ec import utils
+import numpy as np
+import toml
from absl.testing import absltest
+from absl.testing import parameterized
+
+# NOTE: Ensure all tests point are on the curve
+BLS12_377_TEST_CASES = [
+ (
+ "0",
+ [
+ [
+ 0x01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76,
+ 0x011DB83AEC88460820F4868A73B12309EE2E910526E62DB4ACCB303ABF50F86C3985A072ED07A4B81FFB82D8DD247283,
+ ],
+ [
+ 0x0164DDDBF27670CE389E2992C0E7DAB7741F1B925EDBDC254D2BC0830BAF8E0B186F80F0DD4DE0F0EA6176E55934D45B,
+ 0x01908E9D77A0F8AD89AC41441F74248704E756BC59C38920617F51BFCDB738EE5B123876D489D09C9EB904A321A336EC,
+ ],
+ ],
+ [
+ [
+ 0x01546AF2ABB4E189E9BBC412FDBF2A8E5EC6E4A3B0AF132E21EE9CEC3EF5E226490FB98D662670FA3CFB3948B7E2A48C,
+ 0x002961A558A885DF227FDB09F8BDF57AF179CB9437FF8828F13E9DF01AE55502F409AAF5058B88F2F7CCC7BC0676A5D4,
+ ],
+ [
+ 0x00B0630E7F192D20443A93860275447074CE77DF559907FA1900F378D4674649BF25F85C893E2A1916B1DA57594F2E17,
+ 0x01ACC84F362CF60A265C011F0FE4360A15F51BECF7E2C3923FE07C66D5D113104B56E8486C64204A2A9ECD75BA0C41A7,
+ ],
+ ],
+ ),
+]
+
+
+class BLS12_377_Test(parameterized.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(BLS12_377_Test, self).__init__(*args, **kwargs)
+
+ @parameterized.named_parameters(*BLS12_377_TEST_CASES)
+ def test_ExtendedTwistedEdwards_point_add(self, point_batch_1, point_batch_2):
+ ec_config = toml.load(
+ os.path.join(os.path.dirname(__file__), "configurations.toml")
+ )
+ rns_moduli = utils.find_moduli_specified_number(32, 28)
+ finite_field_parameters = {
+ "prime": ec_config["ec_parameters_bls12_377_affine"]["prime"],
+ "rns_moduli": rns_moduli,
+ "precision_bits": 28,
+ "radix_bits": 32,
+ }
+ ete_cfg = ec_config["ec_parameters_bls12_377_extended_twisted_edwards"]
+ ec_parameters = {
+ "finite_field_context_class": ff_context.DRNSlazyContext,
+ "finite_field_parameters": finite_field_parameters,
+ "prime": ete_cfg["prime"],
+ "order": ete_cfg["order"],
+ "a": ete_cfg["a"],
+ "twist_d": ete_cfg["d"],
+ "alpha": ete_cfg["alpha"],
+ "b": ete_cfg["b"],
+ "s": ete_cfg["s"],
+ "MA": ete_cfg["MA"],
+ "MB": ete_cfg["MB"],
+ "t": ete_cfg["t"],
+ "generator": ete_cfg["generator"],
+ }
+ affine_cfg = ec_config["ec_parameters_bls12_377_affine"]
+ ref_ec_parameters = {
+ "finite_field_parameters": finite_field_parameters,
+ "finite_field_context_class": ff_context.DRNSlazyContext,
+ "prime": affine_cfg["prime"],
+ "order": affine_cfg["order"],
+ "a": affine_cfg["a"],
+ "b": affine_cfg["b"],
+ "generator": affine_cfg["generator"],
+ }
+
+ ec_ctx = ec_context.ExtendedTwistedEdwardsContext(ec_parameters)
+ ref_ec_ctx = ec_context.CPUWeierstrassAffineContext(ref_ec_parameters)
+
+ point_batch_1_m = ec_ctx.to_computational_format(point_batch_1)
+ point_batch_2_m = ec_ctx.to_computational_format(point_batch_2)
+ result_m = ec_ctx.point_add(point_batch_1_m, point_batch_2_m)
+ result = ec_ctx.to_original_format(result_m)
+
+ ref_result = ref_ec_ctx._point_add(point_batch_1, point_batch_2)
+
+ np.testing.assert_array_equal(result, ref_result)
-class TestEllipticCurve(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.coordinate_num = 4
- self.batch_size = 1
- self.x1_int_ = 0x01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76
- self.y1_int_ = 0x011DB83AEC88460820F4868A73B12309EE2E910526E62DB4ACCB303ABF50F86C3985A072ED07A4B81FFB82D8DD247283
- self.x2_int_ = 0x01546AF2ABB4E189E9BBC412FDBF2A8E5EC6E4A3B0AF132E21EE9CEC3EF5E226490FB98D662670FA3CFB3948B7E2A48C
- self.y2_int_ = 0x002961A558A885DF227FDB09F8BDF57AF179CB9437FF8828F13E9DF01AE55502F409AAF5058B88F2F7CCC7BC0676A5D4
- self.point_a = [self.x1_int_, self.y1_int_]
- self.point_b = [self.x2_int_, self.y2_int_]
- self.zero_twisted = [0, 1, 1, 0]
- self.ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377)
- self.point_a_sys = self.ec_sys.generate_point(self.point_a)
- self.point_b_sys = self.ec_sys.generate_point(self.point_b)
- assert int(self.point_a_sys.coordinates[0].value) == self.point_a[0]
- assert int(self.point_a_sys.coordinates[1].value) == self.point_a[1]
- assert int(self.point_b_sys.coordinates[0].value) == self.point_b[0]
- assert int(self.point_b_sys.coordinates[1].value) == self.point_b[1]
- self.true_result_padd = self.point_a_sys + self.point_b_sys
- self.true_result_padd_affine = self.true_result_padd.convert_to_affine()
- self.true_result_pdub_a = self.point_a_sys + self.point_a_sys
- self.true_result_pdub_a_affine = self.true_result_pdub_a.convert_to_affine()
- self.true_result_pdub_b = self.point_b_sys + self.point_b_sys
- self.true_result_pdub_b_affine = self.true_result_pdub_b.convert_to_affine()
-
- def test_padd_barrett_xyzz_pack(self):
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))]
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_b + [1] * (self.coordinate_num - len(self.point_b))]
- )
- jit_padd_barrett_xyzz_pack = jax.jit(jec.padd_barrett_xyzz_pack)
-
- result_jax = jit_padd_barrett_xyzz_pack(point_a_jax, point_b_jax)
- result_jax = util.jax_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(result_jax[0][0], self.true_result_padd[0].get_value())
- self.assertEqual(result_jax[0][1], self.true_result_padd[1].get_value())
- self.assertEqual(result_jax[0][2], self.true_result_padd[2].get_value())
- self.assertEqual(result_jax[0][3], self.true_result_padd[3].get_value())
-
- # performance measurement
- tasks = [
- (jit_padd_barrett_xyzz_pack, (point_a_jax, point_b_jax)),
- ]
- profile_name = "jit_padd_barrett_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_pdul_barrett_xyzz(self):
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))]
- )
- jit_pdul_barrett_xyzz_pack = jax.jit(jec.pdul_barrett_xyzz_pack)
- result_jax = jit_pdul_barrett_xyzz_pack(point_a_jax)
- result_jax = util.jax_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(result_jax[0][0], self.true_result_pdub_a[0].get_value())
- self.assertEqual(result_jax[0][1], self.true_result_pdub_a[1].get_value())
- self.assertEqual(result_jax[0][2], self.true_result_pdub_a[2].get_value())
- self.assertEqual(result_jax[0][3], self.true_result_pdub_a[3].get_value())
-
- # performance measurement
- tasks = [
- (jit_pdul_barrett_xyzz_pack, (point_a_jax,)),
- ]
- profile_name = "jit_pdul_barrett_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_jit_pdul_barrett_xyzz_pack_two_no_batch(self):
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))]
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_b + [1] * (self.coordinate_num - len(self.point_b))]
- )
- jit_pdul_barrett_xyzz_pack = jax.jit(jec.pdul_barrett_xyzz_pack)
- result_a_jax = jit_pdul_barrett_xyzz_pack(point_a_jax)
- result_a_int = util.jax_point_pack_to_int_point_batch(result_a_jax)
-
- result_b_jax = jit_pdul_barrett_xyzz_pack(point_b_jax)
- result_b_int = util.jax_point_pack_to_int_point_batch(result_b_jax)
-
- self.assertEqual(result_a_int[0][0], self.true_result_pdub_a[0].get_value())
- self.assertEqual(result_a_int[0][1], self.true_result_pdub_a[1].get_value())
- self.assertEqual(result_a_int[0][2], self.true_result_pdub_a[2].get_value())
- self.assertEqual(result_a_int[0][3], self.true_result_pdub_a[3].get_value())
- self.assertEqual(result_b_int[0][0], self.true_result_pdub_b[0].get_value())
- self.assertEqual(result_b_int[0][1], self.true_result_pdub_b[1].get_value())
- self.assertEqual(result_b_int[0][2], self.true_result_pdub_b[2].get_value())
- self.assertEqual(result_b_int[0][3], self.true_result_pdub_b[3].get_value())
-
- # performance measurement
- tasks = [
- (jit_pdul_barrett_xyzz_pack, (point_a_jax,)),
- (jit_pdul_barrett_xyzz_pack, (point_b_jax,)),
- ]
- profile_name = "jit_pdul_barrett_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_jit_pdul_barrett_xyzz_pack_two_batch(self):
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))]
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_b + [1] * (self.coordinate_num - len(self.point_b))]
- )
- batch_point = jnp.concatenate([point_a_jax, point_b_jax], axis=1)
- jit_pdul_barrett_xyzz_pack = jax.jit(jec.pdul_barrett_xyzz_pack)
- result_jax = jit_pdul_barrett_xyzz_pack(batch_point)
- result_int = util.jax_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(result_int[0][0], self.true_result_pdub_a[0].get_value())
- self.assertEqual(result_int[0][1], self.true_result_pdub_a[1].get_value())
- self.assertEqual(result_int[0][2], self.true_result_pdub_a[2].get_value())
- self.assertEqual(result_int[0][3], self.true_result_pdub_a[3].get_value())
- self.assertEqual(result_int[1][0], self.true_result_pdub_b[0].get_value())
- self.assertEqual(result_int[1][1], self.true_result_pdub_b[1].get_value())
- self.assertEqual(result_int[1][2], self.true_result_pdub_b[2].get_value())
- self.assertEqual(result_int[1][3], self.true_result_pdub_b[3].get_value())
-
- # performance measurement
- tasks = [
- (jit_pdul_barrett_xyzz_pack, (batch_point,)),
- ]
- profile_name = "jit_pdul_barrett_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_lazy_xyzz_pack(self):
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))],
- chunk_num=util.U16_EXT_CHUNK_NUM,
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_b + [1] * (self.coordinate_num - len(self.point_b))],
- chunk_num=util.U16_EXT_CHUNK_NUM,
- )
- # lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT)
- jit_padd_lazy_xyzz_pack = jax.jit(jec.padd_lazy_xyzz_pack)
- result_jax = jit_padd_lazy_xyzz_pack(point_a_jax, point_b_jax)
- result_jax = util.jax_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(
- result_jax[0][0] % util.MODULUS_377_INT,
- self.true_result_padd[0].get_value(),
- )
- self.assertEqual(
- result_jax[0][1] % util.MODULUS_377_INT,
- self.true_result_padd[1].get_value(),
- )
- self.assertEqual(
- result_jax[0][2] % util.MODULUS_377_INT,
- self.true_result_padd[2].get_value(),
- )
- self.assertEqual(
- result_jax[0][3] % util.MODULUS_377_INT,
- self.true_result_padd[3].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_padd_lazy_xyzz_pack, (point_a_jax, point_b_jax)),
- ]
- profile_name = "jit_padd_lazy_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_pdul_lazy_xyzz_pack(self):
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))],
- chunk_num=util.U16_EXT_CHUNK_NUM,
- )
-
- # lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT)
- jit_pdul_lazy_xyzz_pack = jax.jit(jec.pdul_lazy_xyzz_pack)
- result_jax = jit_pdul_lazy_xyzz_pack(point_a_jax)
- result_jax = util.jax_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(
- result_jax[0][0] % util.MODULUS_377_INT,
- self.true_result_pdub_a[0].get_value(),
- )
- self.assertEqual(
- result_jax[0][1] % util.MODULUS_377_INT,
- self.true_result_pdub_a[1].get_value(),
- )
- self.assertEqual(
- result_jax[0][2] % util.MODULUS_377_INT,
- self.true_result_pdub_a[2].get_value(),
- )
- self.assertEqual(
- result_jax[0][3] % util.MODULUS_377_INT,
- self.true_result_pdub_a[3].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_pdul_lazy_xyzz_pack, (point_a_jax,)),
- ]
- profile_name = "jit_pdul_lazy_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_lazy_twisted_pack(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b)
-
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM
- )
-
- jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack)
- result_jax = jit_padd_lazy_twisted_pack(point_a_jax, point_b_jax)
- result_int = util.jax_point_pack_to_int_point_batch(result_jax)
-
- result_affine_point = twisted_ec_sys.generate_point(
- result_int[0], twist=False
- ).convert_to_affine()
-
- self.assertEqual(
- result_affine_point[0].get_value(),
- self.true_result_padd_affine[0].get_value(),
- )
- self.assertEqual(
- result_affine_point[1].get_value(),
- self.true_result_padd_affine[1].get_value(),
- )
-
- def test_padd_lazy_twisted_pack_batch(self):
- for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b)
-
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_a_jax = jnp.broadcast_to(
- point_a_jax, (point_a_jax.shape[0], batch_size, point_a_jax.shape[-1])
- )
- point_b_jax = jnp.broadcast_to(
- point_b_jax, (point_b_jax.shape[0], batch_size, point_b_jax.shape[-1])
- )
-
- jit_padd_lazy_twisted_pack_batch = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_lazy_twisted_pack),
- name=f"jit_padd_lazy_twisted_pack_batch_{batch_size}",
- ),
- )
- result_jax = jit_padd_lazy_twisted_pack_batch(point_a_jax, point_b_jax)
- result_int = util.jax_point_pack_to_int_point_batch(result_jax)
- result_affine_point = twisted_ec_sys.generate_point(
- result_int[0], twist=False
- ).convert_to_affine()
-
- self.assertEqual(
- result_affine_point.coordinates[0].value % util.MODULUS_377_INT,
- self.true_result_padd_affine[0].get_value(),
- )
- self.assertEqual(
- result_affine_point.coordinates[1].value % util.MODULUS_377_INT,
- self.true_result_padd_affine[1].get_value(),
- )
- self.assertEqual(
- result_affine_point.coordinates[2].value % util.MODULUS_377_INT,
- self.true_result_padd_affine[2].get_value(),
- )
- self.assertEqual(
- result_affine_point.coordinates[3].value % util.MODULUS_377_INT,
- self.true_result_padd_affine[3].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_padd_lazy_twisted_pack_batch, (point_a_jax, point_b_jax)),
- ]
- profile_name = f"jit_padd_lazy_twisted_pack_batch_{batch_size}"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_same_lazy_twisted_pack(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a1 = twisted_ec_sys.twist_int_coordinates(self.point_a)
- twist_a2 = twisted_ec_sys.twist_int_coordinates(self.point_a)
-
- point_a1_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a1], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_a2_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a2], chunk_num=util.U16_EXT_CHUNK_NUM
- )
-
- jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack)
- result_jax = jit_padd_lazy_twisted_pack(point_a1_jax, point_a2_jax)
- result_int = util.jax_point_pack_to_int_point_batch(result_jax)
-
- result_affine_point = twisted_ec_sys.generate_point(
- result_int[0], twist=False
- ).convert_to_affine()
-
- self.assertEqual(
- result_affine_point[0].get_value(),
- self.true_result_pdub_a_affine[0].get_value(),
- )
- self.assertEqual(
- result_affine_point[1].get_value(),
- self.true_result_pdub_a_affine[1].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_padd_lazy_twisted_pack, (point_a1_jax, point_a2_jax)),
- ]
- profile_name = "jit_padd_lazy_twisted_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_pdul_lazy_twisted_pack(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM
- )
-
- jit_pdul_lazy_twisted_pack = jax.jit(jec.pdul_lazy_twisted_pack)
- result_jax = jit_pdul_lazy_twisted_pack(point_a_jax)
- result_int = util.jax_point_pack_to_int_point_batch(result_jax)
-
- result_affine_point = twisted_ec_sys.generate_point(
- result_int[0], twist=False
- ).convert_to_affine()
- self.assertEqual(
- result_affine_point[0].get_value(),
- self.true_result_pdub_a_affine[0].get_value(),
- )
- self.assertEqual(
- result_affine_point[1].get_value(),
- self.true_result_pdub_a_affine[1].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_pdul_lazy_twisted_pack, (point_a_jax,)),
- ]
- profile_name = "jit_pdul_lazy_twisted_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_jit_pneg_lazy_twisted_pack(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b)
-
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM
- )
-
- jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack)
- jit_pneg_lazy_twisted_pack = jax.jit(jec.pneg_lazy_twisted_pack)
- a_plus_b = jit_padd_lazy_twisted_pack(point_a_jax, point_b_jax)
- neg_b = jit_pneg_lazy_twisted_pack(point_b_jax)
- result_jax = jit_padd_lazy_twisted_pack(a_plus_b, neg_b)
- result_int = util.jax_point_pack_to_int_point_batch(result_jax)
-
- result_affine_point = twisted_ec_sys.generate_point(
- result_int[0], twist=False
- ).convert_to_affine()
- self.assertEqual(
- result_affine_point[0].get_value(), self.point_a_sys[0].get_value()
- )
- self.assertEqual(
- result_affine_point[1].get_value(), self.point_a_sys[1].get_value()
- )
-
- # performance measurement
- tasks = [
- (jit_pneg_lazy_twisted_pack, (point_b_jax,)),
- ]
- profile_name = "jit_pneg_lazy_twisted_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_rns_xyzz(self):
- point_a_jax = util.int_point_batch_to_jax_rns_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))]
- )
- point_b_jax = util.int_point_batch_to_jax_rns_point_pack(
- [self.point_b + [1] * (self.coordinate_num - len(self.point_b))]
- )
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
-
- jit_padd_rns_xyzz_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_rns_xyzz_pack, rns_mat=rns_mat),
- name="jit_padd_rns_xyzz_pack",
- ),
- )
- result_jax = jit_padd_rns_xyzz_pack(point_a_jax, point_b_jax)
- result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(
- result_jax[0][0] % util.MODULUS_377_INT,
- self.true_result_padd[0].get_value(),
- )
- self.assertEqual(
- result_jax[0][1] % util.MODULUS_377_INT,
- self.true_result_padd[1].get_value(),
- )
- self.assertEqual(
- result_jax[0][2] % util.MODULUS_377_INT,
- self.true_result_padd[2].get_value(),
- )
- self.assertEqual(
- result_jax[0][3] % util.MODULUS_377_INT,
- self.true_result_padd[3].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_padd_rns_xyzz_pack, (point_a_jax, point_b_jax)),
- ]
- profile_name = "jit_padd_rns_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_rns_xyzz_batch(self):
- for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]:
- point_a_jax = util.int_point_batch_to_jax_rns_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))]
- )
- point_b_jax = util.int_point_batch_to_jax_rns_point_pack(
- [self.point_b + [1] * (self.coordinate_num - len(self.point_b))]
- )
- point_a_jax = jnp.broadcast_to(
- point_a_jax, (point_a_jax.shape[0], batch_size, point_a_jax.shape[-1])
- )
- point_b_jax = jnp.broadcast_to(
- point_b_jax, (point_b_jax.shape[0], batch_size, point_b_jax.shape[-1])
- )
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
-
- jit_padd_rns_xyzz_pack_batch = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_rns_xyzz_pack, rns_mat=rns_mat),
- name="jit_padd_rns_xyzz_pack_batch",
- ),
- )
- result_jax = jit_padd_rns_xyzz_pack_batch(point_a_jax, point_b_jax)
- result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(
- result_jax[0][0] % util.MODULUS_377_INT,
- self.true_result_padd[0].get_value(),
- )
- self.assertEqual(
- result_jax[0][1] % util.MODULUS_377_INT,
- self.true_result_padd[1].get_value(),
- )
- self.assertEqual(
- result_jax[0][2] % util.MODULUS_377_INT,
- self.true_result_padd[2].get_value(),
- )
- self.assertEqual(
- result_jax[0][3] % util.MODULUS_377_INT,
- self.true_result_padd[3].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_padd_rns_xyzz_pack_batch, (point_a_jax, point_b_jax)),
- ]
- profile_name = f"jit_padd_rns_xyzz_pack_batch_{batch_size}"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_pdul_rns_xyzz_pack(self):
- point_a_jax = util.int_point_batch_to_jax_rns_point_pack(
- [self.point_a + [1] * (self.coordinate_num - len(self.point_a))]
- )
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
-
- jit_pdul_rns_xyzz_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.pdul_rns_xyzz_pack, rns_mat=rns_mat),
- name="jit_pdul_rns_xyzz_pack",
- ),
- )
- result_jax = jit_pdul_rns_xyzz_pack(point_a_jax)
- result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax)
-
- self.assertEqual(
- result_jax[0][0] % util.MODULUS_377_INT,
- self.true_result_pdub_a[0].get_value(),
- )
- self.assertEqual(
- result_jax[0][1] % util.MODULUS_377_INT,
- self.true_result_pdub_a[1].get_value(),
- )
- self.assertEqual(
- result_jax[0][2] % util.MODULUS_377_INT,
- self.true_result_pdub_a[2].get_value(),
- )
- self.assertEqual(
- result_jax[0][3] % util.MODULUS_377_INT,
- self.true_result_pdub_a[3].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_pdul_rns_xyzz_pack, (point_a_jax,)),
- ]
- profile_name = "jit_pdul_rns_xyzz_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_rns_twisted_pack(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- a = self.point_a_sys.coordinates[:2]
- b = self.point_b_sys.coordinates[:2]
- project_twist_a = twisted_ec_sys.generate_point(a, twist=True)
- project_twist_b = twisted_ec_sys.generate_point(b, twist=True)
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- point_a_jax = util.int_point_batch_to_jax_rns_point_pack(
- [[c.get_value() for c in project_twist_a]]
- )
- point_b_jax = util.int_point_batch_to_jax_rns_point_pack(
- [[c.get_value() for c in project_twist_b]]
- )
- jit_padd_rns_twisted_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat),
- name="jit_padd_rns_twisted_pack",
- ),
- )
- point_c_jax = jit_padd_rns_twisted_pack(point_a_jax, point_b_jax)
- project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(point_c_jax)[
- 0
- ]
- project_twist_sum_point = twisted_ec_sys.generate_point(
- project_twist_sum, twist=False
- ).convert_to_affine()
- s = project_twist_sum_point.coordinates[:2]
- correct_s = self.true_result_padd_affine.coordinates[:2]
- self.assertEqual(
- s[0].get_value() % util.MODULUS_377_INT, correct_s[0].get_value()
- )
- self.assertEqual(
- s[1].get_value() % util.MODULUS_377_INT, correct_s[1].get_value()
- )
-
- # performance measurement
- tasks = [
- (jit_padd_rns_twisted_pack, (point_a_jax, point_b_jax)),
- ]
- profile_name = "jit_padd_rns_twisted_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_pdul_rns_twisted_pack(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- pg = twisted_ec_sys.generate_point(self.point_a, twist=False)
- g_correct_ff = pg << 1
- gcf_af = g_correct_ff.convert_to_affine()
- point_g_jax = util.int_point_batch_to_jax_rns_point_pack(
- [[c.get_value() for c in pg]]
- )
- jit_pdul_rns_twisted_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.pdul_rns_twisted_pack, rns_mat=rns_mat),
- name="jit_pdul_rns_twisted_pack",
- ),
- )
- point_2g_jax = jit_pdul_rns_twisted_pack(point_g_jax)
- g_test = util.jax_rns_point_pack_to_int_point_batch(point_2g_jax)[0]
- gtf_af = twisted_ec_sys.generate_point(
- g_test, twist=False
- ).convert_to_affine()
- self.assertEqual(
- gtf_af.coordinates[0].get_value() % util.MODULUS_377_INT,
- gcf_af.coordinates[0].get_value(),
- )
- self.assertEqual(
- gtf_af.coordinates[1].get_value() % util.MODULUS_377_INT,
- gcf_af.coordinates[1].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_pdul_rns_twisted_pack, (point_g_jax,)),
- ]
- profile_name = "jit_pdul_rns_twisted_pack"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_rns_twisted_pack_new_twisted(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b)
-
- point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a)
- point_b_jax = util.int_point_to_jax_rns_point_pack(twist_b)
-
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- jit_padd_rns_twisted_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat),
- name="jit_padd_rns_twisted_pack",
- ),
- )
- point_c_jax = jit_padd_rns_twisted_pack(point_a_jax, point_b_jax)
- project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(point_c_jax)[
- 0
- ]
- project_twist_sum_point = twisted_ec_sys.generate_point(
- project_twist_sum, twist=False
- ).convert_to_affine()
- s = project_twist_sum_point.coordinates[:2]
- correct_s = self.true_result_padd_affine.coordinates[:2]
- self.assertEqual(
- s[0].get_value() % util.MODULUS_377_INT, correct_s[0].get_value()
- )
- self.assertEqual(
- s[1].get_value() % util.MODULUS_377_INT, correct_s[1].get_value()
- )
-
- # performance measurement
- tasks = [
- (jit_padd_rns_twisted_pack, (point_a_jax, point_b_jax)),
- ]
- profile_name = "jit_padd_rns_twisted_pack_new_twist"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_pdul_rns_twisted_pack_new_twisted(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a)
- jit_pdul_rns_twisted_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.pdul_rns_twisted_pack, rns_mat=rns_mat),
- name="jit_pdul_rns_twisted_pack",
- ),
- )
- point_2a_jax = jit_pdul_rns_twisted_pack(point_a_jax)
- g_test = util.jax_rns_point_pack_to_int_point_batch(point_2a_jax)[0]
- gtf_af = twisted_ec_sys.generate_point(
- g_test, twist=False
- ).convert_to_affine()
- self.assertEqual(
- gtf_af.coordinates[0].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_a_affine[0].get_value(),
- )
- self.assertEqual(
- gtf_af.coordinates[1].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_a_affine[1].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_pdul_rns_twisted_pack, (point_a_jax,)),
- ]
- profile_name = "jit_pdul_rns_twisted_pack_new_twist"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_rns_twisted_pack_new_twist_two_batch(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a).reshape(
- util.COORDINATE_NUM, 1, util.NUM_MODULI
- )
-
- twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b)
- point_b_jax = util.int_point_to_jax_rns_point_pack(twist_b).reshape(
- util.COORDINATE_NUM, 1, util.NUM_MODULI
- )
-
- batch_point = jnp.concatenate([point_a_jax, point_b_jax], axis=1)
-
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- jit_padd_rns_twisted_pack_two_batch = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat),
- name="jit_padd_rns_twisted_pack_two_batch",
- ),
- )
- result_batch = jit_padd_rns_twisted_pack_two_batch(batch_point, batch_point)
- project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(result_batch)
- point_2a_jax = twisted_ec_sys.generate_point(
- project_twist_sum[0], twist=False
- ).convert_to_affine()
- point_2b_jax = twisted_ec_sys.generate_point(
- project_twist_sum[1], twist=False
- ).convert_to_affine()
- self.assertEqual(
- point_2a_jax[0].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_a_affine[0].get_value(),
- )
- self.assertEqual(
- point_2a_jax[1].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_a_affine[1].get_value(),
- )
- self.assertEqual(
- point_2b_jax[0].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_b_affine[0].get_value(),
- )
- self.assertEqual(
- point_2b_jax[1].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_b_affine[1].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_padd_rns_twisted_pack_two_batch, (batch_point, batch_point)),
- ]
- profile_name = "jit_padd_rns_twisted_pack_two_batch"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_pdul_rns_twisted_pack_new_twist_two_batch(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a).reshape(
- util.COORDINATE_NUM, 1, util.NUM_MODULI
- )
-
- twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b)
- point_b_jax = util.int_point_to_jax_rns_point_pack(twist_b).reshape(
- util.COORDINATE_NUM, 1, util.NUM_MODULI
- )
-
- batch_point = jnp.concatenate([point_a_jax, point_b_jax], axis=1)
-
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- jit_pdul_rns_twisted_pack_two_batch = jax.jit(
- jax.named_call(
- functools.partial(jec.pdul_rns_twisted_pack, rns_mat=rns_mat),
- name="jit_pdul_rns_twisted_pack_two_batch",
- ),
- )
- result_batch = jit_pdul_rns_twisted_pack_two_batch(batch_point)
- project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(result_batch)
- point_2a_jax = twisted_ec_sys.generate_point(
- project_twist_sum[0], twist=False
- ).convert_to_affine()
- point_2b_jax = twisted_ec_sys.generate_point(
- project_twist_sum[1], twist=False
- ).convert_to_affine()
- self.assertEqual(
- point_2a_jax[0].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_a_affine[0].get_value(),
- )
- self.assertEqual(
- point_2a_jax[1].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_a_affine[1].get_value(),
- )
- self.assertEqual(
- point_2b_jax[0].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_b_affine[0].get_value(),
- )
- self.assertEqual(
- point_2b_jax[1].get_value() % util.MODULUS_377_INT,
- self.true_result_pdub_b_affine[1].get_value(),
- )
-
- # performance measurement
- tasks = [
- (jit_pdul_rns_twisted_pack_two_batch, (batch_point,)),
- ]
- profile_name = "jit_pdul_rns_twisted_pack_two_batch"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- def test_padd_zero_twisted_pack_new_twisted(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_zero_jax = util.int_point_batch_to_jax_point_pack(
- [self.zero_twisted], chunk_num=util.U16_EXT_CHUNK_NUM
- )
-
- jit_padd_lazy_twisted_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_lazy_twisted_pack),
- name="jit_padd_lazy_twisted_pack",
- ),
- )
- point_c_jax = jit_padd_lazy_twisted_pack(point_a_jax, point_zero_jax)
- # point_c_jax = jec.padd_lazy_twisted_pack(point_a_jax, point_zero_jax)
- project_twist_sum = util.jax_point_pack_to_int_point_batch(point_c_jax)[0]
- project_twist_sum_point = twisted_ec_sys.generate_point(
- project_twist_sum, twist=False
- ).convert_to_affine()
- self.assertEqual(
- project_twist_sum_point[0].get_value() % util.MODULUS_377_INT,
- self.point_a[0],
- )
- self.assertEqual(
- project_twist_sum_point[1].get_value() % util.MODULUS_377_INT,
- self.point_a[1],
- )
-
- @absltest.skip("This is the known issue, which does not affect XYZZ.")
- def test_padd_zero_rns_twisted_pack_new_twisted(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a)
-
- point_a_jax = util.int_point_batch_to_jax_rns_point_pack([twist_a])
- point_zero_jax = util.int_point_batch_to_jax_rns_point_pack(
- [self.zero_twisted]
- )
-
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- jit_padd_rns_twisted_pack = jax.jit(
- jax.named_call(
- functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat),
- name="jit_padd_rns_twisted_pack",
- ),
- )
- point_c_jax = jit_padd_rns_twisted_pack(point_a_jax, point_zero_jax)
- project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(point_c_jax)[
- 0
- ]
- project_twist_sum_point = twisted_ec_sys.generate_point(
- project_twist_sum, twist=False
- ).convert_to_affine()
- self.assertEqual(
- project_twist_sum_point[0].get_value() % util.MODULUS_377_INT,
- self.point_a[0],
- )
- self.assertEqual(
- project_twist_sum_point[1].get_value() % util.MODULUS_377_INT,
- self.point_a[1],
- )
-
- @absltest.skip("This is the known issue, which does not affect XYZZ.")
- def test_padd_rns_a_point_add_zero_correctness(self):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(
- config_file.config_BLS12_377_t
- )
- test_in_point = [
- 184360877740379501345167057231241011318955851506892921023614488028185166370541128591905842464011651119609504970811,
- 47698235458971847835762299820400550031713475079888046003406323907410999702258242394959839249289205517205485978635,
- ]
- twist_a = twisted_ec_sys.twist_int_coordinates(test_in_point)
- twist_b = [0, 1, 1, 0]
- point_a_jax = util.int_point_batch_to_jax_point_pack(
- [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_b_jax = util.int_point_batch_to_jax_point_pack(
- [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM
- )
- point_a_jax_rns = util.int_point_batch_to_jax_rns_point_pack([twist_a])
- point_b_jax_rns = util.int_point_batch_to_jax_rns_point_pack([twist_b])
-
- point_c_jax_rns = jec.padd_rns_twisted_pack(
- point_a_jax_rns, point_b_jax_rns
- )
- project_twist_sum_rns = util.jax_rns_point_pack_to_int_point_batch(
- point_c_jax_rns
- )[0]
- affine_sum_point_rns = twisted_ec_sys.generate_point(
- project_twist_sum_rns, twist=False
- ).convert_to_affine()
- point_c_jax = jec.padd_lazy_twisted_pack(point_a_jax, point_b_jax)
- project_twist_sum = util.jax_point_pack_to_int_point_batch(point_c_jax)[0]
- affine_sum_point = twisted_ec_sys.generate_point(
- project_twist_sum, twist=False
- ).convert_to_affine()
- # In Twisted Edward Representation Verification
- self.assertEqual(
- project_twist_sum[0] % util.MODULUS_377_INT,
- project_twist_sum_rns[0] % util.MODULUS_377_INT,
- )
- self.assertEqual(
- project_twist_sum[1] % util.MODULUS_377_INT,
- project_twist_sum_rns[1] % util.MODULUS_377_INT,
- )
- self.assertEqual(
- project_twist_sum[0] % util.MODULUS_377_INT,
- project_twist_sum_rns[0] % util.MODULUS_377_INT,
- )
- self.assertEqual(
- project_twist_sum[1] % util.MODULUS_377_INT,
- project_twist_sum_rns[1] % util.MODULUS_377_INT,
- )
-
- # Verification in affine
- self.assertEqual(
- affine_sum_point[0].get_value() % util.MODULUS_377_INT,
- affine_sum_point_rns[0].get_value() % util.MODULUS_377_INT,
- )
- self.assertEqual(
- affine_sum_point[1].get_value() % util.MODULUS_377_INT,
- affine_sum_point_rns[1].get_value() % util.MODULUS_377_INT,
- )
-
if __name__ == "__main__":
absltest.main()
diff --git a/jaxite_ec/finite_field.py b/jaxite_ec/finite_field.py
deleted file mode 100644
index 863cd21..0000000
--- a/jaxite_ec/finite_field.py
+++ /dev/null
@@ -1,976 +0,0 @@
-"""library of finite field operations.
-
-This library is used to implement the finite field operations for the
-high-precision
-elliptic curve.
-
-- Data Representation.
-The high-precision elliptic curve coordinate is represented as a vector of
-uint16 24-bit integer.
-The actual base is increased with the index.
-E.g.
- index [ 0, 1, 2, ...]
-bit precision [0~7, 8~15, 16~23, ...]
-
-It includes Barrett based modular multiplication, *barrett_reduction_u16x2*
-
-
-# Function Name Terminology
-## _ indicates that the function only works for a single
-## precision.
-## indicates that the function works for general bit precision.
-
-# Terminology
-## Chunk Reduction: e.g. u8-chunk -> u16-chunk or u32-chunk
-## Chunk Decomposition <-> Chunk Merge:
-### Chunk Decomposition: break int into multiple low-precision chunks.
-### Chunk Merge: Merge multiple low-precision chunks into an int.
-"""
-
-import functools
-
-import jax
-import jax.numpy as jnp
-from jaxite.jaxite_ec import util
-import numpy as np
-
-
-total_modulus = util.total_modulus
-to_rns = util.to_rns
-
-
-jax.config.update("jax_enable_x64", True)
-
-
-@jax.named_call
-@functools.partial(
- jax.jit, static_argnames=("iter_num", "mask", "chunk_shift_bits")
-)
-def carry_add(
- value_c: jax.Array,
- iter_num=util.U16_CHUNK_NUM,
- mask=util.U16_MASK,
- chunk_shift_bits=util.U16_CHUNK_SHIFT_BITS,
-):
- """The purpose of this API is to enable general-purposed carry add, where the following knobs are known before runtime.
-
- iter_num = math.ceil(total_input_bitwidth / chunk_bitwidth),
- mask = 2**chunk_bitwidth - 1,
- chunk_shift_bits = chunk_bitwidth
-
- Args:
- value_c: The value to carry add.
- iter_num: The number of iterations to perform.
- mask: The mask to apply to the value.
- chunk_shift_bits: The number of bits to shift the value.
-
- Returns:
- value_c: The value after carry adding.
- """
- for _ in range(iter_num):
- low = jnp.bitwise_and(value_c, mask)
- high = jnp.right_shift(value_c, chunk_shift_bits)
- high = jnp.roll(high, 1)
- value_c = jnp.add(low, high)
- return value_c
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="chunk_shift_bits")
-def check_any_chunk_with_carry(
- value_c: jax.Array,
- chunk_shift_bits=util.U16_CHUNK_SHIFT_BITS,
-) -> jax.Array:
- """This function check whether any chunk of input vector 'value_c' has carry.
-
- Args:
- value_c: The value to carry add.
- chunk_shift_bits: ideal bit precision of any given chunk. Note that: actual
- bit precision of any given chunk might be higher than chunk_shift_bits
- because it needs to hold the overflow.
-
- Returns:
- cond: A boolean value indicating whether any chunk of input vector 'value_c'
- has carry.
- """
- high = jnp.right_shift(value_c, chunk_shift_bits)
- cond = jnp.any(jnp.not_equal(high, 0))
- return cond
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames=("mask", "chunk_shift_bits"))
-def carry_propagation(
- value_c: jax.Array,
- mask=util.U16_MASK,
- chunk_shift_bits=util.U16_CHUNK_SHIFT_BITS,
-):
- """The purpose of this API is to enable carry propagation.
-
- Args:
- value_c: The value to carry propagate.
- mask: 2**chunk_bitwidth - 1,
- chunk_shift_bits: chunk_bitwidth
-
- This function split each chunk into high and low parts, and high part is left
- roll by 1 to carry the overflowed bits to the next chunk.
- Note that: in a given jax.array, bit range of the chunk within the original
- high precision value is increased from left to the right.
-
- Returns:
- value_c: The value after carry adding.
- """
- precision_dim = value_c.shape[-1]
- roll_mat = jnp.array(
- [0, 1]
- + ([0] * (precision_dim) + [1]) * (precision_dim - 2)
- + [1]
- + [0] * (precision_dim - 1),
- dtype=jnp.uint16,
- ).reshape(precision_dim, precision_dim)
- low = jnp.bitwise_and(value_c, mask)
- high = jnp.right_shift(value_c, chunk_shift_bits).astype(jnp.uint16)
- high = jnp.matmul(high, roll_mat, preferred_element_type=jnp.uint32).astype(
- jnp.uint16
- )
- value_c = jnp.add(low, high)
- return value_c
-
-
-def conv_1d_2u16xn(value_a: jax.Array, value_b: jax.Array):
- """This function performs a 1D convolution of two u16 arrays.
-
- Args:
- value_a: The chunk-decomposition representation of the high-precision int.
- value_b: The chunk-decomposition representation of the high-precision int.
-
- Returns:
- conv: The convolution results of two input arrays being casted to uint8.
- """
- value_a = jax.lax.bitcast_convert_type(value_a, jnp.uint8).reshape(-1)
- value_b = jax.lax.bitcast_convert_type(value_b, jnp.uint8).reshape(-1)
- conv = jnp.convolve(
- value_a,
- value_b,
- preferred_element_type=jnp.uint32,
- )
- return conv
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames=("chunk_num_u16", "chunk_num_u32"))
-def rechunkify(mul_result: jax.Array, chunk_num_u16, chunk_num_u32):
- """Given the carry add takes O(C) algorithm complexity, where C is the number of chunks.
-
- This function performs chunk reduction for ther results of the convolution,
- i.e. merge two consecutive chunks into one chunk with double precision.
- E.g. u8[0, 8, 8, 0] -> u16[8, 2048] 0-> u32[526336]
-
- Args:
- mul_result: The chunk-wise multiplication (using convolution) result.
- chunk_num_u16: The number of bits in each chunk.
- chunk_num_u32: The number of bits in the second chunk.
-
- Returns:
- value_c: The result of the chunk reduction.
- """
- shift_0_8_u16x4 = jnp.array(
- [[0, 8] for _ in range(chunk_num_u16 * 4)], dtype=jnp.uint8
- )
- shift_0_16_u32x4 = jnp.array(
- [[0, 16] for _ in range(chunk_num_u32 * 4)], dtype=jnp.uint8
- )
- new_shape = (
- mul_result.shape[:-1] + (-1, 2) if mul_result.ndim == 2 else (-1, 2)
- )
- value_c = mul_result.reshape(new_shape)
- value_c = jnp.left_shift(value_c, shift_0_8_u16x4[:chunk_num_u16])
- value_c = jnp.sum(value_c, axis=-1)
- value_c = value_c.reshape(new_shape).astype(jnp.uint64)
- value_c = jnp.left_shift(value_c, shift_0_16_u32x4[:chunk_num_u32])
- value_c = jnp.sum(value_c, axis=-1)
- return value_c
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames="chunk_num_u16")
-def compare_u16(
- value_a: jax.Array, value_b: jax.Array, chunk_num_u16=util.U16_CHUNK_NUM
-):
- """Compare two u16 values.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- chunk_num_u16: The number of chunks in the u16 value.
-
- Returns:
- cond > 0 -> value_a > value_b
- cond = 0 -> value_a = value_b
- cond < 0 -> value_a < value_b
- """
- sign = jnp.sign(
- jnp.subtract(value_a.astype(jnp.int32), value_b.astype(jnp.int32))
- )
- comp_check_vec_weights = jnp.array(
- [2**i for i in range(chunk_num_u16)], dtype=jnp.int32
- )
- weight = jnp.multiply(sign, comp_check_vec_weights)
- cond = weight.sum(axis=-1)
- return cond
-
-
-def add_2u16(value_a: jax.Array, value_b: jax.Array):
- """Add two u16 values.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
-
- Returns:
- value_c: The result of the addition.
- """
- value_c = jax.numpy.add(
- value_a.astype(jnp.uint32), value_b.astype(jnp.uint32)
- )
- value_c = jax.lax.while_loop(
- check_any_chunk_with_carry, carry_propagation, value_c
- )
-
- return value_c.astype(jnp.uint16)
-
-
-def add_3u16(value_a: jax.Array, value_b: jax.Array, value_d: jax.Array):
- value_c = jax.numpy.add(
- value_a.astype(jnp.uint32), value_b.astype(jnp.uint32)
- )
- value_c = jax.numpy.add(
- value_c.astype(jnp.uint32), value_d.astype(jnp.uint32)
- )
- value_c = jax.lax.while_loop(
- check_any_chunk_with_carry, carry_propagation, value_c
- )
-
- return value_c.astype(jnp.uint16)
-
-
-@jax.named_call
-@functools.partial(jax.jit, static_argnames=("mask", "chunk_num_u16"))
-def sub_2u16(
- value_a: jax.Array,
- value_b: jax.Array,
- mask=util.U16_MASK,
- chunk_num_u16=util.U16_CHUNK_NUM,
-):
- """Subtract two u16 values.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- mask: The mask to apply to the value.
- chunk_num_u16: The number of chunks in the u16 value (default: 24).
-
- Returns:
- value_c: The result of the subtraction.
- """
- borrow_high_u16_pad_zero_array = jnp.array(
- [0] + [1] * (chunk_num_u16 - 2) + [0], dtype=jnp.uint32
- )
- borrow_low_u16_array = jnp.array(
- [mask + 1] * (chunk_num_u16 - 1) + [0], dtype=jnp.uint32
- )
- value_a = jnp.add(value_a.astype(jnp.uint32), borrow_low_u16_array)
- value_c = jnp.subtract(value_a, value_b)
- value_c = jnp.subtract(value_c, borrow_high_u16_pad_zero_array)
-
- value_c = jax.lax.while_loop(
- check_any_chunk_with_carry, carry_propagation, value_c
- )
- if value_c.ndim == 1:
- value_c = value_c.at[chunk_num_u16 - 1].set(value_c[chunk_num_u16 - 1] - 1)
- else:
- value_c = value_c.at[:, chunk_num_u16 - 1].set(
- value_c[:, chunk_num_u16 - 1] - 1
- )
-
- value_c = value_c.astype(jnp.uint16)
- return value_c
-
-
-@jax.named_call
-@functools.partial(
- jax.jit, static_argnames=("modulus_377_int_chunk", "chunk_num_u16")
-)
-def cond_sub_mod_u16(
- value_a: jax.Array,
- modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK,
- chunk_num_u16=util.U16_CHUNK_NUM,
-):
- """Perform conditional subtraction: value_a - modulus_377_int.
-
- Args:
- value_a: The minuend.
- modulus_377_int_chunk: The modulus 377.
- chunk_num_u16: The number of chunks in the u16 value (default: 24).
-
- Returns:
- value_c: The result of the conditional subtraction.
- """
- compare_u16_local = functools.partial(
- compare_u16, chunk_num_u16=chunk_num_u16
- )
- sub_2u16_local = functools.partial(sub_2u16, chunk_num_u16=chunk_num_u16)
- modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16)
-
- cond = compare_u16_local(value_a, modulus_377_int_array)
- value_b = sub_2u16_local(value_a, modulus_377_int_array)
- cond = jnp.greater_equal(cond, 0).reshape((cond.shape[0], 1))
- value_c = jnp.where(cond, value_b, value_a)
- return value_c
-
-
-@jax.named_call
-@functools.partial(
- jax.jit, static_argnames=("modulus_377_int_chunk", "chunk_num_u16")
-)
-def cond_sub_2u16(
- value_a: jax.Array,
- value_b: jax.Array,
- modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK,
- chunk_num_u16=util.U16_CHUNK_NUM,
-):
- """Perform conditional subtraction: value_a - value_b.
-
- Args:
- value_a: The minuend.
- value_b: The subtrahend.
- modulus_377_int_chunk: The modulus 377.
- chunk_num_u16: The number of chunks in the u16 value (default: 24).
-
- Returns:
- value_c: The result of the conditional subtraction.
- """
- modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16)
- compare_u16_local = functools.partial(
- compare_u16, chunk_num_u16=chunk_num_u16
- )
- sub_2u16_local = functools.partial(sub_2u16, chunk_num_u16=chunk_num_u16)
-
- cond = compare_u16_local(value_a, value_b)
- cond = jnp.greater_equal(cond, 0).reshape((cond.shape[0], 1))
-
- value_ap = jnp.add(
- value_a.astype(jnp.uint32), modulus_377_int_array.astype(jnp.uint32)
- )
-
- value_a = jnp.where(cond, value_a.astype(jnp.uint32), value_ap)
- value_c = sub_2u16_local(value_a, value_b)
- return value_c
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames=(
- "mask",
- "chunk_num_u16",
- "chunk_shift_bits",
- "output_dtype",
- "vmap_axes",
- ),
-)
-def mul_2u16(
- value_a: jax.Array,
- value_b: jax.Array,
- mask=util.U32_MASK,
- chunk_num_u16=util.U16_CHUNK_NUM,
- chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS,
- output_dtype=jnp.uint16,
- vmap_axes=(0, 0),
-):
- """Multiply two u16 values.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- mask: The mask to apply to the value.
- chunk_num_u16: The number of chunks in the u16 value (default: 24).
- chunk_shift_bits: The number of bits to shift the value.
- output_dtype: The desired output data type.
- vmap_axes: The axes to use for vmap.
-
- Returns:
-
- cond > 0 -> value_a > value_b
- cond = 0 -> value_a = value_b
- cond < 0 -> value_a < value_b
- """
- batch_dim = value_a.shape[0]
- mul_result = jax.vmap(conv_1d_2u16xn, in_axes=vmap_axes)(value_a, value_b)
- mul_result = jnp.pad(mul_result, ((0, 0), (0, 1)))
- value_c = rechunkify(mul_result, 2 * chunk_num_u16, chunk_num_u16)
-
- value_c = jax.lax.while_loop(
- functools.partial(
- check_any_chunk_with_carry, chunk_shift_bits=chunk_shift_bits
- ),
- functools.partial(
- carry_propagation,
- mask=mask,
- chunk_shift_bits=chunk_shift_bits,
- ),
- value_c,
- )
- ratio = 4 if output_dtype == jnp.uint8 else 2
- value_c = jax.lax.bitcast_convert_type(
- value_c.astype(jnp.uint32), output_dtype
- ).reshape(batch_dim, -1)[:, : ratio * chunk_num_u16]
- return value_c
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames=(
- "barrett_shift_u8",
- "chunk_num_u16",
- "chunk_num_u32",
- "vmap_axes",
- ),
-)
-def mul_shift_2u16x2x1(
- value_a: jax.Array,
- value_b: jax.Array,
- mask=util.U32_MASK,
- barrett_shift_u8=util.BARRETT_SHIFT_U8,
- chunk_num_u16=util.U16_CHUNK_NUM,
- chunk_num_u32=util.U32_CHUNK_NUM,
- chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS,
- vmap_axes=(0, None),
-):
- """Multiply and shift two u16 values.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- mask: The mask to apply to the value.
- barrett_shift_u8: The number of bits to shift the value.
- chunk_num_u16: The number of chunks in the u16 value.
- chunk_num_u32: The number of chunks in the u32 value.
- chunk_shift_bits: The number of bits to shift the value.
- vmap_axes: (0, None) means axis 0 is the mapped access, and The rest is not.
-
- Returns:
-
- cond > 0 -> value_a > value_b
- cond = 0 -> value_a = value_b
- cond < 0 -> value_a < value_b
- """
- batch_dim = value_a.shape[0]
- conv = jax.vmap(conv_1d_2u16xn, in_axes=vmap_axes)(value_a, value_b)
- conv = jnp.pad(conv, ((0, 0), (0, 1)))
- value_c = rechunkify(conv, chunk_num_u16 * 3, chunk_num_u32 * 3)
- value_c = jax.lax.while_loop(
- functools.partial(
- check_any_chunk_with_carry, chunk_shift_bits=chunk_shift_bits
- ),
- functools.partial(
- carry_propagation, mask=mask, chunk_shift_bits=chunk_shift_bits
- ),
- value_c,
- )
-
- value_c = jax.lax.bitcast_convert_type(
- value_c.astype(jnp.uint32), jnp.uint8
- ).reshape(batch_dim, -1)[:, barrett_shift_u8:]
- value_c = jax.lax.bitcast_convert_type(
- jnp.pad(value_c, ((0, 0), (0, 1))).reshape(batch_dim, -1, 2), jnp.uint16
- )[:, :chunk_num_u16]
- return value_c
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames=(
- "mask",
- "modulus_377_int_chunk",
- "mu_377_int_chunk",
- "chunk_num_u16",
- "vmap_axes",
- ),
-)
-def mod_mul_barrett_2u16(
- value_a: jax.Array,
- value_b: jax.Array,
- mask=util.U16_MASK,
- modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK,
- mu_377_int_chunk=util.MU_377_INT_CHUNK,
- chunk_num_u16=util.U16_CHUNK_NUM,
- vmap_axes=(0, None),
-):
- """Multiply two u16 values with Barrett reduction.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- mask: The mask to apply to the value.
- modulus_377_int_chunk: The modulus 377.
- mu_377_int_chunk: The Barrett reduction coefficient.
- chunk_num_u16: The number of chunks in the u16 value (default: 24).
- vmap_axes: The axes to use for vmap.
-
- Returns:
- value_c: The result of the multiplication.
- """
- modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16)
- mu_377_int_array = jnp.asarray(mu_377_int_chunk, jnp.uint16)
-
- mul_2u16_const = functools.partial(mul_2u16, vmap_axes=vmap_axes)
- sub_2u16_const = functools.partial(
- sub_2u16, mask=mask, chunk_num_u16=chunk_num_u16 * 2
- )
- value_x = mul_2u16(value_a, value_b)
- value_d = mul_shift_2u16x2x1(value_x, mu_377_int_array)
- value_e = mul_2u16_const(value_d, modulus_377_int_array)
- value_t = sub_2u16_const(value_x, value_e)
- value_c = cond_sub_mod_u16(value_t[:, :chunk_num_u16])
- return value_c
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames=(
- "mask",
- "modulus_377_int_chunk",
- "mu_377_int_chunk",
- "chunk_num_u16",
- "vmap_axes",
- ),
-)
-def barrett_reduction_u16x2(
- value_x: jax.Array,
- mask=util.U16_MASK,
- modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK,
- mu_377_int_chunk=util.MU_377_INT_CHUNK,
- chunk_num_u16=util.U16_CHUNK_NUM,
- vmap_axes=(0, None),
-):
- """Performs Barrett reduction on a u16x2 value.
-
- Args:
- value_x: The u16x2 value to perform Barrett reduction on.
- mask: The mask to apply to the value.
- modulus_377_int_chunk: The modulus 377.
- mu_377_int_chunk: The Barrett reduction coefficient.
- chunk_num_u16: The number of chunks in the u16 value (default: 24).
- vmap_axes: The axes to use for vmap.
-
- Returns:
- value_c: The result of the Barrett reduction.
- """
- modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16)
- mu_377_int_array = jnp.asarray(mu_377_int_chunk, jnp.uint16)
-
- mul_2u16_const = functools.partial(mul_2u16, vmap_axes=vmap_axes)
- value_d = mul_shift_2u16x2x1(value_x, mu_377_int_array)
- value_e = mul_2u16_const(value_d, modulus_377_int_array)
- value_t = sub_2u16(value_x, value_e, mask, chunk_num_u16 * 2)
- value_c = cond_sub_mod_u16(value_t[:, :chunk_num_u16])
- return value_c
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames=(
- "modulus_lazy_mat",
- "mask",
- "chunk_num_u8",
- "chunk_shift_bits",
- ),
-)
-def mod_mul_lazy_2u16(
- value_a,
- value_b,
- modulus_lazy_mat=util.MODULUS_377_LAZY_MAT,
- mask=util.U32_MASK,
- chunk_num_u8=util.U8_CHUNK_NUM,
- chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS,
-):
- """Multiply two u16 values with lazy matrix reduction.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- modulus_lazy_mat: The lazy matrix.
- mask: The mask to apply to the value.
- chunk_num_u8: The number of chunks in the u8 value.
- chunk_shift_bits: The number of bits to shift the value.
-
- Returns:
- value_c: The result of the multiplication.
- """
- batch_dim = value_a.shape[0]
- modulus_lazy_mat = jnp.asarray(modulus_lazy_mat, dtype=jnp.uint16)
- mul_2u8 = functools.partial(
- mul_2u16,
- mask=util.U32_MASK,
- chunk_num_u16=util.U16_EXT_CHUNK_NUM,
- chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS,
- output_dtype=jnp.uint8,
- )
- value_c = mul_2u8(value_a, value_b)
- standard_product_low = value_c[:, :chunk_num_u8]
- standard_product_high = value_c[:, chunk_num_u8:]
-
- reduced = jnp.matmul(
- standard_product_high.astype(jnp.uint16),
- modulus_lazy_mat.astype(jnp.uint16),
- preferred_element_type=jnp.uint32,
- )
- value_c_reduced = jnp.add(
- standard_product_low.astype(jnp.uint32), reduced.astype(jnp.uint32)
- )
- value_c_reduced_u32 = rechunkify(
- value_c_reduced, chunk_num_u8 // 2, chunk_num_u8 // 4
- )
- value_c_reduced_u32 = jnp.pad(value_c_reduced_u32, ((0, 0), (0, 1)))
-
- value_c_carried = jax.lax.while_loop(
- functools.partial(
- check_any_chunk_with_carry, chunk_shift_bits=chunk_shift_bits
- ),
- functools.partial(
- carry_propagation, mask=mask, chunk_shift_bits=chunk_shift_bits
- ),
- value_c_reduced_u32,
- )
-
- value_c_u16 = jax.lax.bitcast_convert_type(
- value_c_carried.astype(jnp.uint32), jnp.uint16
- ).reshape(batch_dim, -1)[:, : util.U16_EXT_CHUNK_NUM]
- return value_c_u16
-
-
-def split_view_32_to_16(a: jnp.ndarray):
- # Interpret each 32-bit element as two 16-bit numbers
- # and reshape to add an extra dimension of size 2.
- v = a.view(jnp.uint16).reshape(a.shape + (2,))
- # Assuming little-endian storage, the lower 16 bits are at index 0
- # and the upper 16 bits are at index 1.
- lower = v[..., 0]
- upper = v[..., 1]
- return upper, lower
-
-
-def split_view_32_to_16_8(a: jnp.ndarray):
- # First, reshape the 32-bit integers as groups of 4 bytes.
- v8 = a.view(jnp.uint8).reshape(a.shape + (4,))
- # Also, reshape as 16-bit integers (2 per 32-bit element)
- v16 = a.view(jnp.uint16).reshape(a.shape + (2,))
- # For each 32-bit integer:
- # v16[..., 0] gives the lower 16 bits.
- # v8[..., 2] gives the third byte (i.e. the lower 8 bits of the upper 16 bits)
- lower = v16[..., 0]
- upper8 = v8[..., 2]
- return upper8, lower
-
-
-# Reduce via RNS modulus
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames="moduli_t",
-)
-def moduli_rns_red_internal_2u16(vals, moduli_t=util.RNS_MODULI_T):
- """Reduce via RNS modulus.
-
- Args:
- vals: The values to reduce.
- moduli_t: The moduli for the target.
-
- Returns:
- The reduced values.
- """
- # See jaxite_ec/advanced_algorithm/rns_red.py for description
- moduli_t = jnp.array(moduli_t, dtype=jnp.uint8)
- u1, l1 = split_view_32_to_16(vals)
- i1 = jnp.add(
- l1.astype(jnp.uint32),
- jnp.multiply(u1.astype(jnp.uint32), moduli_t),
- )
- u2, l2 = split_view_32_to_16_8(i1)
- i2 = jnp.add(
- l2.astype(jnp.uint32),
- jnp.multiply(u2.astype(jnp.uint16), moduli_t).astype(
- jnp.uint32
- ),
- )
- u3, l3 = split_view_32_to_16_8(i2)
- out = jnp.add(l3, jnp.multiply(u3, moduli_t).astype(jnp.uint16))
- return out
-
-
-# Reduce via prime modulus
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames=("rns_mat", "moduli_t", "num_moduli", "precision"),
-)
-def mod_red_rns_2u16(
- c_rns_reduced,
- rns_mat=util.RNS_MAT,
- moduli_t=util.RNS_MODULI_T,
- num_moduli=util.NUM_MODULI,
- precision=util.RNS_PRECISION,
-):
- """Reduce via RNS modulus.
-
- Args:
- c_rns_reduced: The values to reduce.
- rns_mat: The RNS precompute.
- moduli_t: The moduli for the target.
- num_moduli: The number of moduli.
- precision: The precision.
-
- Returns:
- The reduced values.
- """
- rns_stacked_mat = jnp.array(rns_mat[0], jnp.uint8)
- cor_mat = jnp.array(rns_mat[1], jnp.uint16)
-
- c_target = jnp.matmul(
- c_rns_reduced.view(jnp.uint8),
- rns_stacked_mat,
- preferred_element_type=jnp.uint32,
- )
-
- mul_res_glb_red_u32 = c_target.reshape(*c_target.shape[:-1], -1, 2)
- mul_res_glb_red_u32 = mul_res_glb_red_u32[..., 0] + (
- mul_res_glb_red_u32[..., 1] << 8
- )
- rns_reduce_u32, qe_u32 = jnp.split(
- mul_res_glb_red_u32, [num_moduli], axis=1
- )
-
- # obtain the high 32 bits from the quotient estimation results qe_u32
- k = (qe_u32 >> precision).astype(jnp.uint16)
- c_corrected = rns_reduce_u32 + jnp.matmul(
- k, cor_mat, preferred_element_type=jnp.uint32
- )
-
- return moduli_rns_red_internal_2u16(c_corrected, moduli_t)
-
-
-# Multiply, without reducing
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames="moduli_t",
-)
-def mul_unreduced_rns_2u16(
- value_a,
- value_b,
- moduli_t=util.RNS_MODULI_T,
-):
- ab = jnp.multiply(value_a.astype(jnp.uint32), value_b.astype(jnp.uint32))
- return moduli_rns_red_internal_2u16(ab, moduli_t)
-
-
-# Multiply and reduce
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames=("rns_mat", "moduli_t"),
-)
-def mod_mul_rns_2u16(
- value_a,
- value_b,
- rns_mat=util.RNS_MAT,
- moduli_t=util.RNS_MODULI_T,
-):
- """Multiply two u16 values with RNS reduction.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- rns_mat: The RNS precompute.
- moduli_t: The moduli for the target.
-
- Returns:
- The product of the two u16 values.
- """
- ab = mul_unreduced_rns_2u16(value_a, value_b, moduli_t)
- return mod_red_rns_2u16(ab, rns_mat, moduli_t)
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames="moduli_t",
-)
-def add_rns_2u16(
- value_a: jax.Array,
- value_b: jax.Array,
- moduli_t=util.RNS_MODULI_T,
-):
- """Add two u16 values with RNS reduction.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- moduli_t: The moduli for the target.
-
- Returns:
- The sum of the two u16 values.
- """
- return add_sub_rns_var(value_a, value_b, moduli_t=moduli_t)
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames="moduli_t",
-)
-def add_rns_3u16(
- value_a: jax.Array,
- value_b: jax.Array,
- value_c: jax.Array,
- moduli_t=util.RNS_MODULI_T,
-):
- """Add three u16 values with RNS reduction.
-
- Args:
- value_a: The first u16 value.
- value_b: The second u16 value.
- value_c: The third u16 value.
- moduli_t: The moduli for the target.
-
- Returns:
- The sum of the three u16 values.
- """
- return add_sub_rns_var(value_a, value_b, value_c, moduli_t=moduli_t)
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames="moduli_sub",
-)
-def negate_rns_for_var_add(
- value_a: jax.Array,
- moduli_sub=util.MODULI_SUB,
-):
- """Negate a value for use in subtraction.
-
- Do not use the output in any function but add_sub_rns_var -- may break
- correctness.
-
- Args:
- value_a: RNS array to negate
- moduli_sub: Precomputed constants for performing negation, that depend on
- the target modulus
-
- Returns:
- An intermediate representing the negation of values_a in the target modulus
- in RNS form.
-
- Note: original data precision is 16 bit, using uint32 to avoid overflow
- """
- moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint32)
-
- return jnp.add(
- jnp.negative(value_a.astype(jnp.uint16)).astype(jnp.uint32),
- moduli_sub,
- )
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames="moduli_sub",
-)
-def negate_rns_for_var_add_zero_check(
- value_a: jax.Array,
- moduli_sub=util.MODULI_SUB,
-):
- """Negate a value for use in subtraction.
-
- Do not use the output in any function but add_sub_rns_var -- may break
- correctness.
-
- Args:
- value_a: RNS array to negate
- moduli_sub: Precomputed constants for performing negation, that depend on
- the target modulus
-
- Returns:
- An intermediate representing the negation of values_a in the target modulus
- in RNS form.
-
- Note: original data precision is 16 bit, using uint32 to avoid overflow
- """
-
- moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint32)
- a = value_a.astype(jnp.uint16)
-
- # Compute two's complement negation: for nonzero a, jnp.negative(a) computes
- # (2^16 - a).
- neg = jnp.negative(a).astype(jnp.uint32)
-
- # Build a branchless mask: 0 if a==0, 1 otherwise.
- mask = (a != 0).astype(jnp.uint32)
-
- # For nonzero a: (2^16 - a) + moduli_sub; for zero: m + 0 multiplied by 0
- # gives 0.
- return (neg + moduli_sub) * mask
-
-
-@jax.named_call
-@functools.partial(
- jax.jit,
- static_argnames="moduli_t",
-)
-def add_sub_rns_var(*values, moduli_t=util.RNS_MODULI_T):
- """Evaluate an static set of additions and subtractions.
-
- Subtractions are implemented by calling negate_rns_for_var_add on inputs to
- this function. Inputs should be "fresh" or multiplication outputs, and the
- output should be used as a multiplication input. Any other usage produces
- undefined behavior and may break correctness.
-
- Args:
- *values: A list of RNS values to accumulate
- moduli_t: The moduli for the RNS form.
-
- Returns:
- The RNS form of the evaluation of the expession.
- """
- assert len(values) > 0
- acc = None
- for v in values:
- if acc != None:
- acc = jnp.add(v.astype(jnp.uint32), acc)
- else:
- acc = v.astype(jnp.uint32)
- assert len(values) < 256
- moduli_t = jnp.array(moduli_t, dtype=jnp.uint8)
- # u1 < 254
- u1, l1 = split_view_32_to_16_8(acc)
- # i1 < 2**16 - 1 + 255t < 2**17 - t for 8 bit t
- i1 = jnp.add(
- jnp.multiply(u1.astype(np.uint16), moduli_t).astype(jnp.uint32),
- l1.astype(jnp.uint32),
- )
- # u2 = 0 or 1, but if u2 = 1 then l < 2**16 - t, so 2**16 - t + t < 2**16
- u2, l2 = split_view_32_to_16_8(i1)
- return jnp.add(jnp.multiply(u2, moduli_t).astype(jnp.uint16), l2)
-
-
-@functools.partial(jax.jit, static_argnames=("c", "num_moduli"))
-def rns_constant(c, num_moduli=util.NUM_MODULI):
- assert c >= 0
- assert c < 2**14 # small constants only please
- return jnp.repeat(jnp.array([c], dtype=jnp.uint16), num_moduli)
diff --git a/jaxite_ec/finite_field_context.py b/jaxite_ec/finite_field_context.py
new file mode 100644
index 0000000..0a690ce
--- /dev/null
+++ b/jaxite_ec/finite_field_context.py
@@ -0,0 +1,1204 @@
+from abc import ABC, abstractmethod
+import math
+from typing import Any, List, Union
+import warnings
+import jax
+import jax.numpy as jnp
+from jaxite.jaxite_ec import utils
+import numpy as np
+
+JaxKernelContextBase = utils.JaxKernelContextBase
+JaxParameters = utils.JaxParameters
+hash_args = utils.hash_args
+jax_jit_lower_compile = utils.jax_jit_lower_compile
+load_jax_executable = utils.load_jax_executable
+pad_jax_array = utils.pad_jax_array
+store_jax_executable = utils.store_jax_executable
+
+jax.config.update("jax_enable_x64", True)
+
+
+class FiniteFieldContextBase(ABC):
+ """Abstract base class defining the interface for finite field operations.
+
+ Subclasses must implement all abstract methods to provide concrete
+ finite field arithmetic operations.
+ """
+
+ prime: Any = None
+ parameters: Any = None
+ rns_moduli: Any = None
+ radix_bits: Any = None
+
+ @abstractmethod
+ def __init__(self, parameters: dict):
+ """Initialize the finite field context.
+
+ Args:
+ parameters: Configuration dictionary containing field parameters.
+ """
+ self.prime = parameters.get("prime", None)
+ assert self.prime is not None, "prime must be provided"
+ self.parameters = parameters
+
+ @abstractmethod
+ def to_computational_format(self, a) -> jnp.ndarray:
+ """Convert input to the internal computational representation.
+
+ Args:
+ a: Input value in standard format.
+
+ Returns:
+ Value converted to computational format (e.g., Montgomery form).
+ """
+ pass
+
+ def _modular_multiply(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ return a
+
+ def _modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ return a
+
+ def _modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ return a
+
+ def _modular_negate(self, a: jnp.ndarray) -> jnp.ndarray:
+ return a
+
+ def _modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray:
+ return a
+
+ @abstractmethod
+ def to_original_format(self, a) -> Union[int, List[Any]]:
+ """Convert from computational format back to standard representation.
+
+ Args:
+ a: Value in computational format.
+
+ Returns:
+ Value in standard integer representation.
+ """
+ pass
+
+ @abstractmethod
+ def context_hash(self) -> str:
+ pass
+
+ @abstractmethod
+ def modular_multiply(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ """Perform modular multiplication: (a * b) mod p.
+
+ Args:
+ a: First operand in computational format.
+ b: Second operand in computational format.
+
+ Returns:
+ Product in computational format.
+ """
+ pass
+
+ @abstractmethod
+ def modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ """Perform modular addition: (a + b) mod p.
+
+ Args:
+ a: First operand.
+ b: Second operand.
+
+ Returns:
+ Sum reduced modulo p.
+ """
+ pass
+
+ @abstractmethod
+ def modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ """Perform modular subtraction: (a - b) mod p.
+
+ Args:
+ a: First operand.
+ b: Second operand.
+
+ Returns:
+ Difference reduced modulo p.
+ """
+ pass
+
+ @abstractmethod
+ def modular_negate(self, a: jnp.ndarray) -> jnp.ndarray:
+ """Perform modular negation: -a mod p.
+
+ Args:
+ a: Operand.
+
+ Returns:
+ Negation reduced modulo p.
+ """
+ pass
+
+ @abstractmethod
+ def modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray:
+ """Reduce value modulo the field prime.
+
+ Args:
+ a: Value to reduce (may be larger than prime).
+
+ Returns:
+ Value reduced to [0, p).
+ """
+ pass
+
+
+class RNSContextBase(FiniteFieldContextBase):
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ self.rns_moduli = parameters.get("rns_moduli", [])
+ assert len(self.rns_moduli) != 0, "rns_moduli must be non-empty"
+ self.total_modulus = math.prod(self.rns_moduli)
+ assert (
+ self.total_modulus > self.prime
+ ), "RNS total modulus must be greater than prime"
+ self.precision_bits = parameters.get(
+ "precision_bits", 0
+ ) # Default precision bits
+ assert self.precision_bits != 0, "precision bits must be non-zero"
+
+ self.crt_factors = self._compute_crt_factors(self.rns_moduli)
+ # RNSContextBase._check_parameters(self)
+
+ def _check_parameters(self):
+ pass # disable the warning for now
+ sum_moduli_bits = math.ceil(math.log2(sum(self.rns_moduli)))
+ if self.precision_bits < sum_moduli_bits:
+ warnings.warn(
+ "precision_bits is less than sum of moduli_bits, precision_bits:"
+ f" {self.precision_bits}, sum of moduli_bits: {sum_moduli_bits}."
+ + "This may cause precision loss.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ max_modulus_bits = max(math.ceil(math.log2(m)) for m in self.rns_moduli)
+ required_precision = max_modulus_bits + math.log2(len(self.rns_moduli)) + 1
+ if self.precision_bits < required_precision:
+ warnings.warn(
+ "precision_bits is less than required precision, precision_bits:"
+ f" {self.precision_bits}, required precision: {required_precision}."
+ + "This may cause precision loss.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ def _compute_crt_factors(self, moduli: list[int]):
+ ms = [self.total_modulus // m for m in moduli]
+ ms_inv = [
+ utils.modular_inverse(ms[i], moduli[i]) for i in range(len(moduli))
+ ]
+ return [
+ (ms[i] * ms_inv[i]) % self.total_modulus for i in range(len(moduli))
+ ]
+
+ def _elementwise_add(self, a: list[int], b: list[int]):
+ assert len(a) == len(b), "a and b must have the same length"
+ return [a[i] + b[i] for i in range(len(a))]
+
+ def _elementwise_subtract(self, a: list[int], b: list[int]):
+ assert len(a) == len(b), "a and b must have the same length"
+ return [a[i] - b[i] for i in range(len(a))]
+
+ def _elementwise_multiply(self, a: list[int], b: list[int]):
+ assert len(a) == len(b), "a and b must have the same length"
+ return [a[i] * b[i] for i in range(len(a))]
+
+ def _elementwise_reduce(self, z: list[int], m: list[int]):
+ assert len(z) == len(m), "z and m must have the same length"
+ return [z[i] % m[i] for i in range(len(m))]
+
+ def _elementwise_left_shift(self, a: list[int], b: Union[list[int], int]):
+ if isinstance(b, list):
+ assert len(a) == len(b), "a and b must have the same length"
+ return [a[i] << b[i] for i in range(len(a))]
+ else:
+ return [a[i] << b for i in range(len(a))]
+
+ def _elementwise_right_shift(self, a: list[int], b: Union[list[int], int]):
+ if isinstance(b, list):
+ assert len(a) == len(b), "a and b must have the same length"
+ return [a[i] >> b[i] for i in range(len(a))]
+ else:
+ return [a[i] >> b for i in range(len(a))]
+
+ def _elementwise_and(self, a: list[int], b: Union[list[int], int]):
+ if isinstance(b, list):
+ assert len(a) == len(b), "a and b must have the same length"
+ return [a[i] & b[i] for i in range(len(a))]
+ else:
+ return [a[i] & b for i in range(len(a))]
+
+ def _rns_decompose(self, x: int, moduli: list[int]):
+ return [(x % m) for m in moduli]
+
+ def _get_crns_vector_I_before_reducing(self, moduli_m: list[int], y):
+ vector_i = []
+ modular_M = math.prod(moduli_m)
+ for i, m_i in enumerate(moduli_m):
+ # M_y_over_m_i = (modular_M * y) // m_i
+ M_y_over_m_i = (modular_M) // m_i
+ # logging.info(f"M_y_over_m_i: {M_y_over_m_i}")
+ inv_M_y_over_m_i = utils.modular_inverse(M_y_over_m_i % m_i, m_i)
+ # logging.info(f"inv_M_y_over_m_i: {inv_M_y_over_m_i}")
+ I_i = inv_M_y_over_m_i * M_y_over_m_i * y
+ vector_i.append(I_i)
+ return vector_i
+
+ def get_moduli_num(self) -> int:
+ return len(self.rns_moduli)
+
+ def _crns(
+ self,
+ x: list[int],
+ matrix_E: list[list[int]],
+ vector_f_T: list[int],
+ vector_g: list[int],
+ moduli: list[int],
+ u: int,
+ ):
+ """CRNS Computation based on Algorithm steps 9-12
+
+ Args:
+ x: Input vector x_M in RNS representation
+ matrix_E: Precomputed matrix E from crns_precompute
+ vector_f_T: Precomputed vector f^T from crns_precompute
+ vector_g: Precomputed vector g from crns_precompute
+ moduli: RNS moduli (should be self.rns_moduli_n)
+
+ Returns:
+ x_N: Result vector in RNS representation modulo N
+ """
+ # Step 9: E, f^T, g = CRNSPRECOMPUTATION(M, y, N, z, u) - already done
+ # (matrix_E, vector_f_T, vector_g are passed as parameters)
+
+ # Step 10: v = x_M · f^T (Dot product; can be parallelized with x_M · E)
+ v = sum(x[i] * vector_f_T[i] for i in range(len(x)))
+
+ # Step 11: k = ⌊v/2^u⌋ (Bitshifting computes fixed-point quotient)
+ k = v >> u # Right shift by u bits is equivalent to floor division by 2^u
+
+ # Step 12: return x_N = x_M · E + k·g (vector-matrix multiply, scalar-vector multiply, compute ICRT)
+ x_N = []
+ for j in range(len(moduli)):
+ # Compute x_M · E for column j
+ dot_product = sum(x[i] * matrix_E[i][j] for i in range(len(x)))
+ dot_product = dot_product % moduli[j]
+ x_N.append(dot_product)
+
+ for j in range(len(moduli)):
+ # Add k * g_j and take modulo n_j
+ # result = (x_N[j] + k * vector_g[j]) % moduli[j]
+ result = x_N[j] + k * vector_g[j]
+ x_N[j] = result
+ return x_N
+
+
+class DRNSlazyContextBase(RNSContextBase):
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ self.radix_bits = parameters.get("radix_bits", 0)
+ assert self.radix_bits != 0, " radix bits must be non-zero"
+ self.moduli_inv_on_radix = [
+ utils.modular_inverse(m, 1 << self.radix_bits) for m in self.rns_moduli
+ ]
+ self.crns_y, self.crns_z = self._precompute_crns_parameters()
+ self.matrix_E, self.vector_f_T, self.vector_g = self._crns_precompute(
+ self.rns_moduli,
+ self.prime,
+ self.crns_y,
+ self.crns_z,
+ self.precision_bits,
+ )
+
+ def _check_parameters(self):
+ pass # disable the warning for now
+ if (self.prime << self.w) > self.total_modulus:
+ warnings.warn(
+ "Total modulus is not enough to hold prime in DRNSlazy, total"
+ f" modulus: {self.total_modulus}, prime: {self.prime}."
+ + "This may cause finite field overflow.",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ def _precompute_crns_parameters(self):
+ radix_inv = utils.modular_inverse(2**self.radix_bits, self.total_modulus)
+ return radix_inv, 1 << (2 * self.radix_bits)
+
+ def _crns_precompute(
+ self,
+ moduli: list[int],
+ prime: int,
+ y: int,
+ z: int,
+ u: int,
+ ):
+ modular = math.prod(moduli)
+ vector_i = []
+ for i, m_i in enumerate(moduli):
+ M_over_m_i = modular // m_i
+ inv_M_over_m_i = utils.modular_inverse(M_over_m_i % m_i, m_i)
+ I_i = (inv_M_over_m_i * M_over_m_i * y) % modular
+ I_i = I_i
+ vector_i.append(I_i)
+
+ matrix_E = []
+ for i, I_i in enumerate(vector_i):
+ E_row = []
+ for j, n_j in enumerate(moduli):
+ E_ij = (z * (I_i % prime)) % n_j
+ E_row.append(E_ij)
+ matrix_E.append(E_row)
+
+ vector_f_T = []
+ for I_i in vector_i:
+ f_i_T = math.ceil((I_i * (1 << u)) / modular)
+ vector_f_T.append(f_i_T)
+
+ vector_g = []
+ for n_j in moduli:
+ g_j = (-z * (modular % prime)) % n_j
+ vector_g.append(g_j)
+
+ return matrix_E, vector_f_T, vector_g
+
+ def _elementwise_montgomery_reduce(
+ self, z: list[int], m: list[int], m_inv: list[int]
+ ):
+ assert len(z) == len(m), "z and m must have the same length"
+ assert len(z) == len(m_inv), "z and m_inv must have the same length"
+ mask = (1 << self.radix_bits) - 1
+ z_low = self._elementwise_and(z, mask)
+ z_high = self._elementwise_right_shift(z, self.radix_bits)
+ q = self._elementwise_and(self._elementwise_multiply(z_low, m_inv), mask)
+ h = self._elementwise_right_shift(
+ self._elementwise_multiply(q, m), self.radix_bits
+ )
+ t = self._elementwise_subtract(z_high, h)
+ t = self._elementwise_add(t, m)
+ return t
+
+
+class DRNSlazyContext(DRNSlazyContextBase, JaxKernelContextBase):
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ JaxKernelContextBase.__init__(self)
+ self.jax_parameters = JaxParameters()
+ self._init_jax_parameters()
+
+ def to_computational_format(self, a: Union[int, list]) -> jnp.ndarray:
+ moduli = self.rns_moduli
+ radix_bits = self.radix_bits
+
+ def individual_convert(a: int) -> list:
+ return [(((a % m) << radix_bits) % m) for m in moduli]
+
+ def recursive_convert(a):
+ if isinstance(a, int):
+ return individual_convert(a)
+ return [recursive_convert(a_i) for a_i in a]
+
+ converted_list = recursive_convert(a)
+ converted_a = jnp.array(
+ np.array(converted_list, dtype=np.uint32), dtype=jnp.uint32
+ )
+ if self.use_sharding:
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=converted_a.shape, axes=[0]
+ )
+ converted_a = pad_jax_array(converted_a, padded_shape)
+ return converted_a.to_device(named_sharding)
+ else:
+ return converted_a.to_device(jax.devices()[0])
+
+ def to_original_format(self, a: jnp.ndarray) -> Union[int, list]: # type: ignore
+ def individual_convert(a: list[int]) -> int:
+ a = self._elementwise_montgomery_reduce(
+ a, self.rns_moduli, self.moduli_inv_on_radix
+ )
+ r = 0
+ for i in range(len(a)):
+ r = (r + a[i] * self.crt_factors[i]) % self.total_modulus
+ return r % self.prime
+
+ def recursive_convert(a):
+ if a.ndim == 1:
+ return individual_convert(a.tolist())
+ return [recursive_convert(a_i) for a_i in a]
+
+ return recursive_convert(a)
+
+ def _get_shape_dtype_structs(
+ self, parameters: dict
+ ) -> list[jax.ShapeDtypeStruct]:
+ batch_shape = parameters["batch_shape"]
+ oprand_shape = batch_shape + (len(self.rns_moduli),)
+ if self.use_sharding:
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=oprand_shape, axes=[0]
+ )
+ return [
+ jax.ShapeDtypeStruct(
+ padded_shape, jnp.uint32, sharding=named_sharding
+ )
+ ]
+ return [jax.ShapeDtypeStruct(oprand_shape, jnp.uint32)]
+
+ def context_hash(self) -> str:
+ return hash_args(
+ self.__class__.__name__,
+ self.prime,
+ self.rns_moduli,
+ self.precision_bits,
+ self.radix_bits,
+ self.use_sharding,
+ )
+
+ def serialize(self, parameters: dict):
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ store_jax_executable(
+ self._modular_multiply,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_multiply_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_add,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_add_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_subtract,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_subtract_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_reduce,
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_reduce_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_negate,
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_negate_{kernel_hash}",
+ )
+
+ def compile(self, parameters: dict):
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ modular_multiply_kernel = load_jax_executable(
+ f"{class_name}_modular_multiply_{kernel_hash}"
+ )
+ modular_add_kernel = load_jax_executable(
+ f"{class_name}_modular_add_{kernel_hash}"
+ )
+ modular_subtract_kernel = load_jax_executable(
+ f"{class_name}_modular_subtract_{kernel_hash}"
+ )
+ modular_reduce_kernel = load_jax_executable(
+ f"{class_name}_modular_reduce_{kernel_hash}"
+ )
+ modular_negate_kernel = load_jax_executable(
+ f"{class_name}_modular_negate_{kernel_hash}"
+ )
+
+ if None in [
+ modular_multiply_kernel,
+ modular_add_kernel,
+ modular_subtract_kernel,
+ modular_reduce_kernel,
+ modular_negate_kernel,
+ ]:
+ # if not self.use_sharding:
+ warnings.warn(
+ f"Not found stored serialized compiled kernels, compiling...",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ kernel_hash = hash_args(
+ shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__()
+ )
+ # print(f"kernel_hash: {kernel_hash}")
+
+ self.compiled_kernels[kernel_hash] = {
+ "modular_multiply": (
+ modular_multiply_kernel
+ if modular_multiply_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_multiply,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ )
+ ),
+ "modular_add": (
+ modular_add_kernel
+ if modular_add_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_add,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ )
+ ),
+ "modular_subtract": (
+ modular_subtract_kernel
+ if modular_subtract_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_subtract,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ )
+ ),
+ "modular_reduce": (
+ modular_reduce_kernel
+ if modular_reduce_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_reduce, shape_dtype_structs[0]
+ )
+ ),
+ "modular_negate": (
+ modular_negate_kernel
+ if modular_negate_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_negate, shape_dtype_structs[0]
+ )
+ ),
+ }
+ self.use_compiled_kernels = True
+
+ def _jax_crns_precompute(
+ self, moduli: list[int], prime: int, y: int, z: int, u: int
+ ):
+ rns_moduli_bytes = 4
+ vector_I = self._get_crns_vector_I_before_reducing(moduli, y)
+ modular = math.prod(moduli)
+
+ vector_I_byteshifted = []
+ for value_i in vector_I:
+ value_i_byteshifted = [
+ (value_i << (8 * byte_idx)) % modular
+ for byte_idx in range(rns_moduli_bytes)
+ ]
+ vector_I_byteshifted = vector_I_byteshifted + value_i_byteshifted
+
+ matrix_E = []
+ for value_i_byteshifted in vector_I_byteshifted:
+ value_i_byteshifted = z * (value_i_byteshifted % prime)
+ matrix_E.append(self._rns_decompose(value_i_byteshifted, moduli))
+
+ matrix_E_np = np.array(matrix_E, dtype=np.uint32).reshape(-1, len(moduli))
+ matrix_E_np_u8 = matrix_E_np.view(np.uint8)
+
+ vector_f = []
+ for value_i_byteshifted in vector_I_byteshifted:
+ value_i_byteshifted_f = math.ceil(
+ (value_i_byteshifted * (1 << u)) / modular
+ )
+ vector_f.append(value_i_byteshifted_f)
+ vector_f_T_np = np.array(vector_f, dtype=np.uint32).reshape(-1, 1)
+ vector_f_T_np_u8 = vector_f_T_np.view(np.uint8)
+
+ vector_g = self._rns_decompose(-z * (modular % prime), moduli)
+ vector_g_np = np.array(vector_g, dtype=np.uint32).reshape(1, -1)
+
+ matrix_E_with_f_T_np = np.hstack((matrix_E_np_u8, vector_f_T_np_u8))
+ matrix_E_with_f_T_np = matrix_E_with_f_T_np.reshape(
+ len(moduli), 4, len(moduli) + 1, 4
+ ) # NOTE: reshape is special part for the optimization
+ return matrix_E_with_f_T_np.tolist(), vector_g_np.tolist()
+
+ def _init_jax_parameters(self):
+ half_word_mask = 0xFFFF
+ half_word_bits = 16
+ word_bits = 32
+ word_mask = (1 << word_bits) - 1
+ rns_moduli_low = [m & half_word_mask for m in self.rns_moduli]
+ rns_moduli_high = [m >> half_word_bits for m in self.rns_moduli]
+ rns_moduli_inv_word = [
+ utils.modular_inverse(m, 2**word_bits) for m in self.rns_moduli
+ ]
+ crns_precision = self.precision_bits
+ crns_matrix_E_with_f_T, crns_vector_g = self._jax_crns_precompute(
+ self.rns_moduli,
+ self.prime,
+ self.crns_y,
+ self.crns_z,
+ self.precision_bits,
+ )
+ num_moduli = len(self.rns_moduli)
+ moduli_sub = self.to_computational_format(
+ 256 * num_moduli * 4 * 2 * self.prime
+ )
+
+ self.jax_parameters.set_parameter(
+ word_mask=word_mask,
+ half_word_mask=half_word_mask,
+ half_word_bits=half_word_bits,
+ word_bits=word_bits,
+ rns_moduli=jnp.array(self.rns_moduli, dtype=jnp.uint64),
+ rns_moduli_low=jnp.array(rns_moduli_low, dtype=jnp.uint16),
+ rns_moduli_high=jnp.array(rns_moduli_high, dtype=jnp.uint16),
+ rns_moduli_inv_word=jnp.array(rns_moduli_inv_word, dtype=jnp.uint32),
+ crns_precision=jnp.array(crns_precision, dtype=jnp.uint32),
+ crns_stacked_mat_E_with_f_T=jnp.array(
+ crns_matrix_E_with_f_T, dtype=jnp.uint8
+ ),
+ crns_vector_g=jnp.array(crns_vector_g, dtype=jnp.uint32),
+ rns_moduli_sub=jnp.array(moduli_sub, dtype=jnp.uint32),
+ rns_moduli_negate=jnp.array(self.rns_moduli, dtype=jnp.uint32),
+ )
+
+ def _jax_montgomery_reduce(self, z: jax.Array) -> jax.Array:
+
+ # Computation
+ z_low = z.astype(jnp.uint32)
+ z_high = (z >> self.jax_parameters.word_bits).astype(jnp.uint32)
+ t = (
+ z_low * self.jax_parameters.rns_moduli_inv_word
+ ) & self.jax_parameters.word_mask
+ t_low = t & self.jax_parameters.half_word_mask
+ t_high = (
+ t >> self.jax_parameters.half_word_bits
+ ) & self.jax_parameters.half_word_mask
+
+ prod_high = (
+ t_high * self.jax_parameters.rns_moduli_high
+ ) # This contributes directly to upper 32 bits
+ prod_mid_high = (
+ t_high * self.jax_parameters.rns_moduli_low
+ ) # Upper 16 bits go to upper 32 bits
+ prod_mid_low = (
+ t_low * self.jax_parameters.rns_moduli_high
+ ) # Upper 16 bits go to upper 32 bits
+ prod_low = (
+ t_low * self.jax_parameters.rns_moduli_low
+ ) # Upper 16 bits contribute to middle part
+ mid_low = (
+ (prod_mid_high & self.jax_parameters.half_word_mask)
+ + (prod_mid_low & self.jax_parameters.half_word_mask)
+ + (prod_low >> self.jax_parameters.half_word_bits)
+ )
+ mid_high = (
+ (prod_mid_high >> self.jax_parameters.half_word_bits)
+ + (prod_mid_low >> self.jax_parameters.half_word_bits)
+ + (mid_low >> self.jax_parameters.half_word_bits)
+ )
+
+ # Final upper 32 bits
+ t_final = prod_high + mid_high
+ b = z_high + self.jax_parameters.rns_moduli - t_final
+ return b.astype(jnp.uint32)
+
+ def _jax_crns(self, z: jax.Array) -> jax.Array:
+ shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint64)
+ precision_mask = jnp.array(
+ (1 << self.jax_parameters.crns_precision) - 1, dtype=jnp.uint32
+ )
+ num_moduli_n = self.jax_parameters.crns_vector_g.shape[1]
+
+ einsum_subscripts = "...kq,kqnp->...np"
+ # computation
+ x_u8 = jax.lax.bitcast_convert_type(z, jnp.uint8)
+
+ # x_reconstruction_with_v = jnp.matmul(x_u8, stacked_mat_E_with_f_T, preferred_element_type=jnp.uint32)
+ x_reconstruction_with_v = jnp.einsum(
+ einsum_subscripts,
+ x_u8,
+ self.jax_parameters.crns_stacked_mat_E_with_f_T,
+ preferred_element_type=jnp.uint32,
+ )
+ x_reconstruction_with_v_u64 = jnp.sum(
+ x_reconstruction_with_v.astype(jnp.uint64) << shift_factors, axis=(-1,)
+ )
+
+ x_n, vector_v = jnp.split(
+ x_reconstruction_with_v_u64, [num_moduli_n], axis=-1
+ )
+ vector_v = (vector_v >> self.jax_parameters.crns_precision) & precision_mask
+
+ x_n = x_n + jnp.multiply(vector_v, self.jax_parameters.crns_vector_g)
+ return x_n
+
+ def _modular_multiply(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ z = jnp.multiply(a.astype(jnp.uint64), b.astype(jnp.uint64))
+ z_reduced = self._jax_montgomery_reduce(z)
+ z_rns_reduced = self._jax_crns(
+ z_reduced
+ ) # could be skipped for small prime
+ z_reduced = self._jax_montgomery_reduce(
+ z_rns_reduced
+ ) # could be skipped for small prime (paired with the above)
+ return z_reduced
+
+ def _modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ return a + b
+
+ def _modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ return self._modular_negate(b) + a
+
+ def _modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray:
+ z_rns_reduced = self._jax_crns(a)
+ z_reduced = self._jax_montgomery_reduce(z_rns_reduced)
+ return z_reduced
+
+ def _modular_negate(self, a: jax.Array) -> jax.Array:
+ return jnp.add(
+ jnp.subtract(self.jax_parameters.rns_moduli_negate, a),
+ self.jax_parameters.rns_moduli_sub,
+ )
+
+ def modular_multiply(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ print(f"using compiled kernel for modular_multiply: {kernel_hash}")
+ return self.compiled_kernels[kernel_hash]["modular_multiply"](a, b)
+ else:
+ return self._modular_multiply(a, b)
+
+ def modular_add(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_add"](a, b)
+ else:
+ return self._modular_add(a, b)
+
+ def modular_subtract(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_subtract"](a, b)
+ else:
+ return self._modular_subtract(a, b)
+
+ def modular_reduce(self, a: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_reduce"](a)
+ else:
+ return self._modular_reduce(a)
+
+ def modular_negate(self, a: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_negate"](a)
+ else:
+ return self._modular_negate(a)
+
+
+# =============================================================================
+# Lazy matrix reduction context
+# =============================================================================
+
+
+def _lazy_check_carry(value_c: jax.Array) -> jax.Array:
+ """Check whether any 32-bit chunk holds a value exceeding 32 bits."""
+ return jnp.any(jnp.not_equal(jnp.right_shift(value_c, jnp.uint64(32)), 0))
+
+
+def _lazy_carry_propagate(value_c: jax.Array) -> jax.Array:
+ """Propagate carries between adjacent 32-bit chunks."""
+ n = value_c.shape[-1]
+ roll_mat = jnp.array(
+ [0, 1] + ([0] * n + [1]) * (n - 2) + [1] + [0] * (n - 1),
+ dtype=jnp.uint16,
+ ).reshape(n, n)
+ low = jnp.bitwise_and(value_c, jnp.uint64(0xFFFFFFFF))
+ high = jnp.right_shift(value_c, jnp.uint64(32)).astype(jnp.uint16)
+ high = jnp.matmul(high, roll_mat, preferred_element_type=jnp.uint32).astype(
+ jnp.uint16
+ )
+ return jnp.add(low, high.astype(jnp.uint64))
+
+
+class LazyContextBase(FiniteFieldContextBase):
+ """Base class for lazy matrix modular reduction.
+
+ Represents field elements as little-endian arrays of uint32 chunks
+ (base 2^32). The number of chunks is chunk_num_u8 // 4 + 1 to
+ provide one extra chunk of headroom for intermediate overflow.
+ """
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ raw_chunk_num_u8 = parameters.get(
+ "chunk_num_u8", math.ceil(int(self.prime).bit_length() / 8)
+ )
+ # The byte pipeline in ``_mul_to_u8`` emits a ``4 * 2 * chunk_num_u32``-byte
+ # buffer, and ``_modular_multiply`` slices ``high = vc[:, n8:2*n8+4]``
+ # against a ``(n8+4, n8)`` lazy matrix. The slice only fits when
+ # ``chunk_num_u8`` is a multiple of 4; round up so the invariant holds
+ # for primes with arbitrary bit-length.
+ self.chunk_num_u8 = ((raw_chunk_num_u8 + 3) // 4) * 4
+ self.chunk_num_u32 = self.chunk_num_u8 // 4 + 1
+ self.rns_moduli = self.chunk_num_u32
+ self.word_mask = (1 << 32) - 1
+ self.modulus_lazy_mat = self._construct_lazy_matrix()
+ self.prime_chunk = tuple(self._int_to_array(self.prime, self.chunk_num_u32))
+
+ @staticmethod
+ def _int_to_array(x: int, size: int, base: int = 32) -> list:
+ mask = (1 << base) - 1
+ elems = []
+ while x > 0:
+ elems.append(int(x & mask))
+ x >>= base
+ return elems[:size] + [0] * max(0, size - len(elems))
+
+ @staticmethod
+ def _array_to_int(arr, base: int = 32) -> int:
+ result = 0
+ for i, v in enumerate(arr):
+ result |= int(v) << (i * base)
+ return result
+
+ def _construct_lazy_matrix(self):
+ """Build the lazy reduction matrix.
+
+ Row i = (256^(chunk_num_u8 + i)) % prime, expressed as chunk_num_u8
+ little-endian uint8 chunks. Shape: (chunk_num_u8 + 4, chunk_num_u8).
+ """
+ n = self.chunk_num_u8
+ return tuple(
+ tuple(self._int_to_array(pow(256, n + i, self.prime), n, base=8))
+ for i in range(n + 4)
+ )
+
+
+class CROSSLazyContext(LazyContextBase, JaxKernelContextBase):
+ """Lazy matrix modular multiplication context for JAX."""
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ JaxKernelContextBase.__init__(self)
+ self._lazy_mat_jnp = jnp.array(self.modulus_lazy_mat, dtype=jnp.uint16)
+ self._prime_jnp = jnp.array(self.prime_chunk, dtype=jnp.uint32)
+
+ # ---------- format conversion ----------
+
+ def to_computational_format(self, a) -> jnp.ndarray:
+ def _convert(x: int) -> jnp.ndarray:
+ return jnp.array(
+ self._int_to_array(x % self.prime, self.chunk_num_u32),
+ dtype=jnp.uint32,
+ )
+
+ def _recurse(x):
+ if isinstance(x, int):
+ return _convert(x)
+ return jnp.array([_recurse(xi) for xi in x], dtype=jnp.uint32)
+
+ converted_a = _recurse(a)
+ if self.use_sharding:
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=converted_a.shape, axes=[0]
+ )
+ converted_a = pad_jax_array(converted_a, padded_shape)
+ return converted_a.to_device(named_sharding)
+ else:
+ return converted_a.to_device(jax.devices()[0])
+
+ def to_original_format(self, a: jnp.ndarray):
+ def _convert(arr) -> int:
+ return self._array_to_int(arr.tolist()) % self.prime
+
+ def _recurse(x):
+ if x.ndim == 1:
+ return _convert(x)
+ return [_recurse(xi) for xi in x]
+
+ return _recurse(a)
+
+ # ---------- internal JAX helpers ----------
+
+ @staticmethod
+ def _conv_1d(a: jax.Array, b: jax.Array) -> jax.Array:
+ a = jax.lax.bitcast_convert_type(a, jnp.uint8).reshape(-1)
+ b = jax.lax.bitcast_convert_type(b, jnp.uint8).reshape(-1)
+ if jax.default_backend() == "gpu":
+ # cuDNN's integer conv only supports s8, not u8, so route through
+ # float32 on GPU. u8*u8 sums fit exactly in float32's 24-bit mantissa.
+ res = jnp.convolve(a.astype(jnp.float32), b.astype(jnp.float32))
+ return res.astype(jnp.uint32)
+ return jnp.convolve(a, b, preferred_element_type=jnp.uint32)
+
+ def _rechunkify(self, x: jax.Array, n_u16: int, n_u32: int) -> jax.Array:
+ """Merge adjacent uint8 coefficients into uint16, then uint32 chunks."""
+ shift_u16 = jnp.array([[0, 8]] * n_u16, dtype=jnp.uint8)
+ shift_u32 = jnp.array([[0, 16]] * n_u32, dtype=jnp.uint8)
+ shape = x.shape[:-1] + (-1, 2) if x.ndim == 2 else (-1, 2)
+ x = jnp.sum(jnp.left_shift(x.reshape(shape), shift_u16), axis=-1)
+ x = jnp.sum(
+ jnp.left_shift(x.reshape(shape).astype(jnp.uint64), shift_u32), axis=-1
+ )
+ return x
+
+ def _mul_to_u8(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ """Multiply two chunk arrays; return product as a flat uint8 array."""
+ n = self.chunk_num_u32
+ batch = a.shape[0]
+ res = jax.vmap(self._conv_1d)(a, b)
+ res = jnp.pad(res, ((0, 0), (0, 1)))
+ res = self._rechunkify(res, 4 * n, 2 * n)
+ res = jax.lax.while_loop(_lazy_check_carry, _lazy_carry_propagate, res)
+ return jax.lax.bitcast_convert_type(
+ res.astype(jnp.uint32), jnp.uint8
+ ).reshape(batch, -1)
+
+ def _sub_raw(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ """Compute a - b assuming a >= b (batched, shape [batch, chunk_num_u32])."""
+ n = self.chunk_num_u32
+ borrow_low = jnp.array(
+ [self.word_mask + 1] * (n - 1) + [0], dtype=jnp.uint64
+ )
+ borrow_high = jnp.array([0] + [1] * (n - 2) + [0], dtype=jnp.uint64)
+ c = jnp.subtract(
+ jnp.add(a.astype(jnp.uint64), borrow_low), b.astype(jnp.uint64)
+ )
+ c = jnp.subtract(c, borrow_high)
+ c = jax.lax.while_loop(_lazy_check_carry, _lazy_carry_propagate, c)
+ c = c.at[:, n - 1].set(c[:, n - 1] - 1)
+ return c.astype(jnp.uint32)
+
+ def _compare(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ """Return per-batch sign: >=0 means a>=b, <0 means a jnp.ndarray:
+ n8 = self.chunk_num_u8
+ batch = a.shape[0]
+ vc = self._mul_to_u8(a, b)
+ low = vc[:, :n8]
+ high = vc[:, n8 : n8 * 2 + 4]
+ reduced = jnp.matmul(
+ high.astype(jnp.uint16),
+ self._lazy_mat_jnp,
+ preferred_element_type=jnp.uint32,
+ )
+ vc2 = jnp.add(low.astype(jnp.uint32), reduced)
+ vc2 = self._rechunkify(vc2, n8 // 2, n8 // 4)
+ vc2 = jnp.pad(vc2, ((0, 0), (0, 1)))
+ vc2 = jax.lax.while_loop(_lazy_check_carry, _lazy_carry_propagate, vc2)
+ return vc2.astype(jnp.uint32)[:, : self.chunk_num_u32]
+
+ def _modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ c = jax.lax.while_loop(
+ _lazy_check_carry,
+ _lazy_carry_propagate,
+ jnp.add(a.astype(jnp.uint64), b.astype(jnp.uint64)),
+ )
+ return c.astype(jnp.uint32)
+
+ def _modular_negate(self, a: jnp.ndarray) -> jnp.ndarray:
+ p = jnp.broadcast_to(self._prime_jnp, a.shape)
+ neg = self._sub_raw(p, a)
+ is_zero = jnp.all(a == 0, axis=-1, keepdims=True)
+ return jnp.where(is_zero, a, neg)
+
+ def _modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ return self._modular_add(a, self._modular_negate(b))
+
+ def _modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray:
+ p = jnp.broadcast_to(self._prime_jnp, a.shape)
+ cond = jnp.greater_equal(self._compare(a, p), 0).reshape(-1, 1)
+ return jnp.where(cond, self._sub_raw(a, p), a)
+
+ # ---------- public interface ----------
+
+ def modular_multiply(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ print(f"using compiled kernel for modular_multiply: {kernel_hash}")
+ return self.compiled_kernels[kernel_hash]["modular_multiply"](a, b)
+ else:
+ return self._modular_multiply(a, b)
+
+ def modular_add(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_add"](a, b)
+ else:
+ return self._modular_add(a, b)
+
+ def modular_subtract(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_subtract"](a, b)
+ else:
+ return self._modular_subtract(a, b)
+
+ def modular_reduce(self, a: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_reduce"](a)
+ else:
+ return self._modular_reduce(a)
+
+ def modular_negate(self, a: jax.Array) -> jax.Array:
+ kernel_hash = hash_args(a.shape, a.dtype.__str__())
+ if self.use_compiled_kernels:
+ return self.compiled_kernels[kernel_hash]["modular_negate"](a)
+ else:
+ return self._modular_negate(a)
+
+ def context_hash(self) -> str:
+ return hash_args(
+ self.__class__.__name__,
+ self.prime,
+ self.chunk_num_u8,
+ self.use_sharding,
+ )
+
+ def _get_shape_dtype_structs(
+ self, parameters: dict
+ ) -> list[jax.ShapeDtypeStruct]:
+ batch_shape = parameters["batch_shape"]
+ operand_shape = batch_shape + (self.chunk_num_u32,)
+ if self.use_sharding:
+ named_sharding, padded_shape = self.create_named_sharding(
+ shape=operand_shape, axes=[0]
+ )
+ return [
+ jax.ShapeDtypeStruct(
+ padded_shape, jnp.uint32, sharding=named_sharding
+ )
+ ]
+ return [jax.ShapeDtypeStruct(operand_shape, jnp.uint32)]
+
+ def serialize(self, parameters): # pytype: disable=signature-mismatch
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ store_jax_executable(
+ self._modular_multiply,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_multiply_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_add,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_add_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_subtract,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_subtract_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_reduce,
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_reduce_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._modular_negate,
+ shape_dtype_structs[0],
+ name=f"{class_name}_modular_negate_{kernel_hash}",
+ )
+
+ def compile(self, parameters): # pytype: disable=signature-mismatch
+ shape_dtype_structs = self._get_shape_dtype_structs(parameters)
+ kernel_hash = hash_args(self.context_hash(), parameters)
+ class_name = self.__class__.__name__
+
+ modular_multiply_kernel = load_jax_executable(
+ f"{class_name}_modular_multiply_{kernel_hash}"
+ )
+ modular_add_kernel = load_jax_executable(
+ f"{class_name}_modular_add_{kernel_hash}"
+ )
+ modular_subtract_kernel = load_jax_executable(
+ f"{class_name}_modular_subtract_{kernel_hash}"
+ )
+ modular_reduce_kernel = load_jax_executable(
+ f"{class_name}_modular_reduce_{kernel_hash}"
+ )
+ modular_negate_kernel = load_jax_executable(
+ f"{class_name}_modular_negate_{kernel_hash}"
+ )
+
+ if None in [
+ modular_multiply_kernel,
+ modular_add_kernel,
+ modular_subtract_kernel,
+ modular_reduce_kernel,
+ modular_negate_kernel,
+ ]:
+ # if not self.use_sharding:
+ warnings.warn(
+ f"Not found stored serialized compiled kernels, compiling...",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ kernel_hash = hash_args(
+ shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__()
+ )
+
+ self.compiled_kernels[kernel_hash] = {
+ "modular_multiply": (
+ modular_multiply_kernel
+ if modular_multiply_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_multiply,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ )
+ ),
+ "modular_add": (
+ modular_add_kernel
+ if modular_add_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_add,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ )
+ ),
+ "modular_subtract": (
+ modular_subtract_kernel
+ if modular_subtract_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_subtract,
+ shape_dtype_structs[0],
+ shape_dtype_structs[0],
+ )
+ ),
+ "modular_reduce": (
+ modular_reduce_kernel
+ if modular_reduce_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_reduce, shape_dtype_structs[0]
+ )
+ ),
+ "modular_negate": (
+ modular_negate_kernel
+ if modular_negate_kernel is not None
+ else jax_jit_lower_compile(
+ self._modular_negate, shape_dtype_structs[0]
+ )
+ ),
+ }
+ self.use_compiled_kernels = True
diff --git a/jaxite_ec/finite_field_perf_test.py b/jaxite_ec/finite_field_perf_test.py
new file mode 100644
index 0000000..0069470
--- /dev/null
+++ b/jaxite_ec/finite_field_perf_test.py
@@ -0,0 +1,108 @@
+import os
+
+import jax
+import jax.numpy as jnp
+from jaxite.jaxite_ec import finite_field_context as ff_context
+from jaxite.jaxite_ec import profiler
+from jaxite.jaxite_ec import utils
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+jax.config.update("jax_enable_x64", True)
+
+PRIME = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
+
+BATCH_SIZE_LIST = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
+
+NUM_MODULI_LIST = [32] # 21 for 256-bit, 32 for 384-bit, 56 for 753-bit
+
+TEST_PARAMS = [(f"moduli_{n}", n, BATCH_SIZE_LIST) for n in NUM_MODULI_LIST]
+
+
+def _modular_multiply_kernel(a, b, parameters):
+ return parameters["ctx"]._modular_multiply(a, b)
+
+
+class FiniteFieldModularMultiplyPerformanceTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
+ if outputs_dir:
+ self.output_trace_root = os.path.join(outputs_dir, "log")
+ else:
+ self.output_trace_root = os.path.join(os.path.dirname(__file__), "log")
+ self.profiler_config = {
+ "iterations": 1,
+ "save_to_file": True,
+ }
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ root_dir = os.path.dirname(os.path.abspath(__file__))
+ print(f"Collecting logs from: {root_dir}")
+ profiler.collect_logs(root_dir)
+
+ def _create_kernel_wrapper(self, kernel_name, ctx, batch, num_moduli):
+ input_shape = (batch, num_moduli)
+ return profiler.KernelWrapper(
+ kernel_name=kernel_name,
+ function_to_wrap=_modular_multiply_kernel,
+ input_structs=[
+ (input_shape, jnp.uint32),
+ (input_shape, jnp.uint32),
+ ],
+ parameters={"ctx": ctx},
+ )
+
+ def _profile_modular_multiply(self, num_moduli, batch_size_list):
+ rns_moduli = utils.find_moduli_specified_number(num_moduli, 28)
+
+ ctx = ff_context.DRNSlazyContext({
+ "prime": PRIME,
+ "rns_moduli": rns_moduli,
+ "precision_bits": 28,
+ "radix_bits": 32,
+ })
+
+ profiler_instance = profiler.Profiler(
+ output_trace_path=self.output_trace_root,
+ profile_naming=f"ff_modular_multiply_moduli_{num_moduli}",
+ configuration=self.profiler_config,
+ )
+
+ for batch in batch_size_list:
+ kernel_name = f"ff_mod_mul_m{num_moduli}_b{batch}"
+ kernel_wrapper = self._create_kernel_wrapper(
+ kernel_name=kernel_name,
+ ctx=ctx,
+ batch=batch,
+ num_moduli=num_moduli,
+ )
+
+ profiler_instance.add_profile(
+ name=kernel_name,
+ kernel_wrapper=kernel_wrapper,
+ kernel_setting_cols={
+ "num_moduli": num_moduli,
+ "batch": batch,
+ },
+ )
+
+ profiler_instance.profile_all_profilers()
+ profiler_instance.post_process_all_profilers()
+
+ @parameterized.named_parameters(*TEST_PARAMS)
+ def test_DRNSlazy_modular_multiply_performance(
+ self, num_moduli, batch_size_list
+ ):
+ self._profile_modular_multiply(
+ num_moduli=num_moduli,
+ batch_size_list=batch_size_list,
+ )
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/jaxite_ec/finite_field_test.py b/jaxite_ec/finite_field_test.py
index 208eda0..7d8f9c3 100644
--- a/jaxite_ec/finite_field_test.py
+++ b/jaxite_ec/finite_field_test.py
@@ -1,254 +1,68 @@
-import random
-
import jax
-import jax.numpy as jnp
-from jaxite.jaxite_ec import util
-import jaxite.jaxite_ec.algorithm.finite_field as pyff
-import jaxite.jaxite_ec.finite_field as ff
+from jaxite.jaxite_ec import finite_field_context as ff_context
+from jaxite.jaxite_ec import utils
import numpy as np
from absl.testing import absltest
+from absl.testing import parameterized
jax.config.update("jax_enable_x64", True)
-randint = random.randint
-
-
-def list_operation(a, b, func):
- return [func(ai, bi) for ai, bi in zip(a, b)]
-
-
-def list_operation_three(a, b, c, func):
- return [func(ai, bi, ci) for ai, bi, ci in zip(a, b, c)]
-
-
-class FiniteFieldTest(absltest.TestCase):
-
- def setUp(self):
- super().setUp()
- self.value_a = [
- 0xBE4FBE5D03CE926E40E058BBDC3269C78CFAFED39796CD13EC8E9B0072DB2538DFFBCA05804574D9E2FF7EEB1DE219,
- 0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF,
- ]
- self.value_b = [
- 0x82A0ED372BFAB8198D0667A1DC5E299C1F6C8FEB0ACD4D05A228325117BE63EAE5BABE6807F41C6C8016BDAC251CFE,
- 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6,
- ]
- self.value_c = [
- 0x125E69CE765D167C0B19F8D6D6708D39C7782F33B6D320802E2FFA92BBB12DBB3897EAF9CC4CF67E487478F3C3FAD16,
- 0x01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76,
- ]
- self.value_a_jax = util.int_list_to_array(
- self.value_a, base=util.BASE, array_size=util.U16_CHUNK_NUM
- )
- self.value_b_jax = util.int_list_to_array(
- self.value_b, base=util.BASE, array_size=util.U16_CHUNK_NUM
- )
- self.value_c_jax = util.int_list_to_array(
- self.value_c, base=util.BASE, array_size=util.U16_CHUNK_NUM
- )
-
- @absltest.skip("This test is only needed if u need to generate RNS Matrix.")
- def test_generate_rns_precompute_matrix(self):
- rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- rns_stack_mat = rns_mat[0]
- cor_mat = rns_mat[1]
- print("Printing out RNS matrix")
- print("(")
- for i in range(len(rns_stack_mat)):
- print("(", end="")
- for j in range(len(rns_stack_mat[0])):
- if j == len(rns_stack_mat[0]) - 1:
- print(rns_stack_mat[i][j], end="), ")
- else:
- print(rns_stack_mat[i][j], end=", ")
- print("")
- print(")")
-
- print("Printing out Correction matrix")
- print("(")
- for i in range(len(cor_mat)):
- print("( ", end="")
- for j in range(len(cor_mat[0])):
- if j == len(cor_mat[0]) - 1:
- print(cor_mat[i][j], end="), ")
- else:
- print(cor_mat[i][j], end=", ")
- print("")
- print(")")
- print()
-
- def test_add_two(self):
- result_jax = ff.add_2u16(self.value_a_jax, self.value_b_jax)
- result = util.array_to_int_list(result_jax, util.BASE)
-
- result_ref = list_operation(self.value_a, self.value_b, lambda a, b: a + b)
-
- self.assertEqual(result, result_ref)
-
- def test_add_three(self):
- result_jax = ff.add_3u16(
- self.value_a_jax, self.value_b_jax, self.value_c_jax
- )
- result = util.array_to_int_list(result_jax, util.BASE)
- result_ref = list_operation(
- list_operation(self.value_a, self.value_b, lambda a, b: a + b),
- self.value_c,
- lambda a, b: a + b,
- )
-
- self.assertEqual(result, result_ref)
-
- def test_cond_sub_1(self):
- result_jax = ff.cond_sub_2u16(self.value_a_jax, self.value_b_jax)
- result = util.array_to_int_list(result_jax, util.BASE)
-
- def cond_sub(a, b):
- if a < b:
- return a + util.MODULUS_377_INT - b
- else:
- return a - b
-
- result_ref = list_operation(self.value_a, self.value_b, cond_sub)
- self.assertEqual(result, result_ref)
-
- def test_cond_sub_2(self):
- result_jax = ff.cond_sub_2u16(self.value_b_jax, self.value_a_jax)
- result = util.array_to_int_list(result_jax, util.BASE)
-
- def cond_sub(a, b):
- if a < b:
- return a + util.MODULUS_377_INT - b
- else:
- return a - b
-
- result_ref = list_operation(self.value_b, self.value_a, cond_sub)
-
- self.assertEqual(result, result_ref)
-
- def test_cond_sub_mod_1(self):
- value_list = [util.MODULUS_377_INT + 123, util.MODULUS_377_INT - 5432]
- value_jax = util.int_list_to_array(
- value_list, base=util.BASE, array_size=util.U16_CHUNK_NUM
- )
- result_jax = ff.cond_sub_mod_u16(value_jax)
- result = util.array_to_int_list(result_jax, util.BASE)
-
- def cond_sub_mod(a):
- if a < util.MODULUS_377_INT:
- return a
- else:
- return a - util.MODULUS_377_INT
-
- result_ref = [cond_sub_mod(a) for a in value_list]
-
- self.assertEqual(result, result_ref)
-
- def test_mul_1(self):
- result_jax = ff.mul_2u16(self.value_a_jax, self.value_b_jax)
- result = util.array_to_int_list(result_jax, util.BASE)
- result_ref = list_operation(self.value_a, self.value_b, lambda a, b: a * b)
- self.assertEqual(result, result_ref)
-
- def test_mod_mul_barrett_1(self):
- result_jax = ff.mod_mul_barrett_2u16(self.value_a_jax, self.value_b_jax)
- result = util.array_to_int_list(result_jax, util.BASE)
-
- def mod_mul_barrett(a, b):
- value_a_barrett = pyff.FiniteFieldElementBarrett(a, util.MODULUS_377_INT)
- value_b_barrett = pyff.FiniteFieldElementBarrett(b, util.MODULUS_377_INT)
- return (value_a_barrett * value_b_barrett).get_value()
-
- result_ref = list_operation(self.value_a, self.value_b, mod_mul_barrett)
-
- self.assertEqual(result, result_ref)
-
- def test_jax_mod_mul_lazy_reduction(self):
- """This test case check the jax version (TPU deployment) of the lazy reduction based modular multiplication algorithm."""
- batch_size = 16
- a_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)]
- b_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)]
-
- a_batch = util.int_list_to_array(
- a_list, base=util.BASE, array_size=util.U16_EXT_CHUNK_NUM
- )
- b_batch = util.int_list_to_array(
- b_list, base=util.BASE, array_size=util.U16_EXT_CHUNK_NUM
- )
- c_batch = ff.mod_mul_lazy_2u16(a_batch, b_batch)
- c_list = util.array_to_int_list(c_batch, util.BASE)
- for i in range(len(a_list)):
- np.testing.assert_equal(
- c_list[i] % util.MODULUS_377_INT,
- (a_list[i] * b_list[i]) % util.MODULUS_377_INT,
- )
-
- def test_jax_mod_mul_rns_reduction(self):
- """This test case check the jax version (TPU deployment) of the rns reduction based modular multiplication algorithm."""
- batch_size = 16
- a_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)]
- b_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)]
-
- modulus_rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
- a_batch = util.int_list_to_array_rns(a_list)
- b_batch = util.int_list_to_array_rns(b_list)
- c_batch = ff.mod_mul_rns_2u16(a_batch, b_batch, modulus_rns_mat)
- c_list = util.array_rns_to_int_list(c_batch)
- for i in range(len(a_list)):
- np.testing.assert_equal(
- c_list[i] % util.MODULUS_377_INT,
- (a_list[i] * b_list[i]) % util.MODULUS_377_INT,
- )
-
- def test_jax_add_rns(self):
- max_val = [2**16 - 1 for _ in range(util.NUM_MODULI)]
- max_normal_val = [m - 1 for m in util.MODULI]
- zero = [0 for _ in range(util.NUM_MODULI)]
- values = [zero, max_val, max_normal_val]
- for a in values:
- for b in values:
- jax_a = jnp.array(a, dtype=jnp.uint16).reshape((1, util.NUM_MODULI))
- jax_b = jnp.array(b, dtype=jnp.uint16).reshape((1, util.NUM_MODULI))
- jax_sum = ff.add_rns_2u16(jax_a, jax_b, tuple(util.RNS_MODULI_T))
- jax_3sum = ff.add_rns_3u16(
- jax_a, jax_b, jax_a, tuple(util.RNS_MODULI_T)
- )
- for i in range(util.NUM_MODULI):
- np.testing.assert_equal(
- int(jax_sum[0, i]) % util.MODULI[i],
- (a[i] + b[i]) % util.MODULI[i],
- )
- np.testing.assert_equal(
- int(jax_3sum[0, i]) % util.MODULI[i],
- (a[i] + b[i] + a[i]) % util.MODULI[i],
- )
-
- def test_jax_sub_rns(self):
- batch_size = 16
- bound = 256 * util.NUM_MODULI * util.MODULUS_377_INT
- a_list = [randint(0, bound) for _ in range(batch_size)]
- b_list = [randint(0, bound) for _ in range(batch_size)]
- b_list[0] = bound - 1
- a_batch = util.int_list_to_array_rns(a_list)
- b_batch = util.int_list_to_array_rns(b_list)
- diff = ff.add_sub_rns_var(a_batch, ff.negate_rns_for_var_add(b_batch))
- diff_int = util.array_rns_to_int_list(diff)
- for i in range(batch_size):
- np.testing.assert_equal(
- diff_int[i] % util.MODULUS_377_INT,
- (a_list[i] - b_list[i]) % util.MODULUS_377_INT,
- )
-
- def test_jax_add_rns_specific_case(self):
- e = 149025596882241982990837486539530757729373308235078472950338530041138824139683871423246406325832428746678481926350
- h = 64504914146370186321601383206234327860947337385636434408981741341000348311615332297641759363669577422255115417749
- q = util.MODULUS_377_INT
- jax_a = util.int_list_to_array_rns([e])
- jax_b = util.int_list_to_array_rns([h])
- jax_sum = ff.mod_mul_rns_2u16(jax_a, jax_b)
- val_sum = util.array_rns_to_int(jax_sum[0])
- np.testing.assert_equal(val_sum % util.MODULUS_377_INT, (e * h % q))
+FF = [
+ (
+ "0",
+ [
+ 0xBE4FBE5D03CE926E40E058BBDC3269C78CFAFED39796CD13EC8E9B0072DB2538DFFBCA05804574D9E2FF7EEB1DE219,
+ 0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF,
+ ],
+ [
+ 0x82A0ED372BFAB8198D0667A1DC5E299C1F6C8FEB0ACD4D05A228325117BE63EAE5BABE6807F41C6C8016BDAC251CFE,
+ 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6,
+ ],
+ ),
+]
+
+
+class FiniteFieldTest(parameterized.TestCase):
+
+ def __init__(self, *args, **kwargs):
+ super(FiniteFieldTest, self).__init__(*args, **kwargs)
+
+ @parameterized.named_parameters(*FF)
+ def test_DRNSlazy_modular_multiply(self, value_a, value_b):
+ prime = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
+ rns_moduli = utils.find_moduli_specified_number(32, 28)
+ ref_value_c = [(a * b) % prime for a, b in zip(value_a, value_b)]
+
+ ctx = ff_context.DRNSlazyContext({
+ "prime": prime,
+ "rns_moduli": rns_moduli,
+ "precision_bits": 28,
+ "radix_bits": 32,
+ })
+
+ value_a_m = ctx.to_computational_format(value_a)
+ value_b_m = ctx.to_computational_format(value_b)
+ value_c_m = ctx.modular_multiply(value_a_m, value_b_m)
+ value_c = ctx.to_original_format(value_c_m)
+
+ np.testing.assert_array_equal(value_c, ref_value_c)
+
+ @parameterized.named_parameters(*FF)
+ def test_lazy_modular_multiply(self, value_a, value_b):
+ prime = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
+
+ ref_value_c = [(a * b) % prime for a, b in zip(value_a, value_b)]
+
+ ctx = ff_context.CROSSLazyContext({"prime": prime, "chunk_num_u8": 48})
+
+ value_a_m = ctx.to_computational_format(value_a)
+ value_b_m = ctx.to_computational_format(value_b)
+ value_c_m = ctx.modular_multiply(value_a_m, value_b_m)
+ value_c = ctx.to_original_format(value_c_m)
+
+ np.testing.assert_array_equal(value_c, ref_value_c)
if __name__ == "__main__":
diff --git a/jaxite_ec/msm_test.py b/jaxite_ec/msm_test.py
deleted file mode 100644
index cd78971..0000000
--- a/jaxite_ec/msm_test.py
+++ /dev/null
@@ -1,730 +0,0 @@
-import csv
-import os
-import sys
-
-import jax
-import jax.numpy as jnp
-from jaxite.jaxite_ec import pippenger
-from jaxite.jaxite_ec import pippenger_rns
-from jaxite.jaxite_ec import util
-from jaxite.jaxite_ec.algorithm import config_file
-import jaxite.jaxite_ec.algorithm.elliptic_curve as ec
-
-# copybara: from google3.pyglib import resources
-from absl.testing import absltest
-from absl.testing import parameterized
-
-
-script_path = os.path.abspath(sys.argv[0])
-script_dir = os.path.dirname(script_path)
-
-jax.config.update("jax_traceback_filtering", "off")
-
-# Only needed when tryingt to understand the HLO dump.
-os.environ["XLA_FLAGS"] = (
- "--xla_dump_to=sponge --xla_backend_optimization_level=4"
-)
-
-TEST_PARAMS = [
- (
- "test_4_degree",
- os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])),
- f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv"
- ),
- os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])),
- f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv"
- ),
- os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])),
- f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv"
- ),
- ),
- (
- "test_1024_degree",
- os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])),
- f"{script_dir}/jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv"
- ),
- os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])),
- f"{script_dir}/jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv"
- ),
- os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])),
- f"{script_dir}/jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv"
- ),
- ),
-]
-
-
-def twist_coordinates_list(ec_config, coordinates_list):
- twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(ec_config)
- twisted_coordinates_list = []
- untwisted_coordinates_indeices = []
- for i, coordinates in enumerate(coordinates_list):
- twisted_coordinates = twisted_ec_sys.twist_int_coordinates(coordinates)
- twisted_coordinates_list.append(twisted_coordinates)
- if twisted_coordinates == [0, 1, 1, 0]:
- untwisted_coordinates_indeices.append(i)
- return twisted_coordinates_list, untwisted_coordinates_indeices
-
-
-class MSMTest(parameterized.TestCase):
- def read_external_file(self, scalar_path, base_path, result_path):
- scalars = []
- with open(
- scalar_path, "r", newline="", encoding="utf-8"
- ) as csvfile: # Handle potential encoding issues
- csv_reader = csv.reader(csvfile)
- for row in csv_reader:
- scalars.append(int(row[-1][13:-2], 16))
-
- points = []
- with open(
- base_path, "r", newline="", encoding="utf-8"
- ) as csvfile: # Handle potential encoding issues
- csv_reader = csv.reader(csvfile)
- for row in csv_reader:
- points.append([int(row[8][13:-2], 16), int(row[-1][13:-2], 16)])
-
- result_ref = []
- with open(
- result_path, "r", newline="", encoding="utf-8"
- ) as csvfile: # Handle potential encoding issues
- csv_reader = csv.reader(csvfile)
- for row in csv_reader:
- result_ref.append(int(row[7][13:-2], 16))
- result_ref.append(int(row[-1][13:-2], 16))
- return scalars, points, result_ref
-
- @parameterized.named_parameters(*TEST_PARAMS)
- def test_pippenger_index_selection(self, scalar_path, base_path, result_path):
- """Normal version Pippenger."""
- scalars, points, result_ref = self.read_external_file(
- scalar_path, base_path, result_path
- )
- slice_length = 4
- msm_algo = pippenger.MSMPippenger(slice_length)
- msm_algo.initialize(scalars, points)
-
- window_num = msm_algo.window_num
- bucket_num_per_window = msm_algo.bucket_num_per_window
- msm_length = msm_algo.msm_length
- coordinate_num = msm_algo.coordinate_num
- chunk_num = util.U16_EXT_CHUNK_NUM
-
- bucket_accumulation_index_scan_jit = (
- jax.jit(
- pippenger.bucket_accumulation_index_scan_algorithm,
- static_argnames="msm_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, bucket_num_per_window, chunk_num),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (msm_length, coordinate_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct((msm_length, window_num), dtype=jnp.uint32),
- jax.ShapeDtypeStruct(
- (msm_length, window_num, bucket_num_per_window), dtype=jnp.uint8
- ),
- msm_length,
- )
- .compile()
- )
-
- bucket_reduction_scan_jit = (
- jax.jit(
- pippenger.bucket_reduction_scan_algorithm,
- static_argnames="bucket_num_in_window",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, bucket_num_per_window, chunk_num),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (bucket_num_per_window, window_num), dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (bucket_num_per_window, window_num), dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (
- bucket_num_per_window,
- window_num,
- ),
- dtype=jnp.uint8,
- ),
- bucket_num_per_window,
- )
- .compile()
- )
-
- window_merge_scan_jit = (
- jax.jit(
- pippenger.window_merge_scan_algorithm,
- static_argnames="slice_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- slice_length,
- )
- .compile()
- )
-
- msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit)
- msm_algo.bucket_reduction(bucket_reduction_scan_jit)
- result = msm_algo.window_merge(window_merge_scan_jit)
- result = util.jax_point_pack_to_int_point(result)
- ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377)
- result_affine_point = ec_sys.generate_point(result).convert_to_affine()
- coordinates = (
- result_affine_point[0].get_value(),
- result_affine_point[1].get_value(),
- )
- self.assertEqual(coordinates[0], result_ref[0])
- self.assertEqual(coordinates[1], result_ref[1])
-
- # performance measurement
- tasks = [
- (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)),
- (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)),
- (msm_algo.window_merge, (window_merge_scan_jit,)),
- ]
- profile_name = "normal_pippenger_index_selection"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- @parameterized.named_parameters(*TEST_PARAMS)
- def test_pippenger_index_selection_twisted_edwards(
- self, scalar_path, base_path, result_path
- ):
- scalars, points, result_ref = self.read_external_file(
- scalar_path, base_path, result_path
- )
- twisted_points, untwisted_coordinates_indeices = twist_coordinates_list(
- config_file.config_BLS12_377_t, points
- )
- assert not untwisted_coordinates_indeices
- slice_length = 4
- parallel_num = 4
- msm_algo = pippenger.MSMPippengerTwisted(slice_length, parallel_num)
- msm_algo.initialize(scalars, twisted_points)
-
- window_num = msm_algo.window_num
- bucket_num_per_window = msm_algo.bucket_num_per_window
- msm_length = msm_algo.msm_length
- coordinate_num = msm_algo.coordinate_num
- chunk_num = util.U16_EXT_CHUNK_NUM
-
- batch_window_num = window_num * parallel_num
- batch_mem_length = msm_length // parallel_num
-
- bucket_accumulation_index_scan_jit = (
- jax.jit(
- pippenger.bucket_accumulation_index_scan_parallel_algorithm_twisted,
- static_argnames="msm_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (
- coordinate_num,
- batch_window_num,
- bucket_num_per_window,
- chunk_num,
- ),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (batch_mem_length, coordinate_num, parallel_num, chunk_num),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (batch_mem_length, batch_window_num), dtype=jnp.uint32
- ),
- batch_mem_length,
- )
- .compile()
- )
-
- bucket_reduction_scan_jit = (
- jax.jit(
- pippenger.bucket_reduction_scan_algorithm_twisted,
- static_argnames="bucket_num_in_window",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (
- coordinate_num,
- batch_window_num,
- bucket_num_per_window,
- chunk_num,
- ),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- bucket_num_per_window,
- )
- .compile()
- )
-
- batch_window_summation_jit = (
- jax.jit(
- pippenger.batch_window_summation_algorithm_twisted,
- static_argnames="point_parallel",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- parallel_num,
- )
- .compile()
- )
-
- window_merge_scan_jit = (
- jax.jit(
- pippenger.window_merge_scan_algorithm_twisted,
- static_argnames="slice_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- slice_length,
- )
- .compile()
- )
-
- # HERE
- msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit)
- msm_algo.bucket_reduction(bucket_reduction_scan_jit)
- msm_algo.batch_window_summation(batch_window_summation_jit)
- result = msm_algo.window_merge(window_merge_scan_jit)
- result = util.jax_point_pack_to_int_point(result)
- # TO HERE
- ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t)
- result_affine_point = ec_sys.generate_point(
- result, twist=False
- ).convert_to_affine()
- coordinates = (
- result_affine_point[0].get_value(),
- result_affine_point[1].get_value(),
- )
- self.assertEqual(coordinates[0], result_ref[0])
- self.assertEqual(coordinates[1], result_ref[1])
-
- # performance measurement
- tasks = [
- (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)),
- (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)),
- (msm_algo.batch_window_summation, (batch_window_summation_jit,)),
- (msm_algo.window_merge, (window_merge_scan_jit,)),
- ]
- profile_name = "pippenger_index_selection_twisted_edwards"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- @parameterized.named_parameters(*TEST_PARAMS)
- def test_pippenger_signed_index_selection_twisted_edwards(
- self, scalar_path, base_path, result_path
- ):
- scalars, points, result_ref = self.read_external_file(
- scalar_path, base_path, result_path
- )
- twisted_points, untwisted_coordinates_indeices = twist_coordinates_list(
- config_file.config_BLS12_377_t, points
- )
- assert not untwisted_coordinates_indeices
- slice_length = 4
- parallel_num = 4
- msm_algo = pippenger.MSMPippengerTwistedSigned(slice_length, parallel_num)
- msm_algo.initialize(scalars, twisted_points)
-
- window_num = msm_algo.window_num
- bucket_num_per_window = msm_algo.bucket_num_per_window
- msm_length = msm_algo.msm_length
- coordinate_num = msm_algo.coordinate_num
- chunk_num = util.U16_EXT_CHUNK_NUM
-
- batch_window_num = window_num * parallel_num
- batch_mem_length = msm_length // parallel_num
-
- bucket_accumulation_index_scan_jit = (
- jax.jit(
- pippenger.bucket_accumulation_signed_index_scan_parallel_algorithm_twisted,
- static_argnames="msm_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (
- coordinate_num,
- batch_window_num,
- bucket_num_per_window,
- chunk_num,
- ),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (batch_mem_length, coordinate_num, parallel_num, chunk_num),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (batch_mem_length, batch_window_num), dtype=jnp.uint32
- ),
- jax.ShapeDtypeStruct(
- (batch_mem_length, batch_window_num), dtype=jnp.uint8
- ),
- batch_mem_length,
- )
- .compile()
- )
-
- bucket_reduction_scan_jit = (
- jax.jit(
- pippenger.bucket_reduction_scan_algorithm_twisted,
- static_argnames="bucket_num_in_window",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (
- coordinate_num,
- batch_window_num,
- bucket_num_per_window,
- chunk_num,
- ),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- bucket_num_per_window,
- )
- .compile()
- )
-
- batch_window_summation_jit = (
- jax.jit(
- pippenger.batch_window_summation_algorithm_twisted,
- static_argnames="point_parallel",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- parallel_num,
- )
- .compile()
- )
-
- window_merge_scan_jit = (
- jax.jit(
- pippenger.window_merge_scan_algorithm_twisted,
- static_argnames="slice_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- slice_length,
- )
- .compile()
- )
-
- # HERE
- msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit)
- msm_algo.bucket_reduction(bucket_reduction_scan_jit)
- msm_algo.batch_window_summation(batch_window_summation_jit)
- result = msm_algo.window_merge(window_merge_scan_jit)
- result = util.jax_point_pack_to_int_point(result)
- # TO HERE
- ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t)
- result_affine_point = ec_sys.generate_point(
- result, twist=False
- ).convert_to_affine()
- coordinates = (
- result_affine_point[0].get_value(),
- result_affine_point[1].get_value(),
- )
- self.assertEqual(coordinates[0], result_ref[0])
- self.assertEqual(coordinates[1], result_ref[1])
-
- # performance measurement
- tasks = [
- (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)),
- (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)),
- (msm_algo.batch_window_summation, (batch_window_summation_jit,)),
- (msm_algo.window_merge, (window_merge_scan_jit,)),
- ]
- profile_name = "pippenger_signed_index_selection_twisted_edwards"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- # @absltest.skip("test pass")
- @parameterized.named_parameters(*TEST_PARAMS)
- def test_pippenger_index_rns_selection(
- self, scalar_path, base_path, result_path
- ):
- """RNS version Pippenger - XYZZ."""
- scalars, points, result_ref = self.read_external_file(
- scalar_path, base_path, result_path
- )
- slice_length = 4
- msm_algo = pippenger_rns.MSMPippenger(slice_length)
- msm_algo.initialize(scalars, points)
-
- window_num = msm_algo.window_num
- bucket_num_per_window = msm_algo.bucket_num_per_window
- msm_length = msm_algo.msm_length
- coordinate_num = msm_algo.coordinate_num
- chunk_num = util.NUM_MODULI
-
- bucket_accumulation_index_scan_jit = (
- jax.jit(
- pippenger_rns.bucket_accumulation_index_scan_algorithm,
- static_argnames="msm_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, bucket_num_per_window, chunk_num),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (msm_length, coordinate_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct((msm_length, window_num), dtype=jnp.uint32),
- jax.ShapeDtypeStruct(
- (msm_length, window_num, bucket_num_per_window), dtype=jnp.uint8
- ),
- msm_length,
- )
- .compile()
- )
-
- bucket_reduction_scan_jit = (
- jax.jit(
- pippenger_rns.bucket_reduction_scan_algorithm,
- static_argnames="bucket_num_in_window",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, bucket_num_per_window, chunk_num),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (bucket_num_per_window, window_num), dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (bucket_num_per_window, window_num), dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (bucket_num_per_window, window_num), dtype=jnp.uint8
- ),
- bucket_num_per_window,
- )
- .compile()
- )
-
- window_merge_scan_jit = (
- jax.jit(
- pippenger_rns.window_merge_scan_algorithm,
- static_argnames="slice_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- slice_length,
- )
- .compile()
- )
-
- # HERE
- msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit)
- msm_algo.bucket_reduction(bucket_reduction_scan_jit)
- result = msm_algo.window_merge(window_merge_scan_jit)
- result = util.jax_rns_point_pack_to_int_point(result)
- # TO HERE
- ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377)
- result_affine_point = ec_sys.generate_point(result).convert_to_affine()
- coordinates = (
- result_affine_point[0].get_value(),
- result_affine_point[1].get_value(),
- )
- self.assertEqual(coordinates[0], result_ref[0])
- self.assertEqual(coordinates[1], result_ref[1])
-
- # performance measurement
- tasks = [
- (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)),
- (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)),
- (msm_algo.window_merge, (window_merge_scan_jit,)),
- ]
- profile_name = "pippenger_index_rns_selection"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
- # @absltest.skip("has some bug in result")
- @parameterized.named_parameters(*TEST_PARAMS)
- def test_pippenger_index_selection_rns_twisted_edwards(
- self, scalar_path, base_path, result_path
- ):
- scalars, points, result_ref = self.read_external_file(
- scalar_path, base_path, result_path
- )
- twisted_points, untwisted_coordinates_indeices = twist_coordinates_list(
- config_file.config_BLS12_377_t, points
- )
- assert not untwisted_coordinates_indeices
- slice_length = 4
- parallel_num = 4
- msm_algo = pippenger_rns.MSMPippengerTwisted(slice_length, parallel_num)
- msm_algo.initialize(scalars, twisted_points)
-
- window_num = msm_algo.window_num
- bucket_num_per_window = msm_algo.bucket_num_per_window
- msm_length = msm_algo.msm_length
- coordinate_num = msm_algo.coordinate_num
- chunk_num = util.NUM_MODULI
-
- batch_window_num = window_num * parallel_num
- batch_mem_length = msm_length // parallel_num
-
- bucket_accumulation_index_scan_jit = (
- jax.jit(
- pippenger_rns.bucket_accumulation_index_scan_parallel_algorithm_twisted,
- static_argnames="msm_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (
- coordinate_num,
- batch_window_num,
- bucket_num_per_window,
- chunk_num,
- ),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (batch_mem_length, coordinate_num, parallel_num, chunk_num),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (batch_mem_length, batch_window_num), dtype=jnp.uint32
- ),
- batch_mem_length,
- )
- .compile()
- )
-
- bucket_reduction_scan_jit = (
- jax.jit(
- pippenger_rns.bucket_reduction_scan_algorithm_twisted,
- static_argnames="bucket_num_in_window",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (
- coordinate_num,
- batch_window_num,
- bucket_num_per_window,
- chunk_num,
- ),
- dtype=jnp.uint16,
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- bucket_num_per_window,
- )
- .compile()
- )
-
- batch_window_summation_jit = (
- jax.jit(
- pippenger_rns.batch_window_summation_algorithm_twisted,
- static_argnames="point_parallel",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16
- ),
- parallel_num,
- )
- .compile()
- )
-
- window_merge_scan_jit = (
- jax.jit(
- pippenger_rns.window_merge_scan_algorithm_twisted,
- static_argnames="slice_length",
- )
- .lower(
- jax.ShapeDtypeStruct(
- (coordinate_num, window_num, chunk_num), dtype=jnp.uint16
- ),
- slice_length,
- )
- .compile()
- )
-
- # HERE
- msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit)
- msm_algo.bucket_reduction(bucket_reduction_scan_jit)
- msm_algo.batch_window_summation(batch_window_summation_jit)
- result = msm_algo.window_merge(window_merge_scan_jit)
- result = util.jax_rns_point_pack_to_int_point(result)
- # TO HERE
- ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t)
- result_affine_point = ec_sys.generate_point(
- result, twist=False
- ).convert_to_affine()
- coordinates = (
- result_affine_point[0].get_value() % util.MODULUS_377_INT,
- result_affine_point[1].get_value() % util.MODULUS_377_INT,
- )
- self.assertEqual(coordinates[0], result_ref[0])
- self.assertEqual(coordinates[1], result_ref[1])
-
- # performance measurement
- tasks = [
- (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)),
- (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)),
- (msm_algo.batch_window_summation, (batch_window_summation_jit,)),
- (msm_algo.window_merge, (window_merge_scan_jit,)),
- ]
- profile_name = "pippenger_index_selection_rns_twisted_edwards"
- # copybara: util.profile_jax_functions(tasks, profile_name)
-
-
-if __name__ == "__main__":
- absltest.main()
diff --git a/jaxite_ec/multiscalar_multiplication_context.py b/jaxite_ec/multiscalar_multiplication_context.py
new file mode 100644
index 0000000..a78252e
--- /dev/null
+++ b/jaxite_ec/multiscalar_multiplication_context.py
@@ -0,0 +1,2772 @@
+from abc import ABC, abstractmethod
+import ctypes
+import math
+import random
+from typing import Any, Optional
+import warnings
+import jax
+import jax.ffi as ffi
+import jax.numpy as jnp
+from jaxite.jaxite_ec.c_kernels.build import ensure_distribution_kernel
+from jaxite.jaxite_ec.elliptic_curve_context import EllipticCurveContextBase
+import jaxite.jaxite_ec.utils as utils
+from jaxite.jaxite_ec.utils import JaxKernelContextBase, hash_args, jax_jit_lower_compile, load_jax_executable, store_jax_executable
+import numpy as np
+
+jax.config.update("jax_enable_x64", True)
+
+
+class MultiscalarMultiplicationContextBase(ABC):
+
+ def __init__(self, parameters: dict):
+ self.parameters = parameters
+ self.ec_ctx_class = parameters.get("elliptic_curve_context_class", None)
+ assert (
+ self.ec_ctx_class is not None
+ ), "elliptic_curve_context_class must be provided"
+ self.ec_ctx: EllipticCurveContextBase = self.ec_ctx_class(
+ parameters.get("elliptic_curve_parameters", {})
+ )
+
+ def _padd(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray:
+ return self.ec_ctx.point_add(a, b)
+
+ @abstractmethod
+ def to_computational_format(self, a: Any) -> Any:
+ pass
+
+ @abstractmethod
+ def to_original_format(self, a: Any) -> Any:
+ pass
+
+ @abstractmethod
+ def multiscalar_multiply(self, tiled_slices: jnp.ndarray):
+ pass
+
+
+class CPUDistributionMSMContextBase(MultiscalarMultiplicationContextBase):
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ self.scalar_bits = 0
+ self.slice_bits = 0
+ self.tile_length = 0
+ self._init_config_parameters()
+ self._init_jax_data()
+ self._init_cpu_kernels()
+ self._init_point_parameters()
+
+ def _init_config_parameters(self):
+ pass
+
+ def _init_jax_data(self):
+ pass
+
+ def _init_cpu_kernels(self):
+ pass
+
+ def _init_point_parameters(self):
+ raw_points = utils.read_external_msm_file(
+ self.parameters.get("points_path"), "points"
+ )
+ with jax.default_device(jax.devices("cpu")[0]):
+ points = self.ec_ctx.to_computational_format(raw_points).to_device(
+ jax.devices("cpu")[0]
+ )
+ self.points = points.transpose(1, 0, 2)
+
+ def _preprocess_scalars(self, scalars: list):
+ tiled_scalar_list = utils.split_list(scalars, self.tile_length)
+ tiled_slices_list = []
+ for tiled_scalars in tiled_scalar_list:
+ sliced_scalars = utils.slice_scalars(
+ tiled_scalars, self.scalar_bits, self.slice_bits
+ )
+ tiled_slices_list.append(sliced_scalars)
+
+ with jax.default_device(jax.devices("cpu")[0]):
+ tiled_slices_list = jnp.array(tiled_slices_list, dtype=jnp.int32)
+ return tiled_slices_list
+
+
+class OldCPUDistributionMSMContext(CPUDistributionMSMContextBase):
+
+ def __init__(self, parameters: dict):
+ CPUDistributionMSMContextBase.__init__(self, parameters)
+
+ def _init_config_parameters(self):
+ self.coordinate_dim = self.parameters.get("coordinate_dim", 4)
+ self.msm_length = self.parameters.get("msm_length")
+ self.tile_length = self.parameters.get("tile_length")
+ assert (
+ self.msm_length % self.tile_length == 0
+ ), "msm_length must be divisible by tile_length"
+ self.tile_num = self.msm_length // self.tile_length
+ self.slice_bits = self.parameters.get("slice_bits")
+ self.scalar_bits = self.parameters.get("scalar_bits")
+ self.order = self.parameters.get("order")
+ self.window_num = int(math.ceil(self.scalar_bits / self.slice_bits)) #
+ self.batch_window_num = self.window_num
+ self.bucket_num_per_window = (
+ 2**self.slice_bits - 1
+ ) # Note: here remove the bucket_0
+ self.bucket_num_last_window = self.order >> (
+ (self.window_num - 1) * self.slice_bits
+ )
+ self.moduli_num = self.ec_ctx.get_finite_field_context().get_moduli_num()
+
+ # Special bucket optimization
+ self.log_special_duplication_ratio = math.ceil(
+ math.log2(self.bucket_num_per_window / self.bucket_num_last_window)
+ )
+ self.special_duplication_ratio = 2**self.log_special_duplication_ratio
+ self.bucket_num_duplication = (
+ self.bucket_num_last_window * self.special_duplication_ratio
+ )
+
+ def _init_jax_data(self):
+ self.zero_point = (
+ self.ec_ctx.get_finite_field_context().to_computational_format(
+ self.ec_ctx.zero_point
+ )
+ )
+ self.all_buckets = jnp.broadcast_to(
+ self.zero_point.reshape(1, self.coordinate_dim, 1, self.moduli_num),
+ (
+ self.batch_window_num,
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ ),
+ )
+ self.temp_sum = jnp.array(
+ [self.zero_point for _ in range(self.batch_window_num)]
+ ).transpose(1, 0, 2)
+ self.window_sum = jnp.array(
+ [self.zero_point for _ in range(self.batch_window_num)]
+ ).transpose(1, 0, 2)
+
+ def _init_cpu_kernels(self):
+ expected_regular_bucket_size = self.tile_length / (
+ self.bucket_num_per_window + 1
+ )
+ expected_special_bucket_size = math.ceil(
+ self.tile_length / self.bucket_num_duplication
+ )
+ # print("expected_regular_bucket_size", expected_regular_bucket_size)
+ # print("expected_special_bucket_size", expected_special_bucket_size)
+ self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"]
+ self.c_kernel_regular_bucket_size = int(
+ expected_regular_bucket_size * self.expend_ratio
+ )
+ self.c_kernel_special_bucket_size = int(
+ expected_special_bucket_size * self.expend_ratio
+ )
+
+ regular_shape = (
+ self.window_num - 1,
+ self.bucket_num_per_window,
+ self.c_kernel_regular_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+
+ special_shape = (
+ self.bucket_num_last_window,
+ self.special_duplication_ratio,
+ self.c_kernel_special_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+ self.ba_input_regular_shape = regular_shape
+ self.ba_input_special_shape = special_shape
+
+ lib = ctypes.cdll.LoadLibrary(ensure_distribution_kernel())
+ jax.ffi.register_ffi_target(
+ "distribute_buf", jax.ffi.pycapsule(lib.DistributeBuf), platform="cpu"
+ )
+ self.distribution_buf_c_kernel_call = ffi.ffi_call(
+ "distribute_buf",
+ (
+ jax.ShapeDtypeStruct(regular_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(special_shape, jnp.uint32),
+ jax.ShapeDtypeStruct((2,), jnp.uint32),
+ ),
+ )
+
+ def _bucket_accumulation_per_window(
+ self, buckets: jnp.ndarray, window_points: jnp.ndarray
+ ) -> jnp.ndarray:
+ """Accumulate points within buckets for a single window.
+
+ Args:
+ buckets: Initial bucket values (coordinate_dim, bucket_dim,
+ precision_dim).
+ window_points: Points to accumulate (bucket_size_dim, coordinate_dim,
+ bucket_dim, precision_dim).
+ parameters: Computation parameters.
+
+ Returns:
+ Accumulated bucket values.
+ """
+ bucket_size_dim = window_points.shape[0]
+
+ def scan_body(buckets, points):
+ buckets = self._padd(buckets, points)
+ return buckets, None
+
+ buckets, _ = jax.lax.scan(
+ scan_body, buckets, window_points, length=bucket_size_dim
+ )
+ return buckets
+
+ def _bucket_accumulation_regular_windows(
+ self,
+ regular_buckets: jnp.ndarray,
+ all_points: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Accumulate points for all regular windows.
+
+ Args:
+ regular_buckets: Initial bucket values (window_dim, coordinate_dim,
+ bucket_dim, precision_dim).
+ all_points: Points for all windows (window_dim, bucket_size_dim,
+ coordinate_dim, bucket_dim, precision_dim).
+
+ Returns:
+ Accumulated bucket values for all regular windows.
+ """
+ window_dim = regular_buckets.shape[0]
+
+ def scan_body(empty, window_bucket_point_pack):
+ window_buckets, points = window_bucket_point_pack
+ buckets = self._bucket_accumulation_per_window(window_buckets, points)
+ return None, buckets
+
+ _, buckets = jax.lax.scan(
+ scan_body, None, (regular_buckets, all_points), length=window_dim
+ )
+ return buckets
+
+ def _bucket_accumulation_last_window(
+ self,
+ buckets_in: jnp.ndarray,
+ window_points: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Accumulate points for the last (special) window with duplication handling.
+
+ Args:
+ buckets_in: Initial bucket values (coordinate_dim, bucket_dim,
+ precision_dim).
+ window_points: Points with duplication (bucket_size_dim, coordinate_dim,
+ bucket_dup_dim, bucket_dim, precision_dim).
+ parameters: Computation parameters.
+
+ Returns:
+ Accumulated bucket values.
+ """
+ coordinate_dim, buckets_dim, precision_dim = buckets_in.shape
+ bucket_size_dim, _, bucket_dup_dim, _, _ = window_points.shape
+
+ # Reshape for processing
+ window_points = window_points.reshape(
+ bucket_size_dim, coordinate_dim, -1, precision_dim
+ )
+ base_dup_buckets = window_points[0]
+
+ def scan_body(buckets, points):
+ buckets = self._padd(buckets, points)
+ return buckets, None
+
+ dup_buckets, _ = jax.lax.scan(
+ scan_body,
+ base_dup_buckets,
+ window_points[1:],
+ length=bucket_size_dim - 1,
+ )
+
+ # Reduce duplicated buckets using tree reduction
+ log_bucket_dup_dim = int(math.log2(bucket_dup_dim))
+ for _ in range(log_bucket_dup_dim):
+ buckets_split = jnp.split(dup_buckets, 2, axis=1)
+ dup_buckets = self._padd(buckets_split[0], buckets_split[1])
+
+ # Add to input buckets
+ buckets_in = self._padd(buckets_in, dup_buckets)
+ return buckets_in
+
+ def _bucket_accumulation_all_windows(
+ self,
+ all_buckets: jnp.ndarray,
+ regular_points: jnp.ndarray,
+ last_window_points: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Accumulate points for all windows with distributed optimization.
+
+ This is the main bucket accumulation kernel that handles both regular
+ windows and the special last window with duplication optimization.
+
+ Args:
+ all_buckets: Initial bucket values (window_dim, coordinate_dim,
+ bucket_dim, precision_dim).
+ regular_points: Points for regular windows (window_dim-1, bucket_num,
+ bucket_size, coord_dim, prec_dim).
+ last_window_points: Points for last window (bucket_num, dup_ratio,
+ bucket_size, coord_dim, prec_dim).
+ parameters: Computation parameters.
+
+ Returns:
+ Accumulated bucket values for all windows.
+ """
+ # Transpose for computation
+ regular_points = regular_points.transpose(0, 2, 3, 1, 4)
+ last_window_points = last_window_points.transpose(2, 3, 1, 0, 4)
+
+ window_dim, coordinate_dim, buckets_dim, precision_dim = all_buckets.shape
+ _, _, last_window_bucket_dup, last_window_bucket_dim, _ = (
+ last_window_points.shape
+ )
+
+ # Process regular windows
+ regular_buckets = all_buckets[: window_dim - 1]
+ regular_buckets = self._bucket_accumulation_regular_windows(
+ regular_buckets, regular_points
+ )
+
+ # Process last window
+ last_point_buckets = all_buckets[
+ window_dim - 1, :, :last_window_bucket_dim, :
+ ]
+ last_blank_buckets = all_buckets[
+ window_dim - 1, :, last_window_bucket_dim:, :
+ ]
+ last_point_buckets = self._bucket_accumulation_last_window(
+ last_point_buckets, last_window_points
+ )
+
+ # Combine last window buckets
+ last_buckets = jax.lax.broadcast(
+ jnp.concatenate((last_point_buckets, last_blank_buckets), axis=1), (1,)
+ )
+
+ # Combine all buckets
+ all_buckets = jnp.concatenate((regular_buckets, last_buckets), axis=0)
+ return all_buckets
+
+ def _bucket_reduction(
+ self,
+ all_buckets: jnp.ndarray,
+ temp_sum: jnp.ndarray,
+ window_sum: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Reduce buckets to window sums using scan algorithm.
+
+ Implements the bucket reduction phase of Pippenger's algorithm using
+ a scan-based approach for efficiency.
+
+ Args:
+ all_buckets: Bucket values (window_dim, coordinate_dim, bucket_dim,
+ precision_dim).
+ temp_sum: Temporary sum array (coordinate_dim, window_dim,
+ precision_dim).
+ window_sum: Initial window sum (coordinate_dim, window_dim,
+ precision_dim).
+ bucket_num_in_window: Number of buckets per window.
+ parameters: Computation parameters.
+
+ Returns:
+ Window sums (coordinate_dim, window_dim, precision_dim).
+ """
+ # Transpose for scan
+ bucket_num_in_window = all_buckets.shape[2]
+ all_buckets = all_buckets.transpose(2, 1, 0, 3)
+
+ def scan_body(temp_and_window_sum_pack, buckets):
+ temp_sum, window_sum = temp_and_window_sum_pack
+ temp_sum = self._padd(temp_sum, buckets)
+ window_sum = self._padd(window_sum, temp_sum)
+ return (temp_sum, window_sum), None
+
+ (_, window_sum), _ = jax.lax.scan(
+ scan_body,
+ (temp_sum, window_sum),
+ all_buckets[:bucket_num_in_window],
+ length=bucket_num_in_window,
+ reverse=True,
+ )
+ return window_sum
+
+ def _window_merge(self, window_sum: jnp.ndarray) -> jnp.ndarray:
+ """Merge window results into final MSM result using scan algorithm.
+
+ Implements the window merging phase of Pippenger's algorithm.
+
+ Args:
+ window_sum: Window sums (coordinate_dim, window_dim, precision_dim).
+ slice_length: Bit width of each window.
+ parameters: Computation parameters.
+
+ Returns:
+ Final MSM result (coordinate_dim, precision_dim).
+ """
+ coordinate_dim, window_dim, precision_dim = window_sum.shape
+ window_sum = window_sum.transpose(1, 0, 2).reshape(
+ (window_dim, coordinate_dim, 1, precision_dim)
+ )
+ result = window_sum[window_dim - 1]
+
+ def fori_loop_body(i, result):
+ result = self._padd(result, result)
+ return result
+
+ def scan_body(result, window_sum):
+ result = jax.lax.fori_loop(0, self.slice_bits, fori_loop_body, result)
+ result = self._padd(result, window_sum)
+ return result, None
+
+ result, _ = jax.lax.scan(
+ scan_body,
+ result,
+ window_sum[: window_dim - 1],
+ reverse=True,
+ length=window_dim - 1,
+ )
+ result = result.reshape((coordinate_dim, precision_dim))
+ return result
+
+ def distribute_buckets(
+ self, tiled_slices: jnp.ndarray, tiled_points: jnp.ndarray
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
+ with jax.default_device(jax.devices("cpu")[0]):
+ tiled_slices = jax.device_put(tiled_slices, jax.devices("cpu")[0])
+ tiled_points = jax.device_put(tiled_points, jax.devices("cpu")[0])
+ zero_point = jax.device_put(self.zero_point, jax.devices("cpu")[0])
+ regular_buckets, last_window_buckets, metadata = (
+ self.distribution_buf_c_kernel_call(
+ tiled_slices,
+ tiled_points,
+ zero_point,
+ window_num=np.uint32(self.window_num),
+ regular_bucket_num=np.uint32(self.bucket_num_per_window),
+ special_bucket_num=np.uint32(self.bucket_num_last_window),
+ msm_length=np.uint32(self.tile_length),
+ fixed_regular_padding_size=np.uint32(
+ self.c_kernel_regular_bucket_size
+ ),
+ fixed_special_padding_size=np.uint32(
+ self.c_kernel_special_bucket_size
+ * self.special_duplication_ratio
+ ),
+ )
+ )
+ return regular_buckets, last_window_buckets
+
+ def multiscalar_multiply(self, tiled_slices: jnp.ndarray):
+
+ for tile_index in range(self.tile_num):
+ idx_start = tile_index * self.tile_length
+ idx_end = idx_start + self.tile_length
+ tiled_slices_tile = tiled_slices[tile_index]
+ tiled_points_tile = self.points[idx_start:idx_end]
+ regular_buckets, last_window_buckets = self.distribute_buckets(
+ tiled_slices_tile, tiled_points_tile
+ )
+ regular_buckets = jax.device_put(regular_buckets, jax.devices()[0])
+ last_window_buckets = jax.device_put(
+ last_window_buckets, jax.devices()[0]
+ )
+ self.all_buckets = self._bucket_accumulation_all_windows(
+ self.all_buckets, regular_buckets, last_window_buckets
+ )
+
+ window_sum = self._bucket_reduction(
+ self.all_buckets, self.temp_sum, self.window_sum
+ )
+ result = self._window_merge(window_sum)
+ return result
+
+ def to_original_format(self, a: jnp.ndarray) -> list:
+ return self.ec_ctx.to_original_format(a)
+
+ def to_computational_format(self, scalars: list) -> jnp.ndarray:
+ tiled_slices = self._preprocess_scalars(scalars)
+ return tiled_slices
+
+
+class CPUDistributionMSMContext(
+ OldCPUDistributionMSMContext, JaxKernelContextBase
+):
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ JaxKernelContextBase.__init__(self)
+
+ def _init_config_parameters(self):
+ self.coordinate_dim = self.parameters.get("coordinate_dim", 4)
+ self.msm_length = self.parameters.get("msm_length")
+ self.tile_length = self.parameters.get("tile_length")
+ assert (
+ self.msm_length % self.tile_length == 0
+ ), "msm_length must be divisible by tile_length"
+ self.tile_num = self.msm_length // self.tile_length
+ self.slice_bits = self.parameters.get("slice_bits")
+ self.scalar_bits = self.parameters.get("scalar_bits")
+ self.order = self.parameters.get("order")
+ self.window_num = int(math.ceil(self.scalar_bits / self.slice_bits)) #
+ self.batch_window_num = self.window_num
+ self.bucket_num_per_window = 2**self.slice_bits # Note: Include Bucket 0
+ orig_bucket_num_last_window = (
+ self.order >> ((self.window_num - 1) * self.slice_bits)
+ ) + 1 # Note: Include Bucket 0
+ # Now pad to nearest value so bucket_num_last_window % 8 == 0
+ added_padding = (8 - orig_bucket_num_last_window % 8) % 8
+ if added_padding > 0:
+ if added_padding / orig_bucket_num_last_window > 0.1:
+ print(
+ f"[bucket_num_last_window] Added {added_padding} to make it"
+ " divisible by 8, but it is too large for"
+ f" {orig_bucket_num_last_window}. Setting to 0."
+ )
+ added_padding = 0
+ else:
+ print(
+ f"[bucket_num_last_window] Was {orig_bucket_num_last_window}, "
+ f"added {added_padding} to make it divisible by 8"
+ )
+ self.bucket_num_last_window = orig_bucket_num_last_window + added_padding
+ self.moduli_num = self.ec_ctx.get_finite_field_context().get_moduli_num()
+
+ # Special bucket optimization
+ self.log_special_duplication_ratio = math.ceil(
+ math.log2(self.bucket_num_per_window / self.bucket_num_last_window)
+ )
+ self.special_duplication_ratio = 2**self.log_special_duplication_ratio
+ self.bucket_num_duplication = (
+ self.bucket_num_last_window * self.special_duplication_ratio
+ )
+
+ def _init_jax_data(self):
+ self.zero_point = (
+ self.ec_ctx.get_finite_field_context().to_computational_format(
+ self.ec_ctx.zero_point
+ )
+ )
+ self.all_buckets = jnp.broadcast_to(
+ self.zero_point.reshape(1, self.coordinate_dim, 1, self.moduli_num),
+ (
+ self.batch_window_num,
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ ),
+ )
+ self.temp_sum = jnp.array(
+ [self.zero_point for _ in range(self.batch_window_num)]
+ ).transpose(1, 0, 2)
+ self.window_sum = jnp.array(
+ [self.zero_point for _ in range(self.batch_window_num)]
+ ).transpose(1, 0, 2)
+ self.window_sum = jnp.broadcast_to(
+ self.window_sum.reshape(
+ self.coordinate_dim, 1, self.batch_window_num, self.moduli_num
+ ),
+ (
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.batch_window_num,
+ self.moduli_num,
+ ),
+ )
+
+ # Pre-store shardings for bucket accumulation inputs
+ self.ba_all_buckets_sharding = None
+ self.ba_regular_points_sharding = None
+ self.ba_special_points_sharding = None
+ # Pre-store shardings for bucket reduction inputs
+ self.br_all_buckets_sharding = None
+ self.br_window_sum_sharding = None
+ # Pre-store shardings for window merge input
+ self.wm_window_sum_sharding = None
+
+ def _init_shardings(self):
+ """Pre-compute and store all shardings for compiled/sharded execution."""
+ all_buckets_shape = (
+ self.batch_window_num,
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ )
+ regular_points_shape = self.ba_input_regular_shape
+ special_points_shape = self.ba_input_special_shape
+ window_sum_shape = (
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.batch_window_num,
+ self.moduli_num,
+ )
+ wm_window_sum_shape = (
+ self.coordinate_dim,
+ self.batch_window_num,
+ self.moduli_num,
+ )
+
+ # Bucket accumulation shardings
+ self.ba_all_buckets_sharding, _ = self.create_named_sharding(
+ shape=all_buckets_shape, axes=[2]
+ )
+ self.ba_regular_points_sharding, _ = self.create_named_sharding(
+ shape=regular_points_shape, axes=[1]
+ )
+ self.ba_special_points_sharding, _ = self.create_named_sharding(
+ shape=special_points_shape, axes=[0, 1]
+ )
+ # Bucket reduction shardings
+ self.br_all_buckets_sharding, _ = self.create_named_sharding(
+ shape=all_buckets_shape, axes=[2]
+ )
+ self.br_window_sum_sharding, _ = self.create_named_sharding(
+ shape=window_sum_shape, axes=[1]
+ )
+ # Window merge shardings
+ self.wm_window_sum_sharding, _ = self.create_named_sharding(
+ shape=wm_window_sum_shape, axes=[]
+ )
+
+ def set_use_sharding(self, use_sharding: bool):
+ super().set_use_sharding(use_sharding)
+ if use_sharding:
+ self._init_shardings()
+
+ def _init_cpu_kernels(self):
+ expected_regular_bucket_size = self.tile_length / (
+ self.bucket_num_per_window
+ )
+ expected_special_bucket_size = math.ceil(
+ self.tile_length / self.bucket_num_duplication
+ )
+ self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"]
+ self.c_kernel_regular_bucket_size = int(
+ expected_regular_bucket_size * self.expend_ratio
+ )
+ self.c_kernel_special_bucket_size = int(
+ expected_special_bucket_size * self.expend_ratio
+ )
+
+ regular_shape = (
+ self.window_num - 1,
+ self.bucket_num_per_window,
+ self.c_kernel_regular_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+
+ special_shape = (
+ self.bucket_num_last_window,
+ self.special_duplication_ratio,
+ self.c_kernel_special_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+ self.ba_input_regular_shape = regular_shape
+ self.ba_input_special_shape = special_shape
+
+ lib = ctypes.cdll.LoadLibrary(ensure_distribution_kernel())
+ jax.ffi.register_ffi_target(
+ "distribute_buf",
+ jax.ffi.pycapsule(lib.DistributeBufZero),
+ platform="cpu",
+ )
+ self.distribution_buf_c_kernel_call = ffi.ffi_call(
+ "distribute_buf",
+ (
+ jax.ShapeDtypeStruct(regular_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(special_shape, jnp.uint32),
+ jax.ShapeDtypeStruct((2,), jnp.uint32),
+ ),
+ )
+
+ def _bucket_accumulation_regular_windows_opt(
+ self,
+ regular_buckets: jnp.ndarray,
+ all_points: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Accumulate points for all regular windows.
+
+ Args:
+ regular_buckets: Initial bucket values (window_dim, coordinate_dim,
+ bucket_dim, precision_dim).
+ all_points: Points for all windows (window_dim, bucket_size_dim,
+ coordinate_dim, bucket_dim, precision_dim).
+
+ Returns:
+ Accumulated bucket values for all regular windows.
+ """
+
+ bucket_size_dim = all_points.shape[1]
+ regular_buckets = regular_buckets.transpose(1, 0, 2, 3)
+ all_points = all_points.transpose(1, 2, 0, 3, 4)
+
+ def scan_body(buckets, points):
+ buckets = self._padd(buckets, points)
+ return buckets, None
+
+ buckets, _ = jax.lax.scan(
+ scan_body, regular_buckets, all_points, length=bucket_size_dim
+ )
+ buckets = buckets.transpose(1, 0, 2, 3)
+ return buckets
+
+ def _bucket_accumulation_all_windows_opt(
+ self,
+ all_buckets: jnp.ndarray,
+ regular_points: jnp.ndarray,
+ last_window_points: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Accumulate points for all windows with distributed optimization.
+
+ This is the main bucket accumulation kernel that handles both regular
+ windows and the special last window with duplication optimization.
+
+ Args:
+ all_buckets: Initial bucket values (window_dim, coordinate_dim,
+ bucket_dim, precision_dim).
+ regular_points: Points for regular windows (window_dim-1, bucket_num,
+ bucket_size, coord_dim, prec_dim).
+ last_window_points: Points for last window (bucket_num, dup_ratio,
+ bucket_size, coord_dim, prec_dim).
+ parameters: Computation parameters.
+
+ Returns:
+ Accumulated bucket values for all windows.
+ """
+ # Transpose for computation
+ regular_points = regular_points.transpose(0, 2, 3, 1, 4)
+ last_window_points = last_window_points.transpose(2, 3, 1, 0, 4)
+
+ window_dim, coordinate_dim, buckets_dim, precision_dim = all_buckets.shape
+ _, _, last_window_bucket_dup, last_window_bucket_dim, _ = (
+ last_window_points.shape
+ )
+
+ # Process regular windows
+ regular_buckets = all_buckets[: window_dim - 1]
+ regular_buckets = self._bucket_accumulation_regular_windows_opt(
+ regular_buckets, regular_points
+ )
+
+ # Process last window
+ last_point_buckets = all_buckets[
+ window_dim - 1, :, :last_window_bucket_dim, :
+ ]
+ last_blank_buckets = all_buckets[
+ window_dim - 1, :, last_window_bucket_dim:, :
+ ]
+ last_point_buckets = self._bucket_accumulation_last_window(
+ last_point_buckets, last_window_points
+ )
+
+ # Combine last window buckets
+ last_buckets = jax.lax.broadcast(
+ jnp.concatenate((last_point_buckets, last_blank_buckets), axis=1), (1,)
+ )
+
+ # Combine all buckets
+ all_buckets = jnp.concatenate((regular_buckets, last_buckets), axis=0)
+ return all_buckets
+
+ def _bucket_reduction(
+ self,
+ all_buckets: jnp.ndarray,
+ temp_sum: jnp.ndarray,
+ window_sum: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Reduce buckets to window sums using tree-based parallel algorithm.
+
+ Computes S = sum_{i=0}^{n-1} i * B[i] per window via adjacent-pair
+ tree reduction in O(log n) parallel steps.
+
+ Tracks H = 2^k * bucket_sums (doubling each level) so that H[1::2]
+ provides the correction weight directly. No separate bucket array needed.
+
+ Args:
+ all_buckets: Bucket values (window_dim, coordinate_dim, bucket_dim,
+ precision_dim).
+ temp_sum: Unused (kept for interface compatibility).
+ window_sum: Initial window sum (coordinate_dim, bucket_dim, window_dim,
+ precision_dim).
+
+ Returns:
+ Window sums (coordinate_dim, window_dim, precision_dim).
+ """
+ bucket_num_in_window = all_buckets.shape[2]
+ # (window, coord, bucket, prec) → (coord, bucket, window, prec)
+ all_buckets = all_buckets.transpose(1, 2, 0, 3)
+ iter_num = int(math.log2(bucket_num_in_window))
+
+ cd, m, wd, pd = all_buckets.shape
+ for _ in range(iter_num):
+ m = all_buckets.shape[1]
+ half = m // 2
+
+ all_buckets = all_buckets.reshape(cd, half, 2, wd, pd)
+ window_sum = window_sum.reshape(cd, half, 2, wd, pd)
+
+ all_buckets_left, all_buckets_right = (
+ all_buckets[:, :, 0],
+ all_buckets[:, :, 1],
+ )
+ window_sum_left, window_sum_right = (
+ window_sum[:, :, 0],
+ window_sum[:, :, 1],
+ )
+
+ window_sum = self._padd(
+ self._padd(window_sum_left, window_sum_right), all_buckets_right
+ )
+ bucket_sum = self._padd(all_buckets_left, all_buckets_right)
+ all_buckets = self._padd(bucket_sum, bucket_sum)
+
+ return window_sum[:, 0]
+
+ def _get_ba_shape_dtype_structs(self):
+ """Get ShapeDtypeStructs for bucket_accumulation_all_windows inputs."""
+ all_buckets_shape = (
+ self.batch_window_num,
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ )
+ regular_points_shape = self.ba_input_regular_shape
+ special_points_shape = self.ba_input_special_shape
+
+ if self.use_sharding:
+ return [
+ jax.ShapeDtypeStruct(
+ all_buckets_shape,
+ jnp.uint32,
+ sharding=self.ba_all_buckets_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ regular_points_shape,
+ jnp.uint32,
+ sharding=self.ba_regular_points_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ special_points_shape,
+ jnp.uint32,
+ sharding=self.ba_special_points_sharding,
+ ),
+ ]
+ return [
+ jax.ShapeDtypeStruct(all_buckets_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(regular_points_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(special_points_shape, jnp.uint32),
+ ]
+
+ def _get_br_shape_dtype_structs(self):
+ """Get ShapeDtypeStructs for bucket_reduction inputs."""
+ all_buckets_shape = (
+ self.batch_window_num,
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ )
+ temp_sum_shape = (
+ self.coordinate_dim,
+ self.batch_window_num,
+ self.moduli_num,
+ )
+ window_sum_shape = (
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.batch_window_num,
+ self.moduli_num,
+ )
+
+ if self.use_sharding:
+ return [
+ jax.ShapeDtypeStruct(
+ all_buckets_shape,
+ jnp.uint32,
+ sharding=self.br_all_buckets_sharding,
+ ),
+ jax.ShapeDtypeStruct(temp_sum_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(
+ window_sum_shape, jnp.uint32, sharding=self.br_window_sum_sharding
+ ),
+ ]
+ return [
+ jax.ShapeDtypeStruct(all_buckets_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(temp_sum_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(window_sum_shape, jnp.uint32),
+ ]
+
+ def _get_wm_shape_dtype_structs(self):
+ """Get ShapeDtypeStructs for window_merge input."""
+ window_sum_shape = (
+ self.coordinate_dim,
+ self.batch_window_num,
+ self.moduli_num,
+ )
+
+ if self.use_sharding:
+ return [
+ jax.ShapeDtypeStruct(
+ window_sum_shape, jnp.uint32, sharding=self.wm_window_sum_sharding
+ )
+ ]
+ return [jax.ShapeDtypeStruct(window_sum_shape, jnp.uint32)]
+
+ def context_hash(self) -> str:
+ return hash_args(
+ self.__class__.__name__,
+ self.ec_ctx.context_hash()
+ if hasattr(self.ec_ctx, "context_hash")
+ else str(self.ec_ctx.__class__.__name__),
+ self.slice_bits,
+ self.scalar_bits,
+ self.msm_length,
+ self.tile_length,
+ self.bucket_num_per_window,
+ self.bucket_num_last_window,
+ self.use_sharding,
+ )
+
+ def serialize(self, parameters: Optional[dict] = None):
+ ba_structs = self._get_ba_shape_dtype_structs()
+ br_structs = self._get_br_shape_dtype_structs()
+ wm_structs = self._get_wm_shape_dtype_structs()
+ kernel_hash = hash_args(
+ self.context_hash(), parameters if parameters else {}
+ )
+ class_name = self.__class__.__name__
+
+ store_jax_executable(
+ self._bucket_accumulation_all_windows,
+ ba_structs[0],
+ ba_structs[1],
+ ba_structs[2],
+ name=f"{class_name}_bucket_accumulation_all_windows_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._bucket_reduction,
+ br_structs[0],
+ br_structs[1],
+ br_structs[2],
+ name=f"{class_name}_bucket_reduction_{kernel_hash}",
+ )
+ store_jax_executable(
+ self._window_merge,
+ wm_structs[0],
+ name=f"{class_name}_window_merge_{kernel_hash}",
+ )
+
+ def compile(self, parameters: Optional[dict] = None):
+ ba_structs = self._get_ba_shape_dtype_structs()
+ br_structs = self._get_br_shape_dtype_structs()
+ wm_structs = self._get_wm_shape_dtype_structs()
+ kernel_hash = hash_args(
+ self.context_hash(), parameters if parameters else {}
+ )
+ class_name = self.__class__.__name__
+
+ ba_kernel = load_jax_executable(
+ f"{class_name}_bucket_accumulation_all_windows_{kernel_hash}"
+ )
+ br_kernel = load_jax_executable(
+ f"{class_name}_bucket_reduction_{kernel_hash}"
+ )
+ wm_kernel = load_jax_executable(f"{class_name}_window_merge_{kernel_hash}")
+
+ if None in [ba_kernel, br_kernel, wm_kernel]:
+ warnings.warn(
+ "Not found stored serialized compiled kernels for MSM, compiling...",
+ UserWarning,
+ stacklevel=2,
+ )
+
+ ba_hash = hash_args(
+ ba_structs[0].shape,
+ ba_structs[0].dtype.__str__(),
+ ba_structs[0].shape,
+ ba_structs[0].dtype.__str__(),
+ )
+ br_hash = hash_args(
+ br_structs[0].shape,
+ br_structs[0].dtype.__str__(),
+ br_structs[0].shape,
+ br_structs[0].dtype.__str__(),
+ )
+ wm_hash = hash_args(
+ wm_structs[0].shape,
+ wm_structs[0].dtype.__str__(),
+ wm_structs[0].shape,
+ wm_structs[0].dtype.__str__(),
+ )
+ self.compiled_kernels.setdefault(ba_hash, {})[
+ "bucket_accumulation_all_windows"
+ ] = (
+ ba_kernel
+ if ba_kernel is not None
+ else jax_jit_lower_compile(
+ self._bucket_accumulation_all_windows,
+ ba_structs[0],
+ ba_structs[1],
+ ba_structs[2],
+ )
+ )
+ self.compiled_kernels.setdefault(br_hash, {})["bucket_reduction"] = (
+ br_kernel
+ if br_kernel is not None
+ else jax_jit_lower_compile(
+ self._bucket_reduction, br_structs[0], br_structs[1], br_structs[2]
+ )
+ )
+ self.compiled_kernels.setdefault(wm_hash, {})["window_merge"] = (
+ wm_kernel
+ if wm_kernel is not None
+ else jax_jit_lower_compile(self._window_merge, wm_structs[0])
+ )
+ self.use_compiled_kernels = True
+
+ def bucket_accumulation_all_windows(
+ self, all_buckets, regular_points, last_window_points
+ ):
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ all_buckets.shape,
+ all_buckets.dtype.__str__(),
+ all_buckets.shape,
+ all_buckets.dtype.__str__(),
+ )
+ return self.compiled_kernels[kernel_hash][
+ "bucket_accumulation_all_windows"
+ ](all_buckets, regular_points, last_window_points)
+ else:
+ return self._bucket_accumulation_all_windows_opt(
+ all_buckets, regular_points, last_window_points
+ )
+
+ def bucket_reduction(self, all_buckets, temp_sum, window_sum):
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ all_buckets.shape,
+ all_buckets.dtype.__str__(),
+ all_buckets.shape,
+ all_buckets.dtype.__str__(),
+ )
+ return self.compiled_kernels[kernel_hash]["bucket_reduction"](
+ all_buckets, temp_sum, window_sum
+ )
+ else:
+ return self._bucket_reduction(all_buckets, temp_sum, window_sum)
+
+ def window_merge(self, window_sum):
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ window_sum.shape,
+ window_sum.dtype.__str__(),
+ window_sum.shape,
+ window_sum.dtype.__str__(),
+ )
+ return self.compiled_kernels[kernel_hash]["window_merge"](window_sum)
+ else:
+ return self._window_merge(window_sum)
+
+ def multiscalar_multiply(self, tiled_slices: jnp.ndarray):
+
+ for tile_index in range(self.tile_num):
+ idx_start = tile_index * self.tile_length
+ idx_end = idx_start + self.tile_length
+ tiled_slices_tile = tiled_slices[tile_index]
+ # tiled_points_tile = self.points[:, idx_start:idx_end].transpose(1, 0, 2)
+ tiled_points_tile = self.points[idx_start:idx_end]
+ regular_buckets, last_window_buckets = self.distribute_buckets(
+ tiled_slices_tile, tiled_points_tile
+ )
+ if self.use_sharding:
+ regular_buckets = regular_buckets.to_device(
+ self.ba_regular_points_sharding
+ )
+ last_window_buckets = last_window_buckets.to_device(
+ self.ba_special_points_sharding
+ )
+ else:
+ regular_buckets = jax.device_put(regular_buckets, jax.devices()[0])
+ last_window_buckets = jax.device_put(
+ last_window_buckets, jax.devices()[0]
+ )
+ self.all_buckets = self.bucket_accumulation_all_windows(
+ self.all_buckets, regular_buckets, last_window_buckets
+ )
+
+ window_sum = self.bucket_reduction(
+ self.all_buckets, self.temp_sum, self.window_sum
+ )
+ result = self.window_merge(window_sum)
+ return result
+
+
+class TPUDistributionMSMContext(CPUDistributionMSMContext):
+ """MSM context where bucket distribution runs on TPU using a sort-based
+
+ bucketize algorithm instead of a CPU C kernel.
+
+ Distribution becomes a major TPU kernel alongside bucket_accumulation,
+ bucket_reduction, and window_merge — all four are compiled by `compile()`
+ and dispatched the same way.
+ """
+
+ def __init__(self, parameters: dict):
+ MultiscalarMultiplicationContextBase.__init__(self, parameters)
+ self._init_config_parameters()
+ self._init_jax_data()
+ self._init_kernels()
+ self._init_point_parameters()
+ JaxKernelContextBase.__init__(self)
+
+ def _init_kernels(self):
+ expected_regular_bucket_size = self.tile_length / self.bucket_num_per_window
+ expected_special_bucket_size = math.ceil(
+ self.tile_length / self.bucket_num_duplication
+ )
+ self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"]
+ self.regular_bucket_size = int(
+ expected_regular_bucket_size * self.expend_ratio
+ )
+ self.special_bucket_size = int(
+ expected_special_bucket_size * self.expend_ratio
+ )
+
+ self.ba_input_regular_shape = (
+ self.window_num - 1,
+ self.bucket_num_per_window,
+ self.regular_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+ self.ba_input_special_shape = (
+ self.bucket_num_last_window,
+ self.special_duplication_ratio,
+ self.special_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+ self.dist_input_slices_shape = (self.window_num, self.tile_length)
+ self.dist_input_points_shape = (
+ self.tile_length,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+
+ def _init_point_parameters(self):
+ raw_points = utils.read_external_msm_file(
+ self.parameters.get("points_path"), "points"
+ )
+ points = self.ec_ctx.to_computational_format(raw_points)
+ points = points.transpose(1, 0, 2) # (N, coordinate_dim, moduli_num)
+ self.points = jax.device_put(points, jax.devices()[0])
+
+ def _bucketize_regular_windows(
+ self, items: jnp.ndarray, bucket_ids: jnp.ndarray
+ ) -> jnp.ndarray:
+ """Sort-based bucketization for regular windows.
+
+ Args:
+ items: Points to distribute (tile_length, coordinate_dim, moduli_num).
+ bucket_ids: Slice values for the regular windows (window_num - 1,
+ tile_length).
+
+ Returns:
+ Bucketed points (window_num - 1, bucket_num_per_window,
+ regular_bucket_size,
+ coordinate_dim, moduli_num).
+ """
+ n = self.tile_length
+ b = self.window_num - 1
+
+ bucket_ids_i32 = bucket_ids.astype(jnp.int32)
+ sorted_indices = jnp.argsort(bucket_ids_i32, axis=1, stable=True)
+ sorted_buckets = jnp.take_along_axis(bucket_ids_i32, sorted_indices, axis=1)
+
+ is_boundary = jnp.concatenate(
+ [
+ jnp.ones_like(sorted_buckets[:, :1], dtype=jnp.bool_),
+ sorted_buckets[:, 1:] != sorted_buckets[:, :-1],
+ ],
+ axis=1,
+ )
+ positions = jnp.broadcast_to(jnp.arange(n, dtype=jnp.int32), (b, n))
+ boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0))
+ bucket_starts = jax.lax.associative_scan(
+ jnp.maximum, boundary_positions, axis=1
+ )
+ within_rank = positions - bucket_starts
+
+ batch_idx = jnp.broadcast_to(
+ jnp.arange(b, dtype=jnp.int32)[:, None], (b, n)
+ )
+ zeros_bn = jnp.zeros((b, n), dtype=jnp.int32)
+ inv_rank = zeros_bn.at[batch_idx, sorted_indices].set(within_rank)
+ items_b = jnp.broadcast_to(items[None, :, :, :], (b, n, 4, 32))
+
+ # items_sorted = items[sorted_indices]
+
+ output = jnp.broadcast_to(
+ self.zero_point,
+ (
+ b,
+ self.bucket_num_per_window,
+ self.regular_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ ),
+ ).copy()
+ # output = output.at[batch_idx, sorted_buckets, within_rank].set(items_sorted)
+ output = output.at[batch_idx, bucket_ids_i32, inv_rank].set(items_b)
+ return output
+
+ def _bucketize_last_window(
+ self, items: jnp.ndarray, last_slices: jnp.ndarray
+ ) -> jnp.ndarray:
+ """Sort-based bucketization for the last (special) window with duplication.
+
+ Each logical bucket s in [0, bucket_num_last_window) is replicated
+ special_duplication_ratio times. Item i is mapped to duplicate (i mod sdup),
+ spreading items across the duplicates so the per-slot capacity stays at
+ special_bucket_size. Reduction back to logical buckets is handled later by
+ _bucket_accumulation_last_window.
+
+ Args:
+ items: Points to distribute (tile_length, coordinate_dim, moduli_num).
+ last_slices: Slice values for the last window (tile_length,).
+
+ Returns:
+ Bucketed points (bucket_num_last_window, special_duplication_ratio,
+ special_bucket_size, coordinate_dim, moduli_num).
+ """
+ n = self.tile_length
+ sdup = self.special_duplication_ratio
+
+ dup_ids = jnp.arange(n, dtype=jnp.int32) % sdup
+ bucket_ids_i32 = last_slices.astype(jnp.int32) * sdup + dup_ids
+
+ sorted_indices = jnp.argsort(bucket_ids_i32, stable=True)
+ sorted_buckets = bucket_ids_i32[sorted_indices]
+
+ is_boundary = jnp.concatenate([
+ jnp.ones((1,), dtype=jnp.bool_),
+ sorted_buckets[1:] != sorted_buckets[:-1],
+ ])
+ positions = jnp.arange(n, dtype=jnp.int32)
+ boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0))
+ bucket_starts = jax.lax.associative_scan(jnp.maximum, boundary_positions)
+ within_rank = positions - bucket_starts
+
+ zeros_n = jnp.zeros((n,), dtype=jnp.int32)
+ inv_rank = zeros_n.at[sorted_indices].set(within_rank)
+
+ output = jnp.broadcast_to(
+ self.zero_point,
+ (
+ self.bucket_num_duplication,
+ self.special_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ ),
+ ).copy()
+ output = output.at[bucket_ids_i32, inv_rank].set(items)
+ return output.reshape(
+ self.bucket_num_last_window,
+ self.special_duplication_ratio,
+ self.special_bucket_size,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+
+ def _distribute_buckets(
+ self,
+ regular_slices: jnp.ndarray,
+ last_window_slice: jnp.ndarray,
+ tiled_points: jnp.ndarray,
+ ):
+ """Major TPU kernel: distribute one tile of points into buckets.
+
+ Args:
+ regular_slices: Slice values for regular windows (window_num - 1,
+ tile_length).
+ last_window_slice: Slice values for the last window (tile_length,).
+ tiled_points: Points for one tile (tile_length, coordinate_dim,
+ moduli_num).
+
+ Returns:
+ regular_buckets: (window_num - 1, bucket_num_per_window,
+ regular_bucket_size,
+ coordinate_dim, moduli_num).
+ last_window_buckets: (bucket_num_last_window, special_duplication_ratio,
+ special_bucket_size, coordinate_dim, moduli_num).
+ """
+ regular_buckets = self._bucketize_regular_windows(
+ tiled_points, regular_slices
+ )
+ last_window_buckets = self._bucketize_last_window(
+ tiled_points, last_window_slice
+ )
+ return regular_buckets, last_window_buckets
+
+ def _get_dist_shape_dtype_structs(self):
+ """Get ShapeDtypeStructs for _distribute_buckets inputs."""
+ regular_slices_shape = (self.window_num - 1, self.tile_length)
+ last_window_slice_shape = (self.tile_length,)
+ return [
+ jax.ShapeDtypeStruct(regular_slices_shape, jnp.int32),
+ jax.ShapeDtypeStruct(last_window_slice_shape, jnp.int32),
+ jax.ShapeDtypeStruct(self.dist_input_points_shape, jnp.uint32),
+ ]
+
+ def compile(self, parameters: Optional[dict] = None):
+ dist_structs = self._get_dist_shape_dtype_structs()
+ ba_structs = self._get_ba_shape_dtype_structs()
+ br_structs = self._get_br_shape_dtype_structs()
+ wm_structs = self._get_wm_shape_dtype_structs()
+
+ dist_hash = hash_args(
+ dist_structs[0].shape,
+ dist_structs[0].dtype.__str__(),
+ dist_structs[1].shape,
+ dist_structs[1].dtype.__str__(),
+ dist_structs[2].shape,
+ dist_structs[2].dtype.__str__(),
+ )
+ ba_hash = hash_args(
+ ba_structs[0].shape,
+ ba_structs[0].dtype.__str__(),
+ ba_structs[0].shape,
+ ba_structs[0].dtype.__str__(),
+ )
+ br_hash = hash_args(
+ br_structs[0].shape,
+ br_structs[0].dtype.__str__(),
+ br_structs[0].shape,
+ br_structs[0].dtype.__str__(),
+ )
+ wm_hash = hash_args(
+ wm_structs[0].shape,
+ wm_structs[0].dtype.__str__(),
+ wm_structs[0].shape,
+ wm_structs[0].dtype.__str__(),
+ )
+
+ self.compiled_kernels.setdefault(dist_hash, {})["distribute_buckets"] = (
+ jax_jit_lower_compile(
+ self._distribute_buckets,
+ dist_structs[0],
+ dist_structs[1],
+ dist_structs[2],
+ )
+ )
+ self.compiled_kernels.setdefault(ba_hash, {})[
+ "bucket_accumulation_all_windows"
+ ] = jax_jit_lower_compile(
+ self._bucket_accumulation_all_windows_opt,
+ ba_structs[0],
+ ba_structs[1],
+ ba_structs[2],
+ )
+ self.compiled_kernels.setdefault(br_hash, {})["bucket_reduction"] = (
+ jax_jit_lower_compile(
+ self._bucket_reduction, br_structs[0], br_structs[1], br_structs[2]
+ )
+ )
+ self.compiled_kernels.setdefault(wm_hash, {})["window_merge"] = (
+ jax_jit_lower_compile(self._window_merge, wm_structs[0])
+ )
+ self.use_compiled_kernels = True
+
+ def distribute_buckets(self, regular_slices, last_window_slice, tiled_points):
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ regular_slices.shape,
+ regular_slices.dtype.__str__(),
+ last_window_slice.shape,
+ last_window_slice.dtype.__str__(),
+ tiled_points.shape,
+ tiled_points.dtype.__str__(),
+ )
+ return self.compiled_kernels[kernel_hash]["distribute_buckets"](
+ regular_slices, last_window_slice, tiled_points
+ )
+ else:
+ return self._distribute_buckets(
+ regular_slices, last_window_slice, tiled_points
+ )
+
+ def multiscalar_multiply(self, tiled_slices: jnp.ndarray):
+ tpu = jax.devices()[0]
+ tiled_slices = jax.device_put(tiled_slices, tpu)
+
+ for tile_index in range(self.tile_num):
+ idx_start = tile_index * self.tile_length
+ idx_end = idx_start + self.tile_length
+ tiled_slices_tile = tiled_slices[tile_index]
+ regular_slices_tile = tiled_slices_tile[: self.window_num - 1]
+ last_window_slice_tile = tiled_slices_tile[self.window_num - 1]
+ tiled_points_tile = self.points[idx_start:idx_end]
+ regular_buckets, last_window_buckets = self.distribute_buckets(
+ regular_slices_tile, last_window_slice_tile, tiled_points_tile
+ )
+ self.all_buckets = self.bucket_accumulation_all_windows(
+ self.all_buckets, regular_buckets, last_window_buckets
+ )
+
+ window_sum = self.bucket_reduction(
+ self.all_buckets, self.temp_sum, self.window_sum
+ )
+ result = self.window_merge(window_sum)
+ return result
+
+
+class FusionMSMContext(
+ MultiscalarMultiplicationContextBase, JaxKernelContextBase
+):
+ """MSM context that fuses bucket distribution and accumulation on TPU."""
+
+ def __init__(self, parameters: dict):
+ super().__init__(parameters)
+ MultiscalarMultiplicationContextBase.__init__(self, parameters)
+ self._init_config_parameters()
+ self._init_jax_data()
+ self._init_point_parameters()
+ JaxKernelContextBase.__init__(self)
+ self.use_fused = False
+
+ def _init_config_parameters(self):
+ self.coordinate_dim = self.parameters.get("coordinate_dim", 4)
+ self.msm_length = self.parameters.get("msm_length")
+ self.tile_length = self.parameters.get("tile_length")
+ assert self.msm_length is not None
+ assert self.tile_length is not None
+ assert (
+ self.msm_length % self.tile_length == 0
+ and self.msm_length >= self.tile_length
+ ), (
+ "msm_length must be divisible by tile_length and greater than or equal"
+ " to tile_length"
+ )
+
+ self.tile_num = self.msm_length // self.tile_length
+ self.slice_bits = self.parameters.get("slice_bits")
+ if self.slice_bits != 15:
+ warnings.warn(
+ f"Slice bits {self.slice_bits} may cause performance issue, using 15"
+ " instead."
+ )
+ if 2**self.slice_bits > self.tile_length:
+ warnings.warn(
+ f"2**{self.slice_bits} is greater than tile_length, which may cause"
+ " performance issue."
+ )
+ self.scalar_bits = self.parameters.get("scalar_bits")
+ self.order = self.parameters.get("order")
+ self.window_num = int(math.ceil(self.scalar_bits / self.slice_bits)) #
+ self.batch_window_num = self.window_num
+ self.bucket_num_per_window = 2**self.slice_bits # Note: Include Bucket 0
+ orig_bucket_num_last_window = (
+ self.order >> ((self.window_num - 1) * self.slice_bits)
+ ) + 1 # Note: Include Bucket 0
+ # Now pad to nearest value so bucket_num_last_window % 8 == 0
+ added_padding = (8 - orig_bucket_num_last_window % 8) % 8
+ if added_padding > 0:
+ if added_padding / orig_bucket_num_last_window > 0.1:
+ print(
+ f"[bucket_num_last_window] Added {added_padding} to make it"
+ " divisible by 8, but it is too large for"
+ f" {orig_bucket_num_last_window}. Setting to 0."
+ )
+ added_padding = 0
+ else:
+ print(
+ f"[bucket_num_last_window] Was {orig_bucket_num_last_window}, "
+ f"added {added_padding} to make it divisible by 8"
+ )
+ self.bucket_num_last_window = orig_bucket_num_last_window + added_padding
+ self.moduli_num = self.ec_ctx.get_finite_field_context().get_moduli_num()
+
+ # Special bucket optimization
+ self.log_special_duplication_ratio = math.ceil(
+ math.log2(self.bucket_num_per_window / self.bucket_num_last_window)
+ )
+ self.special_duplication_ratio = 2**self.log_special_duplication_ratio
+ self.bucket_num_duplication = (
+ self.bucket_num_last_window * self.special_duplication_ratio
+ )
+ expected_regular_bucket_size = self.tile_length / self.bucket_num_per_window
+ expected_special_bucket_size = math.ceil(
+ self.tile_length / self.bucket_num_duplication
+ )
+ self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"]
+ self.regular_bucket_size = int(
+ expected_regular_bucket_size * self.expend_ratio
+ )
+ self.special_bucket_size = int(
+ expected_special_bucket_size * self.expend_ratio
+ )
+
+ def _init_point_parameters(self):
+ if self.parameters.get("points_path") is not None:
+ raw_points = utils.read_external_msm_file(
+ self.parameters.get("points_path"), "points"
+ )
+ points = self.ec_ctx.to_computational_format(raw_points)
+ points = points.transpose(1, 0, 2) # (N, coordinate_dim, moduli_num)
+ else:
+ points = jax.random.randint(
+ jax.random.PRNGKey(0),
+ (self.msm_length, self.coordinate_dim, self.moduli_num),
+ 0,
+ 2**16,
+ dtype=jnp.uint32,
+ )
+ # Reshape once here into per-tile layout so both execution paths index
+ # tiles with a single leading-axis lookup (self.points[tile_index]) and
+ # the fused path can pass self.points directly as tiled_points.
+ self.points = points.reshape(
+ self.tile_num,
+ self.tile_length,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+
+ def _init_jax_data(self):
+ self.zero_point = (
+ self.ec_ctx.get_finite_field_context().to_computational_format(
+ self.ec_ctx.zero_point
+ )
+ )
+ all_buckets = jnp.broadcast_to(
+ self.zero_point.reshape(self.coordinate_dim, 1, 1, self.moduli_num),
+ (
+ self.coordinate_dim,
+ self.batch_window_num,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ ),
+ )
+ self.regular_window_buckets = all_buckets[:, : self.window_num - 1]
+ self.last_window_buckets = all_buckets[:, self.window_num - 1]
+ self.window_sum = jnp.broadcast_to(
+ self.zero_point.reshape(self.coordinate_dim, 1, 1, self.moduli_num),
+ (
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.batch_window_num,
+ self.moduli_num,
+ ),
+ )
+ self.fused_reg_slices_padded = None
+ self.fused_reg_slices_sharding = None
+ self.fused_last_slices_padded = None
+ self.fused_last_slices_sharding = None
+ self.fused_tiled_points_padded = None
+ self.fused_tiled_points_sharding = None
+ self.bba_reg_slices_padded = None
+ self.bba_reg_slices_sharding = None
+ self.bba_last_slice_padded = None
+ self.bba_last_slice_sharding = None
+ self.bba_points_padded = None
+ self.bba_points_sharding = None
+ self.bba_reg_buckets_padded = None
+ self.bba_reg_buckets_sharding = None
+ self.bba_last_buckets_padded = None
+ self.bba_last_buckets_sharding = None
+ self.bws_reg_buckets_padded = None
+ self.bws_reg_buckets_sharding = None
+ self.bws_last_buckets_padded = None
+ self.bws_last_buckets_sharding = None
+ self.bws_window_sum_padded = None
+ self.bws_window_sum_sharding = None
+
+ def _preprocess_scalars(self, scalars: list):
+ tiled_scalar_list = utils.split_list(scalars, self.tile_length)
+ tiled_slices_list = []
+ for tiled_scalars in tiled_scalar_list:
+ sliced_scalars = utils.slice_scalars(
+ tiled_scalars, self.scalar_bits, self.slice_bits
+ )
+ tiled_slices_list.append(sliced_scalars)
+ tiled_slices_list = jnp.array(tiled_slices_list, dtype=jnp.int32)
+ return tiled_slices_list
+
+ def to_original_format(self, a: jnp.ndarray) -> list:
+ return self.ec_ctx.to_original_format(a)
+
+ def to_computational_format(
+ self, scalars: Optional[list] = None
+ ) -> jnp.ndarray:
+ if scalars is None:
+ warnings.warn("No scalars provided, generating random scalars.")
+ scalars = [
+ random.randint(0, self.order - 1) for _ in range(self.msm_length)
+ ]
+ tiled_slices = self._preprocess_scalars(scalars)
+
+ # When sharding is enabled, pad tiled_slices along the tile_length axis
+ # once here so every per-tile slice (regular + last) already matches the
+ # padded shape expected by the compiled BBA kernel. The hot loop can then
+ # just `to_device(...)` without any shape-time work.
+ if self.use_sharding:
+ padded_tile_length = self.bba_reg_slices_padded[1]
+ if tiled_slices.shape[-1] != padded_tile_length:
+ target_shape = (
+ tiled_slices.shape[0],
+ tiled_slices.shape[1],
+ padded_tile_length,
+ )
+ warnings.warn(
+ f"Tiled slices shape {tiled_slices.shape} does not match padded"
+ f" tile length {padded_tile_length}, padding to {target_shape}."
+ )
+ tiled_slices = utils.pad_jax_array(tiled_slices, target_shape)
+ return tiled_slices
+
+ def _bucketize_regular_windows(
+ self, items: jnp.ndarray, bucket_ids: jnp.ndarray
+ ) -> jnp.ndarray:
+ """Sort-based bucketization for regular windows.
+
+ Args:
+ items: Points to distribute (tile_length, coordinate_dim, moduli_num).
+ bucket_ids: Slice values for the regular windows (window_num - 1,
+ tile_length).
+
+ Returns:
+ Bucketed points (regular_bucket_size, coordinate_dim,
+ window_num - 1, bucket_num_per_window, moduli_num).
+ """
+ n = self.tile_length
+ b = self.window_num - 1
+
+ bucket_ids_i32 = bucket_ids.astype(jnp.int32)
+ sorted_indices = jnp.argsort(bucket_ids_i32, axis=1, stable=True)
+ sorted_buckets = jnp.take_along_axis(bucket_ids_i32, sorted_indices, axis=1)
+
+ is_boundary = jnp.concatenate(
+ [
+ jnp.ones_like(sorted_buckets[:, :1], dtype=jnp.bool_),
+ sorted_buckets[:, 1:] != sorted_buckets[:, :-1],
+ ],
+ axis=1,
+ )
+ positions = jnp.broadcast_to(jnp.arange(n, dtype=jnp.int32), (b, n))
+ boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0))
+ bucket_starts = jax.lax.associative_scan(
+ jnp.maximum, boundary_positions, axis=1
+ )
+ within_rank = positions - bucket_starts
+
+ batch_idx = jnp.broadcast_to(
+ jnp.arange(b, dtype=jnp.int32)[:, None], (b, n)
+ )
+ zeros_bn = jnp.zeros((b, n), dtype=jnp.int32)
+ sharding = jax.typeof(sorted_indices).sharding
+ zeros_bn = jax.device_put(zeros_bn, sharding)
+ batch_idx = jax.device_put(batch_idx, sharding)
+ inv_rank = zeros_bn.at[batch_idx, sorted_indices].set(
+ within_rank, out_sharding=sharding
+ )
+ items_b = jnp.broadcast_to(
+ items[None, :, :, :], (b, n, self.coordinate_dim, self.moduli_num)
+ )
+
+ output = jnp.broadcast_to(
+ self.zero_point.reshape(1, self.coordinate_dim, 1, 1, self.moduli_num),
+ (
+ self.regular_bucket_size,
+ self.coordinate_dim,
+ b,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ ),
+ ).copy()
+ sharding_out = jax.typeof(output).sharding
+ output = output.at[inv_rank, :, batch_idx, bucket_ids_i32].set(
+ items_b, out_sharding=sharding_out
+ )
+ return output
+
+ def _bucketize_last_window(
+ self, items: jnp.ndarray, last_slices: jnp.ndarray
+ ) -> jnp.ndarray:
+ """Sort-based bucketization for the last (special) window with duplication.
+
+ Each logical bucket s in [0, bucket_num_last_window) is replicated
+ special_duplication_ratio times. Item i is mapped to duplicate (i mod sdup),
+ spreading items across the duplicates so the per-slot capacity stays at
+ special_bucket_size. Reduction back to logical buckets is handled later by
+ _bucket_accumulation_last_window.
+
+ Args:
+ items: Points to distribute (tile_length, coordinate_dim, moduli_num).
+ last_slices: Slice values for the last window (tile_length,).
+
+ Returns:
+ Bucketed points (special_bucket_size, coordinate_dim,
+ special_duplication_ratio, bucket_num_last_window,
+ moduli_num).
+ """
+ n = self.tile_length
+ sdup = self.special_duplication_ratio
+
+ dup_ids = jnp.arange(n, dtype=jnp.int32) % sdup
+ bucket_ids_i32 = dup_ids * self.bucket_num_last_window + last_slices.astype(
+ jnp.int32
+ )
+
+ sorted_indices = jnp.argsort(bucket_ids_i32, stable=True)
+ sorted_buckets = bucket_ids_i32[sorted_indices]
+
+ is_boundary = jnp.concatenate([
+ jnp.ones((1,), dtype=jnp.bool_),
+ sorted_buckets[1:] != sorted_buckets[:-1],
+ ])
+ positions = jnp.arange(n, dtype=jnp.int32)
+ boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0))
+ bucket_starts = jax.lax.associative_scan(jnp.maximum, boundary_positions)
+ within_rank = positions - bucket_starts
+
+ zeros_n = jnp.zeros((n,), dtype=jnp.int32)
+ inv_rank = zeros_n.at[sorted_indices].set(within_rank)
+
+ output = jnp.broadcast_to(
+ self.zero_point.reshape(1, self.coordinate_dim, 1, self.moduli_num),
+ (
+ self.special_bucket_size,
+ self.coordinate_dim,
+ self.bucket_num_duplication,
+ self.moduli_num,
+ ),
+ ).copy()
+ output = output.at[inv_rank, :, bucket_ids_i32].set(items)
+ return output.reshape(
+ self.special_bucket_size,
+ self.coordinate_dim,
+ self.special_duplication_ratio,
+ self.bucket_num_last_window,
+ self.moduli_num,
+ )
+
+ def _distribute_buckets(
+ self,
+ regular_slices: jnp.ndarray,
+ last_window_slice: jnp.ndarray,
+ tiled_points: jnp.ndarray,
+ ):
+ """Major TPU kernel: distribute one tile of points into buckets.
+
+ Args:
+ regular_slices: Slice values for regular windows (window_num - 1,
+ tile_length).
+ last_window_slice: Slice values for the last window (tile_length,).
+ tiled_points: Points for one tile (tile_length, coordinate_dim,
+ moduli_num).
+
+ Returns:
+ regular_buckets: (regular_bucket_size, coordinate_dim,
+ window_num - 1, bucket_num_per_window, moduli_num).
+ last_window_buckets: (special_bucket_size, coordinate_dim,
+ special_duplication_ratio, bucket_num_last_window,
+ moduli_num).
+ """
+ regular_buckets = self._bucketize_regular_windows(
+ tiled_points, regular_slices
+ )
+ last_window_buckets = self._bucketize_last_window(
+ tiled_points, last_window_slice
+ )
+ return regular_buckets, last_window_buckets
+
+ def _bucket_accumulation_regular_windows_2d_parallel(
+ self,
+ regular_buckets: jnp.ndarray,
+ all_points: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Accumulate points for all regular windows.
+
+ Args:
+ regular_buckets: Initial bucket values (coordinate_dim, window_dim,
+ bucket_dim, precision_dim).
+ all_points: Points for all windows (bucket_size_dim, coordinate_dim,
+ window_dim, bucket_dim, precision_dim).
+
+ Returns:
+ Accumulated bucket values for all regular windows.
+ """
+ bucket_size_dim = all_points.shape[0]
+
+ def scan_body(buckets, points):
+ buckets = self._padd(buckets, points)
+ return buckets, None
+
+ buckets, _ = jax.lax.scan(
+ scan_body, regular_buckets, all_points, length=bucket_size_dim
+ )
+ return buckets
+
+ def _bucket_accumulation_last_window(
+ self,
+ buckets_in: jnp.ndarray,
+ window_points: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Accumulate points for the last (special) window with duplication handling.
+
+ Args:
+ buckets_in: Initial bucket values (coordinate_dim, bucket_dim,
+ precision_dim).
+ window_points: Points with duplication (bucket_size_dim, coordinate_dim,
+ bucket_dup_dim, bucket_dim, precision_dim).
+ parameters: Computation parameters.
+
+ Returns:
+ Accumulated bucket values.
+ """
+ coordinate_dim, buckets_dim, precision_dim = buckets_in.shape
+ bucket_size_dim, _, bucket_dup_dim, _, _ = window_points.shape
+
+ # Reshape for processing
+ window_points = window_points.reshape(
+ bucket_size_dim, coordinate_dim, -1, precision_dim
+ )
+ base_dup_buckets = window_points[0]
+
+ def scan_body(buckets, points):
+ buckets = self._padd(buckets, points)
+ return buckets, None
+
+ dup_buckets, _ = jax.lax.scan(
+ scan_body,
+ base_dup_buckets,
+ window_points[1:],
+ length=bucket_size_dim - 1,
+ )
+
+ # Reduce duplicated buckets using tree reduction
+ log_bucket_dup_dim = int(math.log2(bucket_dup_dim))
+ for _ in range(log_bucket_dup_dim):
+ buckets_split = jnp.split(dup_buckets, 2, axis=1)
+ dup_buckets = self._padd(buckets_split[0], buckets_split[1])
+
+ # Add to input buckets
+ buckets_in = self._padd(buckets_in, dup_buckets)
+ return buckets_in
+
+ def _bucket_accumulation_all_windows(
+ self,
+ regular_window_buckets: jnp.ndarray,
+ last_window_buckets: jnp.ndarray,
+ regular_window_points: jnp.ndarray,
+ last_window_points: jnp.ndarray,
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
+ """Accumulate points for all windows with optimized regular window processing.
+
+ Uses the opt variant for regular windows, which processes all windows
+ simultaneously in a single scan instead of scanning window-by-window.
+
+ Args:
+ regular_buckets: Initial bucket values for regular windows (coord_dim,
+ window_dim-1, bucket_dim, prec_dim).
+ last_window_buckets: Initial bucket values for the last window
+ (coord_dim, bucket_dim, prec_dim).
+ regular_window_points: Points for regular windows (bucket_size,
+ coord_dim, window_dim-1, bucket_num, prec_dim).
+ last_window_points: Points for last window (bucket_size, coord_dim,
+ dup_ratio, bucket_num, prec_dim).
+
+ Returns:
+ (regular_buckets, last_window_buckets): Accumulated bucket values,
+ same shapes as the inputs.
+ """
+ _, _, last_window_bucket_dup, last_window_bucket_dim, _ = (
+ last_window_points.shape
+ )
+
+ regular_window_buckets = (
+ self._bucket_accumulation_regular_windows_2d_parallel(
+ regular_window_buckets, regular_window_points
+ )
+ )
+
+ last_point_buckets = last_window_buckets[:, :last_window_bucket_dim, :]
+ last_blank_buckets = last_window_buckets[:, last_window_bucket_dim:, :]
+ last_point_buckets = self._bucket_accumulation_last_window(
+ last_point_buckets, last_window_points
+ )
+
+ last_window_buckets = jnp.concatenate(
+ (last_point_buckets, last_blank_buckets), axis=1
+ )
+
+ # Combine all buckets
+ return regular_window_buckets, last_window_buckets
+
+ def _bucketize_and_bucket_accumulation(
+ self,
+ regular_window_slices: jnp.ndarray,
+ last_window_slices: jnp.ndarray,
+ tiled_points: jnp.ndarray,
+ regular_window_buckets: jnp.ndarray,
+ last_window_buckets: jnp.ndarray,
+ ) -> tuple[jnp.ndarray, jnp.ndarray]:
+ regular_window_points = self._bucketize_regular_windows(
+ tiled_points, regular_window_slices
+ )
+ last_window_points = self._bucketize_last_window(
+ tiled_points, last_window_slices
+ )
+ regular_window_buckets, last_window_buckets = (
+ self._bucket_accumulation_all_windows(
+ regular_window_buckets,
+ last_window_buckets,
+ regular_window_points,
+ last_window_points,
+ )
+ )
+
+ return regular_window_buckets, last_window_buckets
+
+ def _bucket_reduction(
+ self,
+ regular_window_buckets: jnp.ndarray,
+ last_window_buckets: jnp.ndarray,
+ window_sum: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Reduce buckets to window sums using tree-based parallel algorithm.
+
+ Computes S = sum_{i=0}^{n-1} i * B[i] per window via adjacent-pair
+ tree reduction in O(log n) parallel steps.
+
+ Tracks H = 2^k * bucket_sums (doubling each level) so that H[1::2]
+ provides the correction weight directly. No separate bucket array needed.
+
+ Args:
+ all_buckets: Bucket values (window_dim, coordinate_dim, bucket_dim,
+ precision_dim).
+ window_sum: Initial window sum (coordinate_dim, bucket_dim, window_dim,
+ precision_dim).
+
+ Returns:
+ Window sums (coordinate_dim, window_dim, precision_dim).
+ """
+ all_buckets = jnp.concatenate(
+ (regular_window_buckets, last_window_buckets[:, jnp.newaxis]), axis=1
+ )
+ bucket_num_in_window = all_buckets.shape[2]
+ # (coord, window, bucket, prec) → (coord, bucket, window, prec)
+ all_buckets = all_buckets.transpose(0, 2, 1, 3)
+ iter_num = int(math.log2(bucket_num_in_window))
+
+ cd, m, wd, pd = all_buckets.shape
+ for _ in range(iter_num):
+ m = all_buckets.shape[1]
+ half = m // 2
+
+ all_buckets = all_buckets.reshape(cd, half, 2, wd, pd)
+ window_sum = window_sum.reshape(cd, half, 2, wd, pd)
+
+ all_buckets_left, all_buckets_right = (
+ all_buckets[:, :, 0],
+ all_buckets[:, :, 1],
+ )
+ window_sum_left, window_sum_right = (
+ window_sum[:, :, 0],
+ window_sum[:, :, 1],
+ )
+
+ window_sum = self._padd(
+ self._padd(window_sum_left, window_sum_right), all_buckets_right
+ )
+ bucket_sum = self._padd(all_buckets_left, all_buckets_right)
+ all_buckets = self._padd(bucket_sum, bucket_sum)
+
+ return window_sum[:, 0]
+
+ def _window_merge(self, window_sum: jnp.ndarray) -> jnp.ndarray:
+ """Merge window results into final MSM result using scan algorithm.
+
+ Implements the window merging phase of Pippenger's algorithm.
+
+ Args:
+ window_sum: Window sums (coordinate_dim, window_dim, precision_dim).
+ slice_length: Bit width of each window.
+ parameters: Computation parameters.
+
+ Returns:
+ Final MSM result (coordinate_dim, precision_dim).
+ """
+ coordinate_dim, window_dim, precision_dim = window_sum.shape
+ window_sum = window_sum.transpose(1, 0, 2).reshape(
+ (window_dim, coordinate_dim, 1, precision_dim)
+ )
+ result = window_sum[window_dim - 1]
+
+ def fori_loop_body(i, result):
+ result = self._padd(result, result)
+ return result
+
+ def scan_body(result, window_sum):
+ result = jax.lax.fori_loop(0, self.slice_bits, fori_loop_body, result)
+ result = self._padd(result, window_sum)
+ return result, None
+
+ result, _ = jax.lax.scan(
+ scan_body,
+ result,
+ window_sum[: window_dim - 1],
+ reverse=True,
+ length=window_dim - 1,
+ )
+ result = result.reshape((coordinate_dim, precision_dim))
+ return result
+
+ def _bba_and_bws(
+ self,
+ regular_window_slices: jnp.ndarray,
+ last_window_slices: jnp.ndarray,
+ tiled_points: jnp.ndarray,
+ regular_window_buckets: jnp.ndarray,
+ last_window_buckets: jnp.ndarray,
+ window_sum: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Fused kernel: run the last-tile BBA and then BWS in a single jit.
+
+ Saves a kernel launch + the BBA->BWS round-trip via the host/HBM: XLA
+ sees the accumulated buckets immediately feed into bucket reduction and
+ can overlap the window-axis -> bucket-axis reshard with the reduction
+ tree.
+ """
+ regular_window_buckets, last_window_buckets = (
+ self._bucketize_and_bucket_accumulation(
+ regular_window_slices,
+ last_window_slices,
+ tiled_points,
+ regular_window_buckets,
+ last_window_buckets,
+ )
+ )
+ # Sharding conversion at the BBA -> BWS boundary.
+ # BBA output: regular_window_buckets on axis[1] (window),
+ # last_window_buckets on axis[1] (bucket).
+ # BWS input : regular_window_buckets on axis[2] (bucket),
+ # last_window_buckets on axis[1] (bucket, unchanged).
+ # The last_window_buckets layout already matches BWS; only the
+ # regular buckets need a window-axis -> bucket-axis reshard.
+ if self.use_sharding:
+ P = jax.sharding.PartitionSpec
+ sharding = jax.sharding.NamedSharding(
+ self.sharding_mesh, P(None, None, self.mesh_axes, None)
+ )
+ regular_window_buckets = jax.device_put(regular_window_buckets, sharding)
+ return self._bucket_and_window_sum(
+ regular_window_buckets, last_window_buckets, window_sum
+ )
+
+ def _bucket_and_window_sum(
+ self,
+ regular_window_buckets: jnp.ndarray,
+ last_window_buckets: jnp.ndarray,
+ window_sum: jnp.ndarray,
+ ) -> jnp.ndarray:
+ bucket_summations = self._bucket_reduction(
+ regular_window_buckets, last_window_buckets, window_sum
+ )
+ # bucket_summations: (C, W, M). Window merge has no natural batch axis
+ # to shard along, so force replication here. Without this constraint,
+ # XLA may propagate BR's bucket-axis sharding forward and emit awkward
+ # collectives inside the window-merge scan.
+ if self.use_sharding:
+ P = jax.sharding.PartitionSpec
+ bucket_summations = self.shard_constraint(
+ bucket_summations, P(None, None, None)
+ )
+ window_summations = self._window_merge(bucket_summations)
+ return window_summations
+
+ # ----- Sharding setup -----
+
+ def _init_shardings(self):
+ """Pre-compute named shardings and padded shapes for BBA/BWS kernels.
+
+ Sharding plan (mirrors scratch_profile_msm_sharding.py references):
+
+ BBA (bucketize + bucket accumulation) — analogous to DA fused:
+ regular_slices (W-1, tile_length) -> axis[0] (window)
+ last_window_slice (tile_length,) -> replicated
+ tiled_points (tile_length, C, M) -> replicated
+ regular_window_buckets (C, W-1, B, M) -> axis[1] (window)
+ last_window_buckets (C, B, M) -> axis[1] (bucket)
+
+ BWS (bucket reduction + window merge) — analogous to BR_new_opt + WM:
+ regular_window_buckets (C, W-1, B, M) -> axis[2] (bucket)
+ last_window_buckets (C, B, M) -> axis[1] (bucket)
+ window_sum (C, B, W, M) -> axis[1] (bucket)
+
+ Note: regular_window_buckets is bucket-axis sharded at BWS entry but
+ window-axis sharded during BBA. XLA inserts the reshard at the BWS
+ boundary.
+
+ Any axis whose extent is not divisible by the mesh size triggers a
+ warning (from create_named_sharding) and is padded up.
+ """
+ C = self.coordinate_dim
+ B = self.bucket_num_per_window
+ M = self.moduli_num
+ W = self.window_num
+ T = self.tile_length
+
+ # Shapes
+ reg_slices_shape = (W - 1, T)
+ last_slice_shape = (T,)
+ points_shape = (T, C, M)
+ reg_buckets_shape = (C, W - 1, B, M)
+ last_buckets_shape = (C, B, M)
+ window_sum_shape = (C, B, W, M)
+
+ def _check(name, shape, padded):
+ if tuple(shape) != tuple(padded):
+ warnings.warn(
+ f"[FusionMSMContext sharding] '{name}' shape {shape} not "
+ f"divisible by mesh; padded to {padded}.",
+ stacklevel=2,
+ )
+
+ # ----- BBA kernel shardings -----
+ self.bba_reg_slices_sharding, self.bba_reg_slices_padded = (
+ self.create_named_sharding(shape=reg_slices_shape, axes=[0])
+ )
+ _check("bba_regular_slices", reg_slices_shape, self.bba_reg_slices_padded)
+
+ self.bba_last_slice_sharding, self.bba_last_slice_padded = (
+ self.create_named_sharding(shape=last_slice_shape, axes=[])
+ )
+ _check(
+ "bba_last_window_slice", last_slice_shape, self.bba_last_slice_padded
+ )
+
+ self.bba_points_sharding, self.bba_points_padded = (
+ self.create_named_sharding(shape=points_shape, axes=[])
+ )
+ _check("bba_tiled_points", points_shape, self.bba_points_padded)
+
+ self.bba_reg_buckets_sharding, self.bba_reg_buckets_padded = (
+ self.create_named_sharding(shape=reg_buckets_shape, axes=[1])
+ )
+ _check(
+ "bba_regular_window_buckets",
+ reg_buckets_shape,
+ self.bba_reg_buckets_padded,
+ )
+
+ self.bba_last_buckets_sharding, self.bba_last_buckets_padded = (
+ self.create_named_sharding(shape=last_buckets_shape, axes=[1])
+ )
+ _check(
+ "bba_last_window_buckets",
+ last_buckets_shape,
+ self.bba_last_buckets_padded,
+ )
+
+ # ----- BWS kernel shardings -----
+ self.bws_reg_buckets_sharding, self.bws_reg_buckets_padded = (
+ self.create_named_sharding(shape=reg_buckets_shape, axes=[2])
+ )
+ _check(
+ "bws_regular_window_buckets",
+ reg_buckets_shape,
+ self.bws_reg_buckets_padded,
+ )
+
+ self.bws_last_buckets_sharding, self.bws_last_buckets_padded = (
+ self.create_named_sharding(shape=last_buckets_shape, axes=[1])
+ )
+ _check(
+ "bws_last_window_buckets",
+ last_buckets_shape,
+ self.bws_last_buckets_padded,
+ )
+
+ self.bws_window_sum_sharding, self.bws_window_sum_padded = (
+ self.create_named_sharding(shape=window_sum_shape, axes=[1])
+ )
+ _check("bws_window_sum", window_sum_shape, self.bws_window_sum_padded)
+
+ # ----- Whole-MSM fused kernel shardings -----
+ # Slices get a tile axis prepended, so the regular window axis shifts
+ # from axis[0] to axis[1]. Last slices + points are replicated.
+ fused_reg_slices_shape = (self.tile_num, W - 1, T)
+ fused_last_slices_shape = (self.tile_num, T)
+ fused_tiled_points_shape = (self.tile_num, T, C, M)
+ self.fused_reg_slices_sharding, self.fused_reg_slices_padded = (
+ self.create_named_sharding(shape=fused_reg_slices_shape, axes=[1])
+ )
+ _check(
+ "fused_regular_tiled_slices",
+ fused_reg_slices_shape,
+ self.fused_reg_slices_padded,
+ )
+ self.fused_last_slices_sharding, self.fused_last_slices_padded = (
+ self.create_named_sharding(shape=fused_last_slices_shape, axes=[])
+ )
+ _check(
+ "fused_last_tiled_slices",
+ fused_last_slices_shape,
+ self.fused_last_slices_padded,
+ )
+ self.fused_tiled_points_sharding, self.fused_tiled_points_padded = (
+ self.create_named_sharding(shape=fused_tiled_points_shape, axes=[])
+ )
+ _check(
+ "fused_tiled_points",
+ fused_tiled_points_shape,
+ self.fused_tiled_points_padded,
+ )
+
+ def set_use_sharding(self, use_sharding: bool):
+ super().set_use_sharding(use_sharding)
+ if use_sharding:
+ self._init_shardings()
+ self._place_sharded_state()
+
+ def _place_sharded_state(self):
+ """Pad + place persistent state (accumulators, points) onto the correct
+
+ shardings so compiled kernels receive inputs whose layout matches what
+ they were compiled for. Must be called after _init_shardings().
+ """
+ # Accumulators — BBA input layouts.
+ self.regular_window_buckets = utils.pad_jax_array(
+ jnp.asarray(self.regular_window_buckets), self.bba_reg_buckets_padded
+ ).to_device(self.bba_reg_buckets_sharding)
+ self.last_window_buckets = utils.pad_jax_array(
+ jnp.asarray(self.last_window_buckets), self.bba_last_buckets_padded
+ ).to_device(self.bba_last_buckets_sharding)
+ # window_sum — BWS input layout.
+ self.window_sum = utils.pad_jax_array(
+ jnp.asarray(self.window_sum), self.bws_window_sum_padded
+ ).to_device(self.bws_window_sum_sharding)
+ # Points — replicated on the mesh. Use the 4D fused sharding so that
+ # the runtime placement matches what the compiled fused kernel was
+ # lowered with (both paths see the same PartitionSpec).
+ self.points = jax.device_put(
+ jnp.asarray(self.points), self.fused_tiled_points_sharding
+ )
+
+ # ----- Compile system (mirrors TPUDistributionMSMContext) -----
+
+ def _get_bba_shape_dtype_structs(self):
+ """ShapeDtypeStructs for _bucketize_and_bucket_accumulation inputs.
+
+ When sharding is enabled, inputs carry the shardings/padded shapes
+ established in _init_shardings().
+ """
+ if self.use_sharding:
+ return [
+ jax.ShapeDtypeStruct(
+ self.bba_reg_slices_padded,
+ jnp.int32,
+ sharding=self.bba_reg_slices_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.bba_last_slice_padded,
+ jnp.int32,
+ sharding=self.bba_last_slice_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.bba_points_padded,
+ jnp.uint32,
+ sharding=self.bba_points_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.bba_reg_buckets_padded,
+ jnp.uint32,
+ sharding=self.bba_reg_buckets_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.bba_last_buckets_padded,
+ jnp.uint32,
+ sharding=self.bba_last_buckets_sharding,
+ ),
+ ]
+ regular_slices_shape = (self.window_num - 1, self.tile_length)
+ last_window_slice_shape = (self.tile_length,)
+ tiled_points_shape = (
+ self.tile_length,
+ self.coordinate_dim,
+ self.moduli_num,
+ )
+ regular_window_buckets_shape = (
+ self.coordinate_dim,
+ self.window_num - 1,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ )
+ last_window_buckets_shape = (
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ )
+ return [
+ jax.ShapeDtypeStruct(regular_slices_shape, jnp.int32),
+ jax.ShapeDtypeStruct(last_window_slice_shape, jnp.int32),
+ jax.ShapeDtypeStruct(tiled_points_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(regular_window_buckets_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(last_window_buckets_shape, jnp.uint32),
+ ]
+
+ def _get_bws_shape_dtype_structs(self):
+ """ShapeDtypeStructs for _bucket_and_window_sum inputs."""
+ if self.use_sharding:
+ return [
+ jax.ShapeDtypeStruct(
+ self.bws_reg_buckets_padded,
+ jnp.uint32,
+ sharding=self.bws_reg_buckets_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.bws_last_buckets_padded,
+ jnp.uint32,
+ sharding=self.bws_last_buckets_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.bws_window_sum_padded,
+ jnp.uint32,
+ sharding=self.bws_window_sum_sharding,
+ ),
+ ]
+ regular_window_buckets_shape = (
+ self.coordinate_dim,
+ self.window_num - 1,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ )
+ last_window_buckets_shape = (
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.moduli_num,
+ )
+ window_sum_shape = (
+ self.coordinate_dim,
+ self.bucket_num_per_window,
+ self.batch_window_num,
+ self.moduli_num,
+ )
+ return [
+ jax.ShapeDtypeStruct(regular_window_buckets_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(last_window_buckets_shape, jnp.uint32),
+ jax.ShapeDtypeStruct(window_sum_shape, jnp.uint32),
+ ]
+
+ def _get_bba_bws_shape_dtype_structs(self):
+ """ShapeDtypeStructs for the fused _bba_and_bws inputs.
+
+ Slices/points + BBA-layout accumulators + BWS-layout window_sum. The
+ reshard from BBA's window-axis bucket layout to BWS's bucket-axis
+ layout happens inside the kernel via with_sharding_constraint.
+ """
+ bba = self._get_bba_shape_dtype_structs()
+ bws = self._get_bws_shape_dtype_structs()
+ # 3 slice/point inputs + 2 BBA accumulators + BWS window_sum.
+ return bba + [bws[2]]
+
+ def compile(self, parameters: Optional[dict] = None):
+ """Compile the kernels needed for the chosen execution path.
+
+ Args:
+ parameters: Compile-time options. Recognised keys: - ``use_fused`` (bool,
+ default False): when True, compile the whole-MSM fused kernel
+ (``_multiscalar_multiply_fused``) and only the subkernels it uses
+ internally (``_bba_and_bws`` for the last tile, plus
+ ``_bucketize_and_bucket_accumulation`` for the scan body when ``tile_num
+ > 1``). When False, compile the per-stage path: BBA + BWS, and
+ additionally ``_bba_and_bws`` for the last-tile fusion that
+ ``_multiscalar_multiply`` uses. Also skips compiling the standalone BBA
+ when ``tile_num == 1``.
+ """
+ parameters = parameters or {}
+ use_fused = parameters.get("use_fused", False)
+ self.use_fused = use_fused
+
+ bba_structs = self._get_bba_shape_dtype_structs()
+ bws_structs = self._get_bws_shape_dtype_structs()
+ bba_bws_structs = self._get_bba_bws_shape_dtype_structs()
+
+ bba_hash = hash_args(
+ *(v for s in bba_structs for v in (s.shape, s.dtype.__str__()))
+ )
+ bws_hash = hash_args(
+ *(v for s in bws_structs for v in (s.shape, s.dtype.__str__()))
+ )
+ bba_bws_hash = hash_args(
+ *(v for s in bba_bws_structs for v in (s.shape, s.dtype.__str__()))
+ )
+
+ # Always compile the fused BBA+BWS kernel: it handles the last tile
+ # in both execution paths, and is the single kernel used by
+ # tile_num == 1 in the non-fused path.
+ self.compiled_kernels.setdefault(bba_bws_hash, {})["bba_and_bws"] = (
+ jax_jit_lower_compile(
+ self._bba_and_bws,
+ bba_bws_structs[0],
+ bba_bws_structs[1],
+ bba_bws_structs[2],
+ bba_bws_structs[3],
+ bba_bws_structs[4],
+ bba_bws_structs[5],
+ )
+ )
+
+ if use_fused:
+ fused_structs = self._get_fused_shape_dtype_structs()
+ fused_hash = hash_args(
+ *(v for s in fused_structs for v in (s.shape, s.dtype.__str__()))
+ )
+ self.compiled_kernels.setdefault(fused_hash, {})[
+ "multiscalar_multiply_fused"
+ ] = jax_jit_lower_compile(
+ self._multiscalar_multiply_fused,
+ *fused_structs,
+ )
+ # The whole-MSM fused kernel absorbs BBA (via scan) and BWS
+ # internally, so the standalone per-stage compiles aren't needed.
+ else:
+ # Non-fused path uses standalone BBA for tiles 0..N-2, and
+ # _bba_and_bws for the last tile (already compiled above). When
+ # tile_num == 1 there are no non-last tiles, so skip the
+ # standalone BBA compile entirely.
+ if self.tile_num > 1:
+ self.compiled_kernels.setdefault(bba_hash, {})[
+ "bucketize_and_bucket_accumulation"
+ ] = jax_jit_lower_compile(
+ self._bucketize_and_bucket_accumulation,
+ bba_structs[0],
+ bba_structs[1],
+ bba_structs[2],
+ bba_structs[3],
+ bba_structs[4],
+ )
+ # BWS standalone is still useful if any external caller invokes
+ # bucket_and_window_sum directly. Keep compiling it.
+ self.compiled_kernels.setdefault(bws_hash, {})[
+ "bucket_and_window_sum"
+ ] = jax_jit_lower_compile(
+ self._bucket_and_window_sum,
+ bws_structs[0],
+ bws_structs[1],
+ bws_structs[2],
+ )
+
+ self.use_compiled_kernels = True
+
+ def bba_and_bws(
+ self,
+ regular_slices,
+ last_window_slice,
+ tiled_points,
+ regular_window_buckets,
+ last_window_buckets,
+ window_sum,
+ ):
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ regular_slices.shape,
+ regular_slices.dtype.__str__(),
+ last_window_slice.shape,
+ last_window_slice.dtype.__str__(),
+ tiled_points.shape,
+ tiled_points.dtype.__str__(),
+ regular_window_buckets.shape,
+ regular_window_buckets.dtype.__str__(),
+ last_window_buckets.shape,
+ last_window_buckets.dtype.__str__(),
+ window_sum.shape,
+ window_sum.dtype.__str__(),
+ )
+ return self.compiled_kernels[kernel_hash]["bba_and_bws"](
+ regular_slices,
+ last_window_slice,
+ tiled_points,
+ regular_window_buckets,
+ last_window_buckets,
+ window_sum,
+ )
+ return self._bba_and_bws(
+ regular_slices,
+ last_window_slice,
+ tiled_points,
+ regular_window_buckets,
+ last_window_buckets,
+ window_sum,
+ )
+
+ def bucketize_and_bucket_accumulation(
+ self,
+ regular_slices,
+ last_window_slice,
+ tiled_points,
+ regular_window_buckets,
+ last_window_buckets,
+ ):
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ regular_slices.shape,
+ regular_slices.dtype.__str__(),
+ last_window_slice.shape,
+ last_window_slice.dtype.__str__(),
+ tiled_points.shape,
+ tiled_points.dtype.__str__(),
+ regular_window_buckets.shape,
+ regular_window_buckets.dtype.__str__(),
+ last_window_buckets.shape,
+ last_window_buckets.dtype.__str__(),
+ )
+ return self.compiled_kernels[kernel_hash][
+ "bucketize_and_bucket_accumulation"
+ ](
+ regular_slices,
+ last_window_slice,
+ tiled_points,
+ regular_window_buckets,
+ last_window_buckets,
+ )
+ return self._bucketize_and_bucket_accumulation(
+ regular_slices,
+ last_window_slice,
+ tiled_points,
+ regular_window_buckets,
+ last_window_buckets,
+ )
+
+ def bucket_and_window_sum(
+ self, regular_window_buckets, last_window_buckets, window_sum
+ ):
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ regular_window_buckets.shape,
+ regular_window_buckets.dtype.__str__(),
+ last_window_buckets.shape,
+ last_window_buckets.dtype.__str__(),
+ window_sum.shape,
+ window_sum.dtype.__str__(),
+ )
+ return self.compiled_kernels[kernel_hash]["bucket_and_window_sum"](
+ regular_window_buckets,
+ last_window_buckets,
+ window_sum,
+ )
+ return self._bucket_and_window_sum(
+ regular_window_buckets, last_window_buckets, window_sum
+ )
+
+ def context_hash(self) -> str:
+ return hash_args(
+ self.__class__.__name__,
+ self.ec_ctx.context_hash()
+ if hasattr(self.ec_ctx, "context_hash")
+ else str(self.ec_ctx.__class__.__name__),
+ self.slice_bits,
+ self.scalar_bits,
+ self.msm_length,
+ self.tile_length,
+ self.bucket_num_per_window,
+ self.bucket_num_last_window,
+ self.use_sharding,
+ )
+
+ def _multiscalar_multiply(self, tiled_slices: jnp.ndarray):
+ # Padding (if any) was applied in to_computational_format. Here we only
+ # place the per-tile slices onto the right sharding — points are already
+ # placed on the replicated BBA points sharding via _place_sharded_state.
+ last_tile = self.tile_num - 1
+ for tile_index in range(self.tile_num):
+ tiled_slices_tile = tiled_slices[tile_index]
+ regular_slices_tile = tiled_slices_tile[: self.window_num - 1]
+ last_window_slice_tile = tiled_slices_tile[self.window_num - 1]
+ tiled_points_tile = self.points[tile_index]
+
+ if self.use_sharding:
+ regular_slices_tile = regular_slices_tile.to_device(
+ self.bba_reg_slices_sharding
+ )
+ last_window_slice_tile = last_window_slice_tile.to_device(
+ self.bba_last_slice_sharding
+ )
+
+ if tile_index == last_tile:
+ # Fused BBA + BWS: reshard (BBA -> BWS) stays inside one compiled
+ # kernel, so XLA can overlap it with the reduction tree.
+ return self.bba_and_bws(
+ regular_slices_tile,
+ last_window_slice_tile,
+ tiled_points_tile,
+ self.regular_window_buckets,
+ self.last_window_buckets,
+ self.window_sum,
+ )
+
+ self.regular_window_buckets, self.last_window_buckets = (
+ self.bucketize_and_bucket_accumulation(
+ regular_slices_tile,
+ last_window_slice_tile,
+ tiled_points_tile,
+ self.regular_window_buckets,
+ self.last_window_buckets,
+ )
+ )
+
+ def _multiscalar_multiply_fused(
+ self,
+ regular_tiled_slices: jnp.ndarray,
+ last_tiled_slices: jnp.ndarray,
+ tiled_points: jnp.ndarray,
+ regular_window_buckets: jnp.ndarray,
+ last_window_buckets: jnp.ndarray,
+ window_sum: jnp.ndarray,
+ ) -> jnp.ndarray:
+ """Whole-MSM fused kernel — one jit across all tiles + BWS.
+
+ Regular and last-window slices are passed as separate tensors because
+ they require different input shardings: regular is sharded along its
+ window axis, while the last-window slice has no batch-like axis to
+ shard over and is replicated. Packing them together would force a
+ single sharding on the combined tensor, losing this distinction.
+
+ Args:
+ regular_tiled_slices: (tile_num, W-1, tile_length) — shard axis[1]
+ last_tiled_slices: (tile_num, tile_length) — replicated
+ tiled_points: (tile_num, tile_length, C, M) — replicated
+ regular_window_buckets: (C, W-1, B, M) — BBA axis[1]
+ last_window_buckets: (C, B, M) — BBA axis[1]
+ window_sum: (C, B, W, M) — BWS axis[1]
+
+ Hybrid structure:
+ - Tiles ``0..N-2`` run through ``jax.lax.scan``.
+ - The last tile goes through ``_bba_and_bws`` (reshard + BR->WM
+ constraint live inside that kernel).
+ - When ``tile_num == 1``, the scan is skipped.
+ """
+
+ def scan_body(carry, inputs):
+ reg_buckets, last_buckets = carry
+ regular_slices, last_window_slice, points_tile = inputs
+ reg_buckets, last_buckets = self._bucketize_and_bucket_accumulation(
+ regular_slices,
+ last_window_slice,
+ points_tile,
+ reg_buckets,
+ last_buckets,
+ )
+ return (reg_buckets, last_buckets), None
+
+ if self.tile_num > 1:
+ (regular_window_buckets, last_window_buckets), _ = jax.lax.scan(
+ scan_body,
+ (regular_window_buckets, last_window_buckets),
+ (
+ regular_tiled_slices[:-1],
+ last_tiled_slices[:-1],
+ tiled_points[:-1],
+ ),
+ length=self.tile_num - 1,
+ )
+
+ return self._bba_and_bws(
+ regular_tiled_slices[-1],
+ last_tiled_slices[-1],
+ tiled_points[-1],
+ regular_window_buckets,
+ last_window_buckets,
+ window_sum,
+ )
+
+ def _get_fused_shape_dtype_structs(self):
+ """ShapeDtypeStructs for the whole-MSM fused kernel.
+
+ Regular / last slices are produced as separate structs so each can
+ carry its own sharding (regular: window-axis sharded; last:
+ replicated).
+ """
+ bba = self._get_bba_shape_dtype_structs()
+ bws = self._get_bws_shape_dtype_structs()
+ per_tile_length = bba[0].shape[1]
+
+ regular_tiled_slices_shape = (
+ self.tile_num,
+ self.window_num - 1,
+ per_tile_length,
+ )
+ last_tiled_slices_shape = (self.tile_num, per_tile_length)
+ tiled_points_shape = (self.tile_num,) + tuple(bba[2].shape)
+
+ if self.use_sharding:
+ # Use the shardings/padded shapes pre-computed in _init_shardings so
+ # the compile-time and run-time placements agree.
+ return [
+ jax.ShapeDtypeStruct(
+ self.fused_reg_slices_padded,
+ jnp.int32,
+ sharding=self.fused_reg_slices_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.fused_last_slices_padded,
+ jnp.int32,
+ sharding=self.fused_last_slices_sharding,
+ ),
+ jax.ShapeDtypeStruct(
+ self.fused_tiled_points_padded,
+ jnp.uint32,
+ sharding=self.fused_tiled_points_sharding,
+ ),
+ bba[3],
+ bba[4],
+ bws[2],
+ ]
+ return [
+ jax.ShapeDtypeStruct(regular_tiled_slices_shape, jnp.int32),
+ jax.ShapeDtypeStruct(last_tiled_slices_shape, jnp.int32),
+ jax.ShapeDtypeStruct(tiled_points_shape, jnp.uint32),
+ bba[3],
+ bba[4],
+ bws[2],
+ ]
+
+ def multiscalar_multiply(self, tiled_slices: jnp.ndarray):
+ if not self.use_fused:
+ return self._multiscalar_multiply(tiled_slices)
+
+ # Fused path: split regular / last slices so they can carry independent
+ # shardings. self.points is already placed in per-tile layout.
+ regular_tiled_slices = tiled_slices[:, : self.window_num - 1]
+ last_tiled_slices = tiled_slices[:, self.window_num - 1]
+
+ if self.use_sharding:
+ regular_tiled_slices = regular_tiled_slices.to_device(
+ self.fused_reg_slices_sharding
+ )
+ last_tiled_slices = last_tiled_slices.to_device(
+ self.fused_last_slices_sharding
+ )
+
+ args = (
+ regular_tiled_slices,
+ last_tiled_slices,
+ self.points,
+ self.regular_window_buckets,
+ self.last_window_buckets,
+ self.window_sum,
+ )
+ if self.use_compiled_kernels:
+ kernel_hash = hash_args(
+ *(v for a in args for v in (a.shape, a.dtype.__str__()))
+ )
+ return self.compiled_kernels[kernel_hash]["multiscalar_multiply_fused"](
+ *args
+ )
+ return self._multiscalar_multiply_fused(*args)
diff --git a/jaxite_ec/multiscalar_multiplication_perf_test.py b/jaxite_ec/multiscalar_multiplication_perf_test.py
new file mode 100644
index 0000000..c7eaf2d
--- /dev/null
+++ b/jaxite_ec/multiscalar_multiplication_perf_test.py
@@ -0,0 +1,211 @@
+import os
+
+import jax
+from jaxite.jaxite_ec import utils
+import jaxite.jaxite_ec.elliptic_curve_context as ec_context
+import jaxite.jaxite_ec.finite_field_context as ff_context
+import jaxite.jaxite_ec.multiscalar_multiplication_context as msm_context
+from jaxite.jaxite_ec.profiler import (
+ PrecompiledKernelWrapper,
+ Profiler,
+ collect_logs,
+)
+from jaxite.jaxite_ec.utils import hash_args
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+jax.config.update("jax_enable_x64", True)
+
+MODULUS_377_INT = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
+NUM_MODULI = 32
+RNS_MODULI = utils.find_moduli_specified_number(NUM_MODULI, 28)
+
+MSM_LENGTH_LIST = [2**10]
+SLICE_BITS = 10
+SCALAR_BITS = 253
+
+TEST_PARAMS_MSM_FUSION = [
+ ("msm_fusion", MSM_LENGTH_LIST),
+]
+
+
+def _build_msm_parameters(msm_length: int):
+ ff_parameters = {
+ "prime": MODULUS_377_INT,
+ "rns_moduli": RNS_MODULI,
+ "precision_bits": 28,
+ "radix_bits": 32,
+ }
+ ec_parameters = {
+ "finite_field_context_class": ff_context.DRNSlazyContext,
+ "finite_field_parameters": ff_parameters,
+ "prime": (
+ 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177
+ ),
+ "order": (
+ 8444461749428370424248824938781546531375899335154063827935233455917409239041
+ ),
+ "a": -1,
+ "twist_d": (
+ 122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179
+ ),
+ "alpha": -1,
+ "b": 1,
+ "s": (
+ 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
+ ),
+ "MA": (
+ 228097355113300204138531148905234651262148041026195375645000724271212049151994375092458297304264351187709081232384
+ ),
+ "MB": (
+ 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
+ ),
+ "t": (
+ 23560188534917577818843641916571445935985386319233886518929971599490231428764380923487987729215299304184915158756
+ ),
+ "generator": [
+ 71222569531709137229370268896323705690285216175189308202338047559628438110820800641278662592954630774340654489393,
+ 6177051365529633638563236407038680211609544222665285371549726196884440490905471891908272386851767077598415378235,
+ ],
+ }
+ return {
+ "elliptic_curve_context_class": (
+ ec_context.ExtendedTwistedEdwardsNDContext
+ ),
+ "elliptic_curve_parameters": ec_parameters,
+ "coordinate_dim": 4,
+ "msm_length": msm_length,
+ "tile_length": msm_length,
+ "slice_bits": SLICE_BITS,
+ "scalar_bits": SCALAR_BITS,
+ "order": ec_parameters["order"],
+ "c_kernel_ret_space_ratio": 2,
+ }
+
+
+def _build_fusion_ctx(msm_length: int):
+ """Mirrors scratch_profile_msm_fused.py: sharded fused MSM with
+
+ pre-compiled kernels. The context's multiscalar_multiply internally
+ dispatches to these already-compiled, already-sharded kernels — so it
+ must NOT be re-traced under jax.jit.
+ """
+ params = _build_msm_parameters(msm_length)
+ ctx = msm_context.FusionMSMContext(params)
+ ctx.set_use_compiled_kernels(True)
+ ctx.set_use_sharding(True)
+ ctx.compile(parameters={"use_fused": True})
+ return ctx
+
+
+class FusionMSMPerformanceTest(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
+ if outputs_dir:
+ self.output_trace_root = os.path.join(outputs_dir, "log")
+ else:
+ self.output_trace_root = os.path.join(os.path.dirname(__file__), "log")
+ self.profiler_config = {
+ "iterations": 1,
+ "save_to_file": True,
+ "enable_sharding": True,
+ }
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
+ root_dir = (
+ outputs_dir
+ if outputs_dir
+ else os.path.dirname(os.path.abspath(__file__))
+ )
+ print(f"Collecting logs from: {root_dir}")
+ collect_logs(root_dir)
+
+ def _create_fusion_msm_wrapper(self, kernel_name, ctx, tiled_slices):
+ """Profile ``_multiscalar_multiply_fused`` — the individual jit-able kernel.
+
+ ``ctx.multiscalar_multiply`` is a Python dispatcher that slices inputs,
+ places them on shardings, and hashes into ``compiled_kernels``; it isn't
+ itself a kernel. We reproduce its setup once here, then hand the
+ already-compiled executable + concrete inputs to the profiler.
+ """
+ # Replicate the slicing ctx.multiscalar_multiply does for the fused path.
+ regular_tiled_slices = tiled_slices[:, : ctx.window_num - 1]
+ last_tiled_slices = tiled_slices[:, ctx.window_num - 1]
+ if ctx.use_sharding:
+ regular_tiled_slices = regular_tiled_slices.to_device(
+ ctx.fused_reg_slices_sharding
+ )
+ last_tiled_slices = last_tiled_slices.to_device(
+ ctx.fused_last_slices_sharding
+ )
+
+ input_arrays = [
+ regular_tiled_slices,
+ last_tiled_slices,
+ ctx.points,
+ ctx.regular_window_buckets,
+ ctx.last_window_buckets,
+ ctx.window_sum,
+ ]
+
+ # Pull the already-compiled kernel out of ctx — same hash scheme as
+ # ctx.multiscalar_multiply uses at dispatch time.
+ kernel_hash = hash_args(
+ *(v for a in input_arrays for v in (a.shape, a.dtype.__str__()))
+ )
+ compiled_fn = ctx.compiled_kernels[kernel_hash][
+ "multiscalar_multiply_fused"
+ ]
+ # compiled_fn = None
+
+ return PrecompiledKernelWrapper(
+ kernel_name=kernel_name,
+ callable_function=compiled_fn,
+ input_arrays=input_arrays,
+ enable_sharding=True,
+ callable_function_name="_multiscalar_multiply_fused", # it is important for collecting the correct trace events
+ )
+
+ @parameterized.named_parameters(*TEST_PARAMS_MSM_FUSION)
+ def test_fusion_msm_performance(self, msm_length_list):
+ profiler_instance = Profiler(
+ output_trace_path=self.output_trace_root,
+ profile_naming="msm_fusion",
+ configuration=self.profiler_config,
+ )
+
+ for msm_length in msm_length_list:
+ ctx = _build_fusion_ctx(msm_length)
+ tiled_slices = ctx.to_computational_format(None)
+
+ kernel_name = f"msm_fusion_n{msm_length}"
+ kernel_wrapper = self._create_fusion_msm_wrapper(
+ kernel_name=kernel_name,
+ ctx=ctx,
+ tiled_slices=tiled_slices,
+ )
+ profiler_instance.add_profile(
+ name=kernel_name,
+ kernel_wrapper=kernel_wrapper,
+ kernel_setting_cols={
+ "msm_length": msm_length,
+ "slice_bits": SLICE_BITS,
+ "scalar_bits": SCALAR_BITS,
+ "num_moduli": NUM_MODULI,
+ "use_sharding": True,
+ "use_fused": True,
+ },
+ )
+
+ profiler_instance.profile_all_profilers()
+ profiler_instance.post_process_all_profilers()
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/jaxite_ec/multiscalar_multiplication_test.py b/jaxite_ec/multiscalar_multiplication_test.py
new file mode 100644
index 0000000..ab4b935
--- /dev/null
+++ b/jaxite_ec/multiscalar_multiplication_test.py
@@ -0,0 +1,126 @@
+import os
+
+import jax
+import jaxite.jaxite_ec.elliptic_curve_context as ec_context
+import jaxite.jaxite_ec.finite_field_context as ff_context
+import jaxite.jaxite_ec.multiscalar_multiplication_context as msm_context
+import jaxite.jaxite_ec.utils as utils
+import numpy as np
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+jax.config.update("jax_enable_x64", True)
+
+MODULUS_377_INT = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
+NUM_MODULI = 32
+RNS_MODULI = utils.find_moduli_specified_number(NUM_MODULI, 28)
+
+MSM_DIM = 2**10
+SEED = 0
+MSM_TEST_DIR = os.path.join(os.path.dirname(__file__), "data", f"t{MSM_DIM}")
+POINTS_PATH = os.path.join(
+ MSM_TEST_DIR, f"zprize_msm_curve_377_bases_dim_{MSM_DIM}_seed_{SEED}.csv"
+)
+SCALARS_PATH = os.path.join(
+ MSM_TEST_DIR, f"zprize_msm_curve_377_scalars_dim_{MSM_DIM}_seed_{SEED}.csv"
+)
+REF_RESULT_PATH = os.path.join(
+ MSM_TEST_DIR, f"zprize_msm_curve_377_res_dim_{MSM_DIM}_seed_{SEED}.csv"
+)
+
+
+def _build_msm_parameters():
+ ff_parameters = {
+ "prime": MODULUS_377_INT,
+ "rns_moduli": RNS_MODULI,
+ "precision_bits": 28,
+ "radix_bits": 32,
+ }
+ ec_parameters = {
+ "finite_field_context_class": ff_context.DRNSlazyContext,
+ "finite_field_parameters": ff_parameters,
+ "prime": (
+ 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177
+ ),
+ "order": (
+ 8444461749428370424248824938781546531375899335154063827935233455917409239041
+ ),
+ "a": -1,
+ "twist_d": (
+ 122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179
+ ),
+ "alpha": -1,
+ "b": 1,
+ "s": (
+ 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
+ ),
+ "MA": (
+ 228097355113300204138531148905234651262148041026195375645000724271212049151994375092458297304264351187709081232384
+ ),
+ "MB": (
+ 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931
+ ),
+ "t": (
+ 23560188534917577818843641916571445935985386319233886518929971599490231428764380923487987729215299304184915158756
+ ),
+ "generator": [
+ 71222569531709137229370268896323705690285216175189308202338047559628438110820800641278662592954630774340654489393,
+ 6177051365529633638563236407038680211609544222665285371549726196884440490905471891908272386851767077598415378235,
+ ],
+ }
+ return {
+ "elliptic_curve_context_class": (
+ ec_context.ExtendedTwistedEdwardsNDContext
+ ),
+ "elliptic_curve_parameters": ec_parameters,
+ "coordinate_dim": 4,
+ "msm_length": MSM_DIM,
+ "tile_length": MSM_DIM,
+ "slice_bits": 6,
+ "scalar_bits": 253,
+ "order": ec_parameters["order"],
+ "points_path": POINTS_PATH,
+ "c_kernel_ret_space_ratio": 2,
+ }
+
+
+MSM_CONTEXT_CASES = [
+ ("cpu_distribution", msm_context.CPUDistributionMSMContext, False),
+ ("tpu_distribution", msm_context.TPUDistributionMSMContext, False),
+ ("fusion", msm_context.FusionMSMContext, True),
+]
+
+
+class MsmBls12377Test(parameterized.TestCase):
+
+ def setUp(self):
+ super().setUp()
+ self.msm_parameters = _build_msm_parameters()
+ self.scalars = utils.read_external_msm_file(SCALARS_PATH, "scalars")
+ self.ref_result = utils.read_external_msm_file(
+ REF_RESULT_PATH, "result_ref"
+ )
+
+ def _run_msm(self, context_class, use_fused):
+ ctx = context_class(self.msm_parameters)
+ ctx.set_use_compiled_kernels(True)
+ compile_kwargs = {"use_fused": True} if use_fused else {}
+ ctx.compile(parameters=compile_kwargs)
+
+ tiled_slices = ctx.to_computational_format(self.scalars)
+ result_m = ctx.multiscalar_multiply(tiled_slices)
+ return ctx.to_original_format(result_m)
+
+ @parameterized.named_parameters(*MSM_CONTEXT_CASES)
+ def test_multiscalar_multiply_matches_reference(
+ self, context_class, use_fused
+ ):
+ result = self._run_msm(context_class, use_fused)
+ np.testing.assert_array_equal(
+ np.asarray(result), np.asarray(self.ref_result)
+ )
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/jaxite_ec/number_theory_transform_context.py b/jaxite_ec/number_theory_transform_context.py
new file mode 100644
index 0000000..14f7134
--- /dev/null
+++ b/jaxite_ec/number_theory_transform_context.py
@@ -0,0 +1,834 @@
+from abc import ABC, abstractmethod
+from typing import Any, Dict
+import jax
+import jax.numpy as jnp
+from jaxite.jaxite_ec import utils
+from jaxite.jaxite_ec.finite_field_context import (
+ CROSSLazyContext,
+ DRNSlazyContext,
+)
+import numpy as np
+
+jax.config.update("jax_enable_x64", True)
+
+
+########################
+# Helper Functions
+########################
+def gen_twiddle_matrix(rows, cols, q, omega):
+ """Precompute twiddle matrix T where T[r, c] = omega^(r*c) mod q.
+
+ Stored as ``dtype=object`` so the cells hold arbitrary-precision Python
+ ints; needed because ``q`` can exceed 64 bits (e.g. the 753-bit ZKP
+ prime used in the perf test).
+ """
+ twiddle_matrix = np.empty((rows, cols), dtype=object)
+ for r in range(rows):
+ for c in range(cols):
+ twiddle_matrix[r, c] = pow(int(omega), int(r * c), int(q))
+ return twiddle_matrix
+
+
+def gen_twiddle_matrix_inv(rows, cols, q, omega):
+ """Precompute inverse twiddle matrix T_inv where T_inv[r, c] = omega^{-(r*c)} mod q."""
+ twiddle_matrix_inv = np.empty((rows, cols), dtype=object)
+ for r in range(rows):
+ for c in range(cols):
+ twiddle_matrix_inv[r, c] = pow(int(omega), int(-r * c), int(q))
+ return twiddle_matrix_inv
+
+
+def get_bit_reverse_perm(n):
+ """Generates bit-reversal permutation indices of size n."""
+ if n <= 0:
+ return []
+ bits = n.bit_length() - 1
+ perm = [0] * n
+ for i in range(n):
+ r = 0
+ temp = i
+ for _ in range(bits):
+ r = (r << 1) | (temp & 1)
+ temp >>= 1
+ perm[i] = r
+ return perm
+
+
+########################
+# Abstract Base Class
+########################
+class NumberTheoryTransformBase(ABC):
+ """Abstract base class for all NTT context implementations."""
+
+ @abstractmethod
+ def ntt(self, v):
+ """Forward Number Theory Transform."""
+ pass
+
+ @abstractmethod
+ def intt(self, v):
+ """Inverse Number Theory Transform."""
+ pass
+
+ @abstractmethod
+ def to_computational_format(self, a):
+ """Convert from plain integers to the internal computational representation."""
+ pass
+
+ @abstractmethod
+ def to_original_format(self, a):
+ """Convert from internal representation back to a flat list of integers."""
+ pass
+
+
+########################
+# BAT (Basis Aligned Transformation) helpers
+########################
+def basis_aligned_transformation(
+ matrix_drns: np.ndarray, rns_moduli
+) -> jnp.ndarray:
+ """Convert a 2-D DRNS twiddle matrix to BAT format for 8-bit matmul.
+
+ Each uint32 value is byte-shifted by 0/8/16/24 bits, reduced mod each
+ RNS modulus, then bitcast to uint8.
+
+ Args:
+ matrix_drns: DRNS twiddle, shape (rows, cols, num_moduli) uint32.
+ rns_moduli: sequence of RNS moduli.
+
+ Returns:
+ BAT matrix of shape (rows, 4, cols, 4, num_moduli) uint8.
+ """
+ rows, cols, M = np.array(matrix_drns).shape
+ matrix_u64 = np.array(matrix_drns, dtype=np.uint64)
+ moduli = np.array(rns_moduli, dtype=np.uint64)
+
+ # (4, rows, cols, M) — byte-shifted and reduced per channel
+ shifted = np.empty((4, rows, cols, M), dtype=np.uint32)
+ for s in range(4):
+ shifted[s] = ((matrix_u64 << (8 * s)) % moduli).astype(np.uint32)
+
+ # Bitcast uint32 → uint8: (4, rows, cols, M) → (4, rows, cols, M*4) → (4, rows, cols, M, 4)
+ shifted_u8 = shifted.view(np.uint8).reshape(4, rows, cols, M, 4)
+
+ # Rearrange to (rows, 4_shift, cols, 4_bytes, M)
+ shifted_u8 = shifted_u8.transpose(1, 0, 2, 4, 3)
+ return jnp.array(shifted_u8)
+
+
+def matmul_bat_einsum(
+ v: jax.Array, bat_twiddle: jax.Array, subscripts: str
+) -> jax.Array:
+ """BAT-based 8-bit matrix multiplication.
+
+ Bitcasts the input from uint32 to uint8, performs an 8-bit einsum with
+ the pre-processed BAT twiddle, and reconstructs the uint64 result via
+ byte-shifting.
+
+ Args:
+ v: Input array (uint32). The trailing dimension (num_moduli) is expanded
+ to (num_moduli, 4) by bitcast.
+ bat_twiddle: BAT twiddle (uint8), pre-computed offline.
+ subscripts: Einsum subscript string including the byte dimensions.
+
+ Returns:
+ uint64 result array (trailing moduli dimension, byte dim summed out).
+ """
+ v_u8 = jax.lax.bitcast_convert_type(v, jnp.uint8)
+ shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32)
+ products = jnp.einsum(
+ subscripts, v_u8, bat_twiddle, preferred_element_type=jnp.uint32
+ )
+ return jnp.sum(products.astype(jnp.uint64) << shift_factors, axis=-1)
+
+
+########################
+# BAT subscript generator (right-convention only)
+########################
+_BAT_RESERVED_LETTERS = frozenset("mqpj")
+
+
+def _bat_subscripts_right(v_ndim: int, contract_axis: int) -> str:
+ """Einsum subscripts for BAT-based right-convention matmul.
+
+ Operand ``v`` has shape ``(*leading, M)`` with ``v_ndim == len(leading) + 1``;
+ after bitcast-to-u8 it grows a trailing size-4 byte axis. The twiddle
+ handle has shape ``(K, 4, J, 4, M)``: axis 0 is the contracted dim, axis
+ 2 is the produced dim. The axis at ``contract_axis`` of ``v`` (size K)
+ is replaced on the output by an axis of size J.
+ """
+ if contract_axis < 0:
+ contract_axis += v_ndim
+ if not (0 <= contract_axis < v_ndim - 1):
+ raise ValueError(
+ "contract_axis must be a leading axis of v "
+ f"(got {contract_axis} for v.ndim={v_ndim})"
+ )
+ letters = []
+ c = ord("a")
+ while len(letters) < v_ndim - 1:
+ ch = chr(c)
+ if ch not in _BAT_RESERVED_LETTERS:
+ letters.append(ch)
+ c += 1
+ k = letters[contract_axis]
+ v_sub = "".join(letters) + "mq"
+ out_letters = letters.copy()
+ out_letters[contract_axis] = "j"
+ t_sub = f"{k}qjpm"
+ out_sub = "".join(out_letters) + "mp"
+ return f"{v_sub},{t_sub}->{out_sub}"
+
+
+########################
+# Extension contexts — add NTT-specific modular_matmul / twiddle helpers
+# on top of the finite-field backends. The NTT layer only ever consumes
+# the abstract API exposed here, not the backend internals.
+########################
+class DRNSLazyExtensionContext(DRNSlazyContext):
+ """``DRNSlazyContext`` + NTT-specific modular matrix multiplication.
+
+ Exposes:
+ * :py:meth:`preprocess_matmul` — 2-D integer matrix → BAT
+ uint8 handle of shape ``(K, 4, J, 4, M)``.
+ * :py:meth:`preprocess_elementwise` — N-D integer tensor →
+ DRNS computational-format uint32 tensor of shape ``(..., M)``.
+ * :py:meth:`modular_matmul` — right-convention modular matmul
+ (BAT einsum + Montgomery → CRNS → Montgomery). Only supports the
+ matmul-shape used by the NTT stack (contract axis 0 of the handle).
+ * :py:meth:`modular_multiply_broadcast` — broadcasting element-wise
+ modular multiply for twiddle steps.
+ """
+
+ def __init__(self, parameters: Dict[str, Any]):
+ super().__init__(parameters)
+ self.for_ntt = True
+
+ def preprocess_matmul(self, mat_2d) -> jnp.ndarray:
+ """Encode a 2-D ``(K, J)`` integer matrix as a BAT uint8 handle."""
+ arr = np.asarray(mat_2d)
+ if arr.ndim != 2:
+ raise ValueError(f"matmul twiddle must be 2-D, got shape {arr.shape}")
+ # Reduce mod each RNS modulus in Python-int space so this stays correct
+ # for primes wider than 64 bits. Output fits in uint32 (moduli < 2^28).
+ if arr.dtype != object:
+ arr = arr.astype(object)
+ moduli_obj = np.array(self.rns_moduli, dtype=object)
+ drns_obj = (
+ (arr[..., np.newaxis] % moduli_obj) << self.radix_bits
+ ) % moduli_obj
+ return basis_aligned_transformation(
+ drns_obj.astype(np.uint32), self.rns_moduli
+ )
+
+ def preprocess_elementwise(self, mat_nd) -> jnp.ndarray:
+ """Encode an arbitrary-rank integer tensor in DRNS computational form."""
+ arr = np.asarray(mat_nd)
+ if arr.dtype != object:
+ arr = arr.astype(object)
+ moduli_obj = np.array(self.rns_moduli, dtype=object)
+ drns_obj = (
+ (arr[..., np.newaxis] % moduli_obj) << self.radix_bits
+ ) % moduli_obj
+ return jnp.array(drns_obj.astype(np.uint32))
+
+ def modular_matmul(
+ self, operand: jax.Array, handle: jax.Array, contract_axis: int
+ ) -> jax.Array:
+ """Modular matmul contracting ``operand[contract_axis]`` against
+
+ ``handle[0]``. ``handle`` must come from
+ :py:meth:`preprocess_matmul`.
+ """
+ operand = jnp.asarray(operand)
+ subs = _bat_subscripts_right(operand.ndim, contract_axis)
+ z = matmul_bat_einsum(operand, handle, subs)
+ z = self._jax_montgomery_reduce(z)
+ z = self._jax_crns(z)
+ z = self._jax_montgomery_reduce(z)
+ return z.astype(jnp.uint32)
+
+ def modular_multiply_broadcast(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ """Element-wise modular multiply with numpy-style broadcasting."""
+ return self._modular_multiply(a, b)
+
+
+class CROSSLazyExtensionContext(CROSSLazyContext):
+ """``CROSSLazyContext`` + NTT-specific modular matrix multiplication.
+
+ Same four-method API as :py:class:`DRNSLazyExtensionContext`. The
+ matmul path here is a ``fori_loop`` that accumulates per-``k``
+ broadcast products via ``_modular_multiply`` + ``_modular_add``
+ (there is no dense-MXU-friendly representation for multi-limb CROSS
+ elements).
+ """
+
+ def __init__(self, parameters: Dict[str, Any]):
+ super().__init__(parameters)
+ self.for_ntt = True
+
+ def _encode_chunks(self, mat) -> jnp.ndarray:
+ """Encode an integer tensor as little-endian uint32 chunks."""
+ arr = np.asarray(mat)
+ flat = [int(x) % self.prime for x in arr.reshape(-1).tolist()]
+ n = self.chunk_num_u32
+ chunks = np.zeros((len(flat), n), dtype=np.uint32)
+ for i, v in enumerate(flat):
+ x = v
+ for j in range(n):
+ chunks[i, j] = x & 0xFFFFFFFF
+ x >>= 32
+ return jnp.asarray(chunks.reshape(arr.shape + (n,)))
+
+ def preprocess_matmul(self, mat_2d) -> jnp.ndarray:
+ arr = np.asarray(mat_2d)
+ if arr.ndim != 2:
+ raise ValueError(f"matmul twiddle must be 2-D, got shape {arr.shape}")
+ return self._encode_chunks(arr)
+
+ def preprocess_elementwise(self, mat_nd) -> jnp.ndarray:
+ return self._encode_chunks(mat_nd)
+
+ def modular_multiply_broadcast(self, a: jax.Array, b: jax.Array) -> jax.Array:
+ """Element-wise modular multiply with broadcasting.
+
+ Uses nested ``vmap`` so sharded callers (running under ``shard_map``)
+ see consistent axis sharding between the two operands.
+ """
+ shape = jnp.broadcast_shapes(a.shape, b.shape)
+ a_b = jnp.broadcast_to(a, shape)
+ b_b = jnp.broadcast_to(b, shape)
+ fn = self._modular_multiply
+ for _ in range(len(shape) - 2):
+ fn = jax.vmap(fn)
+ return fn(a_b, b_b)
+
+ def modular_matmul(
+ self, operand: jax.Array, handle: jax.Array, contract_axis: int
+ ) -> jax.Array:
+ """Modular matmul contracting ``operand[contract_axis]`` against
+
+ ``handle[0]``. ``handle`` must come from
+ :py:meth:`preprocess_matmul` and is ``(K, J, chunks)``.
+ """
+ operand = jnp.asarray(operand)
+ handle = jnp.asarray(handle)
+ if contract_axis < 0:
+ contract_axis += operand.ndim
+ if not (0 <= contract_axis < operand.ndim - 1):
+ raise ValueError(
+ "contract_axis must be a leading axis of operand "
+ f"(got {contract_axis} for operand.ndim={operand.ndim})"
+ )
+ nc = self.chunk_num_u32
+ v_moved = jnp.moveaxis(operand, contract_axis, 0) # (K, *rest, chunks)
+ K = v_moved.shape[0]
+ J = handle.shape[1]
+ rest = v_moved.shape[1:-1]
+ out_shape = (J,) + rest + (nc,)
+ t_reshape = (J,) + (1,) * len(rest) + (nc,)
+
+ def body(k, acc):
+ vk = v_moved[k]
+ tk = handle[k, :, :].reshape(t_reshape)
+ prod = self.modular_multiply_broadcast(vk[None], tk)
+ return self._modular_add(acc, prod)
+
+ init = jnp.zeros(out_shape, dtype=jnp.uint32)
+ out = jax.lax.fori_loop(0, K, body, init)
+ return jnp.moveaxis(out, 0, contract_axis)
+
+
+########################
+# Unified NTT stack — one class per step count, works with either
+# extension context above. The extension context is what lets the NTT
+# layer stay backend-agnostic.
+########################
+class NumpyCPUContext:
+ """Pure-numpy CPU reference context compatible with the unified NTT stack.
+
+ Exposes the same ``preprocess_matmul`` / ``preprocess_elementwise``
+ / ``modular_matmul`` / ``modular_multiply_broadcast`` /
+ ``to_computational_format`` / ``to_original_format`` API as
+ :class:`DRNSLazyExtensionContext` and :class:`CROSSLazyExtensionContext`,
+ so that :class:`NTT3Step` / :class:`NTT5Step` / :class:`NTT7Step`
+ reproduce the same outputs as the CPU reference classes
+ (:class:`CPUCROSSContext`, :class:`CPUCROSS5StepContext`,
+ :class:`CPUCROSS7StepContext`).
+
+ Each field element is stored as a single ``uint64`` with a trailing
+ size-1 "chunks" axis (matching the layout the NTT layer expects).
+ All arithmetic is a plain ``numpy`` multiply / tensordot followed by
+ ``% prime`` — correctness only, no JAX / no TPU.
+ """
+
+ def __init__(self, parameters: Dict[str, Any]):
+ self.prime = int(parameters["prime"])
+ self.for_ntt = True
+
+ # ---------- format conversion ----------
+
+ def to_computational_format(self, a) -> np.ndarray:
+ """Wrap ``a`` as an ``(..., 1)`` uint64 ndarray."""
+ arr = np.asarray(a, dtype=np.uint64)
+ return arr[..., np.newaxis]
+
+ def to_original_format(self, a):
+ """Reduce, flatten, return a flat list of Python ints."""
+ arr = np.asarray(a, dtype=np.uint64)
+ return (arr.reshape(-1) % self.prime).tolist()
+
+ # ---------- twiddle preprocessing ----------
+
+ def preprocess_matmul(self, mat_2d) -> np.ndarray:
+ arr = np.asarray(mat_2d)
+ if arr.ndim != 2:
+ raise ValueError(f"matmul twiddle must be 2-D, got shape {arr.shape}")
+ if arr.dtype != object:
+ arr = arr.astype(object)
+ return (arr % self.prime).astype(np.uint64)[..., np.newaxis]
+
+ def preprocess_elementwise(self, mat_nd) -> np.ndarray:
+ arr = np.asarray(mat_nd)
+ if arr.dtype != object:
+ arr = arr.astype(object)
+ return (arr % self.prime).astype(np.uint64)[..., np.newaxis]
+
+ # ---------- modular arithmetic ----------
+
+ def modular_matmul(
+ self, operand: np.ndarray, handle: np.ndarray, contract_axis: int
+ ) -> np.ndarray:
+ """Right-convention modular matmul along ``contract_axis``.
+
+ ``operand`` has trailing size-1 axis (the "chunks" axis).
+ ``handle`` has shape ``(K, J, 1)`` — axis 0 is the contracted dim.
+ """
+ op = np.asarray(operand, dtype=np.uint64)[..., 0]
+ h = np.asarray(handle, dtype=np.uint64)[..., 0]
+ if contract_axis < 0:
+ contract_axis += op.ndim + 1 # +1 because we squeezed trailing chunks
+ # tensordot contracts op.axes[contract_axis] with h.axis[0]; the new
+ # J axis from h lands at the end of the result, so move it back to
+ # ``contract_axis``.
+ out = np.tensordot(op, h, axes=([contract_axis], [0]))
+ out = np.moveaxis(out, -1, contract_axis)
+ out = out % self.prime
+ return out[..., np.newaxis]
+
+ def modular_multiply_broadcast(
+ self, a: np.ndarray, b: np.ndarray
+ ) -> np.ndarray:
+ """Broadcasting element-wise modular multiply."""
+ a_arr = np.asarray(a, dtype=np.uint64)
+ b_arr = np.asarray(b, dtype=np.uint64)
+ return (a_arr * b_arr) % self.prime
+
+
+class NTTBase(NumberTheoryTransformBase):
+ """Shared scaffolding for the unified NTT classes.
+
+ All NTT classes below follow the same algorithm skeleton (precompute
+ twiddles via :py:func:`gen_twiddle_matrix`, apply bit-reversal
+ permutations, then alternate matmul / element-wise twiddle steps)
+ and dispatch all arithmetic through the extension context passed in
+ as ``finite_field_context``.
+
+ The ``finite_field_context`` parameter accepts either:
+
+ * An already-constructed extension context instance
+ (``DRNSLazyExtensionContext``, ``CROSSLazyExtensionContext``, or
+ ``NumpyCPUContext``).
+ * A backend string (``"drns"``, ``"cross"``, or ``"cpu"``), in which
+ case the extension context is auto-constructed from the remaining
+ parameters:
+
+ ============ =======================================================
+ ``"drns"`` ``num_moduli`` (default 21), ``precision_bits`` (28),
+ ``radix_bits`` (32).
+ ``"cross"`` ``chunk_num_u8`` (default derived from prime bit-length).
+ ``"cpu"`` No extra parameters.
+ ============ =======================================================
+ """
+
+ @staticmethod
+ def _build_ff_ctx(parameters: dict):
+ """Auto-construct an extension context from top-level parameters.
+
+ If ``finite_field_context`` is already an instance, return it as-is.
+ If it's a backend string, build the matching extension context using
+ ``prime`` and the optional sizing parameters from ``parameters``.
+ """
+ ff_ctx = parameters["finite_field_context"]
+ if isinstance(ff_ctx, str):
+ prime = parameters["prime"]
+ backend = ff_ctx.lower()
+ if backend == "drns":
+ num_moduli = parameters.get("num_moduli", 21)
+ precision_bits = parameters.get("precision_bits", 28)
+ radix_bits = parameters.get("radix_bits", 32)
+ rns_moduli = utils.find_moduli_specified_number(
+ num_moduli, precision_bits
+ )
+ return DRNSLazyExtensionContext({
+ "prime": prime,
+ "rns_moduli": rns_moduli,
+ "precision_bits": precision_bits,
+ "radix_bits": radix_bits,
+ })
+ elif backend == "cross":
+ ctx_params = {"prime": prime}
+ if "chunk_num_u8" in parameters:
+ ctx_params["chunk_num_u8"] = parameters["chunk_num_u8"]
+ return CROSSLazyExtensionContext(ctx_params)
+ elif backend == "cpu":
+ return NumpyCPUContext({"prime": prime})
+ else:
+ raise ValueError(
+ f"Unknown backend {ff_ctx!r}; use 'drns', 'cross', or 'cpu'"
+ )
+ return ff_ctx
+
+ def __init__(self, ff_ctx, spatial_shape: tuple):
+ if not getattr(ff_ctx, "for_ntt", False):
+ raise TypeError(
+ "finite_field_context does not declare NTT support; "
+ f"{type(ff_ctx).__name__} must set ``self.for_ntt = True``"
+ )
+ self.ff_ctx = ff_ctx
+ self._spatial_shape = spatial_shape
+
+ def to_computational_format(self, a):
+ return self.ff_ctx.to_computational_format(a)
+
+ def to_original_format(self, a):
+ a = jnp.asarray(a)
+ trailing = a.shape[-1]
+ return self.ff_ctx.to_original_format(a.reshape(-1, trailing))
+
+ def _ensure_ntt_shape(self, v: jnp.ndarray) -> jnp.ndarray:
+ v = jnp.asarray(v)
+ expected_ndim = 1 + len(self._spatial_shape) + 1
+ if v.ndim < expected_ndim:
+ v = v.reshape(-1, *self._spatial_shape, v.shape[-1])
+ return v
+
+
+class NTT3Step(NTTBase):
+ """Unified 3-step NTT. Input/output shape: ``(B, R, C, trailing)``."""
+
+ def __init__(self, parameters: Dict[str, Any]):
+ self.prime = parameters["prime"]
+ self.r = parameters["r"]
+ self.c = parameters["c"]
+ ff_ctx = self._build_ff_ctx(parameters)
+ super().__init__(ff_ctx, spatial_shape=(self.r, self.c))
+
+ self.transform_length = self.r * self.c
+ psi = parameters.get("psi")
+ self.psi = (
+ int(psi)
+ if psi is not None
+ else utils.root_of_unity(2 * self.transform_length, self.prime)
+ )
+ self.omega = (self.psi**2) % self.prime
+
+ # --- twiddle matrices (plain integer form) ---
+ omega_col = pow(self.omega, self.c, self.prime)
+ omega_row = pow(self.omega, self.r, self.prime)
+ ntt_tf1 = gen_twiddle_matrix(self.r, self.r, self.prime, omega_col)
+ ntt_tf2 = gen_twiddle_matrix(self.r, self.c, self.prime, self.omega)
+ ntt_tf3 = gen_twiddle_matrix(self.c, self.c, self.prime, omega_row)
+
+ inv_omega_col = pow(omega_col, -1, self.prime)
+ inv_omega_row = pow(omega_row, -1, self.prime)
+ intt_tf1 = gen_twiddle_matrix(self.c, self.c, self.prime, inv_omega_row)
+ intt_tf2 = gen_twiddle_matrix_inv(self.r, self.c, self.prime, self.omega)
+ col_inv = pow(self.c, -1, self.prime)
+ row_inv = pow(self.r, -1, self.prime)
+ intt_tf2 = (intt_tf2 * col_inv) % self.prime
+ intt_tf3 = gen_twiddle_matrix(self.r, self.r, self.prime, inv_omega_col)
+ intt_tf3 = (intt_tf3 * row_inv) % self.prime
+
+ # --- bit-reversal permutations ---
+ perm_r = get_bit_reverse_perm(self.r)
+ perm_c = get_bit_reverse_perm(self.c)
+ ntt_tf1 = ntt_tf1[perm_r, :]
+ ntt_tf2 = ntt_tf2[perm_r, :]
+ ntt_tf3 = ntt_tf3[:, perm_c]
+ intt_tf1 = intt_tf1[perm_c, :]
+ intt_tf2 = intt_tf2[perm_r, :]
+ intt_tf3 = intt_tf3[:, perm_r]
+
+ # Right-convention handles: the contracted dim is at axis 0 of the
+ # handle. Step 1 of NTT is logically "T1 @ v along R", which becomes
+ # right-matmul of v against T1.T; step 3 is already "v @ T3".
+ self.ntt_t1 = ff_ctx.preprocess_matmul(ntt_tf1.T)
+ self.ntt_t2 = ff_ctx.preprocess_elementwise(ntt_tf2)
+ self.ntt_t3 = ff_ctx.preprocess_matmul(ntt_tf3)
+ self.intt_t1 = ff_ctx.preprocess_matmul(intt_tf1)
+ self.intt_t2 = ff_ctx.preprocess_elementwise(intt_tf2)
+ self.intt_t3 = ff_ctx.preprocess_matmul(intt_tf3.T)
+
+ def ntt(self, v: jnp.ndarray):
+ v = self._ensure_ntt_shape(v) # (B, R, C, trailing)
+ v = self.ff_ctx.modular_matmul(v, self.ntt_t1, contract_axis=1)
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.ntt_t2)
+ v = self.ff_ctx.modular_matmul(v, self.ntt_t3, contract_axis=2)
+ return v
+
+ def intt(self, v: jnp.ndarray):
+ v = self._ensure_ntt_shape(v)
+ v = self.ff_ctx.modular_matmul(v, self.intt_t1, contract_axis=2)
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.intt_t2)
+ v = self.ff_ctx.modular_matmul(v, self.intt_t3, contract_axis=1)
+ return v
+
+
+class NTT5Step(NTTBase):
+ """Unified 5-step NTT. Shape: ``(B, RR, RC, C, trailing)``."""
+
+ def __init__(self, parameters: Dict[str, Any]):
+ self.prime = parameters["prime"]
+ self.rr = parameters["rr"]
+ self.rc = parameters["rc"]
+ self.c = parameters["c"]
+ ff_ctx = self._build_ff_ctx(parameters)
+ super().__init__(ff_ctx, spatial_shape=(self.rr, self.rc, self.c))
+
+ self.transform_length = self.rr * self.rc * self.c
+ R = self.rr * self.rc
+ psi = parameters.get("psi")
+ self.psi = (
+ int(psi)
+ if psi is not None
+ else utils.root_of_unity(2 * self.transform_length, self.prime)
+ )
+ self.omega = (self.psi**2) % self.prime
+
+ omega_R = pow(self.omega, self.c, self.prime)
+ omega_RR = pow(omega_R, self.rc, self.prime)
+ omega_RC = pow(omega_R, self.rr, self.prime)
+ omega_C = pow(self.omega, R, self.prime)
+
+ ntt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, omega_RR)
+ ntt_T2 = gen_twiddle_matrix(self.rr, self.rc, self.prime, omega_R)
+ ntt_T3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, omega_RC)
+ ntt_T4 = gen_twiddle_matrix(R, self.c, self.prime, self.omega).reshape(
+ self.rr, self.rc, self.c
+ )
+ ntt_T5 = gen_twiddle_matrix(self.c, self.c, self.prime, omega_C)
+
+ inv_omega_RR = pow(omega_RR, -1, self.prime)
+ inv_omega_RC = pow(omega_RC, -1, self.prime)
+ inv_omega_C = pow(omega_C, -1, self.prime)
+ rr_inv = pow(self.rr, -1, self.prime)
+ rc_inv = pow(self.rc, -1, self.prime)
+ c_inv = pow(self.c, -1, self.prime)
+
+ intt_T5 = gen_twiddle_matrix(self.c, self.c, self.prime, inv_omega_C)
+ intt_T4 = gen_twiddle_matrix_inv(R, self.c, self.prime, self.omega).reshape(
+ self.rr, self.rc, self.c
+ )
+ intt_T4 = (intt_T4 * c_inv) % self.prime
+ intt_t3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, inv_omega_RC)
+ intt_T2 = gen_twiddle_matrix_inv(self.rr, self.rc, self.prime, omega_R)
+ intt_T2 = (intt_T2 * rc_inv) % self.prime
+ intt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, inv_omega_RR)
+ intt_T1 = (intt_T1 * rr_inv) % self.prime
+
+ perm_rr = get_bit_reverse_perm(self.rr)
+ perm_rc = get_bit_reverse_perm(self.rc)
+ perm_c = get_bit_reverse_perm(self.c)
+
+ ntt_T1 = ntt_T1[perm_rr, :]
+ ntt_T2 = ntt_T2[perm_rr, :]
+ ntt_T3 = ntt_T3[:, perm_rc]
+ perm_R = get_bit_reverse_perm(R)
+ ntt_T4 = ntt_T4.reshape(R, self.c)[perm_R, :].reshape(
+ self.rr, self.rc, self.c
+ )
+ ntt_T5 = ntt_T5[:, perm_c]
+
+ intt_T5 = intt_T5[perm_c, :]
+ intt_T4 = intt_T4.reshape(R, self.c)[perm_R, :].reshape(
+ self.rr, self.rc, self.c
+ )
+ intt_t3 = intt_t3[perm_rc, :]
+ intt_T2 = intt_T2[perm_rr, :]
+ intt_T1 = intt_T1[:, perm_rr]
+
+ # Matmul twiddles: transpose when the original step was left-matmul.
+ self.ntt_T1 = ff_ctx.preprocess_matmul(
+ ntt_T1.T
+ ) # left-mat → right via transpose
+ self.ntt_T2 = ff_ctx.preprocess_elementwise(ntt_T2) # (RR, RC, trailing)
+ self.ntt_T3 = ff_ctx.preprocess_matmul(ntt_T3) # right-mat
+ self.ntt_T4 = ff_ctx.preprocess_elementwise(ntt_T4) # (RR, RC, C, trailing)
+ self.ntt_T5 = ff_ctx.preprocess_matmul(ntt_T5) # right-mat
+
+ self.intt_T5 = ff_ctx.preprocess_matmul(intt_T5)
+ self.intt_T4 = ff_ctx.preprocess_elementwise(intt_T4)
+ self.intt_T3 = ff_ctx.preprocess_matmul(intt_t3)
+ self.intt_T2 = ff_ctx.preprocess_elementwise(intt_T2)
+ self.intt_T1 = ff_ctx.preprocess_matmul(
+ intt_T1.T
+ ) # undo the step-1 left-mat
+
+ def ntt(self, v: jnp.ndarray):
+ v = self._ensure_ntt_shape(v) # (B, RR, RC, C, trailing)
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T1, contract_axis=1)
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.ntt_T2[:, :, None, :])
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T3, contract_axis=2)
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.ntt_T4[None])
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T5, contract_axis=3)
+ return v
+
+ def intt(self, v: jnp.ndarray):
+ v = self._ensure_ntt_shape(v)
+ v = self.ff_ctx.modular_matmul(v, self.intt_T5, contract_axis=3)
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.intt_T4[None])
+ v = self.ff_ctx.modular_matmul(v, self.intt_T3, contract_axis=2)
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.intt_T2[:, :, None, :])
+ v = self.ff_ctx.modular_matmul(v, self.intt_T1, contract_axis=1)
+ return v
+
+
+class NTT7Step(NTTBase):
+ """Unified 7-step NTT. Shape: ``(B, RR, RC, CR, CC, trailing)``."""
+
+ def __init__(self, parameters: Dict[str, Any]):
+ self.prime = parameters["prime"]
+ self.rr = parameters["rr"]
+ self.rc = parameters["rc"]
+ self.cr = parameters["cr"]
+ self.cc = parameters["cc"]
+ ff_ctx = self._build_ff_ctx(parameters)
+ super().__init__(ff_ctx, spatial_shape=(self.rr, self.rc, self.cr, self.cc))
+
+ r_total = self.rr * self.rc
+ c_total = self.cr * self.cc
+ self.transform_length = r_total * c_total
+ psi = parameters.get("psi")
+ self.psi = (
+ int(psi)
+ if psi is not None
+ else utils.root_of_unity(2 * self.transform_length, self.prime)
+ )
+ self.omega = (self.psi**2) % self.prime
+
+ omega_r = pow(self.omega, c_total, self.prime)
+ omega_rr = pow(omega_r, self.rc, self.prime)
+ omega_rc = pow(omega_r, self.rr, self.prime)
+ omega_c = pow(self.omega, r_total, self.prime)
+ omega_cr = pow(omega_c, self.cc, self.prime)
+ omega_cc = pow(omega_c, self.cr, self.prime)
+
+ ntt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, omega_rr)
+ ntt_T2 = gen_twiddle_matrix(self.rr, self.rc, self.prime, omega_r)
+ ntt_T3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, omega_rc)
+ ntt_T4 = gen_twiddle_matrix(
+ r_total, c_total, self.prime, self.omega
+ ).reshape(self.rr, self.rc, self.cr, self.cc)
+ ntt_T5 = gen_twiddle_matrix(self.cr, self.cr, self.prime, omega_cr)
+ ntt_T6 = gen_twiddle_matrix(self.cr, self.cc, self.prime, omega_c)
+ ntt_T7 = gen_twiddle_matrix(self.cc, self.cc, self.prime, omega_cc)
+
+ inv_omega_rr = pow(omega_rr, -1, self.prime)
+ inv_omega_rc = pow(omega_rc, -1, self.prime)
+ inv_omega_cr = pow(omega_cr, -1, self.prime)
+ inv_omega_cc = pow(omega_cc, -1, self.prime)
+ rr_inv = pow(self.rr, -1, self.prime)
+ rc_inv = pow(self.rc, -1, self.prime)
+ cr_inv = pow(self.cr, -1, self.prime)
+ cc_inv = pow(self.cc, -1, self.prime)
+
+ intt_T7 = gen_twiddle_matrix(self.cc, self.cc, self.prime, inv_omega_cc)
+ intt_T6 = gen_twiddle_matrix_inv(self.cr, self.cc, self.prime, omega_c)
+ intt_T6 = (intt_T6 * cc_inv) % self.prime
+ intt_T5 = gen_twiddle_matrix(self.cr, self.cr, self.prime, inv_omega_cr)
+ intt_T5 = (intt_T5 * cr_inv) % self.prime
+ intt_T4 = gen_twiddle_matrix_inv(
+ r_total, c_total, self.prime, self.omega
+ ).reshape(self.rr, self.rc, self.cr, self.cc)
+ intt_T3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, inv_omega_rc)
+ intt_T2 = gen_twiddle_matrix_inv(self.rr, self.rc, self.prime, omega_r)
+ intt_T2 = (intt_T2 * rc_inv) % self.prime
+ intt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, inv_omega_rr)
+ intt_T1 = (intt_T1 * rr_inv) % self.prime
+
+ perm_rr = get_bit_reverse_perm(self.rr)
+ perm_rc = get_bit_reverse_perm(self.rc)
+ perm_cr = get_bit_reverse_perm(self.cr)
+ perm_cc = get_bit_reverse_perm(self.cc)
+
+ ntt_T1 = ntt_T1[perm_rr, :]
+ ntt_T2 = ntt_T2[perm_rr, :]
+ ntt_T3 = ntt_T3[:, perm_rc]
+ perm_r = get_bit_reverse_perm(r_total)
+ ntt_T4 = ntt_T4.reshape(r_total, c_total)[perm_r, :].reshape(
+ self.rr, self.rc, self.cr, self.cc
+ )
+ ntt_T5 = ntt_T5[perm_cr, :]
+ ntt_T6 = ntt_T6[perm_cr, :]
+ ntt_T7 = ntt_T7[:, perm_cc]
+
+ intt_T7 = intt_T7[perm_cc, :]
+ intt_T6 = intt_T6[perm_cr, :]
+ intt_T5 = intt_T5[:, perm_cr]
+ intt_T4 = intt_T4.reshape(r_total, c_total)[perm_r, :].reshape(
+ self.rr, self.rc, self.cr, self.cc
+ )
+ intt_T3 = intt_T3[perm_rc, :]
+ intt_T2 = intt_T2[perm_rr, :]
+ intt_T1 = intt_T1[:, perm_rr]
+
+ # Steps 1 and 5 are logically left-matmul; steps 3 and 7 are right.
+ self.ntt_T1 = ff_ctx.preprocess_matmul(ntt_T1.T)
+ self.ntt_T2 = ff_ctx.preprocess_elementwise(ntt_T2)
+ self.ntt_T3 = ff_ctx.preprocess_matmul(ntt_T3)
+ self.ntt_T4 = ff_ctx.preprocess_elementwise(ntt_T4)
+ self.ntt_T5 = ff_ctx.preprocess_matmul(ntt_T5.T)
+ self.ntt_T6 = ff_ctx.preprocess_elementwise(ntt_T6)
+ self.ntt_T7 = ff_ctx.preprocess_matmul(ntt_T7)
+
+ self.intt_T7 = ff_ctx.preprocess_matmul(intt_T7)
+ self.intt_T6 = ff_ctx.preprocess_elementwise(intt_T6)
+ self.intt_T5 = ff_ctx.preprocess_matmul(intt_T5.T)
+ self.intt_T4 = ff_ctx.preprocess_elementwise(intt_T4)
+ self.intt_T3 = ff_ctx.preprocess_matmul(intt_T3)
+ self.intt_T2 = ff_ctx.preprocess_elementwise(intt_T2)
+ self.intt_T1 = ff_ctx.preprocess_matmul(intt_T1.T)
+
+ def ntt(self, v: jnp.ndarray):
+ v = self._ensure_ntt_shape(v) # (B, RR, RC, CR, CC, trailing)
+ # Steps 1-3: inner NTT on R = RR * RC.
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T1, contract_axis=1)
+ v = self.ff_ctx.modular_multiply_broadcast(
+ v, self.ntt_T2[:, :, None, None, :]
+ )
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T3, contract_axis=2)
+ # Step 4: outer twiddle (broadcast over batch only).
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.ntt_T4[None])
+ # Steps 5-7: inner NTT on C = CR * CC.
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T5, contract_axis=3)
+ v = self.ff_ctx.modular_multiply_broadcast(
+ v, self.ntt_T6[None, None, None, :, :, :]
+ )
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T7, contract_axis=4)
+ return v
+
+ def intt(self, v: jnp.ndarray):
+ v = self._ensure_ntt_shape(v)
+ v = self.ff_ctx.modular_matmul(v, self.intt_T7, contract_axis=4)
+ v = self.ff_ctx.modular_multiply_broadcast(
+ v, self.intt_T6[None, None, None, :, :, :]
+ )
+ v = self.ff_ctx.modular_matmul(v, self.intt_T5, contract_axis=3)
+ v = self.ff_ctx.modular_multiply_broadcast(v, self.intt_T4[None])
+ v = self.ff_ctx.modular_matmul(v, self.intt_T3, contract_axis=2)
+ v = self.ff_ctx.modular_multiply_broadcast(
+ v, self.intt_T2[:, :, None, None, :]
+ )
+ v = self.ff_ctx.modular_matmul(v, self.intt_T1, contract_axis=1)
+ return v
diff --git a/jaxite_ec/number_theory_transform_perf_test.py b/jaxite_ec/number_theory_transform_perf_test.py
new file mode 100644
index 0000000..b990f51
--- /dev/null
+++ b/jaxite_ec/number_theory_transform_perf_test.py
@@ -0,0 +1,447 @@
+"""Performance tests for batched NTT.
+
+Profiles the six NTT designs produced by crossing
+ * step count (:class:`NTT3Step`, :class:`NTT5Step`, :class:`NTT7Step`)
+ * backend (DRNS via :class:`DRNSLazyExtensionContext`,
+ CROSS via :class:`CROSSLazyExtensionContext`)
+
+The batch (leading) dimension is sharded across all available JAX
+devices. Performance is measured via ``jax.profiler`` through the
+``KernelWrapper`` / ``Profiler`` helpers in ``profiler.py``, NOT
+wall-clock time. Sharded-correctness checks live in
+``number_theory_transform_test.py``.
+"""
+
+import os
+
+import jax
+from jax.experimental.shard_map import shard_map
+import jax.numpy as jnp
+import jaxite.jaxite_ec.number_theory_transform_context as ntt_context
+from jaxite.jaxite_ec.profiler import KernelWrapper, Profiler, collect_logs
+import jaxite.jaxite_ec.utils as utils
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+jax.config.update("jax_enable_x64", True)
+
+# ---------------------------------------------------------------------------
+# Trailing-dimension sizing — change these to sweep field-element width.
+# NUM_MODULI drives the prime preset below (21 → 256-bit prime, 56 → 753-bit).
+# ---------------------------------------------------------------------------
+NUM_MODULI = 21 # DRNS: number of RNS moduli (trailing dim size)
+PRECISION_BITS = 28 # DRNS: bit-width per modulus
+RADIX_BITS = 32 # DRNS: Montgomery radix
+
+# ---------------------------------------------------------------------------
+# Prime & 2N-th-root presets, keyed by NUM_MODULI. Each preset supplies a
+# matching Q_PERF and PSI_PERF_BY_DEGREE (primitive 2*2^degree-th roots of
+# unity mod Q_PERF). The per-degree psi sidesteps utils.root_of_unity(),
+# which trial-divides q-1 and is infeasible for the 753-bit prime.
+# ---------------------------------------------------------------------------
+_PERF_PRIME_PRESETS: dict[int, dict] = {
+ # 256-bit prime — fits NUM_MODULI * PRECISION_BITS = 21 * 28 = 588 ≥ 256.
+ 21: {
+ "q": 0x8000000000000000000000000000000000000000000000000000000070000001,
+ "psi_by_degree": {
+ 14: (
+ 0x210D1D264152132AE3E5610B7E230BCD0058FE66FB35C5713527EA1FA40D1845
+ ),
+ 16: (
+ 0x40D7C3F33672325E7B65C4A20B0BE07DD32F3EBB05C33DD8675D68EB3A8BDB6B
+ ),
+ 18: (
+ 0x568FDDCD95737AC264EAADA546D74B051CA1B7FC5B8427DCE706674011E009E0
+ ),
+ 20: (
+ 0x3AAE36A59E8E4F95E3118AA64270D0E122E0FC9585380815F737A67D613B5516
+ ),
+ 22: (
+ 0x23E461BCC11091F4A355AD034B454991F9CFA113272B8DBDF38E895C68BE3702
+ ),
+ },
+ },
+ # 753-bit prime — fits NUM_MODULI * PRECISION_BITS = 56 * 28 = 1568 ≥ 753.
+ 56: {
+ "q": (
+ 0x100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000023C00001
+ ),
+ "psi_by_degree": {
+ 14: (
+ 0x987B6025AB8620C5CB8376257CBC1FED3F6A9CD8003B6CB78C442378C5CB76DF824DDFAD28F53E6EFED050F1193BD6B114DDBAF944860CE6CEEC6A4ECCC92D690996CD24ED61167E18B61F33C7F45CA1231F30751C16AA586DA157BC12DA
+ ),
+ 16: (
+ 0xC35116503813C2FEA4AA3458B0F4F8B28BBCC8BEE7004D34FB47B57FE855C6D764AF4F7CE54C9334C6C957CDE2D613405AD78F5BF210772600A9E8FBC797E35DDFCB55643E3F9ADE2C9B4AD8D08F00D7224EE0E4E176C98E61232E23114C
+ ),
+ 18: (
+ 0x80DE16A04909C865C8C3954F81D93780D98B4826B7C8AEC95A3149697831B5CE3BE84CEA28866B38A0458484ACDFBFC0FB26215FFF7083F52721E3BBE094FA42CCAF9D9DF1730211F6A211BED8C8128041609A585C1234113AC33FD11FD2
+ ),
+ 20: (
+ 0x4AE6B7005951673890AC5AEA5DFCA80F449C73138FBBA6EF4E26D43DDA812A3BE60C97E60450C3AA294B751F6FCAB9E3736C02191A19E87C08AB00F3A3D2B599CFC3D52A0886F3EDA8941813BCDBCD51930F838B04CFE32A1AEF52510324
+ ),
+ 22: (
+ 0x26DFA808FE842C5A3C50EED23D433805D7CA3BC5FE9EDFA38C0A325159D75B8DBF7ED80054D92678DCCC4CC6B9A1D47976DFDA07D39C463F312EC45147860C46F8E4ACE696AB8B7F789EBE170FE615C18902C15F97ABC7300DBD9F1870A
+ ),
+ },
+ },
+}
+
+if NUM_MODULI not in _PERF_PRIME_PRESETS:
+ raise ValueError(
+ f"No Q_PERF/PSI preset for NUM_MODULI={NUM_MODULI}; "
+ f"available: {sorted(_PERF_PRIME_PRESETS)}"
+ )
+Q_PERF: int = _PERF_PRIME_PRESETS[NUM_MODULI]["q"]
+PSI_PERF_BY_DEGREE: dict[int, int] = _PERF_PRIME_PRESETS[NUM_MODULI][
+ "psi_by_degree"
+]
+
+# DRNS configs (BAT einsum + Montgomery/CRNS — runs fast, can push to deg=22).
+PERF_CONFIGS_DRNS_3STEP = [
+ {"degree": 14, "r1": 7, "c": 7},
+ # {"degree": 16, "r1": 8, "c": 8},
+ # {"degree": 18, "r1": 9, "c": 9},
+ # {"degree": 20, "r1": 10, "c": 10},
+ # {"degree": 22, "r1": 11, "c": 11},
+]
+
+PERF_CONFIGS_DRNS_5STEP = [
+ {"degree": 14, "r1": 5, "r2": 5, "c": 4},
+ # {"degree": 16, "r1": 5, "r2": 5, "c": 6},
+ # {"degree": 18, "r1": 6, "r2": 6, "c": 6},
+ # {"degree": 20, "r1": 6, "r2": 6, "c": 8},
+ # {"degree": 22, "r1": 7, "r2": 7, "c": 8},
+]
+
+PERF_CONFIGS_DRNS_7STEP = [
+ {"degree": 16, "r1": 4, "r2": 4, "c1": 4, "c2": 4},
+ # {"degree": 20, "r1": 5, "r2": 5, "c1": 5, "c2": 5},
+]
+
+# CROSS configs — CROSS's fori_loop-based matmul is ~1000× slower than
+# DRNS BAT per NTT, so keep sizes modest to finish in reasonable time.
+PERF_CONFIGS_CROSS_3STEP = [
+ {"degree": 14, "r1": 7, "c": 7},
+ # {"degree": 16, "r1": 8, "c": 8},
+ # {"degree": 18, "r1": 9, "c": 9},
+ # {"degree": 20, "r1": 10, "c": 10},
+ # {"degree": 22, "r1": 11, "c": 11},
+]
+
+PERF_CONFIGS_CROSS_5STEP = [
+ {"degree": 14, "r1": 5, "r2": 5, "c": 4},
+ # {"degree": 16, "r1": 5, "r2": 5, "c": 6},
+ # {"degree": 18, "r1": 6, "r2": 6, "c": 6},
+ # {"degree": 20, "r1": 6, "r2": 6, "c": 8},
+ # {"degree": 22, "r1": 7, "r2": 7, "c": 8},
+]
+
+PERF_CONFIGS_CROSS_7STEP = [
+ {"degree": 16, "r1": 4, "r2": 4, "c1": 4, "c2": 4},
+ # {"degree": 20, "r1": 5, "r2": 5, "c1": 5, "c2": 5},
+]
+
+
+# ---------------------------------------------------------------------------
+# Extension-context builders
+# ---------------------------------------------------------------------------
+def _create_drns_ff_ctx(prime):
+ rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS)
+ return ntt_context.DRNSLazyExtensionContext({
+ "prime": prime,
+ "rns_moduli": rns_moduli,
+ "precision_bits": PRECISION_BITS,
+ "radix_bits": RADIX_BITS,
+ })
+
+
+def _create_cross_ff_ctx(prime):
+ params = {"prime": prime}
+ return ntt_context.CROSSLazyExtensionContext(params)
+
+
+# ---------------------------------------------------------------------------
+# Sharding helpers
+# ---------------------------------------------------------------------------
+def _create_sharding():
+ """Create default batch sharding for the current device mesh."""
+ available_devices = jax.devices()
+ if not available_devices:
+ raise RuntimeError("No devices available for sharding test.")
+ if len(available_devices) == 8:
+ mesh_shape = (2, 4)
+ elif len(available_devices) == 4:
+ mesh_shape = (2, 2)
+ elif len(available_devices) == 2:
+ mesh_shape = (2, 1)
+ else:
+ mesh_shape = (1, 1)
+
+ mesh = jax.make_mesh(mesh_shape, ("x", "y"))
+ return mesh, jax.sharding.PartitionSpec
+
+
+def _batch_sharding(mesh, partition_spec, ndim):
+ """NamedSharding that partitions only the leading (batch) axis."""
+ axis_names = mesh.axis_names
+ batch_partition = axis_names if len(axis_names) > 1 else axis_names[0]
+ spec = (batch_partition,) + (None,) * (ndim - 1)
+ return jax.sharding.NamedSharding(mesh, partition_spec(*spec))
+
+
+# ---------------------------------------------------------------------------
+# Kernel-wrapper helpers
+# ---------------------------------------------------------------------------
+def _ntt_kernel(input_array, parameters):
+ return parameters["ctx"].ntt(input_array)
+
+
+def _intt_kernel(input_array, parameters):
+ return parameters["ctx"].intt(input_array)
+
+
+def _shard_mapped_kernel(method_name):
+ """Build a kernel entry that shard-maps ``ctx.`` over batch.
+
+ CROSS's NTT path uses ``jax.vmap`` internally over every leading axis
+ of ``_modular_multiply``. Under a plain jit with a batch-sharded
+ input, the broadcast of replicated twiddles against the sharded input
+ creates vmap axis-spec mismatches. Running the kernel under
+ ``shard_map`` gives every device a shard-local (unsharded) view, so
+ all broadcast / reshape / vmap ops inside the kernel see regular-
+ strided arrays while the heavy lifting still parallelizes across all
+ devices at the outer (shard_map) level.
+ """
+
+ def kernel(input_array, parameters):
+ ctx = parameters["ctx"]
+ mesh = parameters["mesh"]
+ batch_spec = parameters["batch_spec"]
+ fn = getattr(ctx, method_name)
+ mapped = shard_map(
+ fn,
+ mesh=mesh,
+ in_specs=batch_spec,
+ out_specs=batch_spec,
+ check_rep=False,
+ )
+ return mapped(input_array)
+
+ return kernel
+
+
+# ---------------------------------------------------------------------------
+# Performance profiling (jax.profiler traces via KernelWrapper + Profiler).
+# One test method per design: 3 DRNS designs + 3 CROSS designs = 6 tests.
+# ---------------------------------------------------------------------------
+class NTTShardedPerformanceTest(parameterized.TestCase):
+ """Profile all six NTT designs (step-count × backend) at sharded batch."""
+
+ def setUp(self):
+ super().setUp()
+ outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR")
+ self.output_trace_root = (
+ os.path.join(outputs_dir, "log")
+ if outputs_dir
+ else os.path.join(os.path.dirname(os.path.abspath(__file__)), "log")
+ )
+ self.profiler_config = {
+ "iterations": 1,
+ "save_to_file": True,
+ "enable_sharding": True,
+ }
+
+ @classmethod
+ def tearDownClass(cls):
+ super().tearDownClass()
+ collect_logs(os.path.dirname(os.path.abspath(__file__)))
+
+ # -------------------------------------------------------------------------
+ # Profile driver: uniform entry point for DRNS and CROSS. Selects the
+ # right extension context, trailing-dim size, and per-device kernel
+ # wrapping (plain jit for DRNS, shard_map-wrapped jit for CROSS).
+ # -------------------------------------------------------------------------
+ def _profile_design(
+ self, variant_name, configs, ntt_cls, backend, make_params
+ ):
+ assert backend in ("drns", "cross")
+ if backend == "drns":
+ ff_ctx = _create_drns_ff_ctx(Q_PERF)
+ trailing = len(ff_ctx.rns_moduli)
+ trailing_key = "num_moduli"
+ kernel_factory = lambda direction: (
+ _ntt_kernel if direction == "ntt" else _intt_kernel
+ )
+ else:
+ ff_ctx = _create_cross_ff_ctx(Q_PERF)
+ trailing = ff_ctx.chunk_num_u32
+ trailing_key = "chunk_num_u32"
+ kernel_factory = lambda direction: _shard_mapped_kernel(direction)
+
+ mesh, partition_spec = _create_sharding()
+ num_devices = len(jax.devices())
+
+ profiler = Profiler(
+ output_trace_path=self.output_trace_root,
+ profile_naming=f"sharded_{variant_name}_{ntt_cls.__name__}",
+ configuration=self.profiler_config,
+ )
+
+ for cfg in configs:
+ degree = cfg["degree"]
+ spatial_params, spatial_shape = make_params(cfg)
+ params = {
+ "prime": Q_PERF,
+ "finite_field_context": ff_ctx,
+ **spatial_params,
+ }
+ psi = PSI_PERF_BY_DEGREE.get(degree)
+ if psi is not None:
+ params["psi"] = psi
+ ntt_ctx = ntt_cls(params)
+
+ ndim = 1 + len(spatial_shape) + 1
+ batch_sharding = _batch_sharding(mesh, partition_spec, ndim)
+
+ axis_names = mesh.axis_names
+ batch_partition = axis_names if len(axis_names) > 1 else axis_names[0]
+ batch_spec = partition_spec(batch_partition, *([None] * (ndim - 1)))
+
+ setting_base = {
+ "context": ntt_cls.__name__,
+ "backend": backend,
+ "degree": degree,
+ "spatial_shape": str(spatial_shape),
+ trailing_key: trailing,
+ "num_devices": num_devices,
+ }
+
+ batch = num_devices
+ input_shape = (batch,) + spatial_shape + (trailing,)
+ for direction in ("ntt", "intt"):
+ name = f"{variant_name}_{direction}_deg{degree}_batch{batch}"
+ wrapper = KernelWrapper(
+ kernel_name=name,
+ function_to_wrap=kernel_factory(direction),
+ input_structs=[(input_shape, jnp.uint32)],
+ parameters={"ctx": ntt_ctx, "mesh": mesh, "batch_spec": batch_spec},
+ mesh=mesh,
+ input_shardings=(batch_sharding,),
+ output_sharding=batch_sharding,
+ enable_sharding=True,
+ )
+ profiler.add_profile(
+ name=name,
+ kernel_wrapper=wrapper,
+ kernel_setting_cols={
+ **setting_base,
+ "direction": direction,
+ "batch": batch,
+ },
+ )
+
+ profiler.profile_all_profilers()
+ profiler.post_process_all_profilers()
+
+ # -------------------------------------------------------------------------
+ # 3 DRNS-based designs
+ # -------------------------------------------------------------------------
+ def test_sharded_drns_3step(self):
+ """DRNS 3-step NTT (``NTT3Step`` + ``DRNSLazyExtensionContext``)."""
+
+ def make_params(cfg):
+ r, c = 2 ** cfg["r1"], 2 ** cfg["c"]
+ return {"r": r, "c": c}, (r, c)
+
+ self._profile_design(
+ "drns_3step",
+ PERF_CONFIGS_DRNS_3STEP,
+ ntt_context.NTT3Step,
+ "drns",
+ make_params,
+ )
+
+ def test_sharded_drns_5step(self):
+ """DRNS 5-step NTT (``NTT5Step`` + ``DRNSLazyExtensionContext``)."""
+
+ def make_params(cfg):
+ rr, rc, c = 2 ** cfg["r1"], 2 ** cfg["r2"], 2 ** cfg["c"]
+ return {"rr": rr, "rc": rc, "c": c}, (rr, rc, c)
+
+ self._profile_design(
+ "drns_5step",
+ PERF_CONFIGS_DRNS_5STEP,
+ ntt_context.NTT5Step,
+ "drns",
+ make_params,
+ )
+
+ def test_sharded_drns_7step(self):
+ """DRNS 7-step NTT (``NTT7Step`` + ``DRNSLazyExtensionContext``)."""
+
+ def make_params(cfg):
+ rr, rc = 2 ** cfg["r1"], 2 ** cfg["r2"]
+ cr, cc = 2 ** cfg["c1"], 2 ** cfg["c2"]
+ return {"rr": rr, "rc": rc, "cr": cr, "cc": cc}, (rr, rc, cr, cc)
+
+ self._profile_design(
+ "drns_7step",
+ PERF_CONFIGS_DRNS_7STEP,
+ ntt_context.NTT7Step,
+ "drns",
+ make_params,
+ )
+
+ # -------------------------------------------------------------------------
+ # 3 CROSS-backed designs
+ # -------------------------------------------------------------------------
+ def test_sharded_cross_3step(self):
+ """CROSS 3-step NTT (``NTT3Step`` + ``CROSSLazyExtensionContext``)."""
+
+ def make_params(cfg):
+ r, c = 2 ** cfg["r1"], 2 ** cfg["c"]
+ return {"r": r, "c": c}, (r, c)
+
+ self._profile_design(
+ "cross_3step",
+ PERF_CONFIGS_CROSS_3STEP,
+ ntt_context.NTT3Step,
+ "cross",
+ make_params,
+ )
+
+ def test_sharded_cross_5step(self):
+ """CROSS 5-step NTT (``NTT5Step`` + ``CROSSLazyExtensionContext``)."""
+
+ def make_params(cfg):
+ rr, rc, c = 2 ** cfg["r1"], 2 ** cfg["r2"], 2 ** cfg["c"]
+ return {"rr": rr, "rc": rc, "c": c}, (rr, rc, c)
+
+ self._profile_design(
+ "cross_5step",
+ PERF_CONFIGS_CROSS_5STEP,
+ ntt_context.NTT5Step,
+ "cross",
+ make_params,
+ )
+
+ def test_sharded_cross_7step(self):
+ """CROSS 7-step NTT (``NTT7Step`` + ``CROSSLazyExtensionContext``)."""
+
+ def make_params(cfg):
+ rr, rc = 2 ** cfg["r1"], 2 ** cfg["r2"]
+ cr, cc = 2 ** cfg["c1"], 2 ** cfg["c2"]
+ return {"rr": rr, "rc": rc, "cr": cr, "cc": cc}, (rr, rc, cr, cc)
+
+ self._profile_design(
+ "cross_7step",
+ PERF_CONFIGS_CROSS_7STEP,
+ ntt_context.NTT7Step,
+ "cross",
+ make_params,
+ )
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/jaxite_ec/number_theory_transform_test.py b/jaxite_ec/number_theory_transform_test.py
new file mode 100644
index 0000000..c972a48
--- /dev/null
+++ b/jaxite_ec/number_theory_transform_test.py
@@ -0,0 +1,428 @@
+import jax
+import jax.numpy as jnp
+import jaxite.jaxite_ec.number_theory_transform_context as ntt_context
+import jaxite.jaxite_ec.utils as utils
+import numpy as np
+
+from absl.testing import absltest
+from absl.testing import parameterized
+
+MODULI = 134219681
+
+# ---------------------------------------------------------------------------
+# Trailing-dimension sizing — change these to sweep field-element width.
+# ---------------------------------------------------------------------------
+NUM_MODULI = 21 # DRNS: number of RNS moduli (trailing dim size), 21 for 256-bit, 56 for 753-bit
+PRECISION_BITS = 28 # DRNS: bit-width per modulus
+RADIX_BITS = 32 # DRNS: Montgomery radix
+CHUNK_NUM_U8 = 32 # CROSS: override chunk_num_u8 (None = auto from prime), 32 for 256-bit, 95 for 753-bit
+
+TEST_VECTOR = {
+ "coef_in": [
+ 105825732,
+ 68433452,
+ 36629220,
+ 126901109,
+ 89469849,
+ 106633716,
+ 15102657,
+ 108374459,
+ 68789927,
+ 23451922,
+ 93538050,
+ 20585372,
+ 30604976,
+ 37517995,
+ 65644325,
+ 102451383,
+ ],
+ "eval_in": [
+ 26196696,
+ 45475009,
+ 10055359,
+ 23277424,
+ 69041040,
+ 71916973,
+ 73894069,
+ 3311254,
+ 44646798,
+ 49882443,
+ 28097016,
+ 70484730,
+ 10811958,
+ 11946041,
+ 61318182,
+ 19099272,
+ ],
+}
+
+
+def make_ntt_case(name, *layout_dims):
+ return (
+ name,
+ MODULI,
+ None,
+ 1,
+ *layout_dims,
+ TEST_VECTOR["coef_in"],
+ TEST_VECTOR["eval_in"],
+ )
+
+
+NTT_3STEP = [make_ntt_case("0", 4, 4)] # Layout: (r, c)
+NTT_5STEP = [make_ntt_case("0", 2, 2, 4)] # Layout: (rr, rc, c)
+NTT_7STEP = [make_ntt_case("0", 2, 2, 2, 2)] # Layout: (rr, rc, cr, cc)
+
+
+# ---------------------------------------------------------------------------
+# Sharded correctness configs: (case_name, spatial_params, ntt_cls, spatial_shape)
+# ---------------------------------------------------------------------------
+SHARDED_CORRECTNESS_CONFIGS = [
+ ("3step", {"r": 4, "c": 4}, ntt_context.NTT3Step, (4, 4)),
+ ("5step", {"rr": 2, "rc": 2, "c": 4}, ntt_context.NTT5Step, (2, 2, 4)),
+ (
+ "7step",
+ {"rr": 2, "rc": 2, "cr": 2, "cc": 2},
+ ntt_context.NTT7Step,
+ (2, 2, 2, 2),
+ ),
+]
+
+
+# ---------------------------------------------------------------------------
+# Sharding helpers
+# ---------------------------------------------------------------------------
+def _create_sharding():
+ """Create default batch sharding for the current device mesh."""
+ available_devices = jax.devices()
+ if not available_devices:
+ raise RuntimeError("No devices available for sharding test.")
+ if len(available_devices) == 8:
+ mesh_shape = (2, 4)
+ elif len(available_devices) == 4:
+ mesh_shape = (2, 2)
+ elif len(available_devices) == 2:
+ mesh_shape = (2, 1)
+ else:
+ mesh_shape = (1, 1)
+
+ mesh = jax.make_mesh(mesh_shape, ("x", "y"))
+ return mesh, jax.sharding.PartitionSpec
+
+
+def _batch_sharding(mesh, partition_spec, ndim):
+ """NamedSharding that partitions only the leading (batch) axis."""
+ axis_names = mesh.axis_names
+ batch_partition = axis_names if len(axis_names) > 1 else axis_names[0]
+ spec = (batch_partition,) + (None,) * (ndim - 1)
+ return jax.sharding.NamedSharding(mesh, partition_spec(*spec))
+
+
+def _tile_to_batch(single, shard_batch):
+ """Replicate a single ``(1, *spatial, trailing)`` array across batch dim."""
+ return jnp.tile(single, (shard_batch,) + (1,) * (single.ndim - 1))
+
+
+def _create_drns_ff_ctx(prime):
+ rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS)
+ return ntt_context.DRNSLazyExtensionContext({
+ "prime": prime,
+ "rns_moduli": rns_moduli,
+ "precision_bits": PRECISION_BITS,
+ "radix_bits": RADIX_BITS,
+ })
+
+
+class NTTTest(parameterized.TestCase):
+
+ # @absltest.skip("Skip DRNS NTT tests")
+ @parameterized.named_parameters(*NTT_3STEP)
+ def test_DRNS_NTT_3step(self, q, psi, batch, r, c, coef_in, eval_in):
+ """Validate the 3-step NTT with DRNS lazy reduction."""
+ rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS)
+ ff_ctx = ntt_context.DRNSLazyExtensionContext({
+ "prime": q,
+ "rns_moduli": rns_moduli,
+ "precision_bits": PRECISION_BITS,
+ "radix_bits": RADIX_BITS,
+ })
+ ntt_ctx = ntt_context.NTT3Step(
+ {"prime": q, "r": r, "c": c, "finite_field_context": ff_ctx}
+ )
+
+ # to_computational_format gives (N, num_moduli); ntt auto-reshapes
+ coef_drns = ntt_ctx.to_computational_format(coef_in)
+
+ # Forward NTT
+ ntt_result = ntt_ctx.ntt(coef_drns)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+
+ # Inverse NTT
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # @absltest.skip("Skip DRNS NTT tests")
+ @parameterized.named_parameters(*NTT_5STEP)
+ def test_DRNS_NTT_5step(self, q, psi, batch, rr, rc, c, coef_in, eval_in):
+ """Validate the 5-step NTT with DRNS lazy reduction."""
+ rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS)
+ ff_ctx = ntt_context.DRNSLazyExtensionContext({
+ "prime": q,
+ "rns_moduli": rns_moduli,
+ "precision_bits": PRECISION_BITS,
+ "radix_bits": RADIX_BITS,
+ })
+ ntt_ctx = ntt_context.NTT5Step(
+ {"prime": q, "rr": rr, "rc": rc, "c": c, "finite_field_context": ff_ctx}
+ )
+
+ # to_computational_format gives (N, num_moduli); ntt auto-reshapes
+ coef_drns = ntt_ctx.to_computational_format(coef_in)
+
+ # Forward NTT
+ ntt_result = ntt_ctx.ntt(coef_drns)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+
+ # Inverse NTT
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # @absltest.skip("Skip DRNS NTT tests")
+ @parameterized.named_parameters(*NTT_7STEP)
+ def test_DRNS_NTT_7step(
+ self, q, psi, batch, rr, rc, cr, cc, coef_in, eval_in
+ ):
+ """Validate the 7-step NTT with DRNS lazy reduction."""
+ rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS)
+ ff_ctx = ntt_context.DRNSLazyExtensionContext({
+ "prime": q,
+ "rns_moduli": rns_moduli,
+ "precision_bits": PRECISION_BITS,
+ "radix_bits": RADIX_BITS,
+ })
+ ntt_ctx = ntt_context.NTT7Step({
+ "prime": q,
+ "rr": rr,
+ "rc": rc,
+ "cr": cr,
+ "cc": cc,
+ "finite_field_context": ff_ctx,
+ })
+
+ coef_drns = ntt_ctx.to_computational_format(coef_in)
+
+ # Forward NTT
+ ntt_result = ntt_ctx.ntt(coef_drns)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+
+ # Inverse NTT
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # @absltest.skip("Skip CROSS NTT tests")
+ @parameterized.named_parameters(*NTT_3STEP)
+ def test_CROSS_NTT_3step(self, q, psi, batch, r, c, coef_in, eval_in):
+ """Validate the 3-step NTT with CROSS lazy matrix reduction."""
+ cross_params = {"prime": q}
+ if CHUNK_NUM_U8 is not None:
+ cross_params["chunk_num_u8"] = CHUNK_NUM_U8
+ ff_ctx = ntt_context.CROSSLazyExtensionContext(cross_params)
+ ntt_ctx = ntt_context.NTT3Step(
+ {"prime": q, "r": r, "c": c, "finite_field_context": ff_ctx}
+ )
+
+ coef_cross = ntt_ctx.to_computational_format(coef_in)
+
+ ntt_result = ntt_ctx.ntt(coef_cross)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # @absltest.skip("Skip CROSS NTT tests")
+ @parameterized.named_parameters(*NTT_5STEP)
+ def test_CROSS_NTT_5step(self, q, psi, batch, rr, rc, c, coef_in, eval_in):
+ """Validate the 5-step NTT with CROSS lazy matrix reduction."""
+ cross_params = {"prime": q}
+ if CHUNK_NUM_U8 is not None:
+ cross_params["chunk_num_u8"] = CHUNK_NUM_U8
+ ff_ctx = ntt_context.CROSSLazyExtensionContext(cross_params)
+ ntt_ctx = ntt_context.NTT5Step(
+ {"prime": q, "rr": rr, "rc": rc, "c": c, "finite_field_context": ff_ctx}
+ )
+
+ coef_cross = ntt_ctx.to_computational_format(coef_in)
+
+ ntt_result = ntt_ctx.ntt(coef_cross)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # @absltest.skip("Skip CROSS NTT tests")
+ @parameterized.named_parameters(*NTT_7STEP)
+ def test_CROSS_NTT_7step(
+ self, q, psi, batch, rr, rc, cr, cc, coef_in, eval_in
+ ):
+ """Validate the 7-step NTT with CROSS lazy matrix reduction."""
+ cross_params = {"prime": q}
+ if CHUNK_NUM_U8 is not None:
+ cross_params["chunk_num_u8"] = CHUNK_NUM_U8
+ ff_ctx = ntt_context.CROSSLazyExtensionContext(cross_params)
+ ntt_ctx = ntt_context.NTT7Step({
+ "prime": q,
+ "rr": rr,
+ "rc": rc,
+ "cr": cr,
+ "cc": cc,
+ "finite_field_context": ff_ctx,
+ })
+
+ coef_cross = ntt_ctx.to_computational_format(coef_in)
+
+ ntt_result = ntt_ctx.ntt(coef_cross)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # ---------------------------------------------------------------------
+ # Unified NTT + NumpyCPUContext parity tests: NTT{3,5,7}Step with a
+ # NumpyCPUContext backend must produce the same output as the
+ # CPUCROSS{,5Step,7Step}Context legacy reference classes.
+ # ---------------------------------------------------------------------
+ # @absltest.skip("Skip NumpyCPU NTT tests")
+ @parameterized.named_parameters(*NTT_3STEP)
+ def test_NumpyCPU_NTT_3step(self, q, psi, batch, r, c, coef_in, eval_in):
+ ff_ctx = ntt_context.NumpyCPUContext({"prime": q})
+ ntt_ctx = ntt_context.NTT3Step(
+ {"prime": q, "r": r, "c": c, "finite_field_context": ff_ctx}
+ )
+ coef_cpu = ntt_ctx.to_computational_format(coef_in)
+ ntt_result = ntt_ctx.ntt(coef_cpu)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # @absltest.skip("Skip NumpyCPU NTT tests")
+ @parameterized.named_parameters(*NTT_5STEP)
+ def test_NumpyCPU_NTT_5step(self, q, psi, batch, rr, rc, c, coef_in, eval_in):
+ ff_ctx = ntt_context.NumpyCPUContext({"prime": q})
+ ntt_ctx = ntt_context.NTT5Step(
+ {"prime": q, "rr": rr, "rc": rc, "c": c, "finite_field_context": ff_ctx}
+ )
+ coef_cpu = ntt_ctx.to_computational_format(coef_in)
+ ntt_result = ntt_ctx.ntt(coef_cpu)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+ # @absltest.skip("Skip NumpyCPU NTT tests")
+ @parameterized.named_parameters(*NTT_7STEP)
+ def test_NumpyCPU_NTT_7step(
+ self, q, psi, batch, rr, rc, cr, cc, coef_in, eval_in
+ ):
+ ff_ctx = ntt_context.NumpyCPUContext({"prime": q})
+ ntt_ctx = ntt_context.NTT7Step({
+ "prime": q,
+ "rr": rr,
+ "rc": rc,
+ "cr": cr,
+ "cc": cc,
+ "finite_field_context": ff_ctx,
+ })
+ coef_cpu = ntt_ctx.to_computational_format(coef_in)
+ ntt_result = ntt_ctx.ntt(coef_cpu)
+ np.testing.assert_array_equal(
+ eval_in, ntt_ctx.to_original_format(ntt_result)
+ )
+ intt_result = ntt_ctx.intt(ntt_result)
+ np.testing.assert_array_equal(
+ coef_in, ntt_ctx.to_original_format(intt_result)
+ )
+
+
+# ---------------------------------------------------------------------------
+# Sharded correctness (DRNS only — CROSS correctness is covered by the
+# per-backend tests above since CROSS's device-0 closure constants conflict
+# with shard_map under the small-mesh jit harness).
+# ---------------------------------------------------------------------------
+class NTTShardedCorrectnessTest(parameterized.TestCase):
+ """Verifies that batched, device-sharded NTT is element-wise correct."""
+
+ # @absltest.skip("Skip DRNS NTT correctness tests")
+ @parameterized.named_parameters(*SHARDED_CORRECTNESS_CONFIGS)
+ def test_sharded_ntt_correctness(
+ self, spatial_params, ctx_cls, spatial_shape
+ ):
+ coef_in = TEST_VECTOR["coef_in"]
+ eval_in = TEST_VECTOR["eval_in"]
+ ff_ctx = _create_drns_ff_ctx(MODULI)
+ params = {"prime": MODULI, "finite_field_context": ff_ctx, **spatial_params}
+ ntt_ctx = ctx_cls(params)
+
+ coef_drns = ntt_ctx.to_computational_format(coef_in)
+ coef_drns = ntt_ctx._ensure_ntt_shape(coef_drns)
+
+ mesh, partition_spec = _create_sharding()
+ shard_batch = len(jax.devices())
+ batched = _tile_to_batch(coef_drns, shard_batch)
+ sharding = _batch_sharding(mesh, partition_spec, batched.ndim)
+ batched_sharded = jax.device_put(batched, sharding)
+
+ jit_ntt = jax.jit(ntt_ctx.ntt, out_shardings=sharding)
+ jit_intt = jax.jit(ntt_ctx.intt, out_shardings=sharding)
+
+ ntt_result = jit_ntt(batched_sharded)
+ ntt_host = np.asarray(ntt_result)
+ for i in range(shard_batch):
+ np.testing.assert_array_equal(
+ eval_in,
+ ntt_ctx.to_original_format(ntt_host[i]),
+ err_msg=f"NTT mismatch at batch index {i}",
+ )
+
+ intt_result = jit_intt(ntt_result)
+ intt_host = np.asarray(intt_result)
+ for i in range(shard_batch):
+ np.testing.assert_array_equal(
+ coef_in,
+ ntt_ctx.to_original_format(intt_host[i]),
+ err_msg=f"INTT mismatch at batch index {i}",
+ )
+
+
+if __name__ == "__main__":
+ absltest.main()
diff --git a/jaxite_ec/pippenger.py b/jaxite_ec/pippenger.py
deleted file mode 100644
index 63d5372..0000000
--- a/jaxite_ec/pippenger.py
+++ /dev/null
@@ -1,1319 +0,0 @@
-"""Pippenger algorithm for elliptic curves.
-
-This module implements the Pippenger algorithm for elliptic curves. The
-algorithm
-is a generalization of the elliptic curve algorithm that can be used to find
-elliptic curves of arbitrary order.
-
-The Pippenger algorithm works by first finding a set of "scalars" and "points"
-that lie on the elliptic curve. These scalars and points are then used to
-construct
-a "bucket" for each window of the elliptic curve. The buckets are then reduced
-to a single point for each window. Finally, the points from all of the windows
-are merged together to form the final elliptic curve.
-
-The Pippenger algorithm is a powerful tool for finding elliptic curves, and it
-has been used to find elliptic curves of arbitrary order and dimension.
-
-Example Compiled Function:
-
-bucket_accumulation_scan_jax_jit = jax.jit(
- bucket_accumulation_scan_jax,
- static_argnames='msm_length').lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW),
- dtype=jnp.uint8
- ),
- MSM_LENGTH
- ).compile()
-
-bucket_accumulation_index_scan_jax_jit = jax.jit(
- bucket_accumulation_index_scan_jax,
- static_argnames='msm_length').lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM,BUCKET_NUM_PER_WINDOW, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM),
- dtype=jnp.uint32
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM,BUCKET_NUM_PER_WINDOW),
- dtype=jnp.uint8
- ),
- MSM_LENGTH
- ).compile()
-
-bucket_reduction_scan_jax_jit = jax.jit(
- bucket_reduction_scan_jax,
- static_argnames='bucket_num_in_window').lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM, BUCKET_NUM_PER_WINDOW, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (BUCKET_NUM_PER_WINDOW, WINDOW_NUM),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (BUCKET_NUM_PER_WINDOW, WINDOW_NUM),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (BUCKET_NUM_PER_WINDOW, WINDOW_NUM,),
- dtype=jnp.uint8
- ),
- BUCKET_NUM_PER_WINDOW
- ).compile()
-
-window_merge_scan_jax_jit = jax.jit(
- window_merge_scan_jax,
- static_argnames='slice_length').lower(
- jax.ShapeDtypeStruct((COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM),
- dtype=jnp.uint16),
- SLICE_LENGTH
- ).compile()
-
-selective_padd_with_zero_jit = jax.jit(
- selective_padd_with_zero).lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, COORDINATE_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, MSM_LENGTH, COORDINATE_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,),
- dtype=jnp.uint8
- ),
- ).compile()
-"""
-
-import copy
-import math
-from typing import List
-
-import jax
-import jax.numpy as jnp
-from jaxite.jaxite_ec import util
-import jaxite.jaxite_ec.elliptic_curve as jec
-import jaxite.jaxite_ec.finite_field as ff
-import numpy as np
-
-
-deepcopy = copy.deepcopy
-
-"""
-Example Parameters:
-SLICE_LENGTH = 4
-WINDOW_NUM = int(math.ceil(255 / SLICE_LENGTH))
-BUCKET_NUM_PER_WINDOW = 2**SLICE_LENGTH
-MSM_LENGTH = 1024
-COORDINATE_NUM = 4
-"""
-
-
-def selective_padd_with_zero(partial_sum, single_point, select, is_zero):
- """Padd the partial sum with the single point, but only if the selection state is 1.
-
- Args:
- partial_sum: The partial sum.
- single_point: The single point.
- select: The selection state.
- is_zero: The zero states.
-
- Returns:
- The new partial sum.
- """
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point)
-
- cond_select = jnp.equal(select, 1).reshape(1, batch_dim, 1)
- sum_result = jnp.where(cond_select, new_partial_sum, partial_sum)
-
- cond_zero = jnp.equal(is_zero, 1).reshape(1, batch_dim, 1)
- cond_select_and_zero = jnp.logical_and(cond_select, cond_zero)
- result = jnp.where(cond_select_and_zero, single_point, sum_result)
- return result
-
-
-def padd_with_zero(partial_sum, single_point, ps_is_zero, sp_is_zero):
- """Padd the partial sum with the single point.
-
- Check if the partial sum is equal to the single point first.
-
- Args:
- partial_sum: The partial sum.
- single_point: The single point.
- ps_is_zero: The zero states of the partial sum.
- sp_is_zero: The zero states of the single point.
-
- Returns:
- The new partial sum.
- """
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point)
- cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1)
- cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1)
- result_1 = jnp.where(cond_sp_zero, partial_sum, single_point)
- result_2 = jnp.where(
- jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum
- )
- return result_2
-
-
-def padd_with_zero_alter(partial_sum, single_point, ps_is_zero):
-
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point)
- cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1)
- result_2 = jnp.where(cond_ps_zero, single_point, new_partial_sum)
- return result_2
-
-
-def padd_with_zero_and_pdul_check(
- partial_sum, single_point, ps_is_zero, sp_is_zero
-):
- """Padd the partial sum with the single point.
-
- Check if the partial sum is equal to the single point first. If they are
- equal, then double the partial sum.
-
- Args:
- partial_sum: The partial sum.
- single_point: The single point.
- ps_is_zero: The zero states of the partial sum.
- sp_is_zero: The zero states of the single point.
-
- Returns:
- The new partial sum.
- """
- # coordinate_dim, batch_dim, precision_dim = partial_sum.shape
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point)
- double_partial_sum = jec.pdul_lazy_xyzz_pack(partial_sum)
- cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape(
- 1, batch_dim, 1
- )
- cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1)
- cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1)
- result_1 = jnp.where(cond_sp_zero, partial_sum, single_point)
- result_2 = jnp.where(
- jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum
- )
- reuslt_3 = jnp.where(cond_equal, double_partial_sum, result_2)
- return reuslt_3
-
-
-def bucket_accumulation_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Non-scan version BA."""
- coordinate_dim, buckets_dim, precision_dim = all_buckets.shape
- for i in range(msm_length):
- point = jax.lax.broadcast_in_dim(
- all_points[i], (coordinate_dim, buckets_dim, precision_dim), (0, 2)
- )
- all_buckets = selective_padd_with_zero(
- all_buckets, point, selection_list[i], zero_states_list[i]
- )
- return all_buckets
-
-
-def bucket_accumulation_scan_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Scan version BA."""
- coordinate_dim, buckets_dim, precision_dim = all_buckets.shape
-
- def scan_body(buckets, point_with_cond_pack):
- point, selection, zero_states = point_with_cond_pack
- point = jax.lax.broadcast_in_dim(
- point, (coordinate_dim, buckets_dim, precision_dim), (0, 2)
- )
- all_buckets = selective_padd_with_zero(
- buckets, point, selection, zero_states
- )
- return all_buckets, None
-
- all_buckets, _ = jax.lax.scan(
- scan_body,
- all_buckets,
- (all_points, selection_list, zero_states_list),
- length=msm_length,
- )
- return all_buckets
-
-
-def bucket_accumulation_index_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_index_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Non-scan version BA with index selection."""
- # buckets_dim is not used in the algorithm, changed it into _.
- coordinate_dim, window_dim, _, precision_dim = all_buckets.shape
- for i in range(msm_length):
- point = jax.lax.broadcast_in_dim(
- all_points[i], (coordinate_dim, window_dim, precision_dim), (0, 2)
- )
- selective_buckets = all_buckets[
- :, jnp.arange(window_dim), selection_index_list[i], :
- ]
- selective_zero_states = zero_states_list[
- i, jnp.arange(window_dim), selection_index_list[i]
- ]
- selective_update = padd_with_zero_alter(
- selective_buckets, point, selective_zero_states
- )
- all_buckets = all_buckets.at[
- :, jnp.arange(window_dim), selection_index_list[i], :
- ].set(selective_update)
- return all_buckets
-
-
-def bucket_accumulation_index_scan_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_index_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Scan version BA with index selection."""
- # buckets_dim is not used in the algorithm, changed it into _.
- coordinate_dim, window_dim, _, precision_dim = all_buckets.shape
-
- def scan_body(buckets, point_with_cond_pack):
- point, selection_index, zero_states = point_with_cond_pack
- point = jax.lax.broadcast_in_dim(
- point, (coordinate_dim, window_dim, precision_dim), (0, 2)
- )
- selective_buckets = buckets[:, jnp.arange(window_dim), selection_index, :]
- selective_zero_states = zero_states[jnp.arange(window_dim), selection_index]
- selective_update = padd_with_zero_alter(
- selective_buckets, point, selective_zero_states
- )
- return (
- buckets.at[:, jnp.arange(window_dim), selection_index, :].set(
- selective_update
- ),
- None,
- )
-
- all_buckets, _ = jax.lax.scan(
- scan_body,
- all_buckets,
- (all_points, selection_index_list, zero_states_list),
- length=msm_length,
- )
- return all_buckets
-
-
-def bucket_reduction_algorithm(
- all_buckets: jnp.ndarray,
- temp_sum: jnp.ndarray,
- window_sum: jnp.ndarray,
- bucket_zero_states_list: jnp.ndarray,
- temp_zero_states_list: jnp.ndarray,
- window_zero_states_list: jnp.ndarray,
- bucket_num_in_window: int,
-):
- """Non-scan version BR."""
- all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0)
- for i in range(bucket_num_in_window - 1):
- temp_sum = padd_with_zero(
- temp_sum,
- all_buckets[i],
- temp_zero_states_list[i],
- bucket_zero_states_list[i],
- )
- window_sum = padd_with_zero_and_pdul_check(
- window_sum,
- temp_sum,
- window_zero_states_list[i],
- temp_zero_states_list[i + 1],
- )
- return window_sum
-
-
-def bucket_reduction_scan_algorithm(
- all_buckets: jnp.ndarray,
- temp_sum: jnp.ndarray,
- window_sum: jnp.ndarray,
- bucket_zero_states_list: jnp.ndarray,
- temp_zero_states_list: jnp.ndarray,
- window_zero_states_list: jnp.ndarray,
- bucket_num_in_window: int,
-):
- """Scan version BR."""
- all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0)
-
- def scan_body(temp_and_window_sum_pack, bucket_with_cond_pack):
- temp_sum, window_sum = temp_and_window_sum_pack
- (
- bucket,
- bucket_zero_states,
- temp_zero_states,
- temp_zero_states1,
- window_zero_states,
- ) = bucket_with_cond_pack
- temp_sum = padd_with_zero(
- temp_sum, bucket, temp_zero_states, bucket_zero_states
- )
- window_sum = padd_with_zero_and_pdul_check(
- window_sum, temp_sum, window_zero_states, temp_zero_states1
- )
- return (temp_sum, window_sum), None
-
- (_, window_sum), _ = jax.lax.scan(
- scan_body,
- (temp_sum, window_sum),
- (
- all_buckets[: bucket_num_in_window - 1],
- bucket_zero_states_list[: bucket_num_in_window - 1],
- temp_zero_states_list[: bucket_num_in_window - 1],
- temp_zero_states_list[1:],
- window_zero_states_list[: bucket_num_in_window - 1],
- ),
- length=bucket_num_in_window - 1,
- )
-
- return window_sum
-
-
-def window_merge_algorithm(window_sum: jnp.ndarray, slice_length: int):
- """Non-scan version WM."""
- coordinate_dim, window_dim, precision_dim = window_sum.shape
- result = window_sum[:, window_dim - 1, :].reshape(
- (coordinate_dim, 1, precision_dim)
- )
- for w in range(window_dim - 2, -1, -1):
- for _ in range(slice_length):
- result = jec.pdul_lazy_xyzz_pack(result)
- result = jec.padd_lazy_xyzz_pack(
- result,
- window_sum[:, w, :].reshape(
- (coordinate_dim, 1, util.U16_EXT_CHUNK_NUM)
- ),
- )
-
- result = result.reshape((coordinate_dim, precision_dim))
- return result
-
-
-def window_merge_scan_algorithm(window_sum: jnp.ndarray, slice_length: int):
- """Scan version WM."""
- coordinate_dim, window_dim, precision_dim = window_sum.shape
- window_sum = window_sum.transpose(1, 0, 2)
- result = window_sum[window_dim - 1, :, :].reshape(
- (coordinate_dim, 1, precision_dim)
- )
-
- def fori_loop_body(_, result):
- result = jec.pdul_lazy_xyzz_pack(result)
- return result
-
- def scan_body(result, window_sum):
- result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result)
- result = jec.padd_lazy_xyzz_pack(
- result, window_sum.reshape((coordinate_dim, 1, util.U16_EXT_CHUNK_NUM))
- )
- return result, None
-
- result, _ = jax.lax.scan(
- scan_body,
- result,
- window_sum[: window_dim - 1, :, :],
- reverse=True,
- length=window_dim - 1,
- )
- result = result.reshape((coordinate_dim, precision_dim))
- return result
-
-
-class MSMPippenger:
- """Pippenger algorithm for elliptic curves.
-
- Attributes:
- coordinate_num: The number of coordinates in the elliptic curve.
- slice_length: The length of each slice in the elliptic curve.
- window_num: The number of windows in the elliptic curve.
- bucket_num_per_window: The number of buckets in each window.
- bucket_num_in_window: The number of buckets in each window.
- slice_mask: The mask for the slices in the elliptic curve.
- blank_point: A JAX array of zeros, used to initialize the buckets.
- all_buckets: A JAX array of all the buckets in the elliptic curve.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- scalars: A list of integers, where each integer represents an Orignal scalar
- from the trace.
- all_points: A JAX array of all the points in the elliptic curve. from the
- trace.
- window_sum: A JAX array of the window sum.
- bucket_zero_states: A JAX array of the zero states for the buckets.
- temp_sum_zero_states: A JAX array of the zero states for the temp sum.
- window_sum_zero_states: A JAX array of the zero states for the window sum.
- zero_states_list: A JAX array of the zero states for the buckets.
- selection_list: A JAX array of the selection states for the buckets.
- selection_index_list: A JAX array of the selection index for the buckets.
- msm_length: The length of the MSM trace.
- result: The final elliptic curve.
- lazy_mat: The lazy matrix used for padding and doubling.
- """
-
- def __init__(self, slice_length):
- self.coordinate_num = util.COORDINATE_NUM
-
- self.slice_length = slice_length
- self.window_num = int(math.ceil(254 / self.slice_length)) #
- self.bucket_num_per_window = 2**self.slice_length
- self.slice_mask = self.bucket_num_per_window - 1
- self.blank_point = util.int_list_to_array(
- [0, 0, 0, 0], util.BASE, util.U16_EXT_CHUNK_NUM
- ).reshape(self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM)
-
- self.all_buckets = jnp.broadcast_to(
- self.blank_point.reshape(
- 1, self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM
- ).transpose(1, 0, 2, 3),
- (
- self.coordinate_num,
- self.window_num,
- self.bucket_num_per_window,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
-
- self.window_sum: jnp.ndarray
-
- self.msm_length = 0
- self.bucket_zero_states: jnp.ndarray
- self.temp_sum_zero_states: jnp.ndarray
- self.window_sum_zero_states: jnp.ndarray
- self.zero_states_list: jnp.ndarray
- self.selection_list: jnp.ndarray
- self.selection_index_list: jnp.ndarray
- self.all_points: jnp.ndarray
-
- self.scalars: List[int] = [] # Orignal scalar from the trace
- # [Points, Points, ..., Points]
- self.points: List[jnp.ndarray] = [] # Orignal points from the trace
- self.lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT)
-
- self.result = None
-
- def initialize(self, scalars, points):
- """Initialize the Pippenger algorithm.
-
- Args:
- scalars: A list of integers, where each integer represents an Orignal
- scalar from the trace.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- """
- # Initial internal selection from the scalar
- self.scalars = scalars
- self.msm_length = len(scalars)
-
- # Convert high-precision points into a vector of low-precision chunks
- self.points = [
- util.int_list_to_array(
- coordinates + [1, 1], util.BASE, util.U16_EXT_CHUNK_NUM
- )
- for coordinates in points
- ] # pytype: disable=container-type-mismatch
-
- self.all_points = jnp.array(self.points)
-
- # For BA
- zero_states_pylist, selection_pylist, selection_index_pylist = (
- self.construct_ba_zero_states_and_selection()
- )
- self.selection_list = jnp.array(selection_pylist, dtype=jnp.uint8).reshape(
- (-1, self.window_num * self.bucket_num_per_window)
- )
- # For index selection version BA
- self.zero_states_list = jnp.array(zero_states_pylist, dtype=jnp.uint8)
- self.selection_index_list = jnp.array(
- selection_index_pylist, dtype=jnp.uint32
- )
-
- # For BR
- (
- bucket_zero_states_py,
- temp_sum_zero_states_py,
- window_sum_zero_states_py,
- ) = self.construct_br_zero_states(
- zero_states_pylist[len(zero_states_pylist) - 1]
- )
- self.bucket_zero_states = jnp.array(bucket_zero_states_py, dtype=jnp.uint8)
- self.temp_sum_zero_states = jnp.array(
- temp_sum_zero_states_py, dtype=jnp.uint8
- )
- self.window_sum_zero_states = jnp.array(
- window_sum_zero_states_py, dtype=jnp.uint8
- )
-
- def bucket_accumulation(self, bucket_accumulation_index_func):
- """BA index selection version."""
- self.all_buckets = bucket_accumulation_index_func(
- self.all_buckets,
- self.all_points,
- self.selection_index_list,
- self.zero_states_list[: self.msm_length],
- )
-
- return self.all_buckets
-
- def bucket_reduction(self, bucket_reduction_func):
- """Reduce the buckets to a single point for each window."""
- temp_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- self.window_sum = bucket_reduction_func(
- self.all_buckets,
- temp_sum,
- window_sum,
- self.bucket_zero_states[: self.bucket_num_per_window],
- self.temp_sum_zero_states[: self.bucket_num_per_window],
- self.window_sum_zero_states[: self.bucket_num_per_window],
- )
- return self.window_sum
-
- def window_merge(self, window_merge_func):
- """Merge the windows to form the final elliptic curve.
-
- Args:
- window_merge_func: The function to merge the windows.
-
- Returns:
- The final elliptic curve.
- """
- self.result = window_merge_func(self.window_sum)
- return self.result
-
- def construct_ba_zero_states_and_selection(self):
- """Construct the zero states and selection for the bucket accumulation (BA) step.
-
- Returns:
- A tuple of two lists: the zero states for the bucket accumulation, and the
- selection for the bucket accumulation.
- """
- zero_states = [
- deepcopy([1] * self.bucket_num_per_window)
- for _ in range(self.window_num)
- ]
- zero_states_list = []
- zero_states_list.append(deepcopy(zero_states))
- selection_list = []
- selection_index_list = [] # Used for index selection
- for scalar in self.scalars:
- # Compute the zero states for each scalar by time dependence
- selection = [
- deepcopy(([0] * self.bucket_num_per_window))
- for _ in range(self.window_num)
- ]
- selection_index = []
- for w in range(self.window_num):
- slice_index = (scalar >> (w * self.slice_length)) & self.slice_mask
- zero_states[w][slice_index] = 0
- selection[w][slice_index] = 1
- selection_index.append(slice_index)
-
- selection_list.append(deepcopy(selection))
- zero_states_list.append(deepcopy(zero_states))
- selection_index_list.append(deepcopy(selection_index))
- return zero_states_list, selection_list, selection_index_list
-
- def construct_br_zero_states(self, bucket_zero_states):
- """Construct the zero states for the bucket reduction (BR) step.
-
- Args:
- bucket_zero_states: The zero states of the buckets.
-
- Returns:
- A tuple of three lists: the zero states for the bucket reduction, the zero
- states for the temporary sum, and the zero states for the window sum.
- """
- temp_sum_zero_states = np.array([1] * self.window_num)
- window_sum_zero_states = np.array([1] * self.window_num)
- temp_sum_zero_states_list = []
- window_sum_zero_states_list = []
- temp_sum_zero_states_list.append(temp_sum_zero_states)
- window_sum_zero_states_list.append(window_sum_zero_states)
- bucket_zero_states_list = np.flip(
- np.array(bucket_zero_states).transpose(1, 0), axis=0
- )
- for b in range(self.bucket_num_per_window):
- next_temp_sum_zero_states = (
- temp_sum_zero_states_list[b] & bucket_zero_states_list[b]
- )
- next_window_sum_zero_states = (
- window_sum_zero_states_list[b] & next_temp_sum_zero_states
- )
- temp_sum_zero_states_list.append(next_temp_sum_zero_states)
- window_sum_zero_states_list.append(next_window_sum_zero_states)
- return (
- bucket_zero_states_list,
- temp_sum_zero_states_list,
- window_sum_zero_states_list,
- )
-
-
-#########################
-# Functions for twisted curve
-#########################
-
-
-def padd(partial_sum, single_point):
- return jec.padd_lazy_twisted_pack(partial_sum, single_point)
-
-
-def padd_with_pdul_check(partial_sum, single_point):
- # coordinate_dim, batch_dim, precision_dim = partial_sum.shape
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_lazy_twisted_pack(partial_sum, single_point)
- double_partial_sum = jec.pdul_lazy_twisted_pack(partial_sum)
- cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape(
- 1, batch_dim, 1
- )
- return jnp.where(cond_equal, double_partial_sum, new_partial_sum)
-
-
-def bucket_accumulation_index_scan_parallel_algorithm_twisted(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_index_list: jnp.ndarray,
- msm_length: int,
-):
- """Scan version BA with index selection."""
- # buckets_dim is not used in the algorithm, changed it into _.
- coordinate_dim, batch_window_dim, _, precision_dim = all_buckets.shape
- _, _, parallel_dim, _ = (
- all_points.shape
- ) # (serial_dim, coordinate_dim, parallel_dim, precision_dim)
- single_window_dim = batch_window_dim // parallel_dim
-
- def scan_body(buckets, point_with_cond_pack):
- point, selection_index = point_with_cond_pack
- point = jax.lax.broadcast_in_dim(
- point,
- (coordinate_dim, parallel_dim, single_window_dim, precision_dim),
- (0, 1, 3),
- )
- point = point.reshape((coordinate_dim, batch_window_dim, precision_dim))
- selective_buckets = buckets[
- :, jnp.arange(batch_window_dim), selection_index, :
- ]
- selective_update = padd(selective_buckets, point)
- return (
- buckets.at[:, jnp.arange(batch_window_dim), selection_index, :].set(
- selective_update
- ),
- None,
- )
-
- all_buckets, _ = jax.lax.scan(
- scan_body,
- all_buckets,
- (all_points, selection_index_list),
- length=msm_length,
- )
- return all_buckets
-
-
-def bucket_reduction_scan_algorithm_twisted(
- all_buckets: jnp.ndarray,
- temp_sum: jnp.ndarray,
- window_sum: jnp.ndarray,
- bucket_num_in_window: int,
-):
- """Scan version BR."""
- all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0)
-
- def scan_body(temp_and_window_sum_pack, buckets):
- temp_sum, window_sum = temp_and_window_sum_pack
- temp_sum = padd(temp_sum, buckets)
- window_sum = padd_with_pdul_check(window_sum, temp_sum)
- return (temp_sum, window_sum), None
-
- (_, window_sum), _ = jax.lax.scan(
- scan_body,
- (temp_sum, window_sum),
- all_buckets[:bucket_num_in_window],
- length=bucket_num_in_window,
- )
-
- return window_sum
-
-
-def batch_window_summation_algorithm_twisted(
- batch_window_sum: jnp.ndarray,
- all_window_sum: jnp.ndarray,
- point_parallel: int,
-):
- """Batch window summation algorithm for twisted curve."""
- # batch_window_dim is not used in the algorithm, changed it into _.
- coordinate_dim, _, precision_dim = all_window_sum.shape
- all_window_sum = all_window_sum.reshape(
- (coordinate_dim, point_parallel, -1, precision_dim)
- ).transpose(1, 0, 2, 3)
-
- def scan_body(batch_window_sum, single_window_sum):
- batch_window_sum = padd(batch_window_sum, single_window_sum)
- return batch_window_sum, None
-
- batch_window_sum, _ = jax.lax.scan(
- scan_body, batch_window_sum, all_window_sum, length=point_parallel
- )
- return batch_window_sum
-
-
-def window_merge_scan_algorithm_twisted(
- window_sum: jnp.ndarray, slice_length: int
-):
- """Scan version WM."""
- coordinate_dim, window_dim, precision_dim = window_sum.shape
- window_sum = window_sum.transpose(1, 0, 2)
- result = window_sum[window_dim - 1, :, :].reshape(
- (coordinate_dim, 1, precision_dim)
- )
-
- def fori_loop_body(_, result):
- result = jec.pdul_lazy_twisted_pack(result)
- return result
-
- def scan_body(result, window_sum):
- result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result)
- result = jec.padd_lazy_twisted_pack(
- result, window_sum.reshape((coordinate_dim, 1, util.U16_EXT_CHUNK_NUM))
- )
- return result, None
-
- result, _ = jax.lax.scan(
- scan_body,
- result,
- window_sum[: window_dim - 1, :, :],
- reverse=True,
- length=window_dim - 1,
- )
- result = result.reshape((coordinate_dim, precision_dim))
- return result
-
-
-class MSMPippengerTwisted:
- """Pippenger algorithm for elliptic curves with twisted points.
-
- Attributes:
- coordinate_num: The number of coordinates in the elliptic curve.
- slice_length: The length of each slice in the elliptic curve.
- point_parallel: The number of parallel points in the elliptic curve.
- window_num: The number of windows in the elliptic curve.
- batch_window_num: The number of batch windows in the elliptic curve.
- bucket_num_per_window: The number of buckets in each window.
- slice_mask: The mask for the slices in the elliptic curve.
- blank_point: A JAX array of zeros, used to initialize the buckets.
- all_buckets: A JAX array of all the buckets in the elliptic curve.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- scalars: A list of integers, where each integer represents an Orignal scalar
- from the trace.
- all_points: A JAX array of all the points in the elliptic curve. from the
- trace.
- window_sum: A JAX array of the window sum.
- br_temp_sum: A JAX array of the temp sum for bucket reduction.
- batch_window_sum: A JAX array of the batch window sum.
- selection_index_list: A JAX array of the selection index for the buckets.
- msm_length: The length of the MSM trace.
- result: The final elliptic curve.
- lazy_mat: The lazy matrix used for padding and doubling.
- """
-
- def __init__(self, slice_length: int, point_parallel: int):
-
- self.coordinate_num = util.COORDINATE_NUM
-
- self.slice_length = slice_length
- self.point_parallel = point_parallel
- self.window_num = int(math.ceil(254 / self.slice_length)) #
- self.batch_window_num = self.window_num * self.point_parallel
- self.bucket_num_per_window = (
- 2**self.slice_length - 1
- ) # Note: here remove the bucket_0
- self.slice_mask = 2**self.slice_length - 1
- self.blank_point = util.int_list_to_array(
- [0, 1, 1, 0], util.BASE, util.U16_EXT_CHUNK_NUM
- ).reshape(self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM)
-
- self.all_buckets = jnp.broadcast_to(
- self.blank_point.reshape(
- 1, self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM
- ).transpose(1, 0, 2, 3),
- (
- self.coordinate_num,
- self.window_num,
- self.bucket_num_per_window,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
-
- self.all_buckets = jnp.tile(
- self.all_buckets, (1, self.point_parallel, 1, 1)
- )
-
- self.window_sum: jnp.ndarray
- self.br_temp_sum: jnp.ndarray
- self.batch_window_sum: jnp.ndarray
-
- self.msm_length = 0
-
- self.selection_index_list: jnp.ndarray
- self.all_points: jnp.ndarray
-
- self.scalars: List[int] = [] # Orignal scalar from the trace
- # [Points, Points, ..., Points]
- self.points: List[jnp.ndarray] = [] # Orignal points from the trace
- self.lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT)
-
- self.result = None
-
- def initialize(self, scalars, points):
- """Initialize the Pippenger algorithm.
-
- Args:
- scalars: A list of integers, where each integer represents an Orignal
- scalar from the trace.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- """
- # Initial internal selection from the scalar
- self.scalars = scalars
- self.msm_length = len(scalars)
-
- # Convert high-precision points into a vector of low-precision chunks
- self.points = [
- util.int_list_to_array(coordinates, util.BASE, util.U16_EXT_CHUNK_NUM)
- for coordinates in points
- ] # pytype: disable=container-type-mismatch
-
- self.all_points = jnp.array(self.points)
- _, coordinate_dim, precision_dim = self.all_points.shape
-
- # For BA
- selection_index_pylist = self.construct_ba_selection()
- # Note: it contains uint(-1) for the bucket_0.
- # In BA, it may cause some undefined behavior when do bucket selection
- # it is correct now, because when setting buckets after the computation,
- # jax.numpy will ignore the index with uint(-1) out of index.
- self.selection_index_list = jnp.array(selection_index_pylist).astype(
- jnp.uint32
- )
- _, window_dim = self.selection_index_list.shape
-
- # Batch construction
- self.all_points = self.all_points.reshape(
- (-1, self.point_parallel, coordinate_dim, precision_dim)
- ).transpose(0, 2, 1, 3)
- self.selection_index_list = self.selection_index_list.reshape(
- (-1, window_dim * self.point_parallel)
- )
- self.br_temp_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- self.window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- self.batch_window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
-
- def bucket_accumulation(self, bucket_accumulation_index_func):
- """BA index selection version."""
- self.all_buckets = bucket_accumulation_index_func(
- self.all_buckets, self.all_points, self.selection_index_list
- )
- return self.all_buckets
-
- def bucket_reduction(self, bucket_reduction_func):
- """Reduce the buckets to a single point for each window."""
- temp_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- self.window_sum = bucket_reduction_func(
- self.all_buckets, temp_sum, window_sum
- )
- return self.window_sum
-
- def batch_window_summation(self, batch_window_summation_func):
- """Sum the batch windows to form the final window sum."""
- batch_window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- self.window_sum = batch_window_summation_func(
- batch_window_sum, self.window_sum
- )
- return self.window_sum
-
- def window_merge(self, window_merge_func):
- """Merge the windows to form the final elliptic curve."""
- self.result = window_merge_func(self.window_sum)
- return self.result
-
- def construct_ba_selection(self):
- selection_index_list = [] # Used for index selection
- for scalar in self.scalars:
- # Compute the zero states for each scalar by time dependence
- selection_index = []
- for w in range(self.window_num):
- slice_index = (
- (scalar >> (w * self.slice_length)) & self.slice_mask
- ) - 1
- selection_index.append(slice_index)
- selection_index_list.append(deepcopy(selection_index))
- return selection_index_list
-
-
-#########################
-# Functions for Signed bucket + twisted curve
-#########################
-def padd_with_sign(partial_sum, single_point, sign):
- neg_single_point = jec.pneg_lazy_twisted_pack(single_point)
- _, batch_dim, _ = partial_sum.shape
- cond_neg = jnp.equal(sign, 1).reshape(1, batch_dim, 1)
- signed_point = jnp.where(cond_neg, neg_single_point, single_point)
- result = jec.padd_lazy_twisted_pack(partial_sum, signed_point)
- return result
-
-
-def bucket_accumulation_signed_index_scan_parallel_algorithm_twisted(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_index_list: jnp.ndarray,
- selection_sign_list: jnp.ndarray,
- msm_length: int,
-):
- """Scan version BA with index selection."""
- coordinate_dim, batch_window_dim, _, precision_dim = all_buckets.shape
- _, _, parallel_dim, _ = (
- all_points.shape
- ) # (serial_dim, coordinate_dim, parallel_dim, precision_dim)
- single_window_dim = batch_window_dim // parallel_dim
-
- def scan_body(buckets, point_with_cond_pack):
- point, selection_index, selection_sign = point_with_cond_pack
- point = jax.lax.broadcast_in_dim(
- point,
- (coordinate_dim, parallel_dim, single_window_dim, precision_dim),
- (0, 1, 3),
- )
- point = point.reshape((coordinate_dim, batch_window_dim, precision_dim))
- selective_buckets = buckets[
- :, jnp.arange(batch_window_dim), selection_index, :
- ]
- selective_update = padd_with_sign(selective_buckets, point, selection_sign)
- return (
- buckets.at[:, jnp.arange(batch_window_dim), selection_index, :].set(
- selective_update
- ),
- None,
- )
-
- all_buckets, _ = jax.lax.scan(
- scan_body,
- all_buckets,
- (all_points, selection_index_list, selection_sign_list),
- length=msm_length,
- )
- return all_buckets
-
-
-class MSMPippengerTwistedSigned:
- """Pippenger algorithm for elliptic curves with twisted and signed points.
-
- Attributes:
- coordinate_num: The number of coordinates in the elliptic curve.
- slice_length: The length of each slice in the elliptic curve.
- point_parallel: The number of parallel points in the elliptic curve.
- window_num: The number of windows in the elliptic curve.
- batch_window_num: The number of batch windows in the elliptic curve.
- bucket_num_per_window: The number of buckets in each window.
- slice_mask: The mask for the slices in the elliptic curve.
- blank_point: A JAX array of zeros, used to initialize the buckets.
- all_buckets: A JAX array of all the buckets in the elliptic curve.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- scalars: A list of integers, where each integer represents an Orignal scalar
- from the trace.
- all_points: A JAX array of all the points in the elliptic curve. from the
- trace.
- window_sum: A JAX array of the window sum.
- zero_states_list: A JAX array of the zero states for the buckets.
- selection_list: A JAX array of the selection states for the buckets.
- selection_index_list: A JAX array of the selection index for the buckets.
- selection_sign_list: A JAX array of the selection sign for the buckets.
- all_points: A JAX array of all the points in the elliptic curve. from the
- trace.
- msm_length: The length of the MSM trace.
- result: The final elliptic curve.
- lazy_mat: The lazy matrix used for padding and doubling.
- """
-
- def __init__(self, slice_length: int, point_parallel: int):
- self.coordinate_num = util.COORDINATE_NUM
-
- self.slice_length = slice_length
- self.point_parallel = point_parallel
- self.window_num = int(math.ceil(254 / self.slice_length)) #
- self.batch_window_num = self.window_num * self.point_parallel
- self.bucket_num_per_window = 2 ** (self.slice_length - 1)
- self.slice_mask = 2**self.slice_length - 1
- self.blank_point = util.int_list_to_array(
- [0, 1, 1, 0], util.BASE, util.U16_EXT_CHUNK_NUM
- ).reshape(self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM)
-
- self.all_buckets = jnp.broadcast_to(
- self.blank_point.reshape(
- 1, self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM
- ).transpose(1, 0, 2, 3),
- (
- self.coordinate_num,
- self.window_num,
- self.bucket_num_per_window,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
-
- self.all_buckets = jnp.tile(
- self.all_buckets, (1, self.point_parallel, 1, 1)
- )
-
- self.window_sum: jnp.ndarray
-
- self.msm_length = 0
-
- self.zero_states_list: jnp.ndarray
- self.selection_list: jnp.ndarray
- self.selection_index_list: jnp.ndarray
- self.selection_sign_list: jnp.ndarray
- self.all_points: jnp.ndarray
-
- self.scalars: List[int] = [] # Orignal scalar from the trace
- # [Points, Points, ..., Points]
- self.points: List[jnp.ndarray] = [] # Orignal points from the trace
- self.lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT)
-
- self.result = None
-
- def initialize(self, scalars, points):
- """Initialize the Pippenger algorithm.
-
- Args:
- scalars: A list of integers, where each integer represents an Orignal
- scalar from the trace.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- """
- # Initial internal selection from the scalar
- self.scalars = scalars
- self.msm_length = len(scalars)
-
- # Convert high-precision points into a vector of low-precision chunks
- self.points = [
- util.int_list_to_array(coordinates, util.BASE, util.U16_EXT_CHUNK_NUM)
- for coordinates in points
- ]
- self.all_points = jnp.array(self.points)
- _, coordinate_dim, precision_dim = self.all_points.shape
-
- # For BA
- selection_index_pylist, selection_sign_pylist = (
- self.construct_ba_selection_with_sign()
- )
- self.selection_index_list = jnp.asarray(selection_index_pylist).astype(
- jnp.uint32
- )
- self.selection_sign_list = jnp.array(selection_sign_pylist, dtype=jnp.uint8)
- _, window_dim = self.selection_index_list.shape
-
- # Batch construction
- self.all_points = self.all_points.reshape(
- (-1, self.point_parallel, coordinate_dim, precision_dim)
- ).transpose(0, 2, 1, 3)
- self.selection_index_list = self.selection_index_list.reshape(
- (-1, window_dim * self.point_parallel)
- )
- self.selection_sign_list = self.selection_sign_list.reshape(
- (-1, window_dim * self.point_parallel)
- )
-
- def bucket_accumulation(self, bucket_accumulation_index_algorithm):
- """BA index selection version."""
- self.all_buckets = bucket_accumulation_index_algorithm(
- self.all_buckets,
- self.all_points,
- self.selection_index_list,
- self.selection_sign_list,
- )
-
- return self.all_buckets
-
- def bucket_reduction(self, bucket_reduction_func):
- """Reduce the buckets to a single point for each window."""
- temp_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- self.window_sum = bucket_reduction_func(
- self.all_buckets, temp_sum, window_sum
- )
- return self.window_sum
-
- def batch_window_summation(self, batch_window_summation_algorithm):
- batch_window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.U16_EXT_CHUNK_NUM,
- ),
- )
- self.window_sum = batch_window_summation_algorithm(
- batch_window_sum, self.window_sum
- )
- return self.window_sum
-
- def window_merge(self, window_merge_func):
- """Merge the windows to form the final elliptic curve."""
- self.result = window_merge_func(self.window_sum)
- return self.result
-
- def construct_ba_selection_with_sign(self):
- """Construct the selection index and sign for the bucket accumulation (BA) step.
-
- Returns:
- A tuple of two lists: the selection index for the bucket accumulation, and
- the selection sign for the bucket accumulation.
- """
- selection_index_list = [] # Used for index selection
- selection_sign_list = []
- slice_max = 2**self.slice_length
- slice_half = 2 ** (self.slice_length - 1)
- for scalar in self.scalars:
- # Compute the zero states for each scalar by time dependence
- selection_index = []
- selection_sign = []
- carry = 0
- for w in range(self.window_num):
- slice_index = (scalar >> (w * self.slice_length)) & self.slice_mask
- slice_index = slice_index + carry
- if slice_index >= slice_half:
- new_slice_index = abs(slice_index - slice_max)
- carry = 1
- else:
- new_slice_index = slice_index
- carry = 0
- selection_index.append(new_slice_index - 1)
- selection_sign.append(carry)
- assert carry == 0
- selection_index_list.append(deepcopy(selection_index))
- selection_sign_list.append(deepcopy(selection_sign))
- return selection_index_list, selection_sign_list
diff --git a/jaxite_ec/pippenger_rns.py b/jaxite_ec/pippenger_rns.py
deleted file mode 100644
index fd27be3..0000000
--- a/jaxite_ec/pippenger_rns.py
+++ /dev/null
@@ -1,1057 +0,0 @@
-"""RNS Pippenger algorithm for elliptic curves.
-
-This module implements the Pippenger algorithm for elliptic curves. The
-algorithm
-is a generalization of the elliptic curve algorithm that can be used to find
-elliptic curves of arbitrary order.
-
-The Pippenger algorithm works by first finding a set of "scalars" and "points"
-that lie on the elliptic curve. These scalars and points are then used to
-construct
-a "bucket" for each window of the elliptic curve. The buckets are then reduced
-to a single point for each window. Finally, the points from all of the windows
-are merged together to form the final elliptic curve.
-
-The Pippenger algorithm is a powerful tool for finding elliptic curves, and it
-has been used to find elliptic curves of arbitrary order and dimension.
-
-Example Compiled Function:
-
-bucket_accumulation_scan_jax_jit = jax.jit(
- bucket_accumulation_scan_jax,
- static_argnames='msm_length').lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW),
- dtype=jnp.uint8
- ),
- MSM_LENGTH
- ).compile()
-
-bucket_accumulation_index_scan_jax_jit = jax.jit(
- bucket_accumulation_index_scan_jax,
- static_argnames='msm_length').lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM,BUCKET_NUM_PER_WINDOW, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM),
- dtype=jnp.uint32
- ),
- jax.ShapeDtypeStruct(
- (MSM_LENGTH,WINDOW_NUM,BUCKET_NUM_PER_WINDOW),
- dtype=jnp.uint8
- ),
- MSM_LENGTH
- ).compile()
-
-bucket_reduction_scan_jax_jit = jax.jit(
- bucket_reduction_scan_jax,
- static_argnames='bucket_num_in_window').lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM, BUCKET_NUM_PER_WINDOW, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (BUCKET_NUM_PER_WINDOW, WINDOW_NUM),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (BUCKET_NUM_PER_WINDOW, WINDOW_NUM),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (BUCKET_NUM_PER_WINDOW, WINDOW_NUM,),
- dtype=jnp.uint8
- ),
- BUCKET_NUM_PER_WINDOW
- ).compile()
-
-window_merge_scan_jax_jit = jax.jit(
- window_merge_scan_jax,
- static_argnames='slice_length').lower(
- jax.ShapeDtypeStruct((COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM),
- dtype=jnp.uint16),
- SLICE_LENGTH
- ).compile()
-
-selective_padd_with_zero_jit = jax.jit(
- selective_padd_with_zero).lower(
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, COORDINATE_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (COORDINATE_NUM, MSM_LENGTH, COORDINATE_NUM),
- dtype=jnp.uint16
- ),
- jax.ShapeDtypeStruct(
- (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,),
- dtype=jnp.uint8
- ),
- jax.ShapeDtypeStruct(
- (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,),
- dtype=jnp.uint8
- ),
- ).compile()
-"""
-
-import copy
-import math
-from typing import List
-
-import jax
-import jax.numpy as jnp
-from jaxite.jaxite_ec import util
-import jaxite.jaxite_ec.elliptic_curve as jec
-import jaxite.jaxite_ec.finite_field as ff
-import numpy as np
-
-
-deepcopy = copy.deepcopy
-
-"""
-Example Parameters:
-SLICE_LENGTH = 4
-WINDOW_NUM = int(math.ceil(255 / SLICE_LENGTH))
-BUCKET_NUM_PER_WINDOW = 2**SLICE_LENGTH
-MSM_LENGTH = 1024
-COORDINATE_NUM = 4
-"""
-
-
-def selective_padd_with_zero(partial_sum, single_point, select, is_zero):
- """Padd the partial sum with the single point, but only if the selection state is 1.
-
- Args:
- partial_sum: The partial sum.
- single_point: The single point.
- select: The selection state.
- is_zero: The zero states.
-
- Returns:
- The new partial sum.
- """
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point)
-
- cond_select = jnp.equal(select, 1).reshape(1, batch_dim, 1)
- sum_result = jnp.where(cond_select, new_partial_sum, partial_sum)
-
- cond_zero = jnp.equal(is_zero, 1).reshape(1, batch_dim, 1)
- cond_select_and_zero = jnp.logical_and(cond_select, cond_zero)
- result = jnp.where(cond_select_and_zero, single_point, sum_result)
- return result
-
-
-def padd_with_zero(partial_sum, single_point, ps_is_zero, sp_is_zero):
- """Padd the partial sum with the single point.
-
- Check if the partial sum is equal to the single point first.
-
- Args:
- partial_sum: The partial sum.
- single_point: The single point.
- ps_is_zero: The zero states of the partial sum.
- sp_is_zero: The zero states of the single point.
-
- Returns:
- The new partial sum.
- """
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point)
- cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1)
- cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1)
- result_1 = jnp.where(cond_sp_zero, partial_sum, single_point)
- result_2 = jnp.where(
- jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum
- )
- return result_2
-
-
-def padd_with_zero_alter(partial_sum, single_point, ps_is_zero):
-
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point)
- cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1)
- result_2 = jnp.where(cond_ps_zero, single_point, new_partial_sum)
- return result_2
-
-
-def padd_with_zero_and_pdul_check(
- partial_sum, single_point, ps_is_zero, sp_is_zero
-):
- """Padd the partial sum with the single point.
-
- Check if the partial sum is equal to the single point first. If they are
- equal, then double the partial sum.
-
- Args:
- partial_sum: The partial sum.
- single_point: The single point.
- ps_is_zero: The zero states of the partial sum.
- sp_is_zero: The zero states of the single point.
-
- Returns:
- The new partial sum.
- """
- # coordinate_dim, batch_dim, precision_dim = partial_sum.shape
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point)
- double_partial_sum = jec.pdul_rns_xyzz_pack(partial_sum)
- cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape(
- 1, batch_dim, 1
- )
- cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1)
- cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1)
- result_1 = jnp.where(cond_sp_zero, partial_sum, single_point)
- result_2 = jnp.where(
- jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum
- )
- reuslt_3 = jnp.where(cond_equal, double_partial_sum, result_2)
- return reuslt_3
-
-
-def bucket_accumulation_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Non-scan version BA."""
- coordinate_dim, buckets_dim, precision_dim = all_buckets.shape
- for i in range(msm_length):
- point = jax.lax.broadcast_in_dim(
- all_points[i], (coordinate_dim, buckets_dim, precision_dim), (0, 2)
- )
- all_buckets = selective_padd_with_zero(
- all_buckets, point, selection_list[i], zero_states_list[i]
- )
- return all_buckets
-
-
-def bucket_accumulation_scan_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Scan version BA."""
- coordinate_dim, buckets_dim, precision_dim = all_buckets.shape
-
- def scan_body(buckets, point_with_cond_pack):
- point, selection, zero_states = point_with_cond_pack
- point = jax.lax.broadcast_in_dim(
- point, (coordinate_dim, buckets_dim, precision_dim), (0, 2)
- )
- all_buckets = selective_padd_with_zero(
- buckets, point, selection, zero_states
- )
- return all_buckets, None
-
- all_buckets, _ = jax.lax.scan(
- scan_body,
- all_buckets,
- (all_points, selection_list, zero_states_list),
- length=msm_length,
- )
- return all_buckets
-
-
-def bucket_accumulation_index_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_index_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Non-scan version BA with index selection."""
- # buckets_dim is not used in the algorithm, changed it into _.
- coordinate_dim, window_dim, _, precision_dim = all_buckets.shape
- for i in range(msm_length):
- point = jax.lax.broadcast_in_dim(
- all_points[i], (coordinate_dim, window_dim, precision_dim), (0, 2)
- )
- selective_buckets = all_buckets[
- :, jnp.arange(window_dim), selection_index_list[i], :
- ]
- selective_zero_states = zero_states_list[
- i, jnp.arange(window_dim), selection_index_list[i]
- ]
- selective_update = padd_with_zero_alter(
- selective_buckets, point, selective_zero_states
- )
- all_buckets = all_buckets.at[
- :, jnp.arange(window_dim), selection_index_list[i], :
- ].set(selective_update)
- return all_buckets
-
-
-def bucket_accumulation_index_scan_algorithm(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_index_list: jnp.ndarray,
- zero_states_list: jnp.ndarray,
- msm_length: int,
-):
- """Scan version BA with index selection."""
- # buckets_dim is not used in the algorithm, changed it into _.
- coordinate_dim, window_dim, _, precision_dim = all_buckets.shape
-
- def scan_body(buckets, point_with_cond_pack):
- point, selection_index, zero_states = point_with_cond_pack
- point = jax.lax.broadcast_in_dim(
- point, (coordinate_dim, window_dim, precision_dim), (0, 2)
- )
- selective_buckets = buckets[:, jnp.arange(window_dim), selection_index, :]
- selective_zero_states = zero_states[jnp.arange(window_dim), selection_index]
- selective_update = padd_with_zero_alter(
- selective_buckets, point, selective_zero_states
- )
- return (
- buckets.at[:, jnp.arange(window_dim), selection_index, :].set(
- selective_update
- ),
- None,
- )
-
- all_buckets, _ = jax.lax.scan(
- scan_body,
- all_buckets,
- (all_points, selection_index_list, zero_states_list),
- length=msm_length,
- )
- return all_buckets
-
-
-def bucket_reduction_algorithm(
- all_buckets: jnp.ndarray,
- temp_sum: jnp.ndarray,
- window_sum: jnp.ndarray,
- bucket_zero_states_list: jnp.ndarray,
- temp_zero_states_list: jnp.ndarray,
- window_zero_states_list: jnp.ndarray,
- bucket_num_in_window: int,
-):
- """Non-scan version BR."""
- all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0)
- for i in range(bucket_num_in_window - 1):
- temp_sum = padd_with_zero(
- temp_sum,
- all_buckets[i],
- temp_zero_states_list[i],
- bucket_zero_states_list[i],
- )
- window_sum = padd_with_zero_and_pdul_check(
- window_sum,
- temp_sum,
- window_zero_states_list[i],
- temp_zero_states_list[i + 1],
- )
- return window_sum
-
-
-def bucket_reduction_scan_algorithm(
- all_buckets: jnp.ndarray,
- temp_sum: jnp.ndarray,
- window_sum: jnp.ndarray,
- bucket_zero_states_list: jnp.ndarray,
- temp_zero_states_list: jnp.ndarray,
- window_zero_states_list: jnp.ndarray,
- bucket_num_in_window: int,
-):
- """Scan version BR."""
- all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0)
-
- def scan_body(temp_and_window_sum_pack, bucket_with_cond_pack):
- temp_sum, window_sum = temp_and_window_sum_pack
- (
- bucket,
- bucket_zero_states,
- temp_zero_states,
- temp_zero_states1,
- window_zero_states,
- ) = bucket_with_cond_pack
- temp_sum = padd_with_zero(
- temp_sum, bucket, temp_zero_states, bucket_zero_states
- )
- window_sum = padd_with_zero_and_pdul_check(
- window_sum, temp_sum, window_zero_states, temp_zero_states1
- )
- return (temp_sum, window_sum), None
-
- (_, window_sum), _ = jax.lax.scan(
- scan_body,
- (temp_sum, window_sum),
- (
- all_buckets[: bucket_num_in_window - 1],
- bucket_zero_states_list[: bucket_num_in_window - 1],
- temp_zero_states_list[: bucket_num_in_window - 1],
- temp_zero_states_list[1:],
- window_zero_states_list[: bucket_num_in_window - 1],
- ),
- length=bucket_num_in_window - 1,
- )
-
- return window_sum
-
-
-def window_merge_algorithm(window_sum: jnp.ndarray, slice_length: int):
- """Non-scan version WM."""
- coordinate_dim, window_dim, precision_dim = window_sum.shape
- result = window_sum[:, window_dim - 1, :].reshape(
- (coordinate_dim, 1, precision_dim)
- )
- for w in range(window_dim - 2, -1, -1):
- for _ in range(slice_length):
- result = jec.pdul_rns_xyzz_pack(result)
- result = jec.padd_rns_xyzz_pack(
- result,
- window_sum[:, w, :].reshape((coordinate_dim, 1, util.NUM_MODULI)),
- )
-
- result = result.reshape((coordinate_dim, precision_dim))
- return result
-
-
-def window_merge_scan_algorithm(window_sum: jnp.ndarray, slice_length: int):
- """Scan version WM."""
- coordinate_dim, window_dim, precision_dim = window_sum.shape
- window_sum = window_sum.transpose(1, 0, 2)
- result = window_sum[window_dim - 1, :, :].reshape(
- (coordinate_dim, 1, precision_dim)
- )
-
- def fori_loop_body(_, result):
- result = jec.pdul_rns_xyzz_pack(result)
- return result
-
- def scan_body(result, window_sum):
- result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result)
- result = jec.padd_rns_xyzz_pack(
- result, window_sum.reshape((coordinate_dim, 1, util.NUM_MODULI))
- )
- return result, None
-
- result, _ = jax.lax.scan(
- scan_body,
- result,
- window_sum[: window_dim - 1, :, :],
- reverse=True,
- length=window_dim - 1,
- )
- result = result.reshape((coordinate_dim, precision_dim))
- return result
-
-
-class MSMPippenger:
- """Pippenger algorithm for elliptic curves.
-
- Attributes:
- coordinate_num: The number of coordinates in the elliptic curve.
- slice_length: The length of each slice in the elliptic curve.
- window_num: The number of windows in the elliptic curve.
- bucket_num_per_window: The number of buckets in each window.
- bucket_num_in_window: The number of buckets in each window.
- slice_mask: The mask for the slices in the elliptic curve.
- blank_point: A JAX array of zeros, used to initialize the buckets.
- all_buckets: A JAX array of all the buckets in the elliptic curve.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- scalars: A list of integers, where each integer represents an Orignal scalar
- from the trace.
- all_points: A JAX array of all the points in the elliptic curve. from the
- trace.
- window_sum: A JAX array of the window sum.
- bucket_zero_states: A JAX array of the zero states for the buckets.
- temp_sum_zero_states: A JAX array of the zero states for the temp sum.
- window_sum_zero_states: A JAX array of the zero states for the window sum.
- zero_states_list: A JAX array of the zero states for the buckets.
- selection_list: A JAX array of the selection states for the buckets.
- selection_index_list: A JAX array of the selection index for the buckets.
- msm_length: The length of the MSM trace.
- result: The final elliptic curve.
- rns_mat: The lazy matrix used for padding and doubling.
- """
-
- def __init__(self, slice_length):
- self.coordinate_num = util.COORDINATE_NUM
-
- self.slice_length = slice_length
- self.window_num = int(math.ceil(254 / self.slice_length)) #
- self.bucket_num_per_window = 2**self.slice_length
- self.slice_mask = self.bucket_num_per_window - 1
- self.blank_point = util.int_list_to_array(
- [0, 0, 0, 0], util.BASE, util.NUM_MODULI
- ).reshape(self.coordinate_num, 1, util.NUM_MODULI)
-
- self.all_buckets = jnp.broadcast_to(
- self.blank_point.reshape(
- 1, self.coordinate_num, 1, util.NUM_MODULI
- ).transpose(1, 0, 2, 3),
- (
- self.coordinate_num,
- self.window_num,
- self.bucket_num_per_window,
- util.NUM_MODULI,
- ),
- )
-
- self.window_sum: jnp.ndarray
-
- self.msm_length = 0
- self.bucket_zero_states: jnp.ndarray
- self.temp_sum_zero_states: jnp.ndarray
- self.window_sum_zero_states: jnp.ndarray
- self.zero_states_list: jnp.ndarray
- self.selection_list: jnp.ndarray
- self.selection_index_list: jnp.ndarray
- self.all_points: jnp.ndarray
-
- self.scalars: List[int] = [] # Orignal scalar from the trace
- # [Points, Points, ..., Points]
- self.points: List[jnp.ndarray] = [] # Orignal points from the trace
- self.rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
-
- self.result = None
-
- def initialize(self, scalars, points):
- """Initialize the Pippenger algorithm.
-
- Args:
- scalars: A list of integers, where each integer represents an Orignal
- scalar from the trace.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- """
- # Initial internal selection from the scalar
- self.scalars = scalars
- self.msm_length = len(scalars)
-
- # Convert high-precision points into a vector of low-precision chunks
- self.points = [
- util.int_list_to_array_rns(coordinates + [1, 1])
- for coordinates in points
- ] # pytype: disable=container-type-mismatch
-
- self.all_points = jnp.array(self.points).astype(jnp.uint16)
-
- # For BA
- zero_states_pylist, selection_pylist, selection_index_pylist = (
- self.construct_ba_zero_states_and_selection()
- )
- self.selection_list = jnp.array(selection_pylist, dtype=jnp.uint8).reshape(
- (-1, self.window_num * self.bucket_num_per_window)
- )
- # For index selection version BA
- self.zero_states_list = jnp.array(zero_states_pylist, dtype=jnp.uint8)
- self.selection_index_list = jnp.array(
- selection_index_pylist, dtype=jnp.uint32
- )
-
- # For BR
- (
- bucket_zero_states_py,
- temp_sum_zero_states_py,
- window_sum_zero_states_py,
- ) = self.construct_br_zero_states(
- zero_states_pylist[len(zero_states_pylist) - 1]
- )
- self.bucket_zero_states = jnp.array(bucket_zero_states_py, dtype=jnp.uint8)
- self.temp_sum_zero_states = jnp.array(
- temp_sum_zero_states_py, dtype=jnp.uint8
- )
- self.window_sum_zero_states = jnp.array(
- window_sum_zero_states_py, dtype=jnp.uint8
- )
-
- def bucket_accumulation(self, bucket_accumulation_index_func):
- """BA index selection version."""
- self.all_buckets = bucket_accumulation_index_func(
- self.all_buckets,
- self.all_points,
- self.selection_index_list,
- self.zero_states_list[: self.msm_length],
- )
-
- return self.all_buckets
-
- def bucket_reduction(self, bucket_reduction_func):
- """Reduce the buckets to a single point for each window."""
- temp_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.NUM_MODULI,
- ),
- )
- window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.NUM_MODULI,
- ),
- )
- self.window_sum = bucket_reduction_func(
- self.all_buckets,
- temp_sum,
- window_sum,
- self.bucket_zero_states[: self.bucket_num_per_window],
- self.temp_sum_zero_states[: self.bucket_num_per_window],
- self.window_sum_zero_states[: self.bucket_num_per_window],
- )
- return self.window_sum
-
- def window_merge(self, window_merge_func):
- """Merge the windows to form the final elliptic curve.
-
- Args:
- window_merge_func: The function to merge the windows.
-
- Returns:
- The final elliptic curve.
- """
- self.result = window_merge_func(self.window_sum)
- return self.result
-
- def construct_ba_zero_states_and_selection(self):
- """Construct the zero states and selection for the bucket accumulation (BA) step.
-
- Returns:
- A tuple of two lists: the zero states for the bucket accumulation, and the
- selection for the bucket accumulation.
- """
- zero_states = [
- deepcopy([1] * self.bucket_num_per_window)
- for _ in range(self.window_num)
- ]
- zero_states_list = []
- zero_states_list.append(deepcopy(zero_states))
- selection_list = []
- selection_index_list = [] # Used for index selection
- for scalar in self.scalars:
- # Compute the zero states for each scalar by time dependence
- selection = [
- deepcopy(([0] * self.bucket_num_per_window))
- for _ in range(self.window_num)
- ]
- selection_index = []
- for w in range(self.window_num):
- slice_index = (scalar >> (w * self.slice_length)) & self.slice_mask
- zero_states[w][slice_index] = 0
- selection[w][slice_index] = 1
- selection_index.append(slice_index)
-
- selection_list.append(deepcopy(selection))
- zero_states_list.append(deepcopy(zero_states))
- selection_index_list.append(deepcopy(selection_index))
- return zero_states_list, selection_list, selection_index_list
-
- def construct_br_zero_states(self, bucket_zero_states):
- """Construct the zero states for the bucket reduction (BR) step.
-
- Args:
- bucket_zero_states: The zero states of the buckets.
-
- Returns:
- A tuple of three lists: the zero states for the bucket reduction, the zero
- states for the temporary sum, and the zero states for the window sum.
- """
- temp_sum_zero_states = np.array([1] * self.window_num)
- window_sum_zero_states = np.array([1] * self.window_num)
- temp_sum_zero_states_list = []
- window_sum_zero_states_list = []
- temp_sum_zero_states_list.append(temp_sum_zero_states)
- window_sum_zero_states_list.append(window_sum_zero_states)
- bucket_zero_states_list = np.flip(
- np.array(bucket_zero_states).transpose(1, 0), axis=0
- )
- for b in range(self.bucket_num_per_window):
- next_temp_sum_zero_states = (
- temp_sum_zero_states_list[b] & bucket_zero_states_list[b]
- )
- next_window_sum_zero_states = (
- window_sum_zero_states_list[b] & next_temp_sum_zero_states
- )
- temp_sum_zero_states_list.append(next_temp_sum_zero_states)
- window_sum_zero_states_list.append(next_window_sum_zero_states)
- return (
- bucket_zero_states_list,
- temp_sum_zero_states_list,
- window_sum_zero_states_list,
- )
-
-
-#########################
-# Functions for twisted curve
-#########################
-
-
-def padd(partial_sum, single_point):
- return jec.padd_rns_twisted_pack(partial_sum, single_point)
-
-
-def padd_with_pdul_check(partial_sum, single_point):
- # coordinate_dim, batch_dim, precision_dim = partial_sum.shape
- _, batch_dim, _ = partial_sum.shape
- new_partial_sum = jec.padd_rns_twisted_pack(partial_sum, single_point)
- double_partial_sum = jec.pdul_rns_twisted_pack(partial_sum)
- cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape(
- 1, batch_dim, 1
- )
- return jnp.where(cond_equal, double_partial_sum, new_partial_sum)
-
-
-def bucket_accumulation_index_scan_parallel_algorithm_twisted(
- all_buckets: jnp.ndarray,
- all_points: jnp.ndarray,
- selection_index_list: jnp.ndarray,
- msm_length: int,
-):
- """Scan version BA with index selection."""
- # buckets_dim is not used in the algorithm, changed it into _.
- coordinate_dim, batch_window_dim, _, precision_dim = all_buckets.shape
- _, _, parallel_dim, _ = (
- all_points.shape
- ) # (serial_dim, coordinate_dim, parallel_dim, precision_dim)
- single_window_dim = batch_window_dim // parallel_dim
-
- def scan_body(buckets, point_with_cond_pack):
- point, selection_index = point_with_cond_pack
- point = jax.lax.broadcast_in_dim(
- point,
- (coordinate_dim, parallel_dim, single_window_dim, precision_dim),
- (0, 1, 3),
- )
- point = point.reshape((coordinate_dim, batch_window_dim, precision_dim))
- selective_buckets = buckets[
- :, jnp.arange(batch_window_dim), selection_index, :
- ]
- selective_update = padd(selective_buckets, point)
- return (
- buckets.at[:, jnp.arange(batch_window_dim), selection_index, :].set(
- selective_update
- ),
- None,
- )
-
- all_buckets, _ = jax.lax.scan(
- scan_body,
- all_buckets,
- (all_points, selection_index_list),
- length=msm_length,
- )
- return all_buckets
-
-
-def bucket_reduction_scan_algorithm_twisted(
- all_buckets: jnp.ndarray,
- temp_sum: jnp.ndarray,
- window_sum: jnp.ndarray,
- bucket_num_in_window: int,
-):
- """Scan version BR."""
- all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0)
-
- def scan_body(temp_and_window_sum_pack, buckets):
- temp_sum, window_sum = temp_and_window_sum_pack
- temp_sum = padd(temp_sum, buckets)
- window_sum = padd_with_pdul_check(window_sum, temp_sum)
- return (temp_sum, window_sum), None
-
- (_, window_sum), _ = jax.lax.scan(
- scan_body,
- (temp_sum, window_sum),
- all_buckets[:bucket_num_in_window],
- length=bucket_num_in_window,
- )
-
- return window_sum
-
-
-def batch_window_summation_algorithm_twisted(
- batch_window_sum: jnp.ndarray,
- all_window_sum: jnp.ndarray,
- point_parallel: int,
-):
- """Batch window summation algorithm for twisted curve."""
- # batch_window_dim is not used in the algorithm, changed it into _.
- coordinate_dim, _, precision_dim = all_window_sum.shape
- all_window_sum = all_window_sum.reshape(
- (coordinate_dim, point_parallel, -1, precision_dim)
- ).transpose(1, 0, 2, 3)
-
- def scan_body(batch_window_sum, single_window_sum):
- batch_window_sum = padd(batch_window_sum, single_window_sum)
- return batch_window_sum, None
-
- batch_window_sum, _ = jax.lax.scan(
- scan_body, batch_window_sum, all_window_sum, length=point_parallel
- )
- return batch_window_sum
-
-
-def window_merge_scan_algorithm_twisted(
- window_sum: jnp.ndarray, slice_length: int
-):
- """Scan version WM."""
- coordinate_dim, window_dim, precision_dim = window_sum.shape
- window_sum = window_sum.transpose(1, 0, 2)
- result = window_sum[window_dim - 1, :, :].reshape(
- (coordinate_dim, 1, precision_dim)
- )
-
- def fori_loop_body(_, result):
- result = jec.pdul_rns_twisted_pack(result)
- return result
-
- def scan_body(result, window_sum):
- result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result)
- result = jec.padd_rns_twisted_pack(
- result, window_sum.reshape((coordinate_dim, 1, util.NUM_MODULI))
- )
- return result, None
-
- result, _ = jax.lax.scan(
- scan_body,
- result,
- window_sum[: window_dim - 1, :, :],
- reverse=True,
- length=window_dim - 1,
- )
- result = result.reshape((coordinate_dim, precision_dim))
- return result
-
-
-class MSMPippengerTwisted:
- """Pippenger algorithm for elliptic curves with twisted points.
-
- Attributes:
- coordinate_num: The number of coordinates in the elliptic curve.
- slice_length: The length of each slice in the elliptic curve.
- point_parallel: The number of parallel points in the elliptic curve.
- window_num: The number of windows in the elliptic curve.
- batch_window_num: The number of batch windows in the elliptic curve.
- bucket_num_per_window: The number of buckets in each window.
- slice_mask: The mask for the slices in the elliptic curve.
- blank_point: A JAX array of zeros, used to initialize the buckets.
- all_buckets: A JAX array of all the buckets in the elliptic curve.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- scalars: A list of integers, where each integer represents an Orignal scalar
- from the trace.
- all_points: A JAX array of all the points in the elliptic curve. from the
- trace.
- window_sum: A JAX array of the window sum.
- br_temp_sum: A JAX array of the temp sum for bucket reduction.
- batch_window_sum: A JAX array of the batch window sum.
- selection_index_list: A JAX array of the selection index for the buckets.
- msm_length: The length of the MSM trace.
- result: The final elliptic curve.
- rns_mat: The lazy matrix used for padding and doubling.
- """
-
- def __init__(self, slice_length: int, point_parallel: int):
-
- self.coordinate_num = util.COORDINATE_NUM
-
- self.slice_length = slice_length
- self.point_parallel = point_parallel
- self.window_num = int(math.ceil(254 / self.slice_length)) #
- self.batch_window_num = self.window_num * self.point_parallel
- self.bucket_num_per_window = (
- 2**self.slice_length - 1
- ) # Note: here remove the bucket_0
- self.slice_mask = 2**self.slice_length - 1
- self.blank_point = (
- util.int_list_to_array_rns([0, 1, 1, 0])
- .reshape(self.coordinate_num, 1, util.NUM_MODULI)
- .astype(jnp.uint16)
- )
-
- self.all_buckets = jnp.broadcast_to(
- self.blank_point.reshape(
- 1, self.coordinate_num, 1, util.NUM_MODULI
- ).transpose(1, 0, 2, 3),
- (
- self.coordinate_num,
- self.window_num,
- self.bucket_num_per_window,
- util.NUM_MODULI,
- ),
- )
-
- self.all_buckets = jnp.tile(
- self.all_buckets, (1, self.point_parallel, 1, 1)
- )
-
- self.window_sum: jnp.ndarray
- self.br_temp_sum: jnp.ndarray
- self.batch_window_sum: jnp.ndarray
-
- self.msm_length = 0
-
- self.selection_index_list: jnp.ndarray
- self.all_points: jnp.ndarray
-
- self.scalars: List[int] = [] # Orignal scalar from the trace
- # [Points, Points, ..., Points]
- self.points: List[jnp.ndarray] = [] # Orignal points from the trace
- self.rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT)
-
- self.result = None
-
- def initialize(self, scalars, points):
- """Initialize the Pippenger algorithm.
-
- Args:
- scalars: A list of integers, where each integer represents an Orignal
- scalar from the trace.
- points: A list of JAX arrays, where each array represents an Orignal point
- from the trace.
- """
- # Initial internal selection from the scalar
- self.scalars = scalars
- self.msm_length = len(scalars)
-
- # Convert high-precision points into a vector of low-precision chunks
- self.points = [
- util.int_list_to_array_rns(coordinates) for coordinates in points
- ] # pytype: disable=container-type-mismatch
-
- self.all_points = jnp.array(self.points).astype(jnp.uint16)
- _, coordinate_dim, precision_dim = self.all_points.shape
-
- # For BA
- selection_index_pylist = self.construct_ba_selection()
- # Note: it contains uint(-1) for the bucket_0.
- # In BA, it may cause some undefined behavior when do bucket selection
- # it is correct now, because when setting buckets after the computation,
- # jax.numpy will ignore the index with uint(-1) out of index.
- self.selection_index_list = jnp.array(selection_index_pylist).astype(
- jnp.uint32
- )
- _, window_dim = self.selection_index_list.shape
-
- # Batch construction
- self.all_points = self.all_points.reshape(
- (-1, self.point_parallel, coordinate_dim, precision_dim)
- ).transpose(0, 2, 1, 3)
- self.selection_index_list = self.selection_index_list.reshape(
- (-1, window_dim * self.point_parallel)
- )
- self.br_temp_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.NUM_MODULI,
- ),
- )
- self.window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.NUM_MODULI,
- ),
- )
- self.batch_window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.NUM_MODULI,
- ),
- )
-
- def bucket_accumulation(self, bucket_accumulation_index_func):
- """BA index selection version."""
- self.all_buckets = bucket_accumulation_index_func(
- self.all_buckets, self.all_points, self.selection_index_list
- )
- return self.all_buckets
-
- def bucket_reduction(self, bucket_reduction_func):
- """Reduce the buckets to a single point for each window."""
- temp_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.NUM_MODULI,
- ),
- )
- window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.batch_window_num,
- util.NUM_MODULI,
- ),
- )
- self.window_sum = bucket_reduction_func(
- self.all_buckets, temp_sum, window_sum
- )
- return self.window_sum
-
- def batch_window_summation(self, batch_window_summation_func):
- """Sum the batch windows to form the final window sum."""
- batch_window_sum = jnp.broadcast_to(
- self.blank_point,
- (
- self.coordinate_num,
- self.window_num,
- util.NUM_MODULI,
- ),
- )
- self.window_sum = batch_window_summation_func(
- batch_window_sum, self.window_sum
- )
- return self.window_sum
-
- def window_merge(self, window_merge_func):
- """Merge the windows to form the final elliptic curve."""
- self.result = window_merge_func(self.window_sum)
- return self.result
-
- def construct_ba_selection(self):
- selection_index_list = [] # Used for index selection
- for scalar in self.scalars:
- # Compute the zero states for each scalar by time dependence
- selection_index = []
- for w in range(self.window_num):
- slice_index = (
- (scalar >> (w * self.slice_length)) & self.slice_mask
- ) - 1
- selection_index.append(slice_index)
- selection_index_list.append(deepcopy(selection_index))
- return selection_index_list
diff --git a/jaxite_ec/profiler.py b/jaxite_ec/profiler.py
new file mode 100644
index 0000000..85713d7
--- /dev/null
+++ b/jaxite_ec/profiler.py
@@ -0,0 +1,1067 @@
+import csv
+import gzip
+import json
+import os
+import statistics
+from typing import Any, Callable, Dict, List, Optional, Tuple, cast
+import warnings
+import jax
+import jax.numpy as jnp
+import pandas as pd
+
+
+class DataFrameGenerator:
+ """A utility class for building pandas DataFrames from column data."""
+
+ def __init__(self):
+ """Initialize an empty DataFrameGenerator."""
+ self.data: Dict[str, List[Any]] = {}
+
+ def add_data(self, column_name: str, values: List[Any]) -> None:
+ """Add data to a specific column.
+
+ Args:
+ column_name: Name of the column to add data to
+ values: List of values to add to the column
+ """
+ if not isinstance(column_name, str):
+ raise ValueError("column_name must be a string")
+ if not isinstance(values, list):
+ raise ValueError("values must be a list")
+
+ if column_name not in self.data:
+ self.data[column_name] = []
+ self.data[column_name].extend(values)
+
+ def add_single_value(self, column_name: str, value: Any) -> None:
+ """Add a single value to a specific column.
+
+ Args:
+ column_name: Name of the column to add data to
+ value: Single value to add to the column
+ """
+ self.add_data(column_name, [value])
+
+ def get_column_lengths(self) -> Dict[str, int]:
+ """Get the length of each column.
+
+ Returns:
+ Dictionary mapping column names to their lengths
+ """
+ return {col: len(values) for col, values in self.data.items()}
+
+ def is_balanced(self) -> bool:
+ """Check if all columns have the same length.
+
+ Returns:
+ True if all columns have the same length, False otherwise
+ """
+ if not self.data:
+ return True
+ lengths = set(len(col) for col in self.data.values())
+ return len(lengths) == 1
+
+ def to_dataframe(self, auto_balance: bool = True) -> pd.DataFrame:
+ """Convert the stored data to a pandas DataFrame.
+
+ Args:
+ auto_balance: If True, automatically trim columns to the minimum length.
+ If False, raise an error if columns have different lengths.
+
+ Returns:
+ pandas DataFrame with the stored data
+
+ Raises:
+ ValueError: If auto_balance is False and columns have different lengths
+ """
+ if not self.data:
+ return pd.DataFrame()
+
+ if not auto_balance and not self.is_balanced():
+ lengths = self.get_column_lengths()
+ raise ValueError(f"Columns have different lengths: {lengths}")
+
+ # Find the minimum length among all columns
+ min_len = min(len(col) for col in self.data.values())
+
+ # Trim each column to the minimum length
+ trimmed_data = {k: v[:min_len] for k, v in self.data.items()}
+
+ return pd.DataFrame(trimmed_data)
+
+ def clear(self) -> None:
+ """Clear all stored data."""
+ self.data.clear()
+
+ def get_column_names(self) -> List[str]:
+ """Get the names of all columns.
+
+ Returns:
+ List of column names
+ """
+ return list(self.data.keys())
+
+ def has_column(self, column_name: str) -> bool:
+ """Check if a column exists.
+
+ Args:
+ column_name: Name of the column to check
+
+ Returns:
+ True if the column exists, False otherwise
+ """
+ return column_name in self.data
+
+ def merge(self, other_dataframe_generator: "DataFrameGenerator"):
+ """Merge the stored data with another DataFrameGenerator.
+
+ Args:
+ other_dataframe_generator: Another DataFrameGenerator to merge with
+
+ Returns:
+ Merged DataFrameGenerator
+ """
+ if not isinstance(other_dataframe_generator, DataFrameGenerator):
+ raise ValueError("other_dataframe_generator must be a DataFrameGenerator")
+ # Check if this DataFrameGenerator is empty
+ if not self.data:
+ self.data = other_dataframe_generator.data
+ return
+ # Check if the other DataFrameGenerator has the same column names
+ if not set(self.get_column_names()) == set(
+ other_dataframe_generator.get_column_names()
+ ):
+ print("The two DataFrameGenerators have different column names")
+ return
+ # raise ValueError("The two DataFrameGenerators have different column names")
+ # Merge the data
+ for column_name in other_dataframe_generator.get_column_names():
+ self.add_data(column_name, other_dataframe_generator.data[column_name])
+
+ def get_header(self) -> List[str]:
+ """Get the header of the DataFrameGenerator.
+
+ Returns:
+ List of column names
+ """
+ return list(self.data.keys())
+
+ def get_row_dict(self, index: int) -> Dict[str, Any]:
+ """Get a row of the DataFrameGenerator.
+
+ Returns:
+ Dictionary of column names and values
+ """
+ return {
+ column_name: self.data[column_name][index]
+ for column_name in self.get_column_names()
+ }
+
+
+class TraceParser:
+
+ def __init__(self, trace_dir: str):
+ self.trace_dir = trace_dir
+
+ def set_trace_dir(self, new_dir: str):
+ """Set a new trace directory for the parser."""
+ self.trace_dir = new_dir
+
+ def find_trace_file(self):
+ """Recursively search for the latest .trace.json.gz file in the trace_dir.
+
+ Returns the full path to the file, or None if not found.
+ """
+ trace_files = []
+ for root, _, files in os.walk(self.trace_dir):
+ for file in files:
+ if file.endswith(".trace.json.gz"):
+ trace_files.append(os.path.join(root, file))
+
+ if not trace_files:
+ return None
+
+ # Return the most recently modified file
+ return max(trace_files, key=os.path.getmtime)
+
+ def read_trace_json(self):
+ """Finds, unzips, and reads the JSON content from the trace file.
+
+ Returns the loaded JSON object, or None if not found or error.
+ """
+ trace_file = self.find_trace_file()
+ if trace_file is None:
+ print("No trace file found.")
+ return None
+ try:
+ with gzip.open(trace_file, "rt", encoding="utf-8") as f:
+ data = json.load(f)
+ return data
+ except Exception as e:
+ print(f"Error reading trace file: {e}")
+ return None
+
+ def parse_trace_csv(self):
+ """Parses the trace CSV file and returns a list of trace events."""
+ csv_file = os.path.join(self.trace_dir, "trace_events.csv")
+
+ # Read the trace JSON data
+ trace_data = self.read_trace_json()
+ if trace_data is None:
+ print("Failed to read trace data")
+ return None
+
+ # Extract trace events
+ trace_events = trace_data.get("traceEvents", [])
+ if not trace_events:
+ print("No trace events found in the data")
+ return None
+
+ headers = ["pid", "tid", "ts", "dur", "ph", "name", "args"]
+ # Write to CSV directly
+ with open(csv_file, "w", newline="", encoding="utf-8") as f:
+ writer = csv.DictWriter(f, fieldnames=headers)
+ writer.writeheader()
+ for event in trace_events:
+ # Convert args dictionary to string if it exists
+ if "args" in event:
+ event["args"] = json.dumps(event["args"])
+ else:
+ event["args"] = ""
+
+ # Write the event
+ writer.writerow(event)
+ print(f"Trace events written to: {csv_file}")
+
+
+def calculate_statistics(data: List[Any]) -> Dict[str, Any]:
+ """Calculate the statistics of the data.
+
+ Args:
+ data: List of data
+
+ Returns:
+ Dictionary containing the statistics
+ """
+ mean_value = statistics.mean(data)
+ if len(data) == 1:
+ std_value = 0
+ else:
+ std_value = statistics.stdev(data)
+ min_value = min(data)
+ max_value = max(data)
+ median_value = statistics.median(data)
+ return {
+ "mean": mean_value,
+ "std": std_value,
+ "min": min_value,
+ "max": max_value,
+ "median": median_value,
+ }
+
+
+def analyze_trace_bottlenecks(
+ trace_folder: str,
+ top_k: int = 10,
+ device_pid: int = 3,
+) -> List[Dict[str, Any]]:
+ """Analyze a single-iteration TPU trace and return the top-K bottleneck ops.
+
+ Each returned entry contains the op name, duration, source file/line,
+ abbreviated source stack, HLO category, estimated FLOPs, and bytes
+ accessed — everything needed to identify where kernel time is spent
+ and which Python line produced the HLO op.
+
+ Args:
+ trace_folder: Path to an iteration's trace folder (contains
+ ``plugins/profile//*.trace.json.gz``).
+ top_k: Number of top ops to return, sorted by duration descending.
+ device_pid: The ``pid`` of the target TPU device in the trace (default 3 =
+ ``/device:TPU:0``).
+
+ Returns:
+ A list of dicts, each with keys:
+ ``name``, ``dur_us``, ``pct``, ``source``, ``source_stack``,
+ ``hlo_category``, ``model_flops``, ``bytes_accessed``.
+ """
+ parser = TraceParser(trace_folder)
+ trace_json = parser.read_trace_json()
+ if trace_json is None:
+ return []
+ events = trace_json.get("traceEvents", [])
+
+ # Keep only device-level duration events for the target PID.
+ dev_events = [
+ e
+ for e in events
+ if e.get("pid") == device_pid
+ and e.get("ph") == "X"
+ and e.get("dur") is not None
+ ]
+ if not dev_events:
+ return []
+
+ total_us = sum(e["dur"] for e in dev_events)
+ dev_events.sort(key=lambda e: -e["dur"])
+
+ results = []
+ for e in dev_events[:top_k]:
+ args = e.get("args") or {}
+ source = args.get("source", "")
+ stack_raw = args.get("source_stack", "")
+ # Trim the stack to just the user-code frames (skip jax internals).
+ stack_lines = []
+ if stack_raw:
+ for line in stack_raw.strip().split("\n"):
+ line = line.strip()
+ if not line:
+ continue
+ # Keep lines from the user's project; skip jax/site-packages.
+ if "site-packages" not in line:
+ stack_lines.append(line)
+ results.append({
+ "name": e.get("name", ""),
+ "dur_us": e["dur"],
+ "pct": 100.0 * e["dur"] / total_us if total_us else 0,
+ "source": source,
+ "source_stack": stack_lines,
+ "hlo_category": args.get("hlo_category", ""),
+ "model_flops": args.get("model_flops", 0),
+ "bytes_accessed": args.get("bytes_accessed", 0),
+ })
+ return results
+
+
+def print_trace_bottlenecks(
+ trace_folder: str,
+ top_k: int = 10,
+ device_pid: int = 3,
+) -> None:
+ """Pretty-print the top-K bottleneck ops from a single-iteration trace.
+
+ Convenience wrapper around :func:`analyze_trace_bottlenecks`.
+ """
+ bottlenecks = analyze_trace_bottlenecks(trace_folder, top_k, device_pid)
+ if not bottlenecks:
+ print(" (no device events found)")
+ return
+ print(
+ f" {'Rank':<5} {'Dur (ms)':>10} {'%':>6} {'Category':<12}"
+ f" {'Op Name':<40} {'Source'}"
+ )
+ print(" " + "-" * 110)
+ for i, b in enumerate(bottlenecks):
+ print(
+ f" {i+1:<5} {b['dur_us']/1000:>10.2f} {b['pct']:>5.1f}%"
+ f" {b['hlo_category']:<12} {b['name'][:40]:<40} {b['source']}"
+ )
+ if b["source_stack"]:
+ for frame in b["source_stack"][:3]:
+ print(f" {'':>24} ↳ {frame}")
+
+
+def list_add(list1: List[Any], list2: List[Any]) -> List[Any]:
+ """Sum two lists element-wise.
+
+ Args:
+ list1: First list to sum
+ list2: Second list to sum
+
+ Returns:
+ List of the sum of the two lists
+ """
+ assert len(list1) == len(list2), "The two lists must have the same length"
+ return [e1 + e2 for e1, e2 in zip(list1, list2)]
+
+
+class KernelWrapperBase:
+ """Unified interface consumed by ``Profiler``.
+
+ Subclasses expose a callable (``get_compiled_function``) plus the metadata
+ the profiler needs to feed and account for it. The profiler never branches
+ on the concrete wrapper type — it only calls methods defined here.
+ """
+
+ def __init__(
+ self,
+ kernel_name: str,
+ input_structs: List[Tuple[Tuple[int, ...], jnp.dtype]],
+ mesh: Optional[jax.sharding.Mesh] = None,
+ input_shardings: Optional[Tuple[jax.sharding.Sharding, ...]] = None,
+ output_sharding: Optional[jax.sharding.Sharding] = None,
+ enable_sharding: bool = False,
+ ):
+ self.kernel_name = kernel_name
+ self.input_structs = input_structs
+ self.mesh = mesh
+ self.input_shardings = input_shardings
+ self.output_sharding = output_sharding
+ self.enable_sharding = enable_sharding
+ self.callable_function_name = None
+ self.jit_function_name = None
+
+ # ---- Profiler-facing unified API ----
+ def get_compiled_function(self) -> Callable[..., jnp.ndarray]:
+ raise NotImplementedError
+
+ def get_input_structs(self) -> List[Tuple[Tuple[int, ...], jnp.dtype]]:
+ return self.input_structs
+
+ def get_kernel_name(self) -> str:
+ return self.kernel_name
+
+ def get_input_arrays(self) -> Optional[List[jnp.ndarray]]:
+ """Return concrete inputs, or None to let Profiler fabricate random ones."""
+ return None
+
+ def shard_inputs(self, input_arrays: List[jnp.ndarray]) -> List[jnp.ndarray]:
+ if self.enable_sharding and self.input_shardings:
+ return [
+ jax.device_put(arr, s)
+ for arr, s in zip(input_arrays, self.input_shardings)
+ ]
+ return input_arrays
+
+
+class KernelWrapper(KernelWrapperBase):
+
+ def __init__(
+ self,
+ kernel_name: str,
+ function_to_wrap: Callable,
+ input_structs: List[Tuple[Tuple[int, ...], jnp.dtype]],
+ mesh: Optional[jax.sharding.Mesh] = None,
+ input_shardings: Optional[Tuple[jax.sharding.Sharding, ...]] = None,
+ output_sharding: Optional[jax.sharding.Sharding] = None,
+ parameters: Optional[Dict[str, Any]] = {},
+ enable_sharding: bool = False,
+ ):
+ super().__init__(
+ kernel_name=kernel_name,
+ input_structs=input_structs,
+ mesh=mesh,
+ input_shardings=input_shardings,
+ output_sharding=output_sharding,
+ enable_sharding=enable_sharding,
+ )
+ self.callable_function = function_to_wrap
+ self.parameters = parameters
+
+ self.jit_lower = None
+ self.jit_compiled_function = None
+
+ # Compile immediately upon initialization
+ self._compile()
+
+ def _compile(self):
+ jax_input_structs = []
+ if self.enable_sharding and self.input_shardings:
+ for (shape, dtype), sharding in zip(
+ self.input_structs, self.input_shardings
+ ):
+ jax_input_structs.append(
+ jax.ShapeDtypeStruct(shape, dtype, sharding=sharding)
+ )
+ else:
+ for shape, dtype in self.input_structs:
+ jax_input_structs.append(jax.ShapeDtypeStruct(shape, dtype))
+
+ # NOTE: Do not change the name of the function, it is used for profiling
+ if self.parameters:
+
+ def compiled_kernel_function(*jax_array_inputs):
+ return self.callable_function(
+ *jax_array_inputs, parameters=self.parameters
+ )
+
+ else:
+
+ def compiled_kernel_function(*jax_array_inputs):
+ return self.callable_function(*jax_array_inputs)
+
+ if self.enable_sharding and self.mesh:
+ with self.mesh:
+ self.jit_lower = jax.jit(
+ jax.named_call(compiled_kernel_function, name=self.kernel_name),
+ in_shardings=self.input_shardings,
+ out_shardings=self.output_sharding,
+ ).lower(*jax_input_structs)
+ else:
+ self.jit_lower = jax.jit(
+ jax.named_call(compiled_kernel_function, name=self.kernel_name)
+ ).lower(*jax_input_structs)
+
+ self.jit_compiled_function = self.jit_lower.compile()
+
+ def get_compiled_function(self) -> Callable[..., jnp.ndarray]:
+ assert self.jit_compiled_function is not None, "Kernel not compiled"
+ if self.enable_sharding and self.mesh:
+
+ def compiled_with_mesh(*jax_array_inputs):
+ if self.mesh is not None:
+ with self.mesh:
+ return self.jit_compiled_function(*jax_array_inputs)
+ return self.jit_compiled_function(*jax_array_inputs)
+
+ return compiled_with_mesh
+ return self.jit_compiled_function
+
+
+class PrecompiledKernelWrapper(KernelWrapperBase):
+ """Wrapper around an already-compiled/externally-managed callable.
+
+ Unlike ``KernelWrapper``, this does NOT re-trace the function under
+ ``jax.jit``. Use it when the callable already dispatches to pre-compiled
+ executables or carries internal state that cannot be safely re-jitted
+ (e.g. a stateful context method like ``FusionMSMContext.multiscalar_multiply``
+ which internally looks up pre-compiled, pre-sharded kernels).
+
+ The wrapper also takes concrete input arrays directly — Profiler will use
+ them as-is instead of fabricating random inputs from shape/dtype, which is
+ essential when the inputs must match a specific layout (sharding, padding)
+ baked into the compiled kernel.
+ """
+
+ def __init__(
+ self,
+ kernel_name: str,
+ callable_function: Callable,
+ input_arrays: List[jnp.ndarray],
+ enable_sharding: bool = False,
+ callable_function_name: Optional[str] = None,
+ ):
+ input_arrays = list(input_arrays)
+ super().__init__(
+ kernel_name=kernel_name,
+ input_structs=[(a.shape, a.dtype) for a in input_arrays],
+ enable_sharding=enable_sharding,
+ )
+ self.callable_function = callable_function
+ self._input_arrays = input_arrays
+ self.callable_function_name = callable_function_name
+ self.jit_function_name = (
+ f"jit_{callable_function_name}" if callable_function_name else None
+ )
+
+ def get_compiled_function(self) -> Callable[..., jnp.ndarray]:
+ return self.callable_function
+
+ def get_input_arrays(self) -> List[jnp.ndarray]:
+ """Concrete inputs preserved as-is (sharding, layout, device placement)."""
+ return self._input_arrays
+
+ def shard_inputs(self, input_arrays: List[jnp.ndarray]) -> List[jnp.ndarray]:
+ # Inputs were provided pre-placed; nothing to do.
+ return input_arrays
+
+
+class Profiler:
+
+ def __init__(
+ self,
+ output_trace_path: str,
+ profile_naming: str,
+ configuration: Optional[Dict[str, Any]] = None,
+ ):
+ self.trace_dir = output_trace_path
+ self.profiler_name = profile_naming
+ self.profile_dir = os.path.join(self.trace_dir, self.profiler_name)
+ if not os.path.exists(self.profile_dir):
+ os.makedirs(self.profile_dir)
+
+ self.configuration = configuration or {}
+ self.random_seed = self.configuration.get("random_seed", 0)
+ self.iterations = self.configuration.get("iterations", 1)
+ self.save_to_file = self.configuration.get("save_to_file", True)
+ self.enable_sharding = self.configuration.get("enable_sharding", False)
+
+ self.profiles: List[Dict[str, Any]] = []
+ self.profile_name_list: List[str] = []
+
+ # Storage for results
+ self.storage_file = os.path.join(
+ self.profile_dir, f"{self.profiler_name}_results.csv"
+ )
+
+ def add_profile(
+ self,
+ name: str,
+ kernel_wrapper: "KernelWrapperBase",
+ kernel_setting_cols: Optional[Dict[str, Any]] = None,
+ ):
+ if kernel_setting_cols is None:
+ kernel_setting_cols = {}
+ if name in self.profile_name_list:
+ raise ValueError(f"Profiler name {name} already exists")
+
+ self.profile_name_list.append(name)
+
+ profile_folder = os.path.join(self.profile_dir, name)
+ if not os.path.exists(profile_folder):
+ os.makedirs(profile_folder)
+
+ self.profiles.append({
+ "name": name,
+ "wrapper": kernel_wrapper,
+ "settings": kernel_setting_cols,
+ "folder": profile_folder,
+ "failed": False,
+ "trace_events": None,
+ "filtered_events": None,
+ "stats": None,
+ })
+
+ def _get_input_arrays(self, kernel_wrapper: "KernelWrapperBase"):
+ # If the wrapper carries concrete inputs (e.g. pre-sharded, pre-padded
+ # arrays bound to an already-compiled kernel), use them verbatim.
+ provided = kernel_wrapper.get_input_arrays()
+ if provided is not None:
+ for arr in provided:
+ arr.block_until_ready()
+ return provided
+
+ def get_max_value(dtype):
+ if dtype == jnp.uint8:
+ return 128
+ elif dtype == jnp.uint16:
+ return 32768
+ elif dtype == jnp.uint32:
+ return 4294967295
+ elif dtype == jnp.uint64:
+ return 4294967295
+ raise ValueError(f"Unsupported dtype: {dtype}")
+
+ random_key = jax.random.key(self.random_seed)
+ input_arrays = []
+ for shape, dtype in kernel_wrapper.get_input_structs():
+ if jnp.issubdtype(dtype, jnp.floating):
+ input_arrays.append(jax.random.uniform(random_key, shape, dtype))
+ elif jnp.issubdtype(dtype, jnp.integer):
+ input_arrays.append(
+ jax.random.randint(
+ random_key, shape, 0, get_max_value(dtype), dtype
+ )
+ )
+ elif jnp.issubdtype(dtype, jnp.bool_):
+ input_arrays.append(jax.random.bernoulli(random_key, 0.5, shape=shape))
+ else:
+ raise ValueError(f"Unsupported dtype: {dtype}")
+ for input_array in input_arrays:
+ input_array.block_until_ready()
+
+ if self.enable_sharding:
+ input_arrays = kernel_wrapper.shard_inputs(input_arrays)
+
+ return input_arrays
+
+ def profile_all_profilers(self):
+ for profile in self.profiles:
+ print(f"Profiling {profile['name']}")
+ try:
+ # Kernel wrapper is already compiled in its init
+ wrapper = cast(KernelWrapperBase, profile["wrapper"])
+ compiled_function = wrapper.get_compiled_function()
+ input_arrays = self._get_input_arrays(wrapper)
+
+ # Run each iteration in its own trace window so that the TPU
+ # trace buffer is not shared across iterations to avoid buffer
+ # overflow problem for kernel with long running time.
+ iter_folders = []
+ for i in range(self.iterations):
+ iter_folder = os.path.join(profile["folder"], f"iter_{i}")
+ os.makedirs(iter_folder, exist_ok=True)
+ with jax.profiler.trace(iter_folder):
+ compiled_function(*input_arrays).block_until_ready()
+ iter_folders.append(iter_folder)
+ profile["iter_folders"] = iter_folders
+ except Exception as e:
+ print(f"Error profiling {profile['name']}:\n {e}")
+ profile["failed"] = True
+
+ def _parse_json_trace(self, profile):
+ # With per-iteration trace windows, merge trace events from all
+ # iteration sub-folders into one combined list.
+ iter_folders = profile.get("iter_folders", [])
+ if iter_folders:
+ all_trace_events = []
+ trace_file_path = None
+ for folder in iter_folders:
+ parser = TraceParser(folder)
+ tfp = parser.find_trace_file()
+ if tfp is None:
+ continue
+ if trace_file_path is None:
+ trace_file_path = tfp
+ pjson = parser.read_trace_json()
+ if pjson is not None:
+ all_trace_events.extend(pjson.get("traceEvents", []))
+ trace_events = all_trace_events
+ else:
+ # Legacy path: single trace folder.
+ trace_parser = TraceParser(profile["folder"])
+ trace_file_path = trace_parser.find_trace_file()
+ profile_json = trace_parser.read_trace_json()
+ if profile_json is None:
+ warnings.warn(
+ f"{profile['name']}: No trace events found in the data", UserWarning
+ )
+ profile["failed"] = True
+ return None
+ trace_events = profile_json.get("traceEvents", [])
+
+ if not trace_events:
+ warnings.warn(
+ f"{profile['name']}: No trace events found in the data", UserWarning
+ )
+ profile["failed"] = True
+ return None
+ if self.save_to_file:
+ # Save into the same folder as the raw trace file
+ output_dir = os.path.dirname(trace_file_path)
+ profile["output_folder"] = output_dir
+ with open(os.path.join(output_dir, "trace_events.json"), "w") as f:
+ json.dump(trace_events, f, indent=2)
+ profile["trace_events"] = trace_events
+ return trace_events
+
+ def _filter_trace_events(self, profile):
+ trace_events = profile["trace_events"]
+ if trace_events is None:
+ return None
+
+ def merge_filtered_events_by_name(filtered_events):
+ grouped = {}
+ for event in filtered_events:
+ event_name = event.get("name", "unknown")
+ if (
+ "args" in event.keys()
+ and "deduplicated_name" in event["args"].keys()
+ ):
+ event_name += "_" + event["args"]["deduplicated_name"]
+ elif (
+ "custom-call" in event["name"]
+ and "args" in event.keys()
+ and "tf_op" in event["args"].keys()
+ ):
+ event_name += "_" + event["args"]["tf_op"]
+ if event_name not in grouped:
+ grouped[event_name] = []
+ grouped[event_name].append(event)
+
+ merged_filtered_events = {}
+ for event_name, events in grouped.items():
+ merged = events[0].copy()
+ merged["dur"] = [e.get("dur") for e in events if "dur" in e]
+ merged["ts"] = [e.get("ts") for e in events if "ts" in e]
+ merged["repeat_count"] = len(events)
+ merged_filtered_events[event_name] = merged
+ return merged_filtered_events
+
+ filtered_events_list = []
+ # Check if NVIDIA is in device kind OR CPU is used as a fallback if
+ # explicit check needed
+ # But generally JAX trace events differ by backend.
+ # Assuming typical CPU/GPU separation.
+ device_kind = jax.devices()[0].device_kind
+
+ if "NVIDIA" in device_kind:
+ for e in trace_events:
+ if "args" in e and "tf_op" in e["args"]:
+ # Loosen the check for compiled_kernel_function as it might be nested differently or named differently
+ if "compiled_kernel_function" in e["args"].get(
+ "hlo_module", ""
+ ) or "compiled_kernel_function" in e["args"].get("long_name", ""):
+ merged_event = False
+ # Try to merge with existing events
+ for f in filtered_events_list:
+ # Check if correlation_id exists before accessing it
+ if (
+ "correlation_id" in f["args"]
+ and "correlation_id" in e["args"]
+ and f["args"]["correlation_id"] == e["args"]["correlation_id"]
+ and f["name"] == e["name"]
+ ):
+ f["dur"] = f["dur"] + e["dur"]
+ merged_event = True
+ if not merged_event:
+ filtered_events_list.append(e)
+ profile["filtered_events"] = merge_filtered_events_by_name(
+ filtered_events_list
+ )
+
+ elif "TPU" in device_kind:
+ wrapper = cast(KernelWrapperBase, profile["wrapper"])
+ jit_function_name = wrapper.jit_function_name
+ for event in trace_events:
+ if (
+ "pid" not in event.keys() or event["pid"] != 3
+ ): # ToDo: change it into automatic PID detection based on "TPU:0".
+ continue
+ if (
+ "name" in event.keys()
+ and "compiled_kernel_function" in event["name"]
+ and "args" in event.keys()
+ ):
+ filtered_events_list.append(event)
+ elif (
+ "name" in event.keys()
+ and jit_function_name is not None
+ and jit_function_name in event["name"]
+ ):
+ # print(f"Found jit function name {jit_function_name} in event {event['name']}")
+ filtered_events_list.append(event)
+ elif "args" in event.keys() and "long_name" in event["args"].keys():
+ filtered_events_list.append(event)
+ else:
+ continue
+ profile["filtered_events"] = merge_filtered_events_by_name(
+ filtered_events_list
+ )
+ else:
+ # Fallback for CPU or other devices
+ # CPU traces might be different. Let's try to capture events related to our kernel.
+ for event in trace_events:
+ if "name" in event and "compiled_kernel_function" in event["name"]:
+ filtered_events_list.append(event)
+ profile["filtered_events"] = merge_filtered_events_by_name(
+ filtered_events_list
+ )
+
+ # Always save filtered events if we have any
+ if self.save_to_file:
+ # Make sure we don't crash if profile['filtered_events'] is None
+ events_to_dump = (
+ profile["filtered_events"]
+ if profile["filtered_events"] is not None
+ else {}
+ )
+ with open(
+ os.path.join(profile["output_folder"], "filtered_events.json"), "w"
+ ) as f:
+ json.dump(events_to_dump, f, indent=2)
+
+ def _calculate_profiling_statistics(self, profile):
+ if profile["filtered_events"] is None:
+ return
+
+ repeat_count = self.iterations
+ kernel_duration = [0] * repeat_count
+
+ device_kind = jax.devices()[0].device_kind
+
+ if "NVIDIA" in device_kind:
+ filtered_events = cast(Dict[str, Any], profile["filtered_events"])
+ for event in filtered_events.values():
+ if "compiled_kernel_function" in event["args"].get("hlo_module", ""):
+ durations = event["dur"]
+ if not isinstance(durations, list):
+ durations = [durations]
+
+ if len(durations) == repeat_count:
+ kernel_duration = list_add(kernel_duration, durations)
+ elif (
+ len(durations) > repeat_count
+ and len(durations) % repeat_count == 0
+ ):
+ # Assume sequential execution of kernels within one iteration
+ chunk_size = len(durations) // repeat_count
+ aggregated_durations = [
+ sum(durations[i * chunk_size : (i + 1) * chunk_size])
+ for i in range(repeat_count)
+ ]
+ kernel_duration = list_add(kernel_duration, aggregated_durations)
+ else:
+ # Fallback: just take first N or handle mismatch.
+ # For now, adopting CPU strategy of taking first N but this is likely under-reporting.
+ # Ideally log a warning.
+ kernel_duration = list_add(
+ kernel_duration, durations[:repeat_count]
+ )
+ elif "TPU" in device_kind:
+ wrapper = cast(KernelWrapperBase, profile["wrapper"])
+ jit_function_name = wrapper.jit_function_name
+ filtered_events = cast(Dict[str, Any], profile["filtered_events"])
+ for event in filtered_events.values():
+ matches = "compiled_kernel_function" in event["name"] or (
+ jit_function_name is not None and jit_function_name in event["name"]
+ )
+ if not matches:
+ continue
+ durations = event["dur"]
+ if not isinstance(durations, list):
+ durations = [durations]
+ # At larger problem sizes XLA can coalesce per-device event replays
+ # unevenly, producing either a multiple of ``repeat_count`` (one
+ # kernel invocation expanded into several sub-events per iteration)
+ # or fewer samples than ``repeat_count`` (some replays elided).
+ # Mirror the NVIDIA branch's bucket-and-fallback logic instead of
+ # blindly zip-adding, which used to assert out and abort the
+ # post-processing before the CSV was written.
+ if len(durations) == repeat_count:
+ kernel_duration = list_add(kernel_duration, durations)
+ elif (
+ len(durations) > repeat_count and len(durations) % repeat_count == 0
+ ):
+ chunk_size = len(durations) // repeat_count
+ aggregated = [
+ sum(durations[i * chunk_size : (i + 1) * chunk_size])
+ for i in range(repeat_count)
+ ]
+ kernel_duration = list_add(kernel_duration, aggregated)
+ elif len(durations) < repeat_count:
+ padded = durations + [0] * (repeat_count - len(durations))
+ kernel_duration = list_add(kernel_duration, padded)
+ else:
+ kernel_duration = list_add(kernel_duration, durations[:repeat_count])
+ else:
+ # CPU logic - assuming direct name match from filtered events
+ filtered_events = cast(Dict[str, Any], profile["filtered_events"])
+ for event in filtered_events.values():
+ # On CPU, events might be simpler
+ if "compiled_kernel_function" in event.get("name", ""):
+ # DUR might be a single value or list depending on how it was merged
+ durations = event["dur"]
+ if not isinstance(durations, list):
+ durations = [durations]
+
+ # If we have less durations than repeat_count, we might need to pad
+ # or it's a mismatch
+ # For now, let's just add what we have, assuming 1-to-1 or aggregated
+ if len(durations) == repeat_count:
+ kernel_duration = list_add(kernel_duration, durations)
+ elif len(durations) > repeat_count:
+ # Take first N
+ kernel_duration = list_add(
+ kernel_duration, durations[:repeat_count]
+ )
+ else:
+ # Append 0s? Or just take what we have
+ padded = durations + [0] * (repeat_count - len(durations))
+ kernel_duration = list_add(kernel_duration, padded)
+
+ profile["stats"] = {
+ "kernel_all": kernel_duration,
+ }
+
+ def post_process_all_profilers(self):
+ """Post-process trace events and calculate statistics for all profiles."""
+ for profile in self.profiles:
+ if profile["failed"]:
+ continue
+
+ events = self._parse_json_trace(profile)
+ if events is None:
+ continue
+
+ self._filter_trace_events(profile)
+ self._calculate_profiling_statistics(profile)
+
+ self.write_results()
+
+ def get_profiling_dataframe_generator_all_profilers(self):
+ df_generator = DataFrameGenerator()
+ for profile in self.profiles:
+ if profile["failed"] or profile["stats"] is None:
+ continue
+
+ p_df_gen = DataFrameGenerator()
+ p_df_gen.add_single_value(
+ "operation_name", profile["wrapper"].get_kernel_name()
+ )
+
+ for key, value in profile["settings"].items():
+ p_df_gen.add_single_value(key, value)
+
+ all_kernel_duration = profile["stats"]["kernel_all"]
+ for i, duration in enumerate(all_kernel_duration):
+ p_df_gen.add_single_value(f"sample_{i}", duration)
+
+ df_generator.merge(p_df_gen)
+ return df_generator
+
+ def write_results(self):
+ """Write profiling results to a CSV file and print to stdout."""
+ storage_dataframe_generator = (
+ self.get_profiling_dataframe_generator_all_profilers()
+ )
+ # Check if file exists to determine if we need to write header
+ file_exists = os.path.exists(self.storage_file)
+ mode = "a" if file_exists else "w"
+ header = not file_exists
+ storage_dataframe_generator.to_dataframe().to_csv(
+ self.storage_file, mode=mode, header=header, index=False
+ )
+ print(
+ storage_dataframe_generator.to_dataframe().to_csv()
+ ) # Need to see the content of the file in terminal as Google does not have file system
+ print(f"Results written to: {self.storage_file}")
+
+
+def collect_logs(root_dir=".", output_csv_name="all_logs_collected"):
+ """Collects all CSV files found under directories named 'log'
+
+ and aggregates them into a single CSV file.
+ Handles varying headers by taking the union of all found columns.
+ """
+ all_files = []
+
+ # Fieldnames set to collect all unique columns
+ all_fieldnames = set()
+ # To preserve some order, we can use a list and add new ones as we see them
+ ordered_fieldnames = []
+
+ # First pass: identify files and collect all possible fieldnames
+ for dirpath, dirnames, filenames in os.walk(root_dir):
+ path_parts = dirpath.split(os.sep)
+ if "log" in path_parts:
+ for file in filenames:
+ if file.endswith(".csv"):
+ full_path = os.path.join(dirpath, file)
+ all_files.append(full_path)
+ try:
+ with open(full_path, "r", newline="") as csvfile:
+ reader = csv.reader(csvfile)
+ try:
+ header = next(reader)
+ for h in header:
+ if h not in all_fieldnames:
+ all_fieldnames.add(h)
+ ordered_fieldnames.append(h)
+ except StopIteration:
+ # Empty file
+ pass
+ except Exception as e:
+ print(f"Error reading header of {full_path}: {e}")
+
+ if not all_files:
+ print("No CSV files found.")
+ return
+
+ print(f"Found {len(all_files)} CSV files.")
+ print(f"Unified collected columns: {ordered_fieldnames}")
+
+ output_file = os.path.join(root_dir, f"{output_csv_name}.csv")
+ total_rows = 0
+
+ try:
+ with open(output_file, "w", newline="") as csvfile:
+ writer = csv.DictWriter(csvfile, fieldnames=ordered_fieldnames)
+ writer.writeheader()
+
+ for full_path in all_files:
+ try:
+ with open(full_path, "r", newline="") as infile:
+ reader = csv.DictReader(infile)
+ # The DictReader uses the file's own header mapping
+ # We just iterate and write to the master dict writer
+ for row in reader:
+ writer.writerow(row)
+ total_rows += 1
+ except Exception as e:
+ print(f"Error processing {full_path}: {e}")
+
+ print(f"Saved aggregated logs to {os.path.abspath(output_file)}")
+ print(f"Total rows collected: {total_rows}")
+
+ except Exception as e:
+ print(f"Error writing output file: {e}")
diff --git a/jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv b/jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv
deleted file mode 100644
index 9ee1257..0000000
--- a/jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv
+++ /dev/null
@@ -1 +0,0 @@
-0, x, 004cd4683316487f, b89b952325a20f91, 293cd479875ceb77, 2af0a9c29bacf6f9, f8673c91ec915534, aff8c262353486d7, zz: Fp384 "(003F62945CE627D1ECB0F654CD9E80F82DC0E62E484AE9CED02FC5DC5658CD2D6C3AB04E97AD20BD970B5C7E2178709C)", y, 00aa7ad6de759cb7, 270923c09134b2c6, 7f88b0b899bceace, 22071c8c7a09df51, 7f08b237fda96e55, b4d65d75663afe99, zz: Fp384 "(016113C0C182762A49721F7EAF286AFD6B0FA13DD9CE68CE5119C574EDF947D9AEBF6D36C26C3C208AB518E3079AC6FA)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv b/jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv
deleted file mode 100644
index 80821f2..0000000
--- a/jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv
+++ /dev/null
@@ -1 +0,0 @@
-x, 009976f628cc7c42, d55f93b8c7dfbc61, 86a374e1e5a2fafb, a3e0ba0828a88f69, 9f20ac113f371d74, 77a360edb4c28207, zz: Fp384 "(0134A9468957E7FDC88CF9D69A4E7B3AE93432AE7FFFE93F6BE4F6F287BCA1880E185D98E304467DEC0709D9F7BC7AF0)", y, 01a7e5e4f8816331, 28dcee5c74675219, e54fd183a73529c5, 776dbcf5725a252e, 39fdb795acda186b, 0eadea2601fc454d, zz: Fp384 "(004A4A15E00C14B99875584249FCAF18B0AAA567BF67863D6DCF47CDC39A227154C3369DF6D9A63EECDB3BAA19FD2B58)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv b/jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv
deleted file mode 100644
index 73ac78f..0000000
--- a/jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv
+++ /dev/null
@@ -1 +0,0 @@
-0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 "(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv b/jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv
deleted file mode 100644
index d61bb34..0000000
--- a/jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv
+++ /dev/null
@@ -1,2 +0,0 @@
-0, x, 0024f0169fe2cfbb, 2dcc78c274a772fd, a0bf1eb986baf6d8, 0f9d3669482b16c0, e9d8bb1ef8cbea3e, 02f2a0491d3f7f41, zz: Fp384 "(01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76)", y, 002da6d2840f4ce6, 1514adf2c130292c, 86aa9bea8cf858ba, 53b98e247e9dc552, 36abb853a15bff31, 803b023ee6aaf142, zz: Fp384 "(011DB83AEC88460820F4868A73B12309EE2E910526E62DB4ACCB303ABF50F86C3985A072ED07A4B81FFB82D8DD247283)"
-1, x, 00b3b2f197a3b127, cde0e1d3d77e7d89, 49d6a6a3dac76741, fc4e1568a3b90dcc, f690d49a6208e299, fcf3195b0a1534a6, zz: Fp384 "(01546AF2ABB4E189E9BBC412FDBF2A8E5EC6E4A3B0AF132E21EE9CEC3EF5E226490FB98D662670FA3CFB3948B7E2A48C)", y, 012e6fd70da14f01, 3841e339306efa6e, c845ceff3aa42d46, 5d869b1205ec884c, 77718567b7535dd9, 1bc1fdfd9ce464a3, zz: Fp384 "(002961A558A885DF227FDB09F8BDF57AF179CB9437FF8828F13E9DF01AE55502F409AAF5058B88F2F7CCC7BC0676A5D4)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv b/jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv
deleted file mode 100644
index b28b820..0000000
--- a/jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv
+++ /dev/null
@@ -1 +0,0 @@
-x, 001ba77010560355, d4a083cd44ea3b27, 47e444e832f177a9, e5a2cc6f64868bf3, 34875bb04ff9a153, 29a12ab6ca346603, zz: Fp384 "(001E9324B117F5E1E1EA2D63635906F087DD18C9C3840E28CD24FD6960DC4BAC7659B648FA29E89DA0DCDA4B9677254F)", y, 0038529273debc84, 567b59a70a96e53f, f2421d1eeaa395eb, e3d065d79346fc9b, 20a9f718531bd65d, 29662f4a78afce5d, zz: Fp384 "(010FDD33E114E8A119E7855FC6E457A0B2F1FC0436076E9557302DD86CD11EF6148721B9E4E63A01693378E985496444)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv b/jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv
deleted file mode 100644
index f6344a1..0000000
--- a/jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv
+++ /dev/null
@@ -1,2 +0,0 @@
-0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 "(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)"
-1, 002e9d0d87c0a600, 1c9a1f731ec9a8d0, 3ca0557886321ce6, e5716b57188ca258, zz: Fp256 "(0C736B82B849EB8E4D8F81136A30803766F90E66AB059D12A6EDD2577E605788)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv b/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv
deleted file mode 100644
index 8896d35..0000000
--- a/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv
+++ /dev/null
@@ -1,4 +0,0 @@
-0, x, 0054dadbda56447b, 1edf94ecc29bac9d, be41bffb21bbab70, 0d13b7683a5604e7, ae2f805045068087, 3bd87a659df906e5, zz: Fp384 "(0132A41F10B5AAE35C21740C1742AA70445C0DAF97DC1F92B7B51023BCC4ED52DE46227513DCFBCE518FD1AEEC21B83B)", y, 013bed8689b68cbd, 11eb13b43aa44895, 86836736dcaf43a5, a9675efb170a9c2c, be3e2a2333c55942, 31edbd28217dbb58, zz: Fp384 "(004F55B794051CB39882288086577AD6E7D199297A9DA267666D6FEC28F3D8781C95781639307F2A9B9C2522FCBD040B)"
-1, x, 002f3f48c6b8f5fa, 720aa4decad34840, a94743660088442e, 712af8cd5eb5d3d1, 8da9cd5acf6bff08, a003e41885a627d0, zz: Fp384 "(012668C2A40C279AD883A98302D2BA3542871C8B48B21AF176A8C704DED10DB7E4753B830250EDDFA47FCD7B163343DC)", y, 0125bd5ea0a1989d, 3cad55ebd9ef75b3, 6f167f9389f8360d, 87346eda19e3e3b8, 50d092aa5b40cec1, daf39285f5fe38d8, zz: Fp384 "(01A4904AB5F9B7C2FBD154025161FEC3414AF1D3B764377D50741A06FC86742B9EC862BFCD8507486CFE67995CEDCDBD)"
-2, x, 0024be0cb95af7cb, 212a56e5183de381, a551cc6f82fb71e2, 1e3996681d117eb1, f9278d11382dfdec, 01a164e98bc73565, zz: Fp384 "(0145535A3BF5EF85CBA87B177AE7672B3ABD8C9A0C35108D9C92201AAFF925B3678D58F150CD1EA5395AC4D141CCDF28)", y, 012530da4d102355, 59cd4f2fb4b622cc, fad0061949e42051, 38bbcb79a57c7fd0, 67da66ec382efb95, ecf7cfddb7cfc0c9, zz: Fp384 "(007028D05152FD66FEA7662953D4C0B6F36891B2276EF5B7F48CA78EFF4892C3ACEB3FEC17D45C780506A27A18521DD6)"
-3, x, 016efb97086a0478, 0ad22f31d98ac0b5, 934213bb4a45b710, bb75e577d1bfc091, 0880564021f4eb65, 58a3d749c1d1b878, zz: Fp384 "(0035B2C1EC332094CE8D30528D3E5E373021DB3BC54A97675AE885C574D1DD0084CE476EFD85EEBD4109A2D2339E8824)", y, 011f2bb8975d0cfe, 19132fe5974c13ea, eb279bb047b1abe8, e9ac98f26391a563, 0101680172684959, e9e6e94a59e90ffa, zz: Fp384 "(00CFEC2077BD515E1AA9B72FCC2AC8A79D4B7F333E9D139C2E9482905CE5C4294261C9ADF1EA8E05540D4834C2461536)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv b/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv
deleted file mode 100644
index ce32abf..0000000
--- a/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv
+++ /dev/null
@@ -1 +0,0 @@
-x, 014d7f8695379436, 16965d5ab3d91b3a, be23f9940b3b2857, 6924eb4496b95388, 2d22b77f18c3d89e, 2feb08e40a36bb91, zz: Fp384 "(0105E8EEFA0C3DE4D1E65A63489065FB0F8C66F1CBFF2422FE246E3CD1443DC699711B81B7790B7653547310240B2AF0)", y, 015845f50dda6c32, 87e7e0cbd525e428, ea1cd4b49a411788, eb7226604c36c531, d4d3a8e5c5d6d925, 0906ee99df45c728, zz: Fp384 "(00D7A61C6E85442A8DE1B4B14D18BD07FDC08AEFD9C71CAD928D404B6FD3F5C347057AE82ED6581F71EBF6407A1E07EE)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv b/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv
deleted file mode 100644
index 8b24b20..0000000
--- a/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv
+++ /dev/null
@@ -1,4 +0,0 @@
-0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 "(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)"
-1, 002e9d0d87c0a600, 1c9a1f731ec9a8d0, 3ca0557886321ce6, e5716b57188ca258, zz: Fp256 "(0C736B82B849EB8E4D8F81136A30803766F90E66AB059D12A6EDD2577E605788)"
-2, 0e09b93519940300, 0a4867513ba6edb7, 530cce8b6a69232d, 0ecdd08a922d30e3, zz: Fp256 "(0738C6A142F9D317BFD26C26E9629C54E3E28C0CD3B100A6311C83A0BA678390)"
-3, 0a85273b8e6bb8ed, dc0d90c24a4f5c72, 0f57216e312a2f32, 6ef1359bfc3df47f, zz: Fp256 "(04C4AFFA315644F295F99B40E9A87D32380603D675FBA350A0BB852DFEFE7878)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv b/jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv
deleted file mode 100644
index a2a3b4e..0000000
--- a/jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv
+++ /dev/null
@@ -1,8 +0,0 @@
-0, x, 018df51d28d7bb33, 12125d3dc2345560, dc6f8aa4f602ac7a, 9060043fe5dbc225, f9c8eba7ee4b72f0, 4054363299acd090, zz: Fp384 "(00FAF6788CB83075B17A9E3C865EA5DAA699646B8C42C7B0CCCF7D1306DB76C14D9AE857115C728EAA5053B4DD7E8907)", y, 00600f6808058b1d, 93e9bef7c2902d8c, 9480c5e3a177149d, 8fcd6015fc212417, 5073eadb38affa0a, abbd2a0e0b6894be, zz: Fp384 "(0167A1F5704269F768539A23A97A44611791593BD8556BE1718B8A08B54549698604AABA33DD7EBD8E935596E50380EA)"
-1, x, 00daceeae593097d, e7dff89b697bff72, b4f1081a7c7da896, e16b722bb6c4d340, e91940d6dd520c64, a549a1cc3d4185d6, zz: Fp384 "(0005AA5A34376BA2D6D2982F5EADBCACEFA1FA2999B381AEAAF51F8EE541CBB439C56420ED43DC6F137A4116F42018D0)", y, 01373aad56ca01f0, 92e5fb3bd9b1a845, 7231a14c2005c972, 8cb4c2d02007c407, 7319d9315a06327c, a080605f19f5a0af, zz: Fp384 "(01936530B0244C450ADEB2A2F03119D090FE80E3E4C1878B6066E42A3BFF646C503872BDE9A0ACBDCDFA8DF81FFCF13D)"
-2, x, 002ecd4e42ae033a, f64d0e37eabca682, dca865eb442d4487, 390bb206301b7dc9, e6e5addb49897da3, 885c7248fc2da32c, zz: Fp384 "(00099F98BF058B8521588A6D521FF5D95F0A4E6EC9D9F0658FD513BC654C8E2DA2475CA572C5FA1C90EC09D4A3ACC452)", y, 00497ee9a51a8c00, 2e6865603a44a22c, f36f2ee7608b7ebf, 7211ab8e2046edbc, 00c91718b1a340d2, a5086213363a7ff9, zz: Fp384 "(00E8AA0A224C09333B12378C512414EAD8439E8F15F520096F97B3BD45BDCFE06D67E2E881F5AD575D9AAF49A0C6B7A4)"
-3, x, 003c2546fd254a90, 0a3e18aab430f4ad, dad00063a6701bac, e9b3c577880ecf51, c9fbb89016ce87fe, b125886dfaa66b55, zz: Fp384 "(0164DDDBF27670CE389E2992C0E7DAB7741F1B925EDBDC254D2BC0830BAF8E0B186F80F0DD4DE0F0EA6176E55934D45B)", y, 0087161d71559880, 7f6dbe70a20e787d, 996d9a6282b303bc, 66e92f1cbb9bd9f9, e3b3bb47a7de8d87, 2fbb78e7edf485df, zz: Fp384 "(01908E9D77A0F8AD89AC41441F74248704E756BC59C38920617F51BFCDB738EE5B123876D489D09C9EB904A321A336EC)"
-4, x, 00373631a6b80a93, b33760134d4de6be, 56ed27d5df9baba3, 755b753d0162f9b8, 1ba8a82579296bdb, 9aafe4d1b780ac70, zz: Fp384 "(0066CB93F6DBF96A94EBD6874DBD7B0E51D67A924C5AB5F96C2A45AF5EC75EA093843107E2EB1A45D645310AC3984B24)", y, 00a5af9f753d8472, b96fcf742d8981cd, 792ff8a32fab8a6a, 41d6fc8217f9f9ef, cf392541d749bf16, ef7345fe8183a959, zz: Fp384 "(0192AF6092F295756DDEDB7A77E7DB52B9FFCCB9D00EB6B61ED0E9E6B795484287CD9DFADD3B30784390941F86BD0FC2)"
-5, x, 00ea9dc132d85e66, 87fda87d39f0499a, 17b4277d88cdbd0a, 27b8b9dcfea6be53, 34c23351b99f097d, b3cf7ac7ffd6ed2f, zz: Fp384 "(0084801FF559F4882F1180089D4FC1C81A0017932824F2ADAB09956C79AEF5219AAC66D806A086C838DB2424E67B1039)", y, 011f02cfd6c5d920, b577b59e650d501d, 66326514f4c44ebf, 7fc342da9c7ae7b6, 30325f5017e975e8, cbbf2d840d9c64e5, zz: Fp384 "(009B77B4C386398F4CB40B86E74BF5B310ADDC286678166AB660FDEB4A0FD5F1DD080FDB4EE3D8BC4553F5FA529A74DD)"
-6, x, 019f02ca76047ec7, dc062d90f205d06c, 7dfbdc07f954e9d7, 85c4a3e6845721a1, 30f1080faa1484d8, 493b00c9cd7be16e, zz: Fp384 "(00B0630E7F192D20443A93860275447074CE77DF559907FA1900F378D4674649BF25F85C893E2A1916B1DA57594F2E17)", y, 006b2c62488dc2ba, 833a6b9f0c8f808b, 98c825e9e1583848, fc8b9db4aeb70f49, e5840cf05def3735, 88f30a7a997ca239, zz: Fp384 "(01ACC84F362CF60A265C011F0FE4360A15F51BECF7E2C3923FE07C66D5D113104B56E8486C64204A2A9ECD75BA0C41A7)"
-7, x, 0048c29459872078, eb2d1fcf4a139b51, 67da2ccdbaa8f2e3, 2a03ceed817240a7, fe43123affec9af4, c8486205c2ebdfba, zz: Fp384 "(018C09BCC580517E8C93CECCA83E200690FD66AD94DE753023401E3E337938715885C26FC631FF1833618D053E0872B0)", y, 00dbef2406460af5, 382cc49017f942af, 4a56c1a953eedf43, a1aaaa31c829cfd3, 96d012bb8bedbcc2, b31c1ced77585736, zz: Fp384 "(011080F0E32525EBB7F919B908C9F6F7B20E5B46A8CAEF059CE14E0739C9F164570264F91ECA4DE80E45520ED61B881C)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv b/jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv
deleted file mode 100644
index dc0e754..0000000
--- a/jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv
+++ /dev/null
@@ -1 +0,0 @@
-x, 000e569ef8591fd6, 6cfc4f3224e47b68, bc4e8df117d07814, 160ba23f3fcab777, b1fc73a8395568c4, d21d407e354f8f52, zz: Fp384 "(008BFE20A05721F86AFEAEB8A94FCE08847074F2F555A3281E72545045976C72B33A10C9E62EC748C87087274D63ED22)", y, 00a952354041686e, dd824e49cb9926ab, 1d6a3f89b5ddb42d, a64339e196f523d8, 1ecff81311076324, 7ec7773b01f43018, zz: Fp384 "(01803103D87B815018D87204D4AC96B3BD49DC467802B7C8B7E6FC7E1FE0B20443DA072921D521F7B8BA7823C8B76048)"
\ No newline at end of file
diff --git a/jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv b/jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv
deleted file mode 100644
index 19f34d8..0000000
--- a/jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv
+++ /dev/null
@@ -1,8 +0,0 @@
-0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 "(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)"
-1, 002e9d0d87c0a600, 1c9a1f731ec9a8d0, 3ca0557886321ce6, e5716b57188ca258, zz: Fp256 "(0C736B82B849EB8E4D8F81136A30803766F90E66AB059D12A6EDD2577E605788)"
-2, 0e09b93519940300, 0a4867513ba6edb7, 530cce8b6a69232d, 0ecdd08a922d30e3, zz: Fp256 "(0738C6A142F9D317BFD26C26E9629C54E3E28C0CD3B100A6311C83A0BA678390)"
-3, 0a85273b8e6bb8ed, dc0d90c24a4f5c72, 0f57216e312a2f32, 6ef1359bfc3df47f, zz: Fp256 "(04C4AFFA315644F295F99B40E9A87D32380603D675FBA350A0BB852DFEFE7878)"
-4, 0ed8f6c7a5b1a650, 031ebc9b7a93492e, 89f282d49e7d2560, 7a5693b3d8ae2e87, zz: Fp256 "(0D19ECF1BAF67C8B31F01A16F3EA15BDF78812F5A00857F90C61C66F3445F9EC)"
-5, 09014a7226acc95a, 4835d93dae8844b6, cda8ebe010d04060, bce87d13ab77d2f7, zz: Fp256 "(0136D026DE3C43F05C8035F0932506241268E9FA45095A2FA17C72C6FD1076C6)"
-6, 0d40d252d29285e5, b39c381c97f2caed, 963b0d72b9280ba1, ee92ff87524aa9f5, zz: Fp256 "(05963104A4318BC9978D3B7746B2D8E3E9E36DC4CB751A4406B76D04BD4C3DB8)"
-7, 060bae0e7184f261, 1e89f2311ccfb076, 9149396b939c46ae, 4a5c432085813bb4, zz: Fp256 "(0215F60E325265030BA4B084E9C465B1A22969C73E357B37F215AA734C66D6FA)"
\ No newline at end of file
diff --git a/jaxite_ec/util.py b/jaxite_ec/util.py
deleted file mode 100644
index 17afce2..0000000
--- a/jaxite_ec/util.py
+++ /dev/null
@@ -1,628 +0,0 @@
-"""Utility functions for jaxite_ec.
-
-Note that: All functions that directly take Python int as input cannot be
-jitted.
-"""
-
-import csv
-import json
-import math
-from typing import Any, Callable, List, Tuple
-
-import jax
-import jax.numpy as jnp
-import numpy as np
-
-# copybara: from google3.perftools.accelerators.xprof.api.python import xprof_analysis_client
-# copybara: from google3.perftools.accelerators.xprof.api.python import xprof_session
-
-gcd = math.gcd
-
-
-####################################
-# BLS12-377 Curve Configurations
-####################################
-
-MODULUS_377_INT = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001
-MU_377_INT = 0x98542343310183A5DB0F28160BBD3DCEEEB43799DDAC681ABCB52236169B40B43B5A1DE2710A9647E7F56317936BFF32
-TWIST_D_INT = 122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179
-
-
-####################################
-# Global Configurations
-####################################
-
-BASE = 16
-BASE_TYPE = jnp.uint16 # this type must match the BASE, i.e. jnp.uint
-U16_MASK = 0xFFFF
-U32_MASK = 0xFFFFFFFF
-
-U8_CHUNK_NUM = 48
-U16_CHUNK_NUM = 24
-U32_CHUNK_NUM = 12
-U16_CHUNK_SHIFT_BITS = 16
-U32_CHUNK_SHIFT_BITS = 32
-
-U16_EXT_CHUNK_NUM = 25
-
-
-BARRETT_SHIFT_U8 = 95 # BARRETT Params for k = 380
-CHUNK_PRECISION = 8
-
-# Lazy Reduction Logics
-MODULUS_377_S16_INT = MODULUS_377_INT << 16
-
-# Pippenger Logics
-COORDINATE_NUM = 4
-
-# RNS Reduction Logics
-# Hardware friendly moduli factors are 2**16 - v for v in the following list
-RNS_MODULI_T = (
- 0,
- 1,
- 3,
- 5,
- 9,
- 15,
- 17,
- 27,
- 33,
- 39,
- 45,
- 47,
- 57,
- 59,
- 63,
- 77,
- 87,
- 89,
- 99,
- 105,
- 113,
- 117,
- 123,
- 125,
- 129,
- 143,
- 153,
- 155,
- 165,
- 167,
- 173,
- 179,
- 183,
- 189,
- 197,
- 209,
- 213,
- 215,
- 225,
- 227,
- 243,
- 249,
- 14,
- 38,
- 50,
- 54,
- 98,
- 102,
- 110,
- 122,
-)
-
-MODULI = tuple([
- 2**16 if i == 0 else 2**16 - int(i) if i % 2 == 1 else 2**15 - (int(i) // 2)
- for i in RNS_MODULI_T
-])
-
-
-RNS_PRECISION = 16
-NUM_MODULI = len(RNS_MODULI_T)
-# Maximum number of consecutive additions/subtractions
-ADDITION_BOUND = 4
-
-# Warning: specific to target modulus and addition bound
-MODULI_SUB = tuple([
- ((512 * NUM_MODULI * MODULUS_377_INT * ADDITION_BOUND) - 2**16) % m
- for m in MODULI
-])
-TWIST_D_RNS = tuple([TWIST_D_INT % MODULI[i] for i in range(len(MODULI))])
-
-
-####################################
-# Utility Functions
-####################################
-
-
-def print_hex_values(int_list):
- hex_values = " ".join((hex(value)) for value in int_list)
- print(hex_values)
-
-
-def array_to_int(jax_array: jax.Array, base) -> int:
- """Converts a JAX array to a single Python integer."""
- result = 0
-
- for i, elem in enumerate(jax_array):
- result |= int(elem) << (i * base)
-
- return result
-
-
-def int_to_array(
- python_int, base=BASE, dtype=jnp.uint16, array_size=U16_CHUNK_NUM
-):
- """Converts a Python integer to a JAX array."""
- mask = (1 << base) - 1
-
- elements = []
- while python_int > 0:
- elements.append(python_int & mask) # Extract the lower bits
- python_int >>= base # Shift to remove the extracted bits
-
- # we pad or trim the result to match the desired size
- if array_size is not None:
- assert array_size >= len(elements)
- elements = elements[:array_size] + [0] * (array_size - len(elements))
-
- return jnp.array(elements, dtype=dtype)
-
-
-def array_to_int_list(jax_array, base):
- """Converts JAX array to single integer."""
- result_list = []
- for i in range(jax_array.shape[0]):
- value_vector = jax_array[i]
- value_int = array_to_int(value_vector, base)
- result_list.append(value_int)
- return result_list
-
-
-def int_list_to_array(int_list, base=BASE, array_size=U16_CHUNK_NUM):
- """Converts a list of integers to a JAX array."""
- chunked_arrays = []
- for int_value in int_list:
- chunked_arrays.append(int_to_array(int_value, base, array_size=array_size))
- return jnp.array(chunked_arrays)
-
-
-def int_point_to_jax_point_pack(
- coordinates: List[int], base=BASE, chunk_num=U16_CHUNK_NUM
-):
- result = []
- for i in range(len(coordinates)):
- result.append(int_to_array(coordinates[i], base, array_size=chunk_num))
- return jnp.array(result)
-
-
-def jax_point_pack_to_int_point(point: jax.Array):
- coordinate_num = point.shape[0]
- coordinates = []
- for i in range(coordinate_num):
- c = array_to_int(point[i], BASE)
- coordinates.append(c)
- return coordinates
-
-
-# RNS related data formal conversion
-def int_to_array_rns(x):
- return [x % m for m in MODULI]
-
-
-def array_rns_to_int(residues):
- rns_precompute_values = rns_precompute(MODULI)
- return rns_reconstruct(residues, MODULI, rns_precompute_values)
-
-
-def int_list_to_array_rns(int_list) -> jnp.ndarray:
- """Converts a list of integers to a JAX array."""
- limbs = []
- for int_value in int_list:
- limbs.append(int_to_array_rns(int_value))
- return jnp.array(limbs)
-
-
-def array_rns_to_int_list(jax_array):
- """Converts JAX array to single integer."""
- result_list = []
- for i in range(jax_array.shape[0]):
- value_vector = jax_array[i]
- value_int = array_rns_to_int(value_vector)
- result_list.append(value_int)
- return result_list
-
-
-def int_point_to_jax_rns_point_pack(coordinates: List[int]):
- result = []
- for i in range(len(coordinates)):
- result.append(int_to_array_rns(coordinates[i]))
- return jnp.array(result)
-
-
-def jax_rns_point_pack_to_int_point(point: jax.Array):
- coordinate_num = point.shape[0]
- coordinates = []
- for i in range(coordinate_num):
- c = array_rns_to_int(point[i])
- coordinates.append(c)
- return coordinates
-
-
-def int_point_batch_to_jax_point_pack(
- points: List[List[int]], base=BASE, chunk_num=U16_CHUNK_NUM
-):
- result = []
- for i in range(len(points)):
- result.append(int_point_to_jax_point_pack(points[i], base, chunk_num))
- return jnp.transpose(jnp.array(result), (1, 0, 2))
-
-
-def jax_point_pack_to_int_point_batch(point_pack: jnp.ndarray, base=BASE):
- points = jnp.transpose(point_pack, (1, 0, 2))
- results = []
- for i in range(len(points)):
- results.append(array_to_int_list(points[i], base))
- return results
-
-
-def int_point_batch_to_jax_rns_point_pack(points: List[List[int]]):
- result = []
- for i in range(len(points)):
- result.append(int_point_to_jax_rns_point_pack(points[i]))
- return jnp.transpose(jnp.array(result), (1, 0, 2))
-
-
-def jax_rns_point_pack_to_int_point_batch(point_pack: jnp.ndarray):
- points = jnp.transpose(point_pack, (1, 0, 2))
- results = []
- for i in range(len(points)):
- results.append(array_rns_to_int_list(points[i]))
- return results
-
-
-# RNS helpers
-def total_modulus(moduli):
- modulus = 1 # Compute the big modulus
- for m in moduli:
- modulus *= m
- return modulus
-
-
-def rns_precompute(moduli):
- modulus = total_modulus(moduli)
- precomputed = []
- for m in moduli:
- rest = modulus // m # 0 mod all the other moduli
- inverse = pow(rest % m, -1, m) # factor to make 1 mod this moduli
- icrt_val = (rest * inverse) % modulus # combine
- precomputed.append(icrt_val)
- return precomputed
-
-
-def rns_reconstruct(residues, moduli, precomputed):
- assert len(residues) == len(moduli)
- assert len(moduli) == len(precomputed)
- output = 0
- for i, r in enumerate(residues):
- output += precomputed[i] * int(r)
- return output % total_modulus(moduli)
-
-
-def to_rns(x, moduli):
- assert x < total_modulus(moduli)
- return [x % m for m in moduli]
-
-
-def to_tuple(a):
- """Create to convert numpy array into tuple."""
- try:
- return tuple(to_tuple(i) for i in a)
- except TypeError:
- return a
-
-
-# The following function achieves the same function as int_to_array, but it
-# can be pre-run (Google restriction), and returns a tuple.
-def int_to_precomputed_array(
- python_int, base=BASE, dtype=jnp.uint16, array_size=U16_CHUNK_NUM
-):
- """Converts a Python integer to a JAX array."""
- mask = (1 << base) - 1
-
- elements = []
- while python_int > 0:
- elements.append(python_int & mask) # Extract the lower bits
- python_int >>= base # Shift to remove the extracted bits
-
- # we pad or trim the result to match the desired size
- if array_size is not None:
- assert array_size >= len(elements)
- elements = elements[:array_size] + [0] * (array_size - len(elements))
-
- return to_tuple(np.array(elements, dtype=dtype).tolist())
-
-
-####################################
-# Performance Profiler Functions (Google Internal)
-####################################
-
-
-def profile_jax_functions(
- tasks: List[Tuple[Callable[..., Any], Tuple[Any, ...]]],
- profile_name: str = "jax_profile",
-):
- """Profiles a list of JAX functions.
-
- Args:
- tasks: A list of tuples, where each tuple contains a JAX function and its
- arguments.
- profile_name: The name of the profile.
-
- Usage:
- tasks = [
- (jit_pdul_barrett_xyzz_pack, (point_a_jax,)),
- ]
- profile_name = "jit_pdul_barrett_xyzz_pack"
- profile_jax_functions(tasks, profile_name)
- """
- session_id = None
-
- # copybara: session = xprof_session.XprofSession()
- # copybara: session.start_session()
- try:
- # Launch all JAX computations
- results = []
- for func, args_tuple in tasks:
- result = func(*args_tuple)
- results.append(result)
-
- # Wait for all computations launched in the loop to complete
- if results:
- jax.block_until_ready(results)
-
- except Exception as e: # pylint: disable=broad-exception-caught
- print(f"Error type: {type(e).__name__}")
- print(f"Error details: {e}")
- # Attempt to end the session even if there was an error
- # copybara: session_id = session.end_session_and_get_session_id()
- print("Xprof session ended due to error.")
- if session_id:
- print(f"{profile_name}: http://xprof/?session_id={session_id}")
- finally:
- if session_id is None:
- # copybara: session_id = session.end_session_and_get_session_id()
- print(f"{profile_name}: http://xprof/?session_id={session_id}")
- # copybara: client = xprof_analysis_client.XprofAnalysisClient()
- trace = (
- client.get_profile_data("trace_viewer.json", session_id)
- if client
- else None
- )
- jtrace = json.loads(trace[1]) if trace else None
- if jtrace:
- for e in jtrace["traceEvents"]:
- if profile_name in e["name"]:
- print(f"{profile_name} latency: {e['dur']}\n")
-
-
-####################################
-# Lazy Reduction -- Offline Precompute
-####################################
-
-
-def construct_lazy_matrix(p, chunk_precision=8, chunk_num_u8=U8_CHUNK_NUM):
- """Construct the lazy matrix.
-
- Args:
- p: The modulus.
- chunk_precision: The chunk precision.
- chunk_num_u8: The number of chunks in the u8 value.
-
- Returns:
- lazy_mat: The lazy matrix.
-
- Note that: this function runs on CPU of the TPU-VM, which cannot be jitted.
- """
- lazy_mat_list = []
- for i in range(chunk_num_u8 + 4):
- val = int(int(256) ** (chunk_num_u8 + i)) % p
- lazy_mat_list.append(
- int_to_precomputed_array(val, chunk_precision, array_size=chunk_num_u8)
- )
- return to_tuple(lazy_mat_list)
-
-
-MODULUS_377_LAZY_MAT = construct_lazy_matrix(MODULUS_377_INT)
-
-
-####################################
-# RNS Reduction -- Offline Precompute
-####################################
-
-
-def find_moduli(total_modulus, precision):
- """Finds a list of moduli close to the given precision.
-
- Args:
- total_modulus: The target modulus.
- precision: The desired precision of the moduli.
-
- Returns:
- A tuple containing two lists:
- - overall_moduli: A list of moduli close to the given precision.
- - overall_constant_offset: A list of constant offsets for the moduli.
- """
- initial_moduli = 2**precision
- overall_moduli = []
- overall_constant_offset = []
- overall_modulus = 1
- for i in range(2 ** (precision >> 1) - 1):
- cur_moduli = initial_moduli - i
- if math.gcd(cur_moduli, overall_modulus) == 1:
- overall_moduli.append(cur_moduli)
- overall_constant_offset.append(i)
- overall_modulus *= cur_moduli
- if overall_modulus > total_modulus:
- return to_tuple(overall_moduli), to_tuple(overall_constant_offset)
-
- # Find 2**15 - v too
- initial_moduli = 2 ** (precision - 1)
- if overall_modulus < total_modulus:
- for i in range(2 ** (precision >> 1) - 1):
- cur_moduli = initial_moduli - i
- if math.gcd(cur_moduli, overall_modulus) == 1:
- overall_moduli.append(cur_moduli)
- overall_constant_offset.append(i << 1)
- overall_modulus *= cur_moduli
- if overall_modulus > total_modulus:
- return to_tuple(overall_moduli), to_tuple(overall_constant_offset)
-
- return to_tuple(overall_moduli), to_tuple(overall_constant_offset)
-
-
-def rns_icrt_factors_compute(modulus, moduli):
- precomputed = []
- for m in moduli:
- rest = modulus // m # 0 mod all the other moduli
- inverse = pow(rest % m, -1, m) # factor to make 1 mod this moduli
- icrt_val = (rest * inverse) % modulus # combine
- precomputed.append(icrt_val)
- return precomputed
-
-
-def rns_coefficients_precompute(
- icrt_factors,
- overall_moduli,
- num_bytes,
- moduli_precision,
- overall_modulus,
- q,
-):
- """Precompute RNS coefficients.
-
- Args:
- icrt_factors: Precomputed inverse CRT factors.
- overall_moduli: Array of moduli.
- num_bytes: Number of bytes.
- moduli_precision: Precision of the moduli.
- overall_modulus: Overall modulus.
- q: Target modulus.
-
- Returns:
- Precomputed RNS coefficients and correction coefficients.
- """
- num_residues = len(overall_moduli)
- # icrt_factors_byteshifted -- (num_residues, num_bytes)
- icrt_factors_byteshifted = [
- [
- (((1 << (8 * pre_id)) * factor) % overall_modulus)
- for pre_id in range(num_bytes)
- ]
- for factor in icrt_factors
- ]
- # icrt_factors_byteshifted_modq -- (num_residues, num_bytes)
- icrt_factors_byteshifted_modq = [
- [(chunk % q) for chunk in factors] for factors in icrt_factors_byteshifted
- ]
- # icrt_factors_byteshifted_modq_rns
- # (num_residues, num_bytes, num_residues) [Convert each byte range into RNS]
- icrt_factors_byteshifted_modq_rns = [
- [to_rns(chunk, overall_moduli) for chunk in factors]
- for factors in icrt_factors_byteshifted_modq
- ]
-
- rns_mat = np.array(
- icrt_factors_byteshifted_modq_rns, dtype=np.uint16
- ).reshape(-1, num_residues)
-
- # calculate quotient estimation
- fix_point = 1 << moduli_precision
-
- shifted_quotient_estimations = []
- for factors in icrt_factors_byteshifted:
- for chunk in factors:
- shifted_quotient_estimations.append(
- [math.ceil((chunk * fix_point) / overall_modulus)]
- )
- sqe_mat = np.array(shifted_quotient_estimations, dtype=np.uint16)
-
- cor_mat = np.array(
- [to_rns(-overall_modulus % q, overall_moduli)], dtype=np.uint16
- )
-
- # Convert rns_mat and sqe_mat into various bytes.
- # Version 1: split precision into different chunks.
- # rns_mat_u8 = rns_mat.view(np.uint8).reshape(*rns_mat.shape, num_bytes)
- # seq_mat_u8 = sqe_mat.view(np.uint8).reshape(*sqe_mat.shape, num_bytes)
- # rns_stack_mat_u8 = np.hstack((
- # rns_mat_u8[..., 0],
- # seq_mat_u8[..., 0],
- # rns_mat_u8[..., 1],
- # seq_mat_u8[..., 1],
- # ))
- # Version 2: interleave precision -- tested to be faster.
- rns_stack_mat_u8 = np.hstack(
- (rns_mat.view(jnp.uint8), sqe_mat.view(jnp.uint8))
- )
- return to_tuple(rns_stack_mat_u8.tolist()), to_tuple(cor_mat.tolist())
-
-
-def get_parts(u16mat):
- assert u16mat.dtype == np.uint16
- u16bytes = u16mat.view(np.uint8)
- return [u16bytes[:, ::2], u16bytes[:, 1::2]]
-
-
-M = MODULUS_377_INT * MODULUS_377_INT * 256 * 256 * 50 * 50 * 4 * 2
-moduli_precision = 16
-num_bytes = moduli_precision // 8 # 2
-# hardware friendly moduli is 2**precision - t
-# overall_moduli is the jax.array of "2**precision - t"
-# overall_constant_offset is the jax.array of "t"
-overall_moduli, overall_constant_offset = find_moduli(M, moduli_precision)
-M = 1
-for moduli in overall_moduli:
- M *= moduli
-M = int(M)
-assert len(overall_moduli) == (
- (M.bit_length() + moduli_precision - 1) // moduli_precision
-)
-
-icrt_factors = rns_icrt_factors_compute(M, overall_moduli)
-
-RNS_STACK_MAT_NEW, COR_MAT_NEW = rns_coefficients_precompute(
- icrt_factors,
- overall_moduli,
- num_bytes,
- moduli_precision,
- M,
- MODULUS_377_INT,
-)
-
-
-def construct_rns_matrix(q):
- return rns_coefficients_precompute(
- icrt_factors, overall_moduli, num_bytes, moduli_precision, M, q
- )
-
-
-###############################
-# Break High-precision Integer into Chunkcs
-###############################
-
-
-MODULI = overall_moduli
-RNS_MODULI_T = overall_constant_offset
-RNS_MAT = (RNS_STACK_MAT_NEW, COR_MAT_NEW)
-MODULUS_377_INT_CHUNK = int_to_precomputed_array(
- MODULUS_377_INT, base=BASE, array_size=U16_CHUNK_NUM
-)
-MU_377_INT_CHUNK = int_to_precomputed_array(
- MU_377_INT, base=BASE, array_size=U16_CHUNK_NUM
-)
-TWIST_D_INT_CHUNK = int_to_precomputed_array(
- TWIST_D_INT, base=BASE, array_size=U16_EXT_CHUNK_NUM
-)
-MODULUS_377_S16_INT_CHUNK = int_to_precomputed_array(
- MODULUS_377_S16_INT, base=BASE, array_size=U16_EXT_CHUNK_NUM
-)
diff --git a/jaxite_ec/utils.py b/jaxite_ec/utils.py
new file mode 100644
index 0000000..21cabb6
--- /dev/null
+++ b/jaxite_ec/utils.py
@@ -0,0 +1,804 @@
+import csv
+from functools import lru_cache
+import hashlib
+import math
+import os
+import pickle
+from typing import Any, Callable, List, Optional
+import warnings
+import jax
+from jax import export
+from jax import sharding as shd
+import jax.numpy as jnp
+import numpy as np
+import toml
+
+# =============================================================================
+# Load configurations
+# =============================================================================
+_config_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "configurations.toml"
+)
+with open(_config_path, "r", encoding="utf-8") as _f:
+ _config = toml.load(_f)
+_serialized_jax_kernel_dir = _config.get(
+ "serialized_jax_kernel_dir", "./deployments/"
+)
+if "TEST_TMPDIR" in os.environ:
+ _serialized_jax_kernel_dir = os.path.join(
+ os.environ["TEST_TMPDIR"], "deployments"
+ )
+
+if not os.path.exists(_serialized_jax_kernel_dir):
+ os.makedirs(_serialized_jax_kernel_dir)
+_hash_length = _config.get("hash_length", 8)
+
+
+# =============================================================================
+# Number theory transform related helper functions
+# =============================================================================
+def gen_twiddle_matrix(rows, cols, q, omega):
+ """Precompute the twiddle matrix T of shape (rows, cols), where T[r, c] = omega^(r*c) mod q.
+
+ Args:
+ rows: The number of rows in the matrix.
+ cols: The number of columns in the matrix.
+ q: The modulus.
+ omega: The primitive root of unity.
+
+ Returns:
+ The twiddle matrix.
+ """
+ warnings.warn(
+ "gen_twiddle_matrix is deprecated. Use"
+ " jaxite_ntt.number_theory_transform.utils.gen_twiddle_matrix instead.",
+ UserWarning,
+ stacklevel=2,
+ )
+ twiddle_matrix = np.zeros((rows, cols), dtype=int)
+ for r in range(rows):
+ for c in range(cols):
+ twiddle_matrix[r, c] = pow(int(omega), int(r * c), int(q))
+ return twiddle_matrix
+
+
+def gen_twiddle_matrix_inv(rows, cols, q, omega):
+ """Precompute the inverse twiddle matrix T_inv of shape (rows, cols).
+
+ T_inv[r, c] = omega^{- (r*c)} mod q.
+
+ Args:
+ rows: The number of rows in the matrix.
+ cols: The number of columns in the matrix.
+ q: The modulus.
+ omega: The primitive root of unity.
+
+ Returns:
+ The inverse twiddle matrix.
+ """
+ warnings.warn(
+ "gen_twiddle_matrix_inv is deprecated. Use"
+ " jaxite_ntt.number_theory_transform.utils.gen_twiddle_matrix_inv"
+ " instead.",
+ UserWarning,
+ stacklevel=2,
+ )
+ twiddle_matrix_inv = np.zeros((rows, cols), dtype=int)
+ for r in range(rows):
+ for c in range(cols):
+ twiddle_matrix_inv[r, c] = pow(int(omega), int(-r * c), int(q))
+ return twiddle_matrix_inv
+
+
+def prime_factors(n):
+ """Return the set of prime factors of n."""
+ factors = set()
+ # Divide out factors of 2
+ while n % 2 == 0:
+ factors.add(2)
+ n //= 2
+ # Check odd factors from 3 to sqrt(n)
+ p = 3
+ while p**2 <= n:
+ while n % p == 0:
+ factors.add(p)
+ n //= p
+ p += 2
+ if n > 1:
+ factors.add(n)
+ return factors
+
+
+def find_generator(q):
+ """Find a primitive root modulo q.
+
+ Args:
+ q (int): The prime modulus.
+
+ Returns:
+ A generator of GF(q)^*.
+
+ Raises:
+ ValueError: If no generator is found, indicating q is not prime.
+ """
+ phi = q - 1
+ factors = prime_factors(phi)
+
+ # Test candidates from 2 to q-1.
+ for g in range(2, q):
+ is_generator = all(pow(g, phi // p, q) != 1 for p in factors)
+ if is_generator:
+ return g
+ raise ValueError("No generator found, check that q is prime.")
+
+
+def root_of_unity(m: int, q: int) -> int:
+ """Canonical primitive m-th root of unity modulo q that **works with NTT**.
+
+ Args:
+ m (int): The order of the root of unity.
+ q (int): The prime modulus.
+
+ Returns:
+ int: The canonical primitive m-th root of unity modulo q.
+
+ Usage:
+ root_of_unity(16, 134219681) # This works with NTT.
+ computed_psi = [root_of_unity(m, q) for q in original_modulus]
+ """
+ assert (q - 1) % m == 0, "q-1 must be divisible by m"
+ # Step 1: multiplicative generator of Z_q^*
+ g = find_generator(q)
+ # Step 2: raise to (q-1)/m to get an m-th root candidate
+ r = pow(g, (q - 1) // m, q)
+ # Step 3: among r^k with gcd(k,m)=1, pick the minimal value whose order is exactly m
+ # For m=2^t, order check is psi^(m/2) == q-1 (i.e., == -1 mod q)
+ candidates = []
+ half = m // 2
+ for k in range(1, m):
+ if math.gcd(k, m) != 1:
+ continue
+ psi = pow(r, k, q)
+ if pow(psi, half, q) == q - 1 and pow(psi, m, q) == 1:
+ candidates.append(psi)
+ assert candidates, "No primitive m-th root found"
+ return int(min(candidates))
+
+
+# =============================================================================
+# Modular arithmetic related helper functions
+# =============================================================================
+def modular_inverse(a: int, m: int):
+ t, new_t = 0, 1
+ r, new_r = m, a
+
+ while new_r != 0:
+ quotient = r // new_r
+ t, new_t = new_t, t - quotient * new_t
+ r, new_r = new_r, r - quotient * new_r
+
+ if r > 1:
+ raise ValueError(f"{a} is not invertible modulo {m}")
+ if t < 0:
+ t += m
+
+ return t
+
+
+def compute_crt_factors(moduli):
+ modular = math.prod(moduli)
+ ms = [modular // m for m in moduli]
+ ms_inv = [modular_inverse(ms[i], moduli[i]) for i in range(len(moduli))]
+ return [(ms[i] * ms_inv[i]) % modular for i in range(len(moduli))]
+
+
+def to_rns(x, moduli):
+ return [x % m for m in moduli]
+
+
+def rns_reconstruct(residues, moduli, crt_factors):
+ return sum(
+ [residues[i] * crt_factors[i] for i in range(len(residues))]
+ ) % math.prod(moduli)
+
+
+def find_moduli_specified_number(total_number, precision):
+ """Finds a list of moduli close to the given precision.
+
+ The moduli are all odd and coprime.
+
+ Args:
+ total_number: The total number of moduli requirement.
+ precision: The desired precision of the moduli.
+
+ Returns:
+ A tuple containing two lists:
+ - overall_moduli: A list of moduli close to the given precision.
+ """
+ initial_moduli = 2**precision
+ overall_moduli = []
+ overall_modulus = 1
+ for i in range(1, 2 ** (precision >> 1) - 1):
+ cur_moduli = initial_moduli - i
+ if cur_moduli % 2 == 1 and math.gcd(cur_moduli, overall_modulus) == 1:
+ overall_moduli.append(cur_moduli)
+ overall_modulus *= cur_moduli
+ if len(overall_moduli) >= total_number:
+ return to_tuple(overall_moduli)
+
+ # Find 2**31 - v
+ initial_moduli = 2 ** (precision - 1)
+ if len(overall_moduli) < total_number:
+ for i in range(1, 2 ** (precision >> 1) - 1):
+ cur_moduli = initial_moduli - i
+ if cur_moduli % 2 == 1 and math.gcd(cur_moduli, overall_modulus) == 1:
+ overall_moduli.append(cur_moduli)
+ overall_modulus *= cur_moduli
+ if len(overall_moduli) >= total_number:
+ return to_tuple(overall_moduli)
+
+ return to_tuple(overall_moduli)
+
+
+def find_primes_with_bits(number, bits):
+ """Returns a list of 'number' prime numbers, each with exactly 'bits' bits.
+
+ Args:
+ number (int): The number of primes to find.
+ bits (int): The bit length of each prime.
+
+ Returns:
+ List[int]: List of prime numbers with the specified bit length.
+ """
+
+ def is_prime(n):
+ if n < 2:
+ return False
+ if n == 2 or n == 3:
+ return True
+ if n % 2 == 0 or n % 3 == 0:
+ return False
+ i = 5
+ w = 2
+ while i * i <= n:
+ if n % i == 0:
+ return False
+ i += w
+ w = 6 - w
+ return True
+
+ primes = []
+ lower = 1 << (bits - 1)
+ upper = (1 << bits) - 1
+ candidate = lower | 1 # ensure odd
+
+ while candidate <= upper and len(primes) < number:
+ if is_prime(candidate):
+ primes.append(candidate)
+ candidate += 2 # only check odd numbers
+
+ return primes
+
+
+def modular_matrix_np_u32_to_u8_bat_4d(matrix: np.ndarray, modulus: int):
+ rows, cols = matrix.shape
+ assert modulus <= 2**31
+ matrix_u64 = matrix.astype(np.uint64)
+ matrix_u64_byteshifted = np.array(
+ [matrix_u64 << (8 * byte_idx) for byte_idx in range(4)], dtype=np.uint64
+ )
+ # shape is (4, rows, cols)
+ matrix_u64_byteshifted = matrix_u64_byteshifted.transpose(1, 0, 2)
+ matrix_u64_byteshifted_mod_modulus = (
+ matrix_u64_byteshifted % modulus
+ ).astype(np.uint32)
+ matrix_u8 = matrix_u64_byteshifted_mod_modulus.view(np.uint8).reshape(
+ rows, 4, cols, 4
+ )
+ return matrix_u8
+
+
+# =============================================================================
+# MSM related helper functions
+# =============================================================================
+def read_external_msm_file(path, type: str):
+ if type == "scalars":
+ scalars = []
+ with open(
+ path, "r", newline="", encoding="utf-8"
+ ) as csvfile: # Handle potential encoding issues
+ csv_reader = csv.reader(csvfile)
+ for row in csv_reader:
+ scalars.append(int(row[-1][13:-2], 16))
+ return scalars
+
+ elif type == "points":
+ points = []
+ with open(
+ path, "r", newline="", encoding="utf-8"
+ ) as csvfile: # Handle potential encoding issues
+ csv_reader = csv.reader(csvfile)
+ for row in csv_reader:
+ points.append([int(row[8][13:-2], 16), int(row[-1][13:-2], 16)])
+ return points
+
+ elif type == "result_ref":
+ result_ref = []
+ with open(
+ path, "r", newline="", encoding="utf-8"
+ ) as csvfile: # Handle potential encoding issues
+ csv_reader = csv.reader(csvfile)
+ for row in csv_reader:
+ result_ref.append(int(row[7][13:-2], 16))
+ result_ref.append(int(row[-1][13:-2], 16))
+ return result_ref
+
+
+def split_list(lst, chunk_size):
+ """Splits a list into equal-sized chunks."""
+ return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
+
+
+def slice_scalars(
+ scalars: List, scalar_bits: int, slice_length: int
+) -> List[List]:
+ window_num = int(math.ceil(scalar_bits / slice_length))
+ mask = (1 << slice_length) - 1
+ slices_list = [[] for _ in range(window_num)]
+ for scalar in scalars:
+ for i in range(window_num):
+ slices_list[i].append(scalar & mask)
+ scalar >>= slice_length
+ return slices_list
+
+
+# =============================================================================
+# shape and tuple related helper functions
+# =============================================================================
+def nested_list_depth(x: Any) -> int:
+ """Return the nesting depth of a list.
+
+ Examples:
+ nested_list_depth([1, 2, 3]) -> 1
+ nested_list_depth([[1, 2], [3, 4]]) -> 2
+ nested_list_depth([[[1], [2]]]) -> 3
+
+ Args:
+ x: The value to inspect.
+
+ Returns:
+ 0 if x is not a list, otherwise 1 + the maximum depth of its elements.
+ """
+ if not isinstance(x, list):
+ return 0
+ if len(x) == 0:
+ return 1
+ return 1 + max(nested_list_depth(item) for item in x)
+
+
+def to_tuple(a):
+ """Create to convert numpy array into tuple."""
+ if isinstance(a, (list, tuple, np.ndarray)):
+ return tuple(to_tuple(i) for i in a)
+ return int(a)
+
+
+def pad_jax_array(array: jnp.ndarray, target_shape: tuple) -> jax.Array:
+ if array.shape == target_shape:
+ return array
+ assert len(array.shape) == len(
+ target_shape
+ ), "array and target_shape must have the same number of dimensions"
+ pad_width = []
+ for cur, tgt in zip(array.shape, target_shape):
+ assert tgt >= cur, f"target size {tgt} is smaller than current size {cur}"
+ pad_width.append((0, tgt - cur))
+ return jnp.pad(array, pad_width, mode="constant", constant_values=0)
+
+
+# =============================================================================
+# Code structure
+# =============================================================================
+class JaxParameters:
+ word_bits: Any = None
+ rns_moduli_inv_word: Any = None
+ word_mask: Any = None
+ half_word_mask: Any = None
+ half_word_bits: Any = None
+ rns_moduli_high: Any = None
+ rns_moduli_low: Any = None
+ rns_moduli: Any = None
+ crns_precision: Any = None
+ crns_vector_g: Any = None
+ crns_stacked_mat_E_with_f_T: Any = None
+ rns_moduli_negate: Any = None
+ rns_moduli_sub: Any = None
+ twist_d: Any = None
+
+ def __init__(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+ def set_parameter(self, **kwargs):
+ for k, v in kwargs.items():
+ setattr(self, k, v)
+
+
+class JaxKernelContextBase:
+
+ def __init__(self, use_compiled_kernels: bool = False):
+ self.compiled_kernels = {}
+ self.use_compiled_kernels = use_compiled_kernels
+ self.use_sharding = False
+ self.sharding_mesh = None
+ self.mesh_axes = ()
+
+ def set_use_compiled_kernels(self, use_compiled_kernels: bool):
+ self.use_compiled_kernels = use_compiled_kernels
+
+ def serialize(self, parameters) -> Any:
+ pass
+
+ def compile(self, parameters) -> Any:
+ pass
+
+ def context_hash(self) -> str:
+ raise NotImplementedError("Subclasses must implement context_hash")
+
+ def set_use_sharding(self, use_sharding: bool):
+ if use_sharding:
+ self.use_sharding = True
+ mesh, partition_spec = create_sharding()
+ axis_names = mesh.axis_names
+ partition = axis_names if len(axis_names) > 1 else axis_names[0]
+ self.sharding_mesh = mesh
+ self.sharding_partition = partition
+ self.sharding_partition_spec = partition_spec
+ self.mesh_axes = tuple(mesh.axis_names)
+ else:
+ self.use_sharding = False
+ self.sharding_mesh = None
+ self.sharding_partition = None
+ self.sharding_partition_spec = None
+ self.mesh_axes = ()
+
+ def make_named_sharding(self, spec) -> "jax.sharding.NamedSharding | None":
+ """Wrap a PartitionSpec in a NamedSharding bound to the current mesh.
+
+ Returns None if sharding is disabled.
+ """
+ mesh = getattr(self, "sharding_mesh", None)
+ if mesh is None:
+ return None
+ return jax.sharding.NamedSharding(mesh, spec)
+
+ def shard_constraint(self, x, spec):
+ """Apply ``with_sharding_constraint(x, NamedSharding(mesh, spec))``.
+
+ No-op when sharding is disabled. Intended for use inside jitted
+ kernels to pin intermediate layouts.
+ """
+ mesh = getattr(self, "sharding_mesh", None)
+ if mesh is None:
+ return x
+ return jax.lax.with_sharding_constraint(
+ x, jax.sharding.NamedSharding(mesh, spec)
+ )
+
+ def create_named_sharding(
+ self, shape: tuple, axes: list[int]
+ ) -> tuple[jax.sharding.NamedSharding, tuple]:
+ """Create an efficient NamedSharding for the given shape and shard axes.
+
+ (Generated by Claude)
+
+ For 1 axis: shards across all mesh devices on that axis. Skips sharding
+ (returns replicated) if the axis is too small for the mesh size. Pads
+ the axis so that shape[axis] % total_devices == 0.
+
+ For 2 axes: tries both mesh-to-data-axis mappings and picks the one
+ with less padding waste. Pads each sharded axis as needed.
+
+ Args:
+ shape: Array shape.
+ axes: 1 or 2 axis indices to shard along. >=3 is not allowed.
+
+ Returns:
+ (NamedSharding, padded_shape) — the sharding and the shape after
+ padding (same as input shape if no padding was needed).
+ """
+ assert (
+ hasattr(self, "sharding_mesh") and self.sharding_mesh is not None
+ ), "Sharding must be enabled first via enable_sharding()"
+
+ if len(axes) == 0:
+ spec = [None] * len(shape)
+ partition_spec = jax.sharding.PartitionSpec(*spec)
+ return (
+ jax.sharding.NamedSharding(self.sharding_mesh, partition_spec),
+ shape,
+ )
+
+ assert len(axes) >= 1, "Must specify at least 1 axis"
+ assert len(axes) <= 2, "Sharding along >= 3 axes is not supported"
+
+ mesh = self.sharding_mesh
+ mesh_axis_names = mesh.axis_names # e.g. ('x', 'y')
+ mesh_axis_sizes = {name: mesh.shape[name] for name in mesh_axis_names}
+ padded_shape = list(shape)
+
+ # Minimum elements per device below which sharding is not worthwhile
+ _MIN_ELEMS_PER_DEVICE = 1
+
+ if len(axes) == 1:
+ axis = axes[0]
+ total_devices = math.prod(mesh_axis_sizes.values())
+
+ # Small axis: replicate instead of sharding
+ if shape[axis] < total_devices * _MIN_ELEMS_PER_DEVICE:
+ spec = [None] * len(shape)
+ partition_spec = jax.sharding.PartitionSpec(*spec)
+ warnings.warn(
+ f"Shape {shape} is too small for sharding, replicating the axis"
+ f" {axis}",
+ stacklevel=2,
+ )
+ return jax.sharding.NamedSharding(mesh, partition_spec), tuple(shape)
+
+ # Pad so shape[axis] is divisible by total_devices
+ if shape[axis] % total_devices != 0:
+ warnings.warn(
+ f"Shape {shape} is not divisible by total_devices {total_devices},"
+ f" padding the axis {axis} to"
+ f" {math.ceil(shape[axis] / total_devices) * total_devices}",
+ stacklevel=2,
+ )
+ padded_shape[axis] = (
+ math.ceil(shape[axis] / total_devices) * total_devices
+ )
+
+ # Shard this single data axis across all mesh axes
+ spec = [None] * len(shape)
+ spec[axis] = tuple(mesh_axis_names)
+ partition_spec = jax.sharding.PartitionSpec(*spec)
+
+ else: # len(axes) == 2
+ axis0, axis1 = axes[0], axes[1]
+ total_devices = math.prod(mesh_axis_sizes.values())
+
+ def _compute_candidate(spec_list, shard_plan):
+ """Compute (waste, spec, padded) for a shard_plan: list of (mesh_label, data_axis).
+
+ mesh_label is either a single mesh axis name or a tuple of all mesh axis
+ names.
+ """
+ p = list(shape)
+ waste = 0
+ for mesh_label, d_axis in shard_plan:
+ if isinstance(mesh_label, tuple):
+ ms = math.prod(mesh_axis_sizes[n] for n in mesh_label)
+ else:
+ ms = mesh_axis_sizes[mesh_label]
+ s = shape[d_axis]
+ if s % ms != 0:
+ waste += (math.ceil(s / ms) * ms) - s
+ p[d_axis] = math.ceil(s / ms) * ms
+ return waste, spec_list, tuple(p)
+
+ # Candidate A: mesh_x -> axis0, mesh_y -> axis1
+ spec_a = [None] * len(shape)
+ spec_a[axis0] = mesh_axis_names[0]
+ spec_a[axis1] = mesh_axis_names[1]
+ cand_a = _compute_candidate(
+ spec_a, [(mesh_axis_names[0], axis0), (mesh_axis_names[1], axis1)]
+ )
+
+ # Candidate B: mesh_x -> axis1, mesh_y -> axis0
+ spec_b = [None] * len(shape)
+ spec_b[axis1] = mesh_axis_names[0]
+ spec_b[axis0] = mesh_axis_names[1]
+ cand_b = _compute_candidate(
+ spec_b, [(mesh_axis_names[0], axis1), (mesh_axis_names[1], axis0)]
+ )
+
+ # Candidate C: all mesh axes -> axis0 only
+ spec_c = [None] * len(shape)
+ spec_c[axis0] = tuple(mesh_axis_names)
+ cand_c = _compute_candidate(spec_c, [(tuple(mesh_axis_names), axis0)])
+
+ # Candidate D: all mesh axes -> axis1 only
+ spec_d = [None] * len(shape)
+ spec_d[axis1] = tuple(mesh_axis_names)
+ cand_d = _compute_candidate(spec_d, [(tuple(mesh_axis_names), axis1)])
+
+ candidates = [cand_a, cand_b, cand_c, cand_d]
+
+ # Filter out candidates where a sharded axis is too small
+ def _is_viable(spec_list):
+ for i, s in enumerate(spec_list):
+ if s is None:
+ continue
+ if isinstance(s, tuple):
+ ms = math.prod(mesh_axis_sizes[n] for n in s)
+ else:
+ ms = mesh_axis_sizes[s]
+ if shape[i] < ms * _MIN_ELEMS_PER_DEVICE:
+ return False
+ return True
+
+ viable = [(w, sp, pp) for w, sp, pp in candidates if _is_viable(sp)]
+ if not viable:
+ # All too small — replicate
+ spec = [None] * len(shape)
+ partition_spec = jax.sharding.PartitionSpec(*spec)
+ return jax.sharding.NamedSharding(mesh, partition_spec), tuple(shape)
+
+ best_waste, best_spec, best_padded = min(viable, key=lambda x: x[0])
+ spec = best_spec
+ padded_shape = list(best_padded)
+
+ if best_waste > 0:
+ warnings.warn(
+ f"Shape {shape} requires padding for sharding, padded to"
+ f" {tuple(padded_shape)}"
+ )
+
+ partition_spec = jax.sharding.PartitionSpec(*spec)
+
+ return jax.sharding.NamedSharding(mesh, partition_spec), tuple(padded_shape)
+
+
+def jax_jit_lower_compile(func: Callable, *args, **kwargs) -> Callable:
+ # in_shardings = tuple(
+ # arg.sharding if isinstance(arg, jax.ShapeDtypeStruct) and getattr(arg, "sharding", None) is not None
+ # else None
+ # for arg in args
+ # )
+ # if any(s is not None for s in in_shardings):
+ # assert
+ # return jax.jit(func, in_shardings=in_shardings).lower(*args).compile()
+ return jax.jit(func).lower(*args).compile()
+
+
+def store_jax_kernel(func: Callable, *args, **kwargs) -> None:
+ name = kwargs.get("name")
+ path = os.path.join(_serialized_jax_kernel_dir, f"{name}.jax")
+
+ serialized_jax_kernel = export.export(jax.jit(func))(*args).serialize()
+ with open(path, "wb") as f:
+ f.write(serialized_jax_kernel)
+ print(f"stored jax kernel to {path}")
+
+
+def load_jax_kernel(
+ name: str, alternative_callable: Optional[Callable] = None
+) -> Optional[Callable]:
+ path = os.path.join(_serialized_jax_kernel_dir, f"{name}.jax")
+ if not os.path.exists(path):
+ return alternative_callable
+ with open(path, "rb") as f:
+ serialized_jax_kernel = export.deserialize(bytearray(f.read()))
+ print(f"loaded jax kernel from {path}")
+ return serialized_jax_kernel.call
+
+
+def store_jax_executable(func: Callable, *args, **kwargs) -> None:
+ return store_jax_kernel(func, *args, **kwargs)
+
+
+def load_jax_executable(name: str) -> Optional[Callable]:
+ return load_jax_kernel(name)
+
+
+# =============================================================================
+# Hashing utilities
+# =============================================================================
+
+# 62-character alphabet: digits + lowercase + uppercase
+_HASH_CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
+_HASH_BASE = len(_HASH_CHARS) # 62
+
+
+@lru_cache(maxsize=1)
+def _get_hash_length() -> int:
+ """Read hash_length from configurations.toml, defaulting to 16."""
+ import toml
+
+ config_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), "configurations.toml"
+ )
+ with open(config_path, "r", encoding="utf-8") as f:
+ config = toml.load(f)
+ return int(config.get("hash_length", 16))
+
+
+def _serialize(arg: Any) -> str:
+ """Produce a deterministic, type-tagged string for any common Python value.
+
+ Handles arbitrarily nested lists and tuples recursively.
+ """
+ if isinstance(arg, bool):
+ return f"bool:{arg}"
+ if isinstance(arg, int):
+ return f"int:{arg}"
+ if isinstance(arg, float):
+ return f"float:{repr(arg)}"
+ if isinstance(arg, str):
+ # embed length so "ab","c" != "a","bc"
+ return f"str{len(arg)}:{arg}"
+ if isinstance(arg, bytes):
+ return f"bytes:{arg.hex()}"
+ if isinstance(arg, (list, tuple)):
+ tag = "list" if isinstance(arg, list) else "tuple"
+ inner = ",".join(_serialize(x) for x in arg)
+ return f"{tag}[{len(arg)}:{inner}]"
+ if isinstance(arg, dict):
+ items = ";".join(
+ f"{_serialize(k)}->{_serialize(v)}"
+ for k, v in sorted(arg.items(), key=lambda kv: repr(kv[0]))
+ )
+ return f"dict{{{len(arg)}:{items}}}"
+ # Fallback for other types (e.g. numpy scalars, custom objects)
+ return f"{type(arg).__name__}:{repr(arg)}"
+
+
+def hash_args(*args: Any) -> str:
+ """Hash any number of Python values into a fixed-length alphanumeric string.
+
+ The output length is controlled by ``hash_length`` in ``configurations.toml``.
+ Characters are drawn from ``[0-9a-zA-Z]`` (base-62), giving 62^length possible
+ values — e.g. 16 characters yield ~4.7 × 10²⁸ distinct hashes. The result is
+ safe to embed directly in a file name.
+
+ Supported argument types: ``int``, ``float``, ``bool``, ``str``, ``bytes``,
+ ``list``, ``tuple`` (arbitrarily nested), ``dict``, and any type with a
+ stable ``repr``.
+
+ Args:
+ *args: Values to hash.
+
+ Returns:
+ A fixed-length alphanumeric hash string.
+
+ Example:
+ >>> hash_args(42, "hello", [1, 2, 3])
+ 'aB3x9Kp2mNqR7tYz'
+ """
+ length = _hash_length
+ payload = "|".join(_serialize(a) for a in args)
+ digest = hashlib.blake2b(payload.encode("utf-8")).digest() # 64 bytes
+
+ # Base-62 encode the big-endian integer
+ num = int.from_bytes(digest, "big")
+ chars: list[str] = []
+ while num:
+ chars.append(_HASH_CHARS[num % _HASH_BASE])
+ num //= _HASH_BASE
+
+ # blake2b gives ~86 base-62 digits from 64 bytes — always enough for any
+ # reasonable hash_length without padding
+ return "".join(reversed(chars))[:length]
+
+
+# =============================================================================
+# Sharding utilities
+# =============================================================================
+def create_sharding():
+ """Create default batch and replicated shardings for the current device mesh."""
+ available_devices = jax.devices()
+ if not available_devices:
+ raise RuntimeError("No devices available for sharding test.")
+ if len(available_devices) == 8:
+ mesh_shape = (2, 4)
+ elif len(available_devices) == 4:
+ mesh_shape = (2, 2)
+ elif len(available_devices) == 2:
+ mesh_shape = (2, 1)
+ else:
+ mesh_shape = (1, 1)
+
+ mesh = jax.make_mesh(mesh_shape, ("x", "y"))
+ shd.set_mesh(mesh)
+
+ partition_spec = jax.sharding.PartitionSpec
+ return mesh, partition_spec