diff --git a/.github/workflows/build_and_test.yml b/.github/workflows/build_and_test.yml index 7dbb06f..20daa31 100644 --- a/.github/workflows/build_and_test.yml +++ b/.github/workflows/build_and_test.yml @@ -27,13 +27,14 @@ jobs: - name: "Run `bazel build`" run: | - bazel build //... + bazel --host_jvm_args=-Xmx32g build --jobs=8 //... - name: "Run `bazel test`" run: | - bazel test \ + bazel --host_jvm_args=-Xmx32g test \ --test_output=errors \ --test_size_filters=small \ --test_timeout=1800 \ --experimental_ui_max_stdouterr_bytes=10485760 \ + --jobs=8 \ //... diff --git a/.github/workflows/periodic_tests.yml b/.github/workflows/periodic_tests.yml index a9eef38..935e159 100644 --- a/.github/workflows/periodic_tests.yml +++ b/.github/workflows/periodic_tests.yml @@ -14,13 +14,14 @@ jobs: - name: "Run `bazel build`" run: | - bazel build //... + bazel --host_jvm_args=-Xmx32g build --jobs=8 //... - name: "Run `bazel test`" run: | - bazel test \ + bazel --host_jvm_args=-Xmx32g test \ --test_output=errors \ --test_size_filters=medium,large \ --test_timeout=3600 \ --experimental_ui_max_stdouterr_bytes=10485760 \ + --jobs=8 \ //... 
diff --git a/BUILD b/BUILD index 7ce857c..82fa23a 100644 --- a/BUILD +++ b/BUILD @@ -78,6 +78,10 @@ py_library( "jaxite_ckks/*", ], ), + data = [ + "jaxite_ec/configurations.toml", + # "@jaxite//jaxite_ec/c_kernels:distribution.so", + ], visibility = [":internal"], deps = [ ":jaxite_ckks", @@ -88,7 +92,10 @@ py_library( "@jaxite_deps//jaxlib", # copybara: jax/experimental:pallas_lib # copybara: jax/experimental:pallas_tpu + "@jaxite//jaxite_ec/c_kernels:build", "@jaxite_deps//numpy", + # copybara: pandas + # copybara: toml ], ) @@ -142,13 +149,21 @@ tpu_test( ) tpu_test( - name = "jaxite_ec_finite_field_test", - size = "large", - timeout = "moderate", + name = "ec_finite_field_test", srcs = ["jaxite_ec/finite_field_test.py"], - python_version = "PY3", - shard_count = 3, - srcs_version = "PY3ONLY", + deps = [ + ":jaxite", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + ], +) + +tpu_test( + name = "ec_finite_field_perf_test", + srcs = ["jaxite_ec/finite_field_perf_test.py"], deps = [ ":jaxite", # copybara: xprof_analysis_client # buildcleaner: keep @@ -158,57 +173,44 @@ tpu_test( "@jaxite_deps//jax", "@jaxite_deps//jaxlib", "@jaxite_deps//numpy", + # copybara: toml ], ) tpu_test( - name = "msm_test", - size = "large", - timeout = "eternal", - srcs = [ - "jaxite_ec/msm_test.py", + name = "elliptic_curve_test", + srcs = ["jaxite_ec/elliptic_curve_test.py"], + deps = [ + ":jaxite", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + # copybara: toml ], - data = [ - "jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv", - "jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv", - "jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv", - "jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv", - 
"jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv", - "jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv", - "jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv", - "jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv", - "jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv", - "jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv", - "jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv", - "jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv", - "jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv", - "jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv", - "jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv", - ], - python_version = "PY3", - shard_count = 3, - srcs_version = "PY3ONLY", +) + +tpu_test( + name = "elliptic_curve_perf_test", + srcs = ["jaxite_ec/elliptic_curve_perf_test.py"], deps = [ ":jaxite", # copybara: xprof_analysis_client # buildcleaner: keep # copybara: xprof_session # buildcleaner: keep - # copybara: resources "@abseil-py//absl/testing:absltest", "@abseil-py//absl/testing:parameterized", "@jaxite_deps//jax", "@jaxite_deps//jaxlib", "@jaxite_deps//numpy", + # copybara: toml ], ) tpu_test( - name = "elliptic_curve_test", - size = "large", - timeout = "long", - srcs = ["jaxite_ec/elliptic_curve_test.py"], - python_version = "PY3", - shard_count = 16, - srcs_version = "PY3ONLY", + name = "multiscalar_multiplication_test", + srcs = ["jaxite_ec/multiscalar_multiplication_test.py"], + data = glob(["jaxite_ec/data/t1024/*.csv"]), deps = [ ":jaxite", # copybara: xprof_analysis_client # buildcleaner: keep @@ -218,6 +220,50 @@ tpu_test( "@jaxite_deps//jax", "@jaxite_deps//jaxlib", "@jaxite_deps//numpy", + # copybara: toml + ], +) + +tpu_test( + name = "multiscalar_multiplication_perf_test", + srcs = ["jaxite_ec/multiscalar_multiplication_perf_test.py"], + deps = [ + 
":jaxite", + # copybara: xprof_analysis_client # buildcleaner: keep + # copybara: xprof_session # buildcleaner: keep + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + ], +) + +tpu_test( + name = "number_theory_transform_test", + srcs = ["jaxite_ec/number_theory_transform_test.py"], + deps = [ + ":jaxite", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", + ], +) + +tpu_test( + name = "number_theory_transform_perf_test", + size = "large", + timeout = "eternal", + srcs = ["jaxite_ec/number_theory_transform_perf_test.py"], + deps = [ + ":jaxite", + "@abseil-py//absl/testing:absltest", + "@abseil-py//absl/testing:parameterized", + "@jaxite_deps//jax", + "@jaxite_deps//jaxlib", + "@jaxite_deps//numpy", ], ) diff --git a/MODULE.bazel b/MODULE.bazel index 3919bc0..9605595 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -4,6 +4,7 @@ bazel_dep(name = "bazel_skylib", version = "1.9.0") bazel_dep(name = "rules_license", version = "1.0.0") bazel_dep(name = "abseil-py", version = "2.1.0") bazel_dep(name = "rules_python", version = "1.5.1") +bazel_dep(name = "rules_cc", version = "0.0.17") # Hermetic python setup python = use_extension("@rules_python//python/extensions:python.bzl", "python") diff --git a/jaxite_ec/Morph_logo.png b/jaxite_ec/Morph_logo.png new file mode 100644 index 0000000..5890424 Binary files /dev/null and b/jaxite_ec/Morph_logo.png differ diff --git a/jaxite_ec/README.md b/jaxite_ec/README.md new file mode 100644 index 0000000..02a5d07 --- /dev/null +++ b/jaxite_ec/README.md @@ -0,0 +1,227 @@ + +

+ FEATHER +

+ + +

+TPU-accelerated, free, immediate, fast, and cheap HE serving for everyone +

+

+| paper | +code | +tutorial | +

+ +🔥 We have delivered a tutorial at ASPLOS'26 to help you get started with MORPH. Please visit [CPA_tutorial](https://efficientppml.github.io/CROSS_Tutorial/) to learn more. +For questions, please drop an email to our community [email](mailto:cpacommunity@googlegroups.com). + +--- + +# MORPH: Enable AI Accelerator for Zero Knowledge Proof +[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](./LICENSE) + +# 1. What is MORPH? +MORPH is the first project to enable AI accelerators, such as Google TPUs, to accelerate Zero Knowledge Proof primitives (Multi-scalar Multiplication and Number Theoretic Transform), achieving state-of-the-art (SotA) throughput and energy efficiency (performance per watt). Together with [CROSS](https://github.com/EfficientPPML/CROSS), it enables AI ASICs to be SotA throughput machines for cryptographic primitives across a wide range of precisions. + + + +It features +- MXU Lazy Modular Reduction: bringing quadratic high-precision modular reduction down to a linear operation. + + + +- Dataflow optimization for MSM and NTT. Details in the [paper](https://arxiv.org/abs/2604.17808). + +This branch (`asplos`) contains demo scripts for profiling and comparing the two core workloads. + +## Project Structure + +``` +├── finite_field_context.py # Finite field arithmetic (MORPH & CROSS backends) +├── elliptic_curve_context.py # Elliptic curve point arithmetic +├── multiscalar_multiplication_context.py # Multi-scalar multiplication (MSM) +├── number_theory_transform_context.py # Number Theoretic Transform (NTT) +├── utils.py # JAX kernel utilities, number theory helpers +├── profiler.py # Trace parsing and kernel profiling +├── configurations.toml # Curve parameters (BLS12-377) +├── c_kernels/ # Custom C kernels for TPU acceleration +├── deployments/ # Serialized compiled JAX kernels +``` +All functions have `_test.py` and `_perf_test.py` for correctness and performance testing. 
+ + +## Key Concepts + +| Concept | Description | +|---------|-------------| +| **DRNS (Double RNS)** | Residue Number System representation enabling efficient large-integer modular arithmetic on TPU | +| **MORPH** | Alternative modular multiplication backend using chunk-based representation | +| **MSM** | Multi-scalar multiplication — computing $\sum_i s_i \cdot P_i$ over elliptic curve points | +| **Bucket Accumulation** | MSM decomposition strategy: scalars are sliced into windows, points accumulated into buckets per window | +| **Compiled Kernels** | Pre-compiled JAX/C kernels stored in `deployments/` for fast TPU execution | +| **Sharding** | Distribution of computation across TPU cores | + + +# 2. Environment Setup + +Inside the TPU VM, please do the following setup to configure the environment. + +Step 1: install miniconda +``` +wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh +chmod +x ./Miniconda3-latest-Linux-x86_64.sh +./Miniconda3-latest-Linux-x86_64.sh +``` + +Step 2: create environment and install required packages +``` +source ~/.bashrc +conda create --name jaxite python=3.13 +conda activate jaxite +pip install -U "jax[tpu]" +pip install xprof +pip install absl-py +pip install toml +pip install gdown +pip install pandas +pip install gmpy2 +``` + +Step 3: Install the C++ toolchain for the MSM C kernel. + +The MSM path uses a CPU C kernel (`c_kernels/distribution.cpp`) that is +compiled on the first import of `multiscalar_multiplication_context`. You need +a host `g++` with OpenMP, and the conda env's bundled `libstdc++` must be +recent enough to satisfy the symbols emitted by that compiler. On modern Ubuntu +(g++ 13+) this means `GLIBCXX_3.4.32`, which the conda default +`libstdcxx-ng 11.2.0` does **not** ship — so install a newer one from +`conda-forge`: +``` +sudo apt-get install -y g++ # skip if already installed +conda install -n jaxite -c conda-forge 'libstdcxx-ng>=13' 'libgcc-ng>=13' +``` +If you see `OSError: ... 
libstdc++.so.6: version 'GLIBCXX_3.4.XX' not found` +when importing `multiscalar_multiplication_context`, the conda libstdc++ is +older than your system `g++` — re-run the `conda install` line above (or +`LD_PRELOAD=/usr/lib/x86_64-linux-gnu/libstdc++.so.6` for a one-shot +workaround). + +Step 4: Download the reference data +``` +mkdir -p data && gdown 1aJhANlS8hWrjSt9j0nBKoFRBoZh0W1aa -O data/data.tar.gz && cd data && tar -xvf data.tar.gz +``` + +Step 5 (optional): Pre-build the MSM C kernel. + +The C kernel is compiled automatically by `c_kernels/build.py` on the first +import of `multiscalar_multiplication_context`, so no separate build step is +required. To pre-build (or force a rebuild) ahead of time: +``` +python -m c_kernels.build # build if missing or stale +python -m c_kernels.build --force # always rebuild +``` +The compiler defaults to `g++` with `-std=c++17 -fopenmp -O2 -fPIC -shared` +plus `-I/include`. Override via the `CXX` and `CXXFLAGS` env vars +(e.g. point `CXX` at conda's `gxx_linux-64` to keep the build inside the env). + +# 3. TPU Setup +The code is optimized for TPU execution, but it also runs on NVIDIA GPU and CPU for functional preview (not optimized for these devices). + +- Step 1: Create a Google Project ([tutorial](https://cloud.google.com/appengine/docs/standard/nodejs/building-app/creating-project)). + +Obtain the project name and the **Google Project ID** from the created project. + +- Step 2: Apply for the free-tier TPU trial for 30 days via [TRC](https://sites.research.google/trc/about/). + +Once you have submitted the request, an email will be sent to you within one day, containing a link to a survey to fill in with your **Google Project ID**. + +- Step 3: Launch TPU VM. +You can create the TPU VM through the GUI or the gcloud CLI (on your local machine). The gcloud CLI commands are given here because they work for all generations (>=v4) of TPUs. 
+ +For TPUv6e, +```bash +gcloud config set project <PROJECT_ID> +gcloud config set compute/zone us-east1-d +gcloud alpha compute tpus queued-resources create <QUEUED_RESOURCE_ID> --node-id=<NODE_ID> \ + --zone=us-east1-d \ + --accelerator-type=v6e-1 \ + --runtime-version=v2-alpha-tpuv6e \ + --provisioning-model=spot +``` + +Note that TPUv5e and TPUv6e can only be provisioned with the spot provisioning model, because they are popular resources, and Google Cloud can preempt them if higher-priority tasks require these resources. However, you can get a long-term active TPUv4 VM, as it is in less demand from other tasks. + +- Step 4: Setup Remote SSH (VSCode or Cursor) to TPU VM +Once the requested TPU VM is up and running as shown in the Google console, you can use gcloud to forward the SSH port of the remote machine to a port of the local machine and set up VSCode remote SSH. + +You need to first set up a local SSH key for Google's Compute Engine, following [link](https://cloud.google.com/compute/docs/connect/create-ssh-keys#gcloud). After you follow the instructions on the page, the SSH key will be stored at `~/.ssh/google_compute_engine`. + + +```bash +gcloud compute tpus tpu-vm ssh <USER>@<TPU_NAME> -- -L 9009:localhost:22 +``` +Where 9009 is the port of the local machine, while 22 is the SSH port of the TPU VM. + +After you set it up, you can configure VSCode to use the remote SSH package [link](https://code.visualstudio.com/docs/remote/ssh) to remotely access the TPU VM. +```bash +Host tpu-vm + User <USER> + HostName localhost + Port 9009 + IdentityFile ~/.ssh/google_compute_engine +``` + +After this, you should follow the steps on [link](https://code.visualstudio.com/docs/remote/ssh) to log into the TPU VM. + +# 4. Ready to Play? 
+ +Run functional correctness tests for both NTT and MSM: +``` +python3 number_theory_transform_test.py +python3 multiscalar_multiplication_test.py +``` + +Run performance tests for both NTT and MSM: + +``` +python3 number_theory_transform_perf_test.py +python3 multiscalar_multiplication_perf_test.py +``` + +Notes: +- The first MSM test run auto-compiles `c_kernels/distribution.cpp` into + `c_kernels/distribution.so` (a few seconds). Subsequent runs reuse the cached + `.so` and rebuild only when the source is newer. +- The first run of each test also JIT-compiles JAX kernels; expect a longer + first iteration that is then cached under `deployments/`. +- Performance tests assume the reference data from Step 4 is present under + `./data/`. + +# 5. Call for Action +Our mission is to build an open-source SotA library for the community. +- If you find this repository helpful, please consider giving it a star :) +- For any questions, please feel free to open an issue. +- For any suggestions or new features, please feel free to open a pull request. + +# Contact +- Jianming Tong, Georgia Institute of Technology / Google, jianming.tong@gatech.edu / jianmingt@google.com +- Jingtian Dang, Georgia Institute of Technology, dangjingtian@gatech.edu +- Tushar Krishna, Georgia Institute of Technology, tushar@ece.gatech.edu + + +# Citation + +``` +@inproceedings{tong2025MORPH, +author = {Jianming Tong and Jingtian Dang and Simon Langowski and Tianhao Huang and Asra Ali and Jeremy Kun and Srini Devadas and Tushar Krishna}, +title = {MORPH: Enabling AI ASICs for Zero Knowledge Proof}, +year = {2026}, +publisher = {IEEE Press}, +booktitle = {Proceedings of the 63rd Annual ACM/IEEE Design Automation Conference}, +location = {Los Angeles, California, United States}, +series = {DAC '26} +} +``` + +Enjoy! 
:D diff --git a/jaxite_ec/__init__.py b/jaxite_ec/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/jaxite_ec/algorithm/big_integer.py b/jaxite_ec/algorithm/big_integer.py deleted file mode 100644 index d79b763..0000000 --- a/jaxite_ec/algorithm/big_integer.py +++ /dev/null @@ -1,169 +0,0 @@ -"""Big integer classes for jaxite_ec.""" - -import gmpy2 - - -class GMPBigInteger: - """A class representing a big integer using gmpy2. - - This class provides basic arithmetic operations for big integers using gmpy2. - """ - - def __init__(self, value) -> None: - if isinstance(value, (int, gmpy2.mpz)): - self.value = gmpy2.mpz(value) - elif isinstance(value, GMPBigInteger): - self.value = value.value - else: - raise TypeError("Unsupported type for GMPBigInteger initialization") - - def __add__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return GMPBigInteger( - self.value - + gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - ) - return NotImplemented - - def __sub__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return GMPBigInteger( - self.value - - gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - ) - return NotImplemented - - def __mul__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return GMPBigInteger( - self.value - * gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - ) - return NotImplemented - - def __truediv__(self, other): - if isinstance(other, (GMPBigInteger, int)): - if ( - gmpy2.mpz(other.value if isinstance(other, GMPBigInteger) else other) - == 0 - ): - raise ZeroDivisionError("division by zero") - return GMPBigInteger( - self.value - // gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - ) - return NotImplemented - - def __mod__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return GMPBigInteger( - self.value - % gmpy2.mpz( - other.value if isinstance(other, 
GMPBigInteger) else other - ) - ) - return NotImplemented - - def __eq__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return self.value == gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - return NotImplemented - - def __ne__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return self.value != gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - return NotImplemented - - def __lt__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return self.value < gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - return NotImplemented - - def __le__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return self.value <= gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - return NotImplemented - - def __gt__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return self.value > gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - return NotImplemented - - def __ge__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return self.value >= gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - return NotImplemented - - def __pow__(self, exponent, modulus=None): - if isinstance(exponent, (GMPBigInteger, int, gmpy2.mpz)): - if isinstance(exponent, GMPBigInteger): - exponent = gmpy2.mpz(exponent.value) - if isinstance(modulus, GMPBigInteger): - modulus = gmpy2.mpz(modulus.value) - if modulus is None: - return GMPBigInteger(self.value**exponent) - else: - return GMPBigInteger(gmpy2.powmod(self.value, exponent, modulus)) - else: - print(type(exponent)) - raise TypeError("Exponent must be an integer") - - def __lshift__(self, shift): - """Left shift operator (<<).""" - if isinstance(shift, GMPBigInteger): - shift = shift.value - return GMPBigInteger(self.value << shift) - - def __rshift__(self, shift): - """Right shift operator (>>).""" - if isinstance(shift, 
GMPBigInteger): - shift = shift.value - return GMPBigInteger(self.value >> shift) - - def __and__(self, other): - if isinstance(other, (GMPBigInteger, int)): - return GMPBigInteger( - self.value - & gmpy2.mpz( - other.value if isinstance(other, GMPBigInteger) else other - ) - ) - return NotImplemented - - def ceil_log2(self): - """Calculate the base-2 logarithm of the GMPBigInteger.""" - if self.value <= 0: - raise ValueError("log2 is only defined for positive integers") - return GMPBigInteger(gmpy2.ceil(gmpy2.log2(self.value))) - - def __int__(self): - return int(self.value) - - def __str__(self): - return str(self.value) - - def __repr__(self): - return f"GMPBigInteger({self.value})" - - def hex_value_str(self) -> str: - return hex(self.value) diff --git a/jaxite_ec/algorithm/config_file.py b/jaxite_ec/algorithm/config_file.py deleted file mode 100644 index 0233520..0000000 --- a/jaxite_ec/algorithm/config_file.py +++ /dev/null @@ -1,46 +0,0 @@ -"""Jaxite EC algorithm configuration file.""" - -config_BLS12_377 = { - # A small prime field for simplicity - 'prime': 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001, - 'order': 0x12AB655E9A2CA55660B44D1E5C37B00159AA76FED00000010A11800000000001, - 'a': 0, # Coefficient a = 0 - 'b': 1, # Coefficient b = 1 - 'generator': [ - 0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF, - 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6, - ], -} - -# https://docs.rs/ark-ed-on-bls12-377/latest/ark_ed_on_bls12_377/index.html -config_BLS12_377_t = { - 'prime': ( - 258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 - ), - 'order': ( - 8444461749428370424248824938781546531375899335154063827935233455917409239041 - ), - 'a': -1, - 'd': ( - 
122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179 - ), - 'alpha': -1, - 'b': 1, - 's': ( - 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931 - ), - 'MA': ( - 228097355113300204138531148905234651262148041026195375645000724271212049151994375092458297304264351187709081232384 - ), - 'MB': ( - 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931 - ), - 't': ( - 23560188534917577818843641916571445935985386319233886518929971599490231428764380923487987729215299304184915158756 - ), - # 't': 235104237478051516191809091778322087600408126435680774020954291067230236919576441851480900410358060820255406299421, - 'generator': [ - 71222569531709137229370268896323705690285216175189308202338047559628438110820800641278662592954630774340654489393, - 6177051365529633638563236407038680211609544222665285371549726196884440490905471891908272386851767077598415378235, - ], -} diff --git a/jaxite_ec/algorithm/elliptic_curve.py b/jaxite_ec/algorithm/elliptic_curve.py deleted file mode 100644 index 6b649d9..0000000 --- a/jaxite_ec/algorithm/elliptic_curve.py +++ /dev/null @@ -1,775 +0,0 @@ -"""Elliptic curve coordinate systems and points.""" - -import abc -import copy -import enum -from typing import Dict, Generic, Iterable, List, Optional, TypeVar, Union - -from jaxite.jaxite_ec.algorithm import big_integer -from jaxite.jaxite_ec.algorithm import finite_field - - -BigInt = big_integer.GMPBigInteger - -abstractmethod = abc.abstractmethod -ABC = abc.ABC -Auto = enum.auto -Enum = enum.Enum -T = TypeVar('T') -FieldEle = finite_field.FiniteFieldElement - - -class CoordinateSystemType(Enum): - """Enum to represent different types of coordinate systems for elliptic curves. 
- - AFFINE: The affine coordinate system, the standard coordinate system for - elliptic curves: https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html - - PROJECTIVE: The projective coordinate system, - https://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#addition-madd-1998-cmo - - XYZZ: The XYZZ coordinate system: - https://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html - """ - - NONE = Auto() - WEIERSTRASS_AFFINE = Auto() - WEIERSTRASS_PROJECTIVE = Auto() - WEIERSTRASS_XYZZ = Auto() - - -class ECPoint(Generic[T]): - """Represents a point in an elliptic curve coordinate system.""" - - def __init__( - self, - coordinates: List[T], - coordinate_system: 'EllipticCurveCoordinateSystem', - zero: bool = False, - ) -> None: - self.coordinate_system = coordinate_system - self.zero = zero - self.coordinates = self.coordinate_system.generate_formal_coordinates( - coordinates, zero - ) - self.type = coordinate_system.get_type() - - def __getitem__(self, index: Union[int, slice]): - return self.coordinates[index] - - def __setitem__( - self, index: Union[int, slice], value: Union[T, Iterable[T]] - ) -> None: - self.coordinates[index] = value # type: ignore - - def __eq__(self, other: 'ECPoint') -> bool: - if not isinstance(other, ECPoint): - return NotImplemented - - return self.coordinates == other.coordinates - - def __add__(self, other: 'ECPoint') -> 'ECPoint': - return self.coordinate_system.point_add(self, other) - - def __lshift__(self, shift) -> 'ECPoint': - if self.zero: - return self.copy() - return self.coordinate_system.point_lshift(self, shift) - - def is_zero(self): - return self.zero - - def get_type(self) -> CoordinateSystemType: - return self.type - - def set_type(self, cs_type: CoordinateSystemType): - self.type = cs_type - - def set_coordinate_system( - self, coordinate_system: 'EllipticCurveCoordinateSystem' - ): - self.coordinate_system = coordinate_system - self.type = coordinate_system.get_type() - - def append(self, coordinate: T): - 
assert self.coordinates is not None - self.coordinates.append(coordinate) - - def copy(self): - obj = copy.copy(self) - if not self.is_zero(): - obj.coordinates = self.coordinate_system.generate_formal_coordinates( - self.coordinates, self.zero - ) - return obj - - def convert_to_affine(self): - return self.coordinate_system.convert_to_affine(self) - - def __str__(self) -> str: - if self.is_zero(): - return 'Point, O' - coord_strs = [coord.hex_value_str() for coord in self.coordinates] # pytype: disable=attribute-error - return 'Point, ' + ', '.join(coord_strs) - - -class EllipticCurveCoordinateSystem(ABC, Generic[T]): - """Abstract base class for elliptic curve coordinate systems.""" - - def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None: - """Initialize the coordinate system. - - Args: - config: A dictionary containing the configuration of the coordinate - system. The dictionary should contain the following keys: - ff_zero: The - value "zero" in the finite field. - ff_one: The value "one" in the - finite field. - type: The type of the coordinate system. 
- """ - super().__init__() - self.config = config - self.ff_zero: FieldEle = FieldEle(0, config['prime']) - self.ff_one: FieldEle = FieldEle(1, config['prime']) - self.type = CoordinateSystemType.NONE - - @abstractmethod - def generate_formal_coordinates( - self, coordinates: List[Union[int, FieldEle]], zero: bool - ) -> List[T]: - pass - - def get_type(self): - return self.type - - def generate_point( - self, - coordinates: Optional[Union[List[T], List[int]]] = None, - zero: bool = False, - ) -> ECPoint[T]: - return ECPoint[T](coordinates, self, zero) - - @abstractmethod - def point_add(self, point_a: ECPoint[T], point_b: ECPoint[T]): - pass - - @abstractmethod - def point_lshift(self, point_a: ECPoint[T], index): - pass - - @abstractmethod - def convert_to_affine(self, point_a: ECPoint[T]): - pass - - -class ECCSWeierstrass(EllipticCurveCoordinateSystem[FieldEle]): - """Base class for Weierstrass coordinate systems.""" - - def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None: - """Initialize the EC Weierstrass coordinate systems. - - Weierstrass: y^2 = x^3 + a*x + b - - Args: - config: A dictionary containing the configuration of the coordinate - system. The dictionary should contain the following keys: - generator: A - special point "G" in the Elliptic Curve. - order: order * generator == 0 - in the Elliptic Curve. - prime: The modulus of the Elliptic Curve. - a: - The parameter a of the Elliptic Curve. - b: The parameter b of the - Elliptic Curve. 
- """ - super().__init__(config) - self.prime = BigInt(config['prime']) - self.order = BigInt(config['order']) - - self.generator: List[FieldEle] = [] - for coordinate in config['generator']: - element = self.ff_zero.copy(coordinate) - self.generator.append(element) - self.a = self.ff_zero.copy(config['a']) - self.b = self.ff_zero.copy(config['b']) - - def generate_formal_coordinates( - self, coordinates: List[Union[int, FieldEle]], zero: bool = False - ): - if zero: - return None - formal_coordinate: List[FieldEle] = [] - for coordinate in coordinates: - if isinstance(coordinate, FieldEle): - formal_coordinate.append(coordinate) - else: - formal_coordinate.append(self.ff_zero.copy(coordinate)) - return formal_coordinate - - -class ECCSWeierstrassAffine(ECCSWeierstrass): - """Weierstrass affine coordinate system. - - It has unique point add calculation. Details pleaase refer to: - AFFINE: The affine coordinate system, the standard coordinate system for - elliptic curves: https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html - """ - - def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None: - super().__init__(config) - self.type = CoordinateSystemType.WEIERSTRASS_AFFINE - - def add_general(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]): - slope = (point_b[1] - point_a[1]) / (point_b[0] - point_a[0]) - cx = (slope * slope) - point_a[0] - point_b[0] - cy = slope * (point_a[0] - cx) - point_a[1] - return ECPoint[FieldEle]([cx, cy], self) - - def double_general(self, point_a: ECPoint[FieldEle]): - x1 = point_a[0] - y1 = point_a[1] - slope = (x1 * x1 * 3 + self.a) / (y1 * 2) - cx = slope * slope - x1 - x1 - cy = slope * (x1 - cx) - y1 - return ECPoint[FieldEle]([cx, cy], self) - - def point_add(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]): - if point_b.is_zero(): - return point_a.copy() - elif point_a.is_zero(): - return point_b.copy() - - if point_a == point_b: - result = self.double_general(point_a) - else: - result = 
self.add_general(point_a, point_b) - - return result - - def point_lshift(self, point_a: ECPoint[FieldEle], shift: int): - if point_a.is_zero(): - return point_a.copy() - - for _ in range(shift): - point_a = self.double_general(point_a) - return point_a - - def convert_to_affine(self, point_a: ECPoint[FieldEle]) -> ECPoint[FieldEle]: - return point_a - - def generate_formal_coordinates( - self, coordinates: List[Union[int, FieldEle]], zero: bool = False - ) -> Optional[List[FieldEle]]: - if zero: - return None - coordinate_length = len(coordinates) - assert coordinate_length == 2 or coordinate_length == 3 - formal_coordinates: List[FieldEle] = [] - for coordinate in coordinates: - if isinstance(coordinate, FieldEle): - formal_coordinates.append(coordinate) - else: - formal_coordinates.append(self.ff_zero.copy(coordinate)) - if coordinate_length == 2: - assert len(formal_coordinates) == 3 - return formal_coordinates - - -class ECCSWeierstrassProjective(ECCSWeierstrass): - """Weierstrass projective coordinate system. 
- - PROJECTIVE: The projective coordinate system, - https://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#addition-madd-1998-cmo - """ - - def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None: - super().__init__(config) - self.prime = BigInt(config['prime']) - self.order = BigInt(config['order']) - self.type = CoordinateSystemType.WEIERSTRASS_PROJECTIVE - - self.generator: List[FieldEle] = [] - for coordinate in config['generator']: - element = self.ff_zero.copy(coordinate) - self.generator.append(element) - - self.a = self.ff_zero.copy(config['a']) - self.b = self.ff_zero.copy(config['b']) - - def add_z2_eq_1(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]): - if point_b[2] == self.ff_one: - x1, y1, z1 = point_a - x2, y2, _ = point_b - elif point_a[2] == self.ff_one: - x1, y1, z1 = point_b - x2, y2, _ = point_a - else: - raise NotImplementedError - - u = y2 * z1 - y1 - uu = u * u - v = x2 * z1 - x1 - vv = v * v - vvv = v * vv - r = vv * x1 - a = uu * z1 - vvv - (r + r) - x3 = v * a - y3 = u * (r - a) - vvv * y1 - z3 = vvv * z1 - return ECPoint[FieldEle]([x3, y3, z3], self) - - def add_general(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]): - x1, y1, z1 = point_a - x2, y2, z2 = point_b - - b3 = self.b * self.ff_zero.copy(3) - a = self.a - - # Perform the operations - t0 = x1 * x2 - t1 = y1 * y2 - t2 = z1 * z2 - t3 = (x1 + y1) * (x2 + y2) - t4 = t0 + t1 - t3 = t3 - t4 - t4 = (x1 + z1) * (x2 + z2) - t5 = t0 + t2 - t4 = t4 - t5 - t5 = (y1 + z1) * (y2 + z2) - x3 = t1 + t2 - t5 = t5 - x3 - z3 = a * t4 - x3 = b3 * t2 - z3 = x3 + z3 - x3 = t1 - z3 - z3 = t1 + z3 - y3 = x3 * z3 - t1 = t0 + t0 - t1 = t1 + t0 - t2 = a * t2 - t4 = b3 * t4 - t1 = t1 + t2 - t2 = t0 - t2 - t2 = a * t2 - t4 = t4 + t2 - t0 = t1 * t4 - y3 = y3 + t0 - t0 = t5 * t4 - x3 = t3 * x3 - x3 = x3 - t0 - t0 = t3 * t1 - z3 = t5 * z3 - z3 = z3 + t0 - - return ECPoint[FieldEle]([x3, y3, z3], self) - - def double_general(self, point_a: ECPoint): - x1, y1, 
z1 = point_a - a = self.a - ff2 = self.ff_zero.copy(2) - - # Perform the operations based on the pseudocode - xx = x1 * x1 - zz = z1 * z1 - w = a * zz + xx * self.ff_zero.copy(3) - s = y1 * z1 * ff2 - ss = s * s - sss = s * ss - r = y1 * s - rr = r * r - b = (x1 + r) * (x1 + r) - xx - rr - h = w * w - b * ff2 - x3 = h * s - y3 = w * (b - h) - rr * ff2 - z3 = sss - - return ECPoint([x3, y3, z3], self) - - def point_add( - self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle] - ) -> ECPoint[FieldEle]: - if point_b.is_zero(): - return point_a.copy() - elif point_a.is_zero(): - return point_b.copy() - - if point_a == point_b: - result = self.double_general(point_a) - elif point_a[2] == self.ff_one or point_b[2] == self.ff_one: - result = self.add_z2_eq_1(point_a, point_b) - else: - result = self.add_general(point_a, point_b) - return result - - def point_lshift(self, point: ECPoint[FieldEle], shift: int): - if point.is_zero(): - return point.copy() - - for _ in range(shift): - point = self.double_general(point) - return point - - def convert_from_affine( - self, point_a: ECPoint[FieldEle] - ) -> ECPoint[FieldEle]: - assert point_a.get_type() == self.type - new_point = point_a.copy() - if new_point.coordinates is not None: - new_point.coordinates.clear() - z_invert = self.ff_one / point_a[2] - new_point.append(point_a[0] * z_invert) - new_point.append(point_a[1] * z_invert) - new_point.append(self.ff_one) - new_point.set_type(CoordinateSystemType.WEIERSTRASS_AFFINE) - return new_point - else: - return new_point - - def convert_to_affine(self, point: ECPoint) -> ECPoint: - assert point.get_type() == self.type - new_point = point.copy() - if new_point.coordinates is not None: - new_point.coordinates.clear() - z_invert = self.ff_one / point[2] - new_point.append(point[0] * z_invert) - new_point.append(point[1] * z_invert) - new_point.append(self.ff_one) - new_point.set_type(CoordinateSystemType.WEIERSTRASS_AFFINE) - return new_point - else: - return new_point - - 
-class ECCSWeierstrassXYZZ(ECCSWeierstrass): - """Weierstrass XYZZ coordinate system.""" - - def __init__(self, config: Dict[str, Union[int, List[int]]]) -> None: - super().__init__(config) - self.type = CoordinateSystemType.WEIERSTRASS_XYZZ - - def generate_formal_coordinates( - self, coordinates: List[Union[int, FieldEle]], zero: bool = False - ) -> Optional[List[FieldEle]]: - if zero: - return None - coordinate_length = len(coordinates) - assert coordinate_length == 2 or coordinate_length == 4 - formal_coordinates: List[FieldEle] = [] - for coordinate in coordinates: - if isinstance(coordinate, FieldEle): - formal_coordinates.append(coordinate) - else: - formal_coordinates.append(self.ff_zero.copy(coordinate)) - if coordinate_length == 2: - formal_coordinates.append(self.ff_one.copy()) - formal_coordinates.append(self.ff_one.copy()) - assert len(formal_coordinates) == 4 - return formal_coordinates - - def add_z2_eq_1(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]): - """Add the general coordinates of two points in the XYZZ coordinate system. - - This function is not implemented for the XYZZ coordinate system. - - Args: - point_a: The first point. - point_b: The second point. - - Returns: - The added general coordinates of the two points. - """ - raise NotImplementedError - # return ECPoint[FieldEle]([], self) - - def add_general(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]): - x1, y1, zz1, zzz1 = point_a - x2, y2, zz2, zzz2 = point_b - - u1 = x1 * zz2 - u2 = x2 * zz1 - s1 = y1 * zzz2 - s2 = y2 * zzz1 - p = u2 - u1 - r = s2 - s1 - pp = p * p - ppp = p * pp - q = u1 * pp - x3 = r * r - ppp - (q + q) - y3 = r * (q - x3) - s1 * ppp - zz3 = zz1 * zz2 * pp - zzz3 = zzz1 * zzz2 * ppp - - return ECPoint[FieldEle]([x3, y3, zz3, zzz3], self) - - def double_general(self, point_a: ECPoint[FieldEle]): - """Double the general coordinates of a point in the XYZZ coordinate system. - - Args: - point_a: The point to double the general coordinates of. 
- - Returns: - The doubled general coordinates of the point. - """ - x1, y1, zz1, zzz1 = point_a - a = self.a - - # Perform the operations based on the pseudocode - u = y1 + y1 - v = u * u - w = u * v - s = x1 * v - x1_sq = x1 * x1 - m = (x1_sq + x1_sq + x1_sq) + a * (zz1 * zz1) - x3 = (m * m) - (s + s) - y3 = m * (s - x3) - w * y1 - zz3 = v * zz1 - zzz3 = w * zzz1 - - return ECPoint[FieldEle]([x3, y3, zz3, zzz3], self) - - def point_add( - self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle] - ) -> ECPoint[FieldEle]: - if point_b.is_zero(): - return point_a.copy() - elif point_a.is_zero(): - return point_b.copy() - - if point_a == point_b: - result = self.double_general(point_a) - else: - result = self.add_general(point_a, point_b) - return result - - def point_lshift(self, point_a: ECPoint[FieldEle], shift: int): - if point_a.is_zero(): - return point_a.copy() - - for _ in range(shift): - point_a = self.double_general(point_a) - return point_a - - def convert_from_affine( - self, point_a: ECPoint[FieldEle] - ) -> ECPoint[FieldEle]: - assert ( - point_a.coordinate_system.get_type() - == CoordinateSystemType.WEIERSTRASS_AFFINE - ) - new_point = point_a.copy() - new_point.append(self.ff_one) - new_point.append(self.ff_one) - new_point.set_coordinate_system(self) - return new_point - - def convert_to_affine(self, point_a: ECPoint[FieldEle]) -> ECPoint[FieldEle]: - assert point_a.get_type() == self.type - new_point = point_a.copy() - if new_point.coordinates is not None: - new_point.coordinates.clear() - a = self.ff_one / point_a[3] - b = point_a[2] * a - b_sq = b * b - new_point.append(point_a[0] * b_sq) - new_point.append(point_a[1] * a) - new_point.append(self.ff_one) - new_point.append(self.ff_one) - new_point.set_type(CoordinateSystemType.WEIERSTRASS_AFFINE) - return new_point - else: - return new_point - - -class ECCSTwistedEdwardsExtended(ECCSWeierstrass): - """Twisted Edwards Extended coordinate system.""" - - def __init__(self, config: Dict[str, 
Union[int, List[int]]]) -> None: - super().__init__(config) - self.type = CoordinateSystemType.WEIERSTRASS_PROJECTIVE - self.a = self.ff_zero.copy(config['a']) - self.d = self.ff_zero.copy(config['d']) - self.alpha = self.ff_zero.copy(config['alpha']) - self.s = self.ff_zero.copy(config['s']) - self.ma = self.ff_zero.copy(config['MA']) - self.mb = self.ff_zero.copy(config['MB']) - self.t = self.ff_zero.copy(config['t']) - self.k = self.d + self.d - - def generate_point( - self, - coordinates: Optional[Union[List[T], List[int]]] = None, - zero: bool = False, - twist: bool = True, - ) -> ECPoint[T]: - if twist: - ff_coordinates = [] - for coordinate in coordinates: - if isinstance(coordinate, FieldEle): - ff_coordinates.append(coordinate) - else: - ff_coordinates.append(self.ff_zero.copy(coordinate)) - coordinates = self.twist(ff_coordinates) - return ECPoint[T](coordinates, self, zero) - - def generate_formal_coordinates( - self, - coordinates: List[Union[int, FieldEle]], - zero: bool = False, - ) -> Optional[List[FieldEle]]: - if zero: - coordinates = self.zero() - coordinate_length = len(coordinates) - assert coordinate_length in {2, 4} - formal_coordinates: List[FieldEle] = [] - for coordinate in coordinates: - if isinstance(coordinate, FieldEle): - formal_coordinates.append(coordinate) - else: - formal_coordinates.append(self.ff_zero.copy(coordinate)) - - if coordinate_length == 2: - formal_coordinates.append(self.ff_one.copy()) - formal_coordinates.append(formal_coordinates[0] * formal_coordinates[1]) - - return formal_coordinates - - def point_add(self, point_a: ECPoint[FieldEle], point_b: ECPoint[FieldEle]): - # https://www.hyperelliptic.org/EFD/g1p/data/twisted/extended/addition/madd-2008-hwcd - # copied from arkworks_bls12_377 - x1, y1, z1, t1 = point_a - x2, y2, z2, t2 = point_b # 0, 1, 1, 0 - a = x1 * x2 # 0 - b = y1 * y2 # y1 - c = self.d * t1 * t2 # 0 - d = z1 * z2 # z1 - # h = b - (self.a * a) # self.a = -1 in standard form - h = b + a # y1 - # 
karatsuba - e = (x1 + y1) * (x2 + y2) - h # (x1 + y1) - y1 = x1 - f = d - c # z1 - g = d + c # z1 - x3 = e * f # x1 * z1 - y3 = g * h # y1 * z1 - z3 = f * g # z1 * z1 - t3 = e * h # x1 * y1 - - return ECPoint[FieldEle]([x3, y3, z3, t3], self) - - def double_general(self, point_a: ECPoint[FieldEle]): - # https://www.hyperelliptic.org/EFD/g1p/data/twisted/extended/doubling/dbl-2008-hwcd - # copied from arkworks_bls12_377 - x1, y1, z1, _ = point_a - a = x1 * x1 - b = y1 * y1 - ct = z1 * z1 - # d = self.a * a - # d = self.ff_zero - a - h = self.ff_zero - a - b - et = x1 + y1 - e = (et * et) + h - g = b - a - f = g - ct - ct - x3 = e * f - y3 = g * h - z3 = f * g - t3 = e * h - return ECPoint[FieldEle]([x3, y3, z3, t3], self) - - def point_lshift(self, point_a: ECPoint[FieldEle], shift: int): - if point_a.is_zero(): - return point_a.copy() - - for _ in range(shift): - point_a = self.double_general(point_a) - return point_a - - def twist(self, point_a: List[FieldEle]): - # https://en.wikipedia.org/wiki/Montgomery_curve - x, y = point_a - # Convert to montgomery - xm = self.s * (x - self.alpha) - ym = self.s * y - # Convert to edwards - if ym == self.ff_zero: - return None - xt = xm / ym - - yt_denom = xm + self.ff_one - if yt_denom == self.ff_zero: - return None - yt = (xm - self.ff_one) / (yt_denom) - - xt = xt * self.t - return [xt, yt] - - def untwist(self, point_a: List[FieldEle]): - xt, yt = point_a - xt = xt / self.t - # print("ut", xt, yt) - # Convert to montgomery - xm = (self.ff_one + yt) / ((self.ff_one - yt)) - ym = (self.ff_one + yt) / ((self.ff_one - yt) * xt) - # print("um", xm, ym) - # Convert to weierstrass - three = self.ff_zero.copy(3) - x = (xm / self.mb) + (self.ma / (three * self.mb)) - y = ym / self.mb - # print("uw", x, y) - return [x, y] - - def convert_from_affine(self, point_a): - # apply twist? 
- return NotImplementedError - - def convert_to_twisted_edwards_affine( - self, point_a: ECPoint[FieldEle] - ) -> ECPoint[FieldEle]: - new_point = point_a.copy() - if new_point.coordinates is not None: - new_point.coordinates.clear() - x, y, z = point_a[:3] - inv_z = self.ff_one / z - x2 = x * inv_z - y2 = y * inv_z - new_point.append(x2) - new_point.append(y2) - new_point.append(self.ff_one) - new_point.append(self.ff_one) - return new_point - - def convert_to_affine(self, point_a: ECPoint[FieldEle]) -> ECPoint[FieldEle]: - new_point = self.convert_to_twisted_edwards_affine(point_a) - twisted_coords = new_point[:2] - new_point[:2] = self.untwist(twisted_coords) - return new_point - - def zero(self): - return [self.ff_zero, self.ff_one, self.ff_one, self.ff_zero] - - def twist_int_coordinates(self, coordinates: List[int]) -> List[int]: - ff_coordinates = [ - self.ff_zero.copy(coordinate) for coordinate in coordinates - ] - ff_coordinates = self.twist(ff_coordinates) - if ff_coordinates is None: - return [0, 1, 1, 0] # zero - ff_coordinates.append(self.ff_one.copy()) - ff_coordinates.append(ff_coordinates[0] * ff_coordinates[1]) - return [int(coord.value) for coord in ff_coordinates] - - def twist_projective(self, point_a): - x, y, z = point_a - xm = self.s * (x - self.alpha) - ym = self.s * y - # Convert to edwards - # Common denominator - zt = z * ym * (xm + self.ff_one) - xt = xm * (xm + self.ff_one) - yt = (xm - self.ff_one) * ym - - xt = xt * self.t - return [xt, yt, zt] - - -class TwistedInvertedXYZ(ECCSWeierstrass): - - def convert_from_affine(self, point_a): - x, y = point_a - return [y, x, x * y] - - # Looks like the cheapest strongly unified three coordinate option - # https://www.hyperelliptic.org/EFD/g1p/data/twisted/inverted/addition/add-2008-bbjlp diff --git a/jaxite_ec/algorithm/finite_field.py b/jaxite_ec/algorithm/finite_field.py deleted file mode 100644 index 6a3ea36..0000000 --- a/jaxite_ec/algorithm/finite_field.py +++ /dev/null @@ -1,304 +0,0 @@ 
-"""Finite field elements for elliptic curve cryptography.""" - -import copy -import math - -from jaxite.jaxite_ec.algorithm import big_integer - -BigInt = big_integer.GMPBigInteger - - -class FiniteFieldElement: - """Finite field element for elliptic curve cryptography.""" - - def __init__(self, value, prime): - if isinstance(value, BigInt): - self.value = value - else: - self.value = BigInt(value) - - if isinstance(prime, BigInt): - self.prime = prime - else: - self.prime = BigInt(prime) - - if self.value < 0 or self.value >= self.prime: - raise ValueError(f"Value {self.value} not in range 0 to {self.prime - 1}") - - def set_value(self, value): - """Set the value of the finite field element, with validation.""" - if isinstance(value, BigInt): - new_value = value - else: - new_value = BigInt(value) - - if new_value < BigInt(0) or new_value >= self.prime: - raise ValueError(f"Value {new_value} not in range 0 to {self.prime - 1}") - - self.value = new_value - - def get_value(self) -> int: - return int(self.value) - - def get_prime(self): - return copy.deepcopy(self.prime) - - def copy(self, value=None, transform=False, reduction=False): - """Create a deep copy of the current finite field element. - - transform and reduction are only place holders for unified API. - - Args: - value (int): The value of the new finite field element. If None, the - value of the current finite field element is used. - transform (bool): Whether to transform the value of the new finite field - element. - reduction (bool): Whether to reduce the value of the new finite field - element. - - Returns: - FiniteFieldElement: A deep copy of the current finite field element. 
- """ - obj = copy.copy(self) - if not value: - obj.value = BigInt(int(self.value)) - else: - obj.value = BigInt(value) - return obj - - def __add__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot add two numbers in different Fields") - result = (self.value + other.value) % self.prime - return FiniteFieldElement(result, self.prime.value) - - def __sub__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot subtract two numbers in different Fields") - result = (self.value - other.value) % self.prime - return FiniteFieldElement(result, self.prime.value) - - def __mul__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot multiply two numbers in different Fields") - result = (self.value * other.value) % self.prime - return FiniteFieldElement(result, self.prime.value) - - def __truediv__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot divide two numbers in different Fields") - # Use Fermat's Little Theorem to find the inverse: a^(p-1) ≡ 1 (mod p) - inverse = other.value.__pow__(self.prime.value - 2, self.prime.value) - result = (self.value * inverse) % self.prime - return FiniteFieldElement(result, self.prime.value) - - def __pow__(self, exponent): - result = self.value.__pow__(exponent, self.prime.value) - return FiniteFieldElement(result.value, self.prime.value) - - def __eq__(self, other): - return self.value == other.value and self.prime == other.prime - - def __str__(self): - return f"FieldElement_{self.prime.value}({self.value.value})" - - def __repr__(self): - return ( - f"FiniteFieldElement(value={self.value.value}," - f" prime={self.prime.value})" - ) - - def __hex__(self): - return hex(int(self.value.value)) - - def hex_value_str(self) -> str: - return self.value.hex_value_str() - - -class FiniteFieldElementBarrett(FiniteFieldElement): - """Finite field element for elliptic curve cryptography using Barrett reduction.""" - - def __init__(self, value, prime, k=None): - 
super().__init__(value, prime) - if k == None: - if isinstance(prime, BigInt): - self.two_k = prime.ceil_log2() * 2 - else: - self.two_k = BigInt(math.ceil(math.log2(prime))) * 2 - else: - self.two_k = BigInt(2 * k) - self.mu = BigInt(2) ** self.two_k / prime - - def barrett_reduction(self, x): - # q = (x * mu) >> 2k - q = (x * self.mu) >> self.two_k - # r = x - q * prime - r = x - q * self.prime - if r >= self.prime: - r -= self.prime - return r - - def __add__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot add two numbers in different Fields") - result = self.value + other.value - if result > self.prime: - result -= self.prime - new_instance = self.copy() - new_instance.value = result - return new_instance - - def __sub__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot subtract two numbers in different Fields") - if self.value < other.value: - result = self.value + self.prime - other.value - else: - result = self.value - other.value - new_instance = self.copy() - new_instance.value = result - return new_instance - - def __mul__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot multiply two numbers in different Fields") - result = self.value * other.value - reduced_result = self.barrett_reduction(result) - new_instance = self.copy() - new_instance.value = reduced_result - return new_instance - - def __truediv__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot divide two numbers in different Fields") - inverse = other.value.__pow__(self.prime.value - 2, self.prime.value) - result = self.value * inverse - reduced_result = self.barrett_reduction(result) - new_instance = self.copy() - new_instance.value = reduced_result - return new_instance - - def copy(self, value=None, transform=False, reduction=False): - """Create a deep copy of the current finite field element.""" - obj = copy.copy(self) - if not value: - obj.value = BigInt(self.value) - else: - if reduction: - value 
= self.barrett_reduction(BigInt(value)) - obj.value = BigInt(value) - return obj - - -class FiniteFieldElementMontgomery(FiniteFieldElement): - """Finite field element for elliptic curve cryptography using Montgomery reduction.""" - - def __init__(self, value, prime, k=None): - super().__init__(value, prime) - if not k: - if isinstance(prime, BigInt): - self.k = prime.ceil_log2() - else: - self.k = BigInt(math.ceil(math.log2(prime))) - else: - self.k = BigInt(k) - - self.r = BigInt(2) ** self.k - # self.r_inverse = (self.r ** (self.prime - 2)) % self.prime - self.r_inverse = self.r.__pow__(self.prime - 2, self.prime) - self.n_prime = (self.r * self.r_inverse - 1) / self.prime - self.r_mask = self.r - 1 - self.value = self.montgomeryize(self.value) - self.montgomeryized = True - self.one_bar = self.montgomeryize(BigInt(1)) - - def montgomery_reduction(self, x): - m = ((x & self.r_mask) * self.n_prime) & self.r_mask - u = (x + m * self.prime) >> self.k - if u >= self.prime: - u -= self.prime - return u - - def montgomeryize(self, x): - x_bar = (x * self.r) % self.prime - return x_bar - - def de_montgomeryize(self, x_bar): - x = self.montgomery_reduction(x_bar) - return x - - def change_montgomery_form(self): - if self.montgomeryized: - self.value = self.de_montgomeryize(self.value) - self.montgomeryized = False - else: - self.value = self.montgomeryize(self.value) - self.montgomeryized = True - return self - - def __add__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot add two numbers in different Fields") - assert other.montgomeryized - result = self.value + other.value - if result > self.prime: - result -= self.prime - new_instance = self.copy() - new_instance.value = result - return new_instance - - def __sub__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot subtract two numbers in different Fields") - assert other.montgomeryized - if self.value < other.value: - result = self.value + self.prime - other.value - 
else: - result = self.value - other.value - new_instance = self.copy() - new_instance.value = result - return new_instance - - def __mul__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot multiply two numbers in different Fields") - assert other.montgomeryized - result = self.value * other.value - reduced_result = self.montgomery_reduction(result) - new_instance = self.copy() - new_instance.value = reduced_result - return new_instance - - def __truediv__(self, other): - if self.prime != other.prime: - raise ValueError("Cannot divide two numbers in different Fields") - - if other.montgomeryized: - other_value = self.de_montgomeryize(other.value) - else: - other_value = other.value - - inverse = other_value.__pow__(self.prime.value - 2, self.prime.value) - - inverse_bar = self.montgomeryize(inverse) - - result = self.value * inverse_bar - reduced_result = self.montgomery_reduction(result) - new_instance = self.copy() - new_instance.value = reduced_result - return new_instance - - def copy(self, value=None, transform=False, reduction=False): - """Create a deep copy of the current finite field element.""" - obj = copy.copy(self) - if not value: - obj.value = BigInt(self.value) - else: - if reduction: - obj.value = self.montgomery_reduction(BigInt(value)) - elif transform: - obj.value = self.montgomeryize(BigInt(value)) - else: - obj.value = BigInt(value) - - return obj diff --git a/jaxite_ec/algorithm/lazy_reduction.py b/jaxite_ec/algorithm/lazy_reduction.py deleted file mode 100644 index a13b093..0000000 --- a/jaxite_ec/algorithm/lazy_reduction.py +++ /dev/null @@ -1,91 +0,0 @@ -"""This file implements the lazy reduction algorithm.""" - -import random - -import numpy as np - -randint = random.randint - - -def byte_len(x: int): - return (x.bit_length() + 7) // 8 - - -def chunk_decomposition(x, length=-1): - if length == -1: - length = byte_len(x) - return [(x >> (8 * i)) & 255 for i in range(length)] - - -# Input: a_list: (t+2) byte numbers for 
a -# b_list: (t+2) byte numbers for b -# p: modulus -# Output: c_list: (t_2) byte numbers for (ab % p) -def lazy_reduction_via_matrix(a_list, b_list, p): - """Performs lazy reduction via matrix multiplication. - - Args: - a_list: (t+2) byte numbers for a - b_list: (t+2) byte numbers for b - p: modulus - - Returns: - c_list: (t_2) byte numbers for (ab % p) - """ - t = len(chunk_decomposition(p)) - # Precomputation - lazy_mat = np.zeros((t + 4, t), dtype=np.uint8) - for i in range(t + 4): - val = (256 ** (t + i)) % p - lazy_mat[i, :] = chunk_decomposition(val, t) - - # Begin computation - assert len(a_list) == len(b_list) - batch_size = len(a_list) - batch_mat = np.zeros((batch_size, t + 4), dtype=np.uint8) - standard_product = [a_list[i] * b_list[i] for i in range(batch_size)] - standard_product_low = [s & (256**t - 1) for s in standard_product] - standard_product_high = [s >> (8 * t) for s in standard_product] - # Matrix packing - for i in range(batch_size): - batch_mat[i] = chunk_decomposition(standard_product_high[i], t + 4) - # Matrix product - # Upcast to get proper accumulators - reduced = np.matmul(batch_mat.astype(np.uint32), lazy_mat.astype(np.uint32)) - c_list = [] - # Recombine into integers - for i in range(batch_size): - val = 0 - for j in reversed(range(t)): - val *= 256 - val += int(reduced[i][j]) - c_list.append(val) - # Add in standard_product_low; this could be done in u8 form before the - # carry-add chain above, perhaps from the raw toeplitz output. - # Since we are using a redundant form the upper part doesn't have to be - # accurate. 
- c_list = [c_list[i] + standard_product_low[i] for i in range(batch_size)] - return c_list - - -def main(): - """This test case check the functionality of the lazy reduction algorithm.""" - p = randint(2**381, 2**384) - batch_size = 16 - bound = p * 256 * 256 - # a and b are both < bound - a_list = [randint(0, bound) for _ in range(batch_size)] - b_list = [randint(0, bound) for _ in range(batch_size)] - c_list = lazy_reduction_via_matrix(a_list, b_list, p) - for i in range(batch_size): - # each output is congruent modulo p - assert c_list[i] % p == (a_list[i] * b_list[i]) % p - # each output is < bound - assert c_list[i] < bound - print(a_list) - print(b_list) - print(c_list) - - -if __name__ == "__main__": - main() diff --git a/jaxite_ec/algorithm/rns_mod_reduce.py b/jaxite_ec/algorithm/rns_mod_reduce.py deleted file mode 100644 index 910d312..0000000 --- a/jaxite_ec/algorithm/rns_mod_reduce.py +++ /dev/null @@ -1,419 +0,0 @@ -"""Implementation of RNS modular reduction.""" - -import math -import random - -import jax -import jax.numpy as jnp - -randint = random.randint - - -jax.config.update("jax_enable_x64", True) -chunk_dtype = jnp.uint16 -mul_res_dtype = jnp.uint32 - - -def to_tuple(a): - """Create to convert numpy array into tuple.""" - try: - return tuple(to_tuple(i) for i in a) - except TypeError: - return a - - -def find_moduli(total_modulus, precision): - """Find moduli for RNS.""" - initial_moduli = 2**precision - overall_moduli = [] - constant_offset_list = [] - overall_modulus = 1 - for i in range(2 ** (precision >> 1) - 1): - cur_moduli = initial_moduli - i - if math.gcd(cur_moduli, overall_modulus) == 1: - overall_moduli.append(cur_moduli) - constant_offset_list.append(i) - overall_modulus *= cur_moduli - if overall_modulus > total_modulus: - return overall_moduli, constant_offset_list - - # Find 2**15 - v too - initial_moduli = 2 ** (precision - 1) - if overall_modulus < total_modulus: - for i in range(2 ** (precision >> 1) - 1): - cur_moduli = 
initial_moduli - i - if math.gcd(cur_moduli, overall_modulus) == 1: - overall_moduli.append(cur_moduli) - constant_offset_list.append(i << 1) - overall_modulus *= cur_moduli - if overall_modulus > total_modulus: - return overall_moduli, constant_offset_list - - return overall_moduli, constant_offset_list - - -def hardware_friendly_mod_reduce(x, moduli_t): - """Convert input value x into the RNS form. - - Args: - x: Input value as a jnp.uint32 array. - moduli_t: List of hardware friendly moduli_t. - - Returns: - The RNS representation of x as a jnp.uint16 array. - """ - assert x.dtype == jnp.uint32 - x_h = (x >> 16) & 0xFFFF - x_l = x & 0xFFFF - x_reduce = x_h * moduli_t + x_l - - x_h_sec = (x_reduce >> 16) & 0xFFFF - x_l_sec = x_reduce & 0xFFFF - x_reduce_sec = x_h_sec * moduli_t + x_l_sec - - x_h_third = (x_reduce_sec >> 16) & 0xFFFF - x_l_third = x_reduce_sec & 0xFFFF - x_reduce_third = x_h_third * moduli_t + x_l_third - return x_reduce_third.astype(jnp.uint16) - - -def to_rns(x, moduli): - return [x % m for m in moduli] - - -def rns_reconstruct(x, overall_moduli, icrt_factors): - x_reconstruct = [] - for value_rns in x: - big_int = 0 - for idx, tower in enumerate(value_rns): - big_int += int(tower) * int(icrt_factors[idx]) - x_reconstruct.append(big_int % overall_moduli) - return x_reconstruct - - -def rns_icrt_factors_compute(modulus, moduli): - precomputed = [] - for m in moduli: - rest = modulus // m # 0 mod all the other moduli - inverse = pow(rest % m, -1, m) # factor to make 1 mod this moduli - icrt_val = (rest * inverse) % modulus # combine - precomputed.append(icrt_val) - return precomputed - - -def rns_coefficients_precompute( - icrt_factors, - overall_moduli, - num_bytes, - moduli_precision, - overall_modulus, - q, -): - """Precompute RNS coefficients. - - Args: - icrt_factors: Precomputed inverse CRT factors. - overall_moduli: Array of moduli. - num_bytes: Number of bytes. - moduli_precision: Precision of the moduli. 
- overall_modulus: Overall modulus. - q: Target modulus. - - Returns: - Precomputed RNS coefficients and correction coefficients. - """ - num_residues = len(overall_moduli) - # icrt_factors_byteshifted -- (num_residues, num_bytes) - icrt_factors_byteshifted = [ - [ - (((1 << (8 * pre_id)) * factor) % overall_modulus) - for pre_id in range(num_bytes) - ] - for factor in icrt_factors - ] - # icrt_factors_byteshifted_modq -- (num_residues, num_bytes) - icrt_factors_byteshifted_modq = [ - [(chunk % q) for chunk in factors] for factors in icrt_factors_byteshifted - ] - # icrt_factors_byteshifted_modq_rns - # (num_residues, num_bytes, num_residues) [Convert each byte range into RNS] - icrt_factors_byteshifted_modq_rns = [ - [to_rns(chunk, overall_moduli) for chunk in factors] - for factors in icrt_factors_byteshifted_modq - ] - - rns_mat = jnp.array( - icrt_factors_byteshifted_modq_rns, dtype=jnp.uint16 - ).reshape(-1, num_residues) - - # calculate quotient estimation - fix_point = 1 << moduli_precision - - shifted_quotient_estimations = [] - for factors in icrt_factors_byteshifted: - for chunk in factors: - shifted_quotient_estimations.append( - [math.ceil((chunk * fix_point) / overall_modulus)] - ) - sqe_mat = jnp.array(shifted_quotient_estimations, dtype=jnp.uint16) - - cor_mat = jnp.array( - [to_rns(-overall_modulus % q, overall_moduli)], dtype=jnp.uint16 - ) - - # Convert rns_mat and sqe_mat into various bytes. - rns_mat_u8 = jax.lax.bitcast_convert_type(rns_mat, jnp.uint8) - seq_mat_u8 = jax.lax.bitcast_convert_type(sqe_mat, jnp.uint8) - rns_stack_mat_u8 = jnp.hstack(( - rns_mat_u8[..., 0], - seq_mat_u8[..., 0], - rns_mat_u8[..., 1], - seq_mat_u8[..., 1], - )) - return to_tuple(rns_stack_mat_u8.tolist()), to_tuple(cor_mat.tolist()) - - -def rns_mod_reduce( - data_a_rns, - data_b_rns, - moduli, - moduli_t, - rns_stack_mat_u8, - cor_mat, - icrt_factors, - overall_modulus, -): - """Performs RNS modular reduction. - - Args: - data_a_rns: First input in RNS form. 
- data_b_rns: Second input in RNS form. - moduli: Array of moduli. - moduli_t: Array of constant offsets. - rns_stack_mat_u8: Precomputed RNS coefficients. - cor_mat: Precomputed correction coefficients. - icrt_factors: Precomputed inverse CRT factors. - overall_modulus: Overall modulus. - - Returns: - Result of the modular reduction. - """ - num_residues = moduli.shape[0] - mul_res = jnp.multiply( - data_a_rns.astype(mul_res_dtype), data_b_rns.astype(mul_res_dtype) - ) - mul_res_tower_red = hardware_friendly_mod_reduce(mul_res, moduli_t) - - # Global Modular reduction - mul_res_glb_red = jnp.matmul( - mul_res_tower_red.view(jnp.uint8), - rns_stack_mat_u8, - preferred_element_type=mul_res_dtype, - ) - - mul_res_glb_red_u32_l, mul_res_glb_red_u32_h = jnp.split( - mul_res_glb_red, [num_residues + 1], axis=1 - ) - mul_res_glb_red_u32 = mul_res_glb_red_u32_l + (mul_res_glb_red_u32_h << 8) - rns_reduce_u32, qe_u32 = jnp.split( - mul_res_glb_red_u32, [num_residues], axis=1 - ) - - # obtain the high 16 bits from the quotient estimation results qe_u32 - c_corrected = rns_reduce_u32 + jnp.matmul( - qe_u32 >> moduli_precision, cor_mat, preferred_element_type=mul_res_dtype - ) - c_corrected_reduce = hardware_friendly_mod_reduce(c_corrected, moduli_t) - - result_rns = rns_reconstruct( - c_corrected_reduce.tolist(), overall_modulus, icrt_factors - ) - - return result_rns - - -if __name__ == "__main__": - ########################### - # User Configured Input - ########################### - q = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001 - moduli_precision = 16 - extra_bit_to_avoid_addition_overflow = 4 - num_bytes = moduli_precision // 8 - num_residues_for_q = ( - q.bit_length() + moduli_precision - 1 - ) // moduli_precision + 1 - overall_modulus = ( - ((q * 256 * num_bytes * num_residues_for_q) ** 2) - * extra_bit_to_avoid_addition_overflow - * 2 - ) - assert overall_modulus == q * q * 256 * 256 * 50 * 50 * 4 * 2 - 
num_elements = 1024 - - ########################### - # Offline Precompute - ########################### - - overall_moduli, constant_offset_list = find_moduli( - overall_modulus, moduli_precision - ) - assert overall_moduli == [ - 65536, - 65535, - 65533, - 65531, - 65527, - 65521, - 65519, - 65509, - 65503, - 65497, - 65491, - 65489, - 65479, - 65477, - 65473, - 65459, - 65449, - 65447, - 65437, - 65431, - 65423, - 65419, - 65413, - 65411, - 65407, - 65393, - 65383, - 65381, - 65371, - 65369, - 65363, - 65357, - 65353, - 65347, - 65339, - 65327, - 65323, - 65321, - 65311, - 65309, - 65293, - 65287, - 32761, - 32749, - 32743, - 32741, - 32719, - 32717, - 32713, - 32707, - ] - overall_modulus = 1 - for moduli in overall_moduli: - overall_modulus *= moduli - overall_modulus = int(overall_modulus) - assert len(overall_moduli) == ( - (overall_modulus.bit_length() + moduli_precision - 1) // moduli_precision - ) - icrt_factors = rns_icrt_factors_compute(overall_modulus, overall_moduli) - - # hardware friendly moduli is 2**precision - t - # moduli is the jax.array of "2**precision - t" - # moduli_t is the jax.array of "t" - rns_stack_mat_u8, cor_mat = rns_coefficients_precompute( - icrt_factors, - overall_moduli, - num_bytes, - moduli_precision, - overall_modulus, - q, - ) - assert cor_mat == ( - ( - 57491, - 26379, - 24673, - 4733, - 47122, - 11996, - 11119, - 12151, - 45048, - 10179, - 3397, - 45514, - 12274, - 62018, - 4316, - 141, - 20271, - 17626, - 20758, - 57875, - 41612, - 44321, - 30081, - 6090, - 16501, - 13984, - 14909, - 14581, - 47918, - 44932, - 34016, - 7605, - 33574, - 30236, - 15843, - 26521, - 52723, - 28347, - 32242, - 11676, - 31854, - 34463, - 30291, - 29806, - 1344, - 25148, - 23069, - 4869, - 6178, - 32502, - ), - ) - rns_stack_mat_u8 = jnp.array(rns_stack_mat_u8, dtype=jnp.uint8) - cor_mat = jnp.array(cor_mat, dtype=jnp.uint16) - - moduli = jnp.array(overall_moduli, dtype=mul_res_dtype) - moduli_t = jnp.array(constant_offset_list, 
dtype=chunk_dtype) - - ########################### - # Generate Random Data - ########################### - - random_data = [randint(0, q) for _ in range(num_elements)] - result_ref = [val * val % q for val in random_data] - data_rns = jnp.array( - [to_rns(ele, overall_moduli) for ele in random_data], jnp.uint32 - ) - - ########################### - # Compute Modular Reduction in RNS Form - ########################### - # Limb-wise modular multiplication - # (num_elements, num_residue) - result_rns = rns_mod_reduce( - data_rns, - data_rns, - moduli, - moduli_t, - rns_stack_mat_u8, - cor_mat, - icrt_factors, - overall_modulus, - ) - result_mod_q = [val % q for val in result_rns] - assert result_ref == result_mod_q diff --git a/jaxite_ec/algorithm/util.py b/jaxite_ec/algorithm/util.py deleted file mode 100644 index cfbef87..0000000 --- a/jaxite_ec/algorithm/util.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Utility functions for the Bukect distribution algorithm.""" - -import numpy as np - - -def bits_to_numpy_dtype(bits): - if bits == 8: - return np.uint8 - elif bits == 16: - return np.uint16 - elif bits == 32: - return np.uint32 - elif bits == 64: - return np.uint64 - else: - raise ValueError("Unsupported bit size. Use 8, 16, 32, or 64.") - - -def int_to_array(python_int, base, array_size=None) -> np.ndarray: - """Converts a Python integer to a JAX array. - - Args: - python_int: The Python integer to convert. - base: The number of bits per element in the integer. - array_size: The desired size of the resulting JAX array. If provided, the - integer will be padded or trimmed to match this size. - - Returns: - A JAX array containing the elements of the Python integer. 
- """ - chunks = [] - mask = ( - 1 << base - ) - 1 # Mask to extract the lower bits (e.g., 32 bits -> 0xFFFFFFFF) - dtype = bits_to_numpy_dtype(base) - # Extract each element from the integer - while python_int > 0: - chunks.append(python_int & mask) # Extract the lower bits - python_int >>= base # Shift to remove the extracted bits - - # If array_size is provided, we pad or trim the result to match the - # desired size - if array_size is not None: - assert array_size >= len(chunks) - chunks = chunks[:array_size] + [0] * (array_size - len(chunks)) - - # Convert the list to a JAX array - return np.array(chunks, dtype=dtype) - - -def array_to_int(np_array: np.ndarray, base): - # Initialize the result as a Python integer - result = 0 - - # Iterate over the elements in the array - for i, elem in enumerate(np_array): - # Convert each element to an int and shift it by the appropriate - # number of bits - result |= int(elem) << (i * base) - - return result - - -def int_list_to_array(int_list, base, array_size=None) -> np.ndarray: - chunked_arrays = [ - int_to_array(int_value, base, array_size) for int_value in int_list - ] - return np.array(chunked_arrays) diff --git a/jaxite_ec/c_kernels/BUILD b/jaxite_ec/c_kernels/BUILD new file mode 100644 index 0000000..f762ed2 --- /dev/null +++ b/jaxite_ec/c_kernels/BUILD @@ -0,0 +1,24 @@ +load("@rules_python//python:defs.bzl", "py_library") +# load("//third_party/bazel_rules/rules_cc/cc:cc_binary.bzl", "cc_binary") + +package( + default_applicable_licenses = ["@jaxite//:license"], + default_visibility = ["//visibility:public"], +) + +# cc_binary( +# name = "distribution.so", +# srcs = ["distribution.cpp"], +# features = ["-layering_check"], +# linkshared = 1, +# deps = [ +# "//third_party/tensorflow/compiler/xla/ffi/api:c_api", +# "//third_party/tensorflow/compiler/xla/ffi/api:ffi", +# ], +# ) + +py_library( + name = "build", + srcs = ["build.py"], + visibility = ["//visibility:public"], +) diff --git a/jaxite_ec/c_kernels/build.py 
b/jaxite_ec/c_kernels/build.py new file mode 100644 index 0000000..9b44d73 --- /dev/null +++ b/jaxite_ec/c_kernels/build.py @@ -0,0 +1,103 @@ +"""Python-side build driver for the distribution C kernel. + +:func:`ensure_distribution_kernel` compiles ``distribution.cpp`` (or returns +the cached build) and yields the absolute path of the resulting shared +library, so callers can ``ctypes.cdll.LoadLibrary`` it without a separate +build step. + +Default flags: ``-std=c++17 -fopenmp -O2 -fPIC -shared -I/include``. +``CXX`` and ``CXXFLAGS`` env vars override the compiler and flags. + +Run as a script (``python -m c_kernels.build``) to pre-build without importing +the rest of the package. +""" + +from __future__ import annotations + +import os +import pathlib +import shlex +import subprocess +import sys +import threading + +_KERNEL_DIR = pathlib.Path(__file__).parent +_SRC_PATH = _KERNEL_DIR / "distribution.cpp" +_LIB_PATH = _KERNEL_DIR / "distribution.so" + +_DEFAULT_CXXFLAGS = ("-std=c++17", "-fopenmp", "-O2", "-fPIC", "-shared") + +_BUILD_LOCK = threading.Lock() + + +def _jaxlib_include_dir() -> pathlib.Path: + import jaxlib # pytype: disable=import-error + + return pathlib.Path(jaxlib.__file__).resolve().parent / "include" + + +def _needs_rebuild() -> bool: + if not _LIB_PATH.exists(): + return True + if not _SRC_PATH.exists(): + return False + return _SRC_PATH.stat().st_mtime > _LIB_PATH.stat().st_mtime + + +def _build() -> None: + cxx = os.environ.get("CXX", "g++") + flags_env = os.environ.get("CXXFLAGS") + cxxflags = shlex.split(flags_env) if flags_env else list(_DEFAULT_CXXFLAGS) + cmd = [ + cxx, + *cxxflags, + f"-I{_jaxlib_include_dir()}", + str(_SRC_PATH), + "-o", + str(_LIB_PATH), + ] + print( + f"[c_kernels.build] compiling: {' '.join(shlex.quote(c) for c in cmd)}", + file=sys.stderr, + ) + subprocess.run(cmd, check=True, cwd=_KERNEL_DIR) + + +def ensure_distribution_kernel(force: bool = False) -> str: + """Compile :file:`distribution.cpp` on demand and return 
the .so path. + + Args: + force: rebuild even if the cached library is up to date. + + Returns: + Absolute path to ``distribution.so``. + + Raises: + FileNotFoundError: if pre-built kernel not found in test environment. + """ + if "TEST_SRCDIR" in os.environ: + runfiles_dir = pathlib.Path(os.environ["TEST_SRCDIR"]) + lib_path = ( + runfiles_dir + / "google3/third_party/py/jaxite/jaxite_ec/c_kernels/distribution.so" + ) + if lib_path.exists(): + return str(lib_path) + raise FileNotFoundError( + f"Pre-built kernel not found in runfiles at {lib_path}" + ) + + if "UNITTEST_ON_FORGE" in os.environ: + if _LIB_PATH.exists(): + return str(_LIB_PATH) + raise FileNotFoundError(f"Pre-built kernel not found at {_LIB_PATH}") + + with _BUILD_LOCK: + if force or _needs_rebuild(): + _build() + return str(_LIB_PATH) + + +if __name__ == "__main__": + path = ensure_distribution_kernel(force="--force" in sys.argv[1:]) + print(path) diff --git a/jaxite_ec/c_kernels/distribution.cpp b/jaxite_ec/c_kernels/distribution.cpp new file mode 100644 index 0000000..e5644ba --- /dev/null +++ b/jaxite_ec/c_kernels/distribution.cpp @@ -0,0 +1,1151 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef _OPENMP +#include +#endif +#include + +#include "third_party/tensorflow/compiler/xla/ffi/api/c_api.h" +#include "third_party/tensorflow/compiler/xla/ffi/api/ffi.h" + +namespace ffi = xla::ffi; +#define COORDINATE_NUM 4 +#define CHUNK_NUM 32 +#define RESERVE_RATIO 2.0 +#define SCALAR_BITS 253 +#define ORDER_HIGH 0x12AB655E +#define ORDER_LOW_BITS 224 +// #define PROFILE + +typedef struct { + uint32_t chunks[CHUNK_NUM * COORDINATE_NUM]; +} point_t; + +typedef struct { + // Parameters for the distributor + uint32_t window_num; + uint32_t regular_bucket_num; + uint32_t special_bucket_num; + uint32_t msm_length; + uint32_t fixed_regular_padding_size; + uint32_t fixed_special_padding_size; + + // Inputs buffers + uint32_t* zero; + int32_t* slices_list; 
+ uint32_t* points_list; + uint32_t* neg_points_list; + + // Output buffers + uint32_t* regular_buckets; + uint32_t* special_buckets; +} distributor_params_t; + +class Distributor { + public: + Distributor(uint32_t window_num, uint32_t regular_bucket_num, + uint32_t special_bucket_num, uint32_t msm_length) + : window_num_(window_num), + regular_bucket_num_(regular_bucket_num), + special_bucket_num_(special_bucket_num), + msm_length_(msm_length) { + this->reserve_ratio_ = RESERVE_RATIO; + this->windows_.resize(window_num); + this->fixed_regular_padding_size_ = 0; + this->fixed_special_padding_size_ = 0; + this->truncated = 0; + } + + void set_slices_list(int32_t* slices_list) { + this->slices_list_pointers_.resize(this->window_num_); + for (uint32_t i = 0; i < this->window_num_; ++i) { + slices_list_pointers_[i] = slices_list + i * this->msm_length_; + } + } + + void set_output_buffers(uint32_t* regular_buckets, + uint32_t* special_buckets) { + this->regular_buckets_ = regular_buckets; + this->special_buckets_ = special_buckets; + this->bucket_pointers_.resize(window_num_); + this->bucket_sizes_.resize(window_num_); + uint32_t regular_bucket_capacity_in_bytes = + fixed_regular_padding_size_ * sizeof(point_t); + uint32_t special_bucket_capacity_in_bytes = + fixed_special_padding_size_ * sizeof(point_t); + uint32_t regular_window_capacity_in_bytes = + regular_bucket_num_ * regular_bucket_capacity_in_bytes; + // uint32_t special_window_capacity_in_bytes = special_bucket_num_ * + // special_bucket_capacity_in_bytes; + + // Initialize the bucket pointers and sizes + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + this->bucket_pointers_[i].resize(regular_bucket_num_); + this->bucket_sizes_[i].resize(regular_bucket_num_); + for (uint32_t j = 0; j < this->bucket_pointers_[i].size(); ++j) { + this->bucket_sizes_[i][j] = 0; + this->bucket_pointers_[i][j] = reinterpret_cast( + reinterpret_cast(regular_buckets_) + + i * regular_window_capacity_in_bytes + + j * 
regular_bucket_capacity_in_bytes); + } + } + // For the last window, which has special buckets + this->bucket_pointers_[window_num_ - 1].resize(special_bucket_num_); + this->bucket_sizes_[window_num_ - 1].resize(special_bucket_num_); + for (uint32_t j = 0; j < this->bucket_pointers_[window_num_ - 1].size(); + ++j) { + this->bucket_sizes_[window_num_ - 1][j] = 0; + this->bucket_pointers_[window_num_ - 1][j] = reinterpret_cast( + reinterpret_cast(special_buckets_) + + j * special_bucket_capacity_in_bytes); + } + } + + void set_points_list(uint32_t* points_list) { + this->points_list_ = reinterpret_cast(points_list); + } + + void set_neg_points_list(uint32_t* neg_points_list) { + this->neg_points_list_ = reinterpret_cast(neg_points_list); + } + + void set_zeros(uint32_t* zero) { + this->zero_ = reinterpret_cast(zero); + } + + void set_fixed_padding_size(uint32_t regular_size, uint32_t special_size) { + this->fixed_regular_padding_size_ = regular_size; + this->fixed_special_padding_size_ = special_size; + } + + void set_tile_length(uint32_t tile_length) { + this->tile_length_ = tile_length; + assert(msm_length_ % tile_length == 0); + this->tile_num_ = msm_length_ / tile_length; + } + + void distribute_a_window(int32_t* slices, point_t* points, + uint32_t bucket_num, + std::vector>& buckets, + uint32_t& max_size) { + // Implement the distribution logic for a single window here + buckets.resize(bucket_num); + uint32_t reserve_num = + static_cast(msm_length_ * reserve_ratio_ / bucket_num); + for (uint32_t i = 0; i < bucket_num; ++i) { + buckets[i].reserve(reserve_num); + } + + for (uint32_t i = 0; i < msm_length_; ++i) { + if (slices[i] == 0) { + continue; + } + uint32_t bucket_index = slices[i] - 1; + buckets[bucket_index].push_back(points[i]); + } + // max_size = 0; + for (uint32_t i = 0; i < bucket_num; ++i) { + if (buckets[i].size() > max_size) { + max_size = buckets[i].size(); + } + } + } + + void distribute_a_window_to_buffer(int32_t* slices, point_t* points, + 
uint32_t tile_length, + uint32_t max_bucket_size, + std::vector& buckets, + std::vector& bucket_sizes, + uint32_t& overflow) { + for (uint32_t i = 0; i < tile_length; ++i) { + uint32_t prefetch_distance = 32; // Adjust prefetch distance as needed + if (i + prefetch_distance < tile_length) { + __builtin_prefetch(&slices[i + prefetch_distance], 0, 1); + __builtin_prefetch(&points[i + prefetch_distance], 0, 1); + } + if (slices[i] == 0) { + continue; + } + int32_t bucket_index = slices[i] - 1; + if (bucket_index >= buckets.size()) { + std::cerr << "Error: bucket index out of range." << std::endl; + continue; + } + if (bucket_sizes[bucket_index] < max_bucket_size) { + buckets[bucket_index][bucket_sizes[bucket_index]] = points[i]; + bucket_sizes[bucket_index]++; + } else { + overflow++; + } + } + } + + void distribute_a_window_to_buffer_with_sign( + int32_t* slices, point_t* points, point_t* neg_points, + uint32_t tile_length, uint32_t max_bucket_size, + std::vector& buckets, std::vector& bucket_sizes, + uint32_t& overflow) { + for (uint32_t i = 0; i < tile_length; ++i) { + if (slices[i] == 0) { + continue; + } + point_t* points_to_use = (slices[i] > 0) ? &points[i] : &neg_points[i]; + int32_t bucket_index = abs(slices[i]) - 1; + if (bucket_index >= buckets.size()) { + std::cerr << "Error: bucket index out of range." 
<< std::endl; + continue; + } + if (bucket_sizes[bucket_index] < max_bucket_size) { + buckets[bucket_index][bucket_sizes[bucket_index]] = *points_to_use; + bucket_sizes[bucket_index]++; + } else { + overflow++; + } + } + } + + void pad_a_window(point_t* zero, uint32_t target_size, + std::vector>& buckets) { + for (uint32_t i = 0; i < buckets.size(); ++i) { + uint32_t current_size = buckets[i].size(); + // assert(current_size <= target_size); + if (current_size <= target_size) { + uint32_t pad_num = target_size - current_size; + buckets[i].insert(buckets[i].end(), pad_num, *zero); + } else { + // If the current size exceeds the target size, truncate it + buckets[i].resize(target_size); + } + } + } + + void pad_a_window_to_buffer(point_t* zero, uint32_t target_size, + std::vector& buckets, + std::vector& bucket_sizes) { + for (uint32_t i = 0; i < buckets.size(); ++i) { + uint32_t current_size = bucket_sizes[i]; + // assert(current_size <= target_size); + if (current_size < target_size) { + uint32_t pad_num = target_size - current_size; + for (uint32_t j = 0; j < pad_num; ++j) { + buckets[i][current_size + j] = *zero; + } + bucket_sizes[i] += pad_num; + } + } + } + + void distribute() { + uint32_t max_regular_bucket_size = 0; + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto duration_us = + std::chrono::duration_cast(end - start) + .count(); + start = std::chrono::high_resolution_clock::now(); + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + int32_t* slices = slices_list_pointers_[i]; + distribute_a_window(slices, points_list_, regular_bucket_num_, + windows_[i], max_regular_bucket_size); + } + uint32_t max_special_bucket_size = 0; + auto end2 = std::chrono::high_resolution_clock::now(); + distribute_a_window(slices_list_pointers_[window_num_ - 1], points_list_, + special_bucket_num_, windows_[window_num_ - 1], + max_special_bucket_size); + end = std::chrono::high_resolution_clock::now(); + + 
duration_us = + std::chrono::duration_cast(end - start) + .count(); + auto duration2_us = + std::chrono::duration_cast(end2 - start) + .count(); + std::cout << "dist_r time: " << duration2_us << " us" << std::endl; + std::cout << "dist time: " << duration_us << " us" << std::endl; + + this->real_regular_padding_size_ = max_regular_bucket_size; + this->real_special_padding_size_ = max_special_bucket_size; + + if (this->fixed_regular_padding_size_ > 0) { + if (this->fixed_regular_padding_size_ < max_regular_bucket_size) { + this->truncated = 1; + } + // assert(this->fixed_regular_padding_size_ >= max_regular_bucket_size); + max_regular_bucket_size = this->fixed_regular_padding_size_; + } + if (this->fixed_special_padding_size_ > 0) { + if (this->fixed_special_padding_size_ < max_special_bucket_size) { + this->truncated = 1; + } + // assert(this->fixed_special_padding_size_ >= max_special_bucket_size); + max_special_bucket_size = this->fixed_special_padding_size_; + } + + start = std::chrono::high_resolution_clock::now(); + for (uint32_t i = 0; i < window_num_; ++i) { + uint32_t target_size = (i == window_num_ - 1) ? 
max_special_bucket_size + : max_regular_bucket_size; + pad_a_window(zero_, target_size, windows_[i]); + } + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + std::cout << "pad time: " << duration_us << " us" << std::endl; + } + + void distribute_parallel_v1() { + uint32_t max_regular_bucket_size = 0; + std::vector max_regular_bucket_sizes(window_num_ - 1, 0); + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto duration_us = + std::chrono::duration_cast(end - start) + .count(); + start = std::chrono::high_resolution_clock::now(); +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + int32_t* slices = slices_list_pointers_[i]; + distribute_a_window(slices, points_list_, regular_bucket_num_, + windows_[i], max_regular_bucket_sizes[i]); + } + auto end2 = std::chrono::high_resolution_clock::now(); + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + if (max_regular_bucket_sizes[i] > max_regular_bucket_size) { + max_regular_bucket_size = max_regular_bucket_sizes[i]; + } + } + uint32_t max_special_bucket_size = 0; + + distribute_a_window(slices_list_pointers_[window_num_ - 1], points_list_, + special_bucket_num_, windows_[window_num_ - 1], + max_special_bucket_size); + this->real_regular_padding_size_ = max_regular_bucket_size; + this->real_special_padding_size_ = max_special_bucket_size; + end = std::chrono::high_resolution_clock::now(); + + duration_us = + std::chrono::duration_cast(end - start) + .count(); + auto duration2_us = + std::chrono::duration_cast(end2 - start) + .count(); + std::cout << "dist_r time: " << duration2_us << " us" << std::endl; + std::cout << "dist time: " << duration_us << " us" << std::endl; + + if (this->fixed_regular_padding_size_ > 0) { + if (this->fixed_regular_padding_size_ < max_regular_bucket_size) { + this->truncated = 1; + } + // assert(this->fixed_regular_padding_size_ >= 
max_regular_bucket_size); + max_regular_bucket_size = this->fixed_regular_padding_size_; + } + if (this->fixed_special_padding_size_ > 0) { + if (this->fixed_special_padding_size_ < max_special_bucket_size) { + this->truncated = 1; + } + // assert(this->fixed_special_padding_size_ >= + // max_special_bucket_size); + max_special_bucket_size = this->fixed_special_padding_size_; + } + + start = std::chrono::high_resolution_clock::now(); +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_; ++i) { + uint32_t target_size = (i == window_num_ - 1) ? max_special_bucket_size + : max_regular_bucket_size; + pad_a_window(zero_, target_size, windows_[i]); + } + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + std::cout << "pad time: " << duration_us << " us" << std::endl; + } + + void distribute_a_window_to_buffer_zero_indexed( + int32_t* slices, point_t* points, uint32_t tile_length, + uint32_t max_bucket_size, std::vector& buckets, + std::vector& bucket_sizes, uint32_t& overflow) { + for (uint32_t i = 0; i < tile_length; ++i) { + uint32_t prefetch_distance = 32; + if (i + prefetch_distance < tile_length) { + __builtin_prefetch(&slices[i + prefetch_distance], 0, 1); + __builtin_prefetch(&points[i + prefetch_distance], 0, 1); + } + if (slices[i] <= 0) { + continue; + } + int32_t bucket_index = slices[i]; + if (bucket_index >= static_cast(buckets.size())) { + std::cerr << "Error: bucket index out of range." 
<< std::endl; + continue; + } + if (bucket_sizes[bucket_index] < max_bucket_size) { + buckets[bucket_index][bucket_sizes[bucket_index]] = points[i]; + bucket_sizes[bucket_index]++; + } else { + overflow++; + } + } + } + + void distribute_to_buffer_zero_indexed_parallel() { + std::vector bucket_overflow(window_num_, 0); +#ifdef PROFILE + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto duration_us = + std::chrono::duration_cast(end - start) + .count(); + start = std::chrono::high_resolution_clock::now(); +#endif +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + int32_t* slices = slices_list_pointers_[i]; + point_t* points = points_list_; + std::vector& buckets = bucket_pointers_[i]; + std::vector& bucket_sizes = bucket_sizes_[i]; + uint32_t max_bucket_size = this->fixed_regular_padding_size_; + distribute_a_window_to_buffer_zero_indexed( + slices, points, msm_length_, max_bucket_size, buckets, bucket_sizes, + bucket_overflow[i]); + } +#ifdef PROFILE + auto end2 = std::chrono::high_resolution_clock::now(); +#endif + + distribute_a_window_to_buffer_zero_indexed( + slices_list_pointers_[window_num_ - 1], points_list_, msm_length_, + this->fixed_special_padding_size_, bucket_pointers_[window_num_ - 1], + bucket_sizes_[window_num_ - 1], bucket_overflow[window_num_ - 1]); + +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + auto duration2_us = + std::chrono::duration_cast(end2 - start) + .count(); + std::cout << "dist_r time: " << duration2_us << " us" << std::endl; + std::cout << "dist time: " << duration_us << " us" << std::endl; +#endif + + for (uint32_t i = 0; i < window_num_; ++i) { + if (bucket_overflow[i] > 0) { + this->truncated = 1; + break; + } + } + +#ifdef PROFILE + start = std::chrono::high_resolution_clock::now(); +#endif +#pragma omp parallel for + for (uint32_t i = 0; i < 
window_num_; ++i) { + uint32_t target_size = (i == window_num_ - 1) + ? this->fixed_special_padding_size_ + : this->fixed_regular_padding_size_; + pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i], + bucket_sizes_[i]); + } +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + std::cout << "pad time: " << duration_us << " us" << std::endl; +#endif + } + + void distribute_to_buffer_parallel() { + std::vector bucket_overflow(window_num_, 0); +#ifdef PROFILE + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto duration_us = + std::chrono::duration_cast(end - start) + .count(); + start = std::chrono::high_resolution_clock::now(); +#endif +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + int32_t* slices = slices_list_pointers_[i]; + point_t* points = points_list_; + std::vector& buckets = bucket_pointers_[i]; + std::vector& bucket_sizes = bucket_sizes_[i]; + uint32_t max_bucket_size = this->fixed_regular_padding_size_; + distribute_a_window_to_buffer(slices, points, msm_length_, + max_bucket_size, buckets, bucket_sizes, + bucket_overflow[i]); + } +#ifdef PROFILE + auto end2 = std::chrono::high_resolution_clock::now(); +#endif + + distribute_a_window_to_buffer( + slices_list_pointers_[window_num_ - 1], points_list_, msm_length_, + this->fixed_special_padding_size_, bucket_pointers_[window_num_ - 1], + bucket_sizes_[window_num_ - 1], bucket_overflow[window_num_ - 1]); + +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + auto duration2_us = + std::chrono::duration_cast(end2 - start) + .count(); + std::cout << "dist_r time: " << duration2_us << " us" << std::endl; + std::cout << "dist time: " << duration_us << " us" << std::endl; +#endif + + for (uint32_t i = 0; i < window_num_; ++i) { + if 
(bucket_overflow[i] > 0) { + this->truncated = 1; + break; + } + } + +#ifdef PROFILE + start = std::chrono::high_resolution_clock::now(); +#endif +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_; ++i) { + uint32_t target_size = (i == window_num_ - 1) + ? this->fixed_special_padding_size_ + : this->fixed_regular_padding_size_; + pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i], + bucket_sizes_[i]); + } +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + std::cout << "pad time: " << duration_us << " us" << std::endl; +#endif + } + + void distribute_to_buffer_signed_parallel() { + std::vector bucket_overflow(window_num_, 0); +#ifdef PROFILE + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto duration_us = + std::chrono::duration_cast(end - start) + .count(); + start = std::chrono::high_resolution_clock::now(); +#endif +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + int32_t* slices = slices_list_pointers_[i]; + point_t* points = points_list_; + point_t* neg_points = neg_points_list_; + std::vector& buckets = bucket_pointers_[i]; + std::vector& bucket_sizes = bucket_sizes_[i]; + uint32_t max_bucket_size = this->fixed_regular_padding_size_; + distribute_a_window_to_buffer_with_sign( + slices, points, neg_points, msm_length_, max_bucket_size, buckets, + bucket_sizes, bucket_overflow[i]); + } +#ifdef PROFILE + auto end2 = std::chrono::high_resolution_clock::now(); +#endif + + distribute_a_window_to_buffer_with_sign( + slices_list_pointers_[window_num_ - 1], points_list_, neg_points_list_, + msm_length_, this->fixed_special_padding_size_, + bucket_pointers_[window_num_ - 1], bucket_sizes_[window_num_ - 1], + bucket_overflow[window_num_ - 1]); + +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - 
start) + .count(); + auto duration2_us = + std::chrono::duration_cast(end2 - start) + .count(); + std::cout << "dist_r time: " << duration2_us << " us" << std::endl; + std::cout << "dist time: " << duration_us << " us" << std::endl; +#endif + + for (uint32_t i = 0; i < window_num_; ++i) { + if (bucket_overflow[i] > 0) { + this->truncated = 1; + break; + } + } + +#ifdef PROFILE + start = std::chrono::high_resolution_clock::now(); +#endif +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_; ++i) { + uint32_t target_size = (i == window_num_ - 1) + ? this->fixed_special_padding_size_ + : this->fixed_regular_padding_size_; + pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i], + bucket_sizes_[i]); + } +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + std::cout << "pad time: " << duration_us << " us" << std::endl; +#endif + } + + void distribute_to_buffer_parallel_v2() { + std::vector bucket_overflow(window_num_, 0); +#ifdef PROFILE + auto start = std::chrono::high_resolution_clock::now(); + auto end = std::chrono::high_resolution_clock::now(); + auto duration_us = + std::chrono::duration_cast(end - start) + .count(); + start = std::chrono::high_resolution_clock::now(); +#endif + for (uint32_t t = 0; t < tile_num_; ++t) { +// Distribute each tile to the corresponding window +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + int32_t* slices = slices_list_pointers_[i] + t * tile_length_; + point_t* points = points_list_ + t * tile_length_; + std::vector& buckets = bucket_pointers_[i]; + std::vector& bucket_sizes = bucket_sizes_[i]; + uint32_t max_bucket_size = this->fixed_regular_padding_size_; + distribute_a_window_to_buffer(slices, points, tile_length_, + max_bucket_size, buckets, bucket_sizes, + bucket_overflow[i]); + } + } +#ifdef PROFILE + auto end2 = std::chrono::high_resolution_clock::now(); +#endif + + 
distribute_a_window_to_buffer( + slices_list_pointers_[window_num_ - 1], points_list_, msm_length_, + this->real_special_padding_size_, bucket_pointers_[window_num_ - 1], + bucket_sizes_[window_num_ - 1], bucket_overflow[window_num_ - 1]); + +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + auto duration2_us = + std::chrono::duration_cast(end2 - start) + .count(); + std::cout << "dist_r time: " << duration2_us << " us" << std::endl; + std::cout << "dist time: " << duration_us << " us" << std::endl; +#endif + + for (uint32_t i = 0; i < window_num_; ++i) { + if (bucket_overflow[i] > 0) { + this->truncated = 1; + break; + } + } + +#ifdef PROFILE + start = std::chrono::high_resolution_clock::now(); +#endif +#pragma omp parallel for + for (uint32_t i = 0; i < window_num_; ++i) { + uint32_t target_size = (i == window_num_ - 1) + ? this->fixed_special_padding_size_ + : this->fixed_regular_padding_size_; + pad_a_window_to_buffer(zero_, target_size, bucket_pointers_[i], + bucket_sizes_[i]); + } +#ifdef PROFILE + end = std::chrono::high_resolution_clock::now(); + duration_us = + std::chrono::duration_cast(end - start) + .count(); + std::cout << "pad time: " << duration_us << " us" << std::endl; +#endif + } + + void get_merged_regular_buckets(uint32_t* regular_buckets) { + uint32_t offset = 0; + for (uint32_t i = 0; i < window_num_ - 1; ++i) { + uint32_t bucket_size = windows_[i].size(); + for (uint32_t j = 0; j < bucket_size; ++j) { + std::memcpy(regular_buckets + offset, windows_[i][j].data(), + windows_[i][j].size() * sizeof(point_t)); + offset += windows_[i][j].size() * CHUNK_NUM * COORDINATE_NUM; + } + } + } + + void get_merged_special_buckets(uint32_t* special_buckets) { + uint32_t offset = 0; + uint32_t window_idx = window_num_ - 1; + uint32_t bucket_size = windows_[window_num_ - 1].size(); + for (uint32_t j = 0; j < bucket_size; ++j) { + std::memcpy(special_buckets + offset, 
windows_[window_idx][j].data(), + windows_[window_idx][j].size() * sizeof(point_t)); + offset += windows_[window_idx][j].size() * CHUNK_NUM * COORDINATE_NUM; + } + } + + void get_merged_metadata(uint32_t* metadata) { + metadata[0] = real_regular_padding_size_; + metadata[1] = real_special_padding_size_; + } + + private: + uint32_t window_num_; + uint32_t regular_bucket_num_; + uint32_t special_bucket_num_; + uint32_t msm_length_; + + uint32_t tile_length_; + uint32_t tile_num_; + float reserve_ratio_; + + std::vector slices_list_pointers_; + point_t* points_list_; + point_t* neg_points_list_; + point_t* zero_; + std::vector>> windows_; + + uint32_t fixed_regular_padding_size_; + uint32_t real_regular_padding_size_; + uint32_t fixed_special_padding_size_; + uint32_t real_special_padding_size_; + + uint32_t* regular_buckets_; + uint32_t* special_buckets_; + // windows + std::vector> bucket_pointers_; + std::vector> bucket_sizes_; + + /*When truncated is true, the reuslt will be incorrect. 
+ It is only for perfmance profiling goal.*/ + uint32_t truncated; + + /* For profiling*/ +}; + +distributor_params_t init_distributor_param(uint32_t slice_length, + uint32_t msm_length, + double buf_extend_ratio, + bool signed_bucket) { + // General parameter initialization + uint32_t window_num = (SCALAR_BITS + slice_length - 1) / slice_length; + uint32_t regular_bucket_num = (1 << slice_length) - 1; // 2^slice_length - 1 + uint32_t shift_bits = ((window_num - 1) * slice_length) - ORDER_LOW_BITS; + assert(shift_bits >= 0); + uint32_t special_bucket_num = ORDER_HIGH >> shift_bits; + if (signed_bucket) { + regular_bucket_num = 1 << (slice_length - 1); // 2^(slice_length - 1) + special_bucket_num++; + assert(special_bucket_num <= regular_bucket_num); + } + + // Special bucket optoimization parameter initialization + uint32_t log_special_duplication_ratio = static_cast(std::ceil( + std::log2(static_cast(regular_bucket_num) / special_bucket_num))); + uint32_t special_duplication_ratio = 1U << log_special_duplication_ratio; + uint32_t bucket_num_duplication = + special_bucket_num * special_duplication_ratio; + + // Output buffer parameter initialization + double expected_regular_bucket_size = + static_cast(msm_length) / (regular_bucket_num + 1); + double expected_special_bucket_size = + static_cast(msm_length) / bucket_num_duplication; + uint32_t regular_bucket_size = + static_cast(expected_regular_bucket_size * buf_extend_ratio); + uint32_t special_bucket_size = + static_cast(expected_special_bucket_size * buf_extend_ratio); + uint32_t regular_buffer_size_in_U32 = (window_num - 1) * regular_bucket_num * + regular_bucket_size * COORDINATE_NUM * + CHUNK_NUM; + uint32_t special_buffer_size_in_U32 = + special_bucket_num * special_duplication_ratio * special_bucket_size * + COORDINATE_NUM * CHUNK_NUM; + + distributor_params_t params; + params.window_num = window_num; + params.regular_bucket_num = regular_bucket_num; + params.special_bucket_num = special_bucket_num; + 
params.msm_length = msm_length; + params.fixed_regular_padding_size = regular_bucket_size; + params.fixed_special_padding_size = + special_bucket_size * special_duplication_ratio; + params.zero = + new uint32_t[COORDINATE_NUM * CHUNK_NUM](); // Initialize to zero + params.slices_list = new int32_t[window_num * msm_length]; + params.points_list = new uint32_t[msm_length * COORDINATE_NUM * CHUNK_NUM]; + params.neg_points_list = + new uint32_t[msm_length * COORDINATE_NUM * CHUNK_NUM]; + params.regular_buckets = new uint32_t[regular_buffer_size_in_U32]; + params.special_buckets = new uint32_t[special_buffer_size_in_U32]; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution regular_dist(0, regular_bucket_num); + std::uniform_int_distribution special_dist(0, special_bucket_num); + // Initialize slices_list with random values + for (uint32_t w = 0; w < window_num - 1; ++w) { + for (uint32_t i = 0; i < msm_length; ++i) { + params.slices_list[w * msm_length + i] = regular_dist(gen); + } + } + for (uint32_t i = 0; i < msm_length; ++i) { + params.slices_list[(window_num - 1) * msm_length + i] = special_dist(gen); + } + + return params; +} + +void free_distributor_param(distributor_params_t& params) { + delete[] params.zero; + delete[] params.slices_list; + delete[] params.points_list; + delete[] params.neg_points_list; + delete[] params.regular_buckets; + delete[] params.special_buckets; +} + +int main_old() { + // Example usage of the Distributor class + uint32_t slice_length = 10; // Example slice length + uint32_t msm_length = 1 << 16; // Example MSM length + double buf_extend_ratio = 1.1; // Example buffer extend ratio + bool signed_bucket = false; // Example signed bucket flag + uint32_t repeat = 50; + double duration_us = 0; + // printing the parameters + std::cout << "Slice Length: " << slice_length << std::endl; + std::cout << "MSM Length: " << msm_length << std::endl; + std::cout << "Buffer Extend Ratio: " << buf_extend_ratio << 
std::endl; + std::cout << "Signed Bucket: " << (signed_bucket ? "true" : "false") + << std::endl; + + for (uint32_t i = 0; i < repeat; ++i) { + std::cout << "Run " << i + 1 << " of " << repeat << std::endl; + distributor_params_t params = init_distributor_param( + slice_length, msm_length, buf_extend_ratio, signed_bucket); + + auto start_time = std::chrono::high_resolution_clock::now(); + + Distributor* distributor = + new Distributor(params.window_num, params.regular_bucket_num, + params.special_bucket_num, params.msm_length); + + distributor->set_fixed_padding_size(params.fixed_regular_padding_size, + params.fixed_special_padding_size); + distributor->set_slices_list(params.slices_list); + distributor->set_points_list(params.points_list); + distributor->set_neg_points_list(params.neg_points_list); + distributor->set_zeros(params.zero); + distributor->set_output_buffers(params.regular_buckets, + params.special_buckets); + + distributor->distribute_to_buffer_parallel(); + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time) + .count(); + std::cout << "Distribution time: " << duration << " us" << std::endl; + duration_us += duration; + + // Free allocated memory + free_distributor_param(params); + delete distributor; + } + std::cout << "Average distribution time: " << (duration_us / repeat) << " us" + << std::endl; + + return 0; +} + +// Dummy performance test function (replace with your real test) +double run_performance_test(uint32_t slice_length, uint32_t msm_length, + double buf_extend_ratio, bool signed_bucket) { + distributor_params_t params = init_distributor_param( + slice_length, msm_length, buf_extend_ratio, signed_bucket); + + auto start_time = std::chrono::high_resolution_clock::now(); + + Distributor distributor = + Distributor(params.window_num, params.regular_bucket_num, + params.special_bucket_num, params.msm_length); + + 
distributor.set_fixed_padding_size(params.fixed_regular_padding_size, + params.fixed_special_padding_size); + distributor.set_slices_list(params.slices_list); + distributor.set_points_list(params.points_list); + distributor.set_neg_points_list(params.neg_points_list); + distributor.set_zeros(params.zero); + distributor.set_output_buffers(params.regular_buckets, + params.special_buckets); + + distributor.distribute_to_buffer_parallel(); + auto end_time = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast( + end_time - start_time) + .count(); + // std::cout << "Distribution time: " << duration << " us" << std::endl; + // Free allocated memory + free_distributor_param(params); + return static_cast(duration); +} + +// Compute statistics +void compute_stats(const std::vector& data, double& min, double& max, + double& avg, double& stddev) { + if (data.empty()) { + min = max = avg = stddev = 0.0; + return; + } + min = *std::min_element(data.begin(), data.end()); + max = *std::max_element(data.begin(), data.end()); + avg = std::accumulate(data.begin(), data.end(), 0.0) / data.size(); + double sum_sq = 0.0; + for (double v : data) sum_sq += (v - avg) * (v - avg); + stddev = std::sqrt(sum_sq / data.size()); +} + +int main() { + std::ofstream csv("C_kernel_perf_in_C.csv"); + int trial_num = 50; + + csv << "slice_length,log2_msm_length,msm_length,buf_extend_ratio,signed_" + "bucket,min,max,avg(us),stddev\n"; + + std::vector slice_lengths = {6, 8, 10, 12, 14}; + std::vector log2_msm_lengths = {10, 12, 14, + 16, 18, 20}; // log2(msm_length) + std::vector buf_extend_ratios = {1.1}; + std::vector signed_buckets = {false, true}; + + for (uint32_t slice_length : slice_lengths) { + for (uint32_t log2_msm_length : log2_msm_lengths) { + uint32_t msm_length = 1U << log2_msm_length; + for (double buf_extend_ratio : buf_extend_ratios) { + for (bool signed_bucket : signed_buckets) { + if (slice_length >= log2_msm_length) { + // std::cerr << "Error: 
slice_length must be less than + // log2_msm_length." << std::endl; + continue; + } + std::cout << "Testing slice_length: " << slice_length + << ", log2_msm_length: " << log2_msm_length + << ", buf_extend_ratio: " << buf_extend_ratio + << ", signed_bucket: " << (signed_bucket ? "true" : "false") + << std::endl; + // Run the test multiple times for statistics + std::vector results; + for (int trial = 0; trial < trial_num; ++trial) { + double perf = run_performance_test(slice_length, msm_length, + buf_extend_ratio, signed_bucket); + results.push_back(perf); + } + double min, max, avg, stddev; + compute_stats(results, min, max, avg, stddev); + + csv << slice_length << "," << log2_msm_length << "," << msm_length + << "," << buf_extend_ratio << "," << signed_bucket << "," << min + << "," << max << "," << avg << "," << stddev << "\n"; + } + } + } + } + + csv.close(); + std::cout << "Performance results written to performance_results.csv\n"; + return 0; +} + +namespace ffi = xla::ffi; +ffi::Error DistributeImpl(uint32_t window_num, uint32_t regular_bucket_num, + uint32_t special_bucket_num, uint32_t msm_length, + uint32_t fixed_regular_padding_size, + uint32_t fixed_special_padding_size, + ffi::Buffer slices_list, + ffi::Buffer points_list, + ffi::Buffer zero, + ffi::ResultBuffer regular_buckets, + ffi::ResultBuffer special_buckets, + ffi::ResultBuffer metadata) { + // Create a Distributor object + Distributor distributor(window_num, regular_bucket_num, special_bucket_num, + msm_length); + distributor.set_fixed_padding_size(fixed_regular_padding_size, + fixed_special_padding_size); + distributor.set_slices_list(slices_list.typed_data()); + distributor.set_points_list(points_list.typed_data()); + distributor.set_zeros(zero.typed_data()); + + distributor.distribute(); + + distributor.get_merged_regular_buckets(regular_buckets->typed_data()); + distributor.get_merged_special_buckets(special_buckets->typed_data()); + distributor.get_merged_metadata(metadata->typed_data()); + 
// Return success + return xla::ffi::Error::Success(); +} + +ffi::Error DistributeBufImpl(uint32_t window_num, uint32_t regular_bucket_num, + uint32_t special_bucket_num, uint32_t msm_length, + uint32_t fixed_regular_padding_size, + uint32_t fixed_special_padding_size, + ffi::Buffer slices_list, + ffi::Buffer points_list, + ffi::Buffer zero, + ffi::ResultBuffer regular_buckets, + ffi::ResultBuffer special_buckets, + ffi::ResultBuffer metadata) { + // Create a Distributor object + Distributor distributor(window_num, regular_bucket_num, special_bucket_num, + msm_length); + distributor.set_fixed_padding_size(fixed_regular_padding_size, + fixed_special_padding_size); + distributor.set_slices_list(slices_list.typed_data()); + distributor.set_points_list(points_list.typed_data()); + distributor.set_zeros(zero.typed_data()); + distributor.set_output_buffers(regular_buckets->typed_data(), + special_buckets->typed_data()); + + distributor.distribute_to_buffer_parallel(); + + // Return success + return ffi::Error::Success(); +} + +ffi::Error DistributeBufSignedImpl( + uint32_t window_num, uint32_t regular_bucket_num, + uint32_t special_bucket_num, uint32_t msm_length, + uint32_t fixed_regular_padding_size, uint32_t fixed_special_padding_size, + ffi::Buffer slices_list, ffi::Buffer points_list, + ffi::Buffer neg_points_list, ffi::Buffer zero, + ffi::ResultBuffer regular_buckets, + ffi::ResultBuffer special_buckets, + ffi::ResultBuffer metadata) { + // Create a Distributor object + Distributor distributor(window_num, regular_bucket_num, special_bucket_num, + msm_length); + distributor.set_fixed_padding_size(fixed_regular_padding_size, + fixed_special_padding_size); + distributor.set_slices_list(slices_list.typed_data()); + distributor.set_points_list(points_list.typed_data()); + distributor.set_neg_points_list(neg_points_list.typed_data()); + distributor.set_zeros(zero.typed_data()); + distributor.set_output_buffers(regular_buckets->typed_data(), + 
special_buckets->typed_data()); + + distributor.distribute_to_buffer_signed_parallel(); + + // Return success + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + Distribute, DistributeImpl, + ffi::Ffi::Bind() + .Attr("window_num") + .Attr("regular_bucket_num") + .Attr("special_bucket_num") + .Attr("msm_length") + .Attr("fixed_regular_padding_size") + .Attr("fixed_special_padding_size") + .Arg>() // slices_list + .Arg>() // points_list + .Arg>() // zero + .Ret>() // regular_buckets + .Ret>() // special_buckets + .Ret>() // metadata +); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + DistributeBuf, DistributeBufImpl, + ffi::Ffi::Bind() + .Attr("window_num") + .Attr("regular_bucket_num") + .Attr("special_bucket_num") + .Attr("msm_length") + .Attr("fixed_regular_padding_size") + .Attr("fixed_special_padding_size") + .Arg>() // slices_list + .Arg>() // points_list + .Arg>() // zero + .Ret>() // regular_buckets + .Ret>() // special_buckets + .Ret>() // metadata +); + +ffi::Error DistributeBufZeroImpl( + uint32_t window_num, uint32_t regular_bucket_num, + uint32_t special_bucket_num, uint32_t msm_length, + uint32_t fixed_regular_padding_size, uint32_t fixed_special_padding_size, + ffi::Buffer slices_list, ffi::Buffer points_list, + ffi::Buffer zero, ffi::ResultBuffer regular_buckets, + ffi::ResultBuffer special_buckets, + ffi::ResultBuffer metadata) { + Distributor distributor(window_num, regular_bucket_num, special_bucket_num, + msm_length); + distributor.set_fixed_padding_size(fixed_regular_padding_size, + fixed_special_padding_size); + distributor.set_slices_list(slices_list.typed_data()); + distributor.set_points_list(points_list.typed_data()); + distributor.set_zeros(zero.typed_data()); + distributor.set_output_buffers(regular_buckets->typed_data(), + special_buckets->typed_data()); + + distributor.distribute_to_buffer_zero_indexed_parallel(); + + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + DistributeBufZero, DistributeBufZeroImpl, + 
ffi::Ffi::Bind() + .Attr("window_num") + .Attr("regular_bucket_num") + .Attr("special_bucket_num") + .Attr("msm_length") + .Attr("fixed_regular_padding_size") + .Attr("fixed_special_padding_size") + .Arg>() // slices_list + .Arg>() // points_list + .Arg>() // zero + .Ret>() // regular_buckets + .Ret>() // special_buckets + .Ret>() // metadata +); + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + DistributeBufSigned, DistributeBufSignedImpl, + ffi::Ffi::Bind() + .Attr("window_num") + .Attr("regular_bucket_num") + .Attr("special_bucket_num") + .Attr("msm_length") + .Attr("fixed_regular_padding_size") + .Attr("fixed_special_padding_size") + .Arg>() // slices_list + .Arg>() // points_list + .Arg>() // neg_points_list + .Arg>() // zero + .Ret>() // regular_buckets + .Ret>() // special_buckets + .Ret>() // metadata +); diff --git a/jaxite_ec/configurations.toml b/jaxite_ec/configurations.toml new file mode 100644 index 0000000..bec8438 --- /dev/null +++ b/jaxite_ec/configurations.toml @@ -0,0 +1,30 @@ +######################################################## +# General +######################################################## +serialized_jax_kernel_dir = "./deployments/" +hash_length = 4 # length of hash strings used in file names (alphanumeric, base-62) + + + +######################################################## +# Elliptic Curve Parameters +######################################################## +[ec_parameters_bls12_377_affine] +prime = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001 +order = 0x12AB655E9A2CA55660B44D1E5C37B00159AA76FED00000010A11800000000001 +a = 0 +b = 1 +generator = [0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF, 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6] + +[ec_parameters_bls12_377_extended_twisted_edwards] +prime = 
258664426012969094010652733694893533536393512754914660539884262666720468348340822774968888139573360124440321458177 +order = 8444461749428370424248824938781546531375899335154063827935233455917409239041 +a = -1 +d = 122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179 +alpha = -1 +b = 1 +s = 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931 +MA = 228097355113300204138531148905234651262148041026195375645000724271212049151994375092458297304264351187709081232384 +MB = 10189023633222963290707194929886294091415157242906428298294512798502806398782149227503530278436336312243746741931 +t = 23560188534917577818843641916571445935985386319233886518929971599490231428764380923487987729215299304184915158756 +generator = [71222569531709137229370268896323705690285216175189308202338047559628438110820800641278662592954630774340654489393, 6177051365529633638563236407038680211609544222665285371549726196884440490905471891908272386851767077598415378235] \ No newline at end of file diff --git a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv b/jaxite_ec/data/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv similarity index 99% rename from jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv rename to jaxite_ec/data/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv index 1affbca..4d5efb5 100644 --- a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv +++ b/jaxite_ec/data/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv @@ -1021,4 +1021,4 @@ 1020, x, 012c98c9f7a2a5d1, ea874762b1e8c4a1, 35fbe0fb5894980c, 44f8e64cd96d180c, 85ba13911acb084f, b3b26ad0609738fe, zz: Fp384 "(0188570868A82EB0E0DDD97F55DAD85FF37B7637DAD6A12A80F9C86AB30F2CE853CCB5833FC3394CB83B187B9913FE7E)", y, 015b19304056067c, 908e5c19fe908420, e0f26a2c8fd72c93, 0840903a30687f8a, bccf44f4ea625176, 501ba7f12adc27a8, zz: Fp384 
"(017ED2A2320E71894D670B2226830E2411AC8A717D1768A4CBC6B353571220C96496127BC7362B6B9D91373E8CD2681B)" 1021, x, 00f2048bceae505c, fe87dff235238fb0, 4ded13aa23396358, 84d5a9690203a9b5, d0ac634f189abad8, 4cdbb5659e84a52c, zz: Fp384 "(010AC4C7D2713EEC119C408FFAE8D3384DF4E7C4BCB12E001FDC8A7ADE7FC248FCCB7D919281EA13BBB7872B4D95044C)", y, 0098bdfaec3f29a0, 2e58bfc101ea8d74, a557b0596c9a161c, 735751360474b38c, 921c1d87ec7be426, 3d63974b26fa7705, zz: Fp384 "(0145B8E22FD6E243A109F4D2EED6B6E4AC4E59EB640A2494209777B50D5321A186F0A301EB762D5DB91401FF0C3D46DB)" 1022, x, 00356c16958fb477, 8ae26c11a4828367, 34db5eb3749d3968, 17041b9560c36b01, d5b3a6e348b8f643, c55b73c307c9329b, zz: Fp384 "(010E4E39A08185356BFCB7A9A2A35B9920389D094C35A80CD139B6E9ACDE311D73FBF42845B9322A6C6D9B52B711522D)", y, 0088810ee70cf6b5, 9d2312a04fa2cc6c, 71a6895532c12549, 4dbb59f3b1f28742, 5de8969809d4048b, 4245fa5364a4b35d, zz: Fp384 "(00C00B0BDF9032FBE826E6F890EAA7FDA47C5165934F6CED340B9C3012636D780E3C7DE8DA62FBC4863EC79E1936F8F5)" -1023, x, 0133064a88647c9a, eef3ffdd08549f02, fcd5b94f6e966386, 96f66ce8af9b507d, 2525c4c66cdde115, 1e52e82abddda9e9, zz: Fp384 "(0103B2D12B9ECAF92133B192BF3DD477160BF1F637EFDC393C31BDB9579F196B1FAD1E4C5AF9FF565AAAACD3A5867659)", y, 0098742883b34a4a, 6290ee8ddccf15b5, 2790a01ac2896430, 5aafe8673f723240, b2149c9d67dced62, 7160bb12c0fc2da8, zz: Fp384 "(015155F210905A6D7700FCC1098F6A9DD895BCADE864008EDA180DB8C32934CAEF5BBD97CC1A49C52E5F593FF31E50A8)" \ No newline at end of file +1023, x, 0133064a88647c9a, eef3ffdd08549f02, fcd5b94f6e966386, 96f66ce8af9b507d, 2525c4c66cdde115, 1e52e82abddda9e9, zz: Fp384 "(0103B2D12B9ECAF92133B192BF3DD477160BF1F637EFDC393C31BDB9579F196B1FAD1E4C5AF9FF565AAAACD3A5867659)", y, 0098742883b34a4a, 6290ee8ddccf15b5, 2790a01ac2896430, 5aafe8673f723240, b2149c9d67dced62, 7160bb12c0fc2da8, zz: Fp384 "(015155F210905A6D7700FCC1098F6A9DD895BCADE864008EDA180DB8C32934CAEF5BBD97CC1A49C52E5F593FF31E50A8)" diff --git 
a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv b/jaxite_ec/data/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv similarity index 86% rename from jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv rename to jaxite_ec/data/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv index d77e7dd..363b1a8 100644 --- a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv +++ b/jaxite_ec/data/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv @@ -1 +1 @@ -x, 01543e9a8f6cd81c, 2c5c084d7e5a0ff0, 408dc7edb8382513, 8ddfcfd7249d9b0c, d8e6f994e24cbb4d, 77682f864dc1ce64, zz: Fp384 "(00B5A8DC793FF0EBBD3A0E7170D83EA721FA80ABA6E306D0AF2703EF9762DFB47A80BE839B97E72DAE7C2D0B6757106C)", y, 01729e1ff869e64f, 9ad1b3c12f3f5481, e8f1f4e5a91bed5d, a874614f38e0ded6, 31531869aa1785fb, 4d8e2298435ee696, zz: Fp384 "(00F9A8D085C36B7CD26A6AD1C3DFF2DDA1FE33D4828A527AE6475E1863AE685762B6B9CD59B8FB67C58703D16B340F09)" \ No newline at end of file +x, 01543e9a8f6cd81c, 2c5c084d7e5a0ff0, 408dc7edb8382513, 8ddfcfd7249d9b0c, d8e6f994e24cbb4d, 77682f864dc1ce64, zz: Fp384 "(00B5A8DC793FF0EBBD3A0E7170D83EA721FA80ABA6E306D0AF2703EF9762DFB47A80BE839B97E72DAE7C2D0B6757106C)", y, 01729e1ff869e64f, 9ad1b3c12f3f5481, e8f1f4e5a91bed5d, a874614f38e0ded6, 31531869aa1785fb, 4d8e2298435ee696, zz: Fp384 "(00F9A8D085C36B7CD26A6AD1C3DFF2DDA1FE33D4828A527AE6475E1863AE685762B6B9CD59B8FB67C58703D16B340F09)" diff --git a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv b/jaxite_ec/data/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv similarity index 99% rename from jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv rename to jaxite_ec/data/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv index bec8106..fa3fff7 100644 --- a/jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv +++ b/jaxite_ec/data/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv @@ -1021,4 +1021,4 @@ 1020, 
0ecccbe98754c222, 03a75f6652abc4b2, 943173a9fdfec684, eda2e4c7a7bf0baf, zz: Fp256 "(105BE4B38C035D1848D9A285173F8716BAB21F59E0371871E048033B5980544F)" 1021, 0c08c6b9a73f213b, 895b7d0f13cd478b, 1ee44d69464688cc, 509d47f371f7185d, zz: Fp256 "(02CA79487BE3365732763ACF0E818B6552FC45A786480EC7BBA27DC26EE93858)" 1022, 0855c8b333f7c1c0, 1b6f4de421fb8a6d, 64d1e921237109a7, ef194ae414988390, zz: Fp256 "(09B817A72336D6D78EBD429DAA73F178A0C10AAE103DEC47401A8ED4F089C9FE)" -1023, 0d30713603e6c901, 56e77dbc8c902e50, 88c1dbdc29e27097, c1516a88192dcbae, zz: Fp256 "(103A3BF5E3AAC3D6C62F18DEC884B7A7CBAF9C7BCEB861C64C5441B2D210BE76)" \ No newline at end of file +1023, 0d30713603e6c901, 56e77dbc8c902e50, 88c1dbdc29e27097, c1516a88192dcbae, zz: Fp256 "(103A3BF5E3AAC3D6C62F18DEC884B7A7CBAF9C7BCEB861C64C5441B2D210BE76)" diff --git a/jaxite_ec/elliptic_curve.py b/jaxite_ec/elliptic_curve.py deleted file mode 100644 index bf846d2..0000000 --- a/jaxite_ec/elliptic_curve.py +++ /dev/null @@ -1,896 +0,0 @@ -"""The jaxite_ec implementation of the Elliptic curve operations on TPU. 
- -Detailed algorithms come from the following papers: -xyzz: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-xyzz.html -affine: https://www.hyperelliptic.org/EFD/g1p/auto-shortw.html -projective: -https://www.hyperelliptic.org/EFD/g1p/auto-shortw-projective.html#addition-madd-1998-cmo - -A non-TPU version of the same functions can be found in -jaxite_ec/algorithm/elliptic_curve.py - -To test the functionalities of this library, please refer to -jaxite_ec/elliptic_curve_test.py -""" - -import functools - -import jax -import jax.numpy as jnp -from jaxite.jaxite_ec import finite_field -from jaxite.jaxite_ec import util - - -add_3u16 = finite_field.add_3u16 -add_2u16 = finite_field.add_2u16 -sub_2u16 = finite_field.sub_2u16 -cond_sub_2u16 = finite_field.cond_sub_2u16 -cond_sub_mod_u16 = finite_field.cond_sub_mod_u16 -mod_mul_barrett_2u16 = finite_field.mod_mul_barrett_2u16 -mod_mul_lazy_2u16 = finite_field.mod_mul_lazy_2u16 -mod_mul_rns_2u16 = finite_field.mod_mul_rns_2u16 -add_rns_2u16 = finite_field.add_rns_2u16 -add_rns_3u16 = finite_field.add_rns_3u16 -add_sub_rns_var = finite_field.add_sub_rns_var -negate_rns_for_var_add = finite_field.negate_rns_for_var_add -negate_rns_for_var_add_zero_check = ( - finite_field.negate_rns_for_var_add_zero_check -) -rns_constant = finite_field.rns_constant - - -# Barrett Reduction Based Functions -@jax.named_call -def padd_barret_xyzz( - x1: jax.Array, - y1: jax.Array, - zz1: jax.Array, - zzz1: jax.Array, - x2: jax.Array, - y2: jax.Array, - zz2: jax.Array, - zzz2: jax.Array, -): - """PADD-BARRETT elliptic curve operation with packed arguments. - - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general - - This function implements the PADD-BARRETT elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. 
- zz1: The third generator element. - zzz1: The third generator element. - x2: The first generator element. - y2: The second generator element. - zz2: The third generator element. - zzz2: The third generator element. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. - """ - u1 = mod_mul_barrett_2u16(x1, zz2) - u2 = mod_mul_barrett_2u16(x2, zz1) - s1 = mod_mul_barrett_2u16(y1, zzz2) - s2 = mod_mul_barrett_2u16(y2, zzz1) - zz1_zz2 = mod_mul_barrett_2u16(zz1, zz2) - zzz1_zzz2 = mod_mul_barrett_2u16(zzz1, zzz2) - - p = cond_sub_2u16(u2, u1) - r = cond_sub_2u16(s2, s1) - - pp = mod_mul_barrett_2u16(p, p) - rr = mod_mul_barrett_2u16(r, r) - - ppp = mod_mul_barrett_2u16(pp, p) - q = mod_mul_barrett_2u16(u1, pp) - zz3 = mod_mul_barrett_2u16(zz1_zz2, pp) - - ppp_q_2 = add_3u16(ppp, q, q) - ppp_q_2 = cond_sub_mod_u16(ppp_q_2) - ppp_q_2 = cond_sub_mod_u16(ppp_q_2) - - x3 = cond_sub_2u16(rr, ppp_q_2) - - q_x3 = cond_sub_2u16(q, x3) - s1_ppp = mod_mul_barrett_2u16(s1, ppp) - zzz3 = mod_mul_barrett_2u16(zzz1_zzz2, ppp) - - y3 = mod_mul_barrett_2u16(r, q_x3) - y3 = cond_sub_2u16(y3, s1_ppp) - - return jnp.array([x3, y3, zz3, zzz3]) - - -@jax.named_call -def pdul_barret_xyzz( - x1: jax.Array, y1: jax.Array, zz1: jax.Array, zzz1: jax.Array -): - """PDUL-BARRET elliptic curve operation with packed arguments. - - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general - - This function implements the PDUL-BARRET elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - zz1: The third generator element. - zzz1: The third generator element. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. 
- """ - u = add_2u16(y1, y1) - u = cond_sub_mod_u16(u) - - x1x1 = mod_mul_barrett_2u16(x1, x1) - v = mod_mul_barrett_2u16(u, u) - - w = mod_mul_barrett_2u16(u, v) - s = mod_mul_barrett_2u16(x1, v) - - s_2 = add_2u16(s, s) - s_2 = cond_sub_mod_u16(s_2) - - m = add_3u16(x1x1, x1x1, x1x1) - m = cond_sub_mod_u16(m) - m = cond_sub_mod_u16(m) - - mm = mod_mul_barrett_2u16(m, m) - w_y1 = mod_mul_barrett_2u16(w, y1) - zz3 = mod_mul_barrett_2u16(v, zz1) - zzz3 = mod_mul_barrett_2u16(w, zzz1) - - x3 = cond_sub_2u16(mm, s_2) - - s_x3 = cond_sub_2u16(s, x3) - - y3 = mod_mul_barrett_2u16(m, s_x3) - y3 = cond_sub_2u16(y3, w_y1) - - return jnp.array([x3, y3, zz3, zzz3]) - - -@jax.named_call -def pdul_barrett_xyzz_pack(x1_y1_zz1_zzz1: jax.Array): - return pdul_barret_xyzz( - x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1], x1_y1_zz1_zzz1[2], x1_y1_zz1_zzz1[3] - ) - - -@jax.named_call -def padd_barrett_xyzz_pack( - x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array -): - return padd_barret_xyzz( - x1_y1_zz1_zzz1[0], - x1_y1_zz1_zzz1[1], - x1_y1_zz1_zzz1[2], - x1_y1_zz1_zzz1[3], - x2_y2_zz2_zzz2[0], - x2_y2_zz2_zzz2[1], - x2_y2_zz2_zzz2[2], - x2_y2_zz2_zzz2[3], - ) - - -@jax.named_call -def pdul_barrett_xyzz_pack_batch_first( - x1_y1_zz1_zzz1: jax.Array, transpose=(0, 1, 2) -): - return pdul_barret_xyzz( - x1_y1_zz1_zzz1[:, 0], - x1_y1_zz1_zzz1[:, 1], - x1_y1_zz1_zzz1[:, 2], - x1_y1_zz1_zzz1[:, 3], - ).transpose(transpose[0], transpose[1], transpose[2]) - - -@jax.named_call -def padd_barrett_xyzz_pack_batch_first( - x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array, transpose=(0, 1, 2) -): - return padd_barret_xyzz( - x1_y1_zz1_zzz1[:, 0], - x1_y1_zz1_zzz1[:, 1], - x1_y1_zz1_zzz1[:, 2], - x1_y1_zz1_zzz1[:, 3], - x2_y2_zz2_zzz2[:, 0], - x2_y2_zz2_zzz2[:, 1], - x2_y2_zz2_zzz2[:, 2], - x2_y2_zz2_zzz2[:, 3], - ).transpose(transpose[0], transpose[1], transpose[2]) - - -# Lazy Reduction Based Functions -@jax.named_call -def padd_lazy_xyzz( - x1: jax.Array, - y1: jax.Array, - zz1: jax.Array, - 
zzz1: jax.Array, - x2: jax.Array, - y2: jax.Array, - zz2: jax.Array, - zzz2: jax.Array, -): - """PADD-LAZY elliptic curve operation with packed arguments. - - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general - - This function implements the PADD-LAZY elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - zz1: The third generator element. - zzz1: The third generator element. - x2: The first generator element. - y2: The second generator element. - zz2: The third generator element. - zzz2: The third generator element. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. - """ - cond_sub_2u16_ext = functools.partial( - cond_sub_2u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - cond_sub_mod_u16_ext = functools.partial( - cond_sub_mod_u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - - u1 = mod_mul_lazy_2u16(x1, zz2) - u2 = mod_mul_lazy_2u16(x2, zz1) - s1 = mod_mul_lazy_2u16(y1, zzz2) - s2 = mod_mul_lazy_2u16(y2, zzz1) - zz1_zz2 = mod_mul_lazy_2u16(zz1, zz2) - zzz1_zzz2 = mod_mul_lazy_2u16(zzz1, zzz2) - - p = cond_sub_2u16_ext(u2, u1) - r = cond_sub_2u16_ext(s2, s1) - - pp = mod_mul_lazy_2u16(p, p) - rr = mod_mul_lazy_2u16(r, r) - - ppp = mod_mul_lazy_2u16(pp, p) - q = mod_mul_lazy_2u16(u1, pp) - zz3 = mod_mul_lazy_2u16(zz1_zz2, pp) - - # Can be replaced by mod_add_lazy. 
- ppp_q_2 = add_3u16(ppp, q, q) - ppp_q_2 = cond_sub_mod_u16_ext(ppp_q_2) - ppp_q_2 = cond_sub_mod_u16_ext(ppp_q_2) - - x3 = cond_sub_2u16_ext(rr, ppp_q_2) - - q_x3 = cond_sub_2u16_ext(q, x3) - s1_ppp = mod_mul_lazy_2u16(s1, ppp) - zzz3 = mod_mul_lazy_2u16(zzz1_zzz2, ppp) - - y3 = mod_mul_lazy_2u16(r, q_x3) - y3 = cond_sub_2u16_ext(y3, s1_ppp) - - return jnp.array([x3, y3, zz3, zzz3]) - - -@jax.named_call -def pdul_lazy_xyzz( - x1: jax.Array, - y1: jax.Array, - zz1: jax.Array, - zzz1: jax.Array, -): - """PDUL-BARRET elliptic curve operation with packed arguments. - - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general - - This function implements the PDUL-BARRET elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - zz1: The third generator element. - zzz1: The third generator element. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. 
- """ - cond_sub_2u16_ext = functools.partial( - cond_sub_2u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - cond_sub_mod_u16_ext = functools.partial( - cond_sub_mod_u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - u = add_2u16(y1, y1) - u = cond_sub_mod_u16_ext(u) - - x1x1 = mod_mul_lazy_2u16(x1, x1) - v = mod_mul_lazy_2u16(u, u) - - w = mod_mul_lazy_2u16(u, v) - s = mod_mul_lazy_2u16(x1, v) - - s_2 = add_2u16(s, s) - s_2 = cond_sub_mod_u16_ext(s_2) - - m = add_3u16(x1x1, x1x1, x1x1) - m = cond_sub_mod_u16_ext(m) - m = cond_sub_mod_u16_ext(m) - - mm = mod_mul_lazy_2u16(m, m) - w_y1 = mod_mul_lazy_2u16(w, y1) - zz3 = mod_mul_lazy_2u16(v, zz1) - zzz3 = mod_mul_lazy_2u16(w, zzz1) - - x3 = cond_sub_2u16_ext(mm, s_2) - - s_x3 = cond_sub_2u16_ext(s, x3) - - y3 = mod_mul_lazy_2u16(m, s_x3) - y3 = cond_sub_2u16_ext(y3, w_y1) - - return jnp.array([x3, y3, zz3, zzz3]) - - -@jax.named_call -def padd_lazy_xyzz_pack(x1_y1_zz1_zzz1: jax.Array, x2_y2_zz2_zzz2: jax.Array): - return padd_lazy_xyzz( - x1_y1_zz1_zzz1[0], - x1_y1_zz1_zzz1[1], - x1_y1_zz1_zzz1[2], - x1_y1_zz1_zzz1[3], - x2_y2_zz2_zzz2[0], - x2_y2_zz2_zzz2[1], - x2_y2_zz2_zzz2[2], - x2_y2_zz2_zzz2[3], - ) - - -@jax.named_call -def pdul_lazy_xyzz_pack(x1_y1_zz1_zzz1: jax.Array): - return pdul_lazy_xyzz( - x1_y1_zz1_zzz1[0], - x1_y1_zz1_zzz1[1], - x1_y1_zz1_zzz1[2], - x1_y1_zz1_zzz1[3], - ) - - -# Lazy Reduction Based Function -@jax.named_call -@functools.partial(jax.jit, static_argnames="twisted_d_chunk") -def padd_lazy_twisted( - x1: jax.Array, - y1: jax.Array, - z1: jax.Array, - t1: jax.Array, - x2: jax.Array, - y2: jax.Array, - z2: jax.Array, - t2: jax.Array, - twisted_d_chunk=util.TWIST_D_INT_CHUNK, -): - """PADD-LAZY elliptic curve operation with packed arguments. 
- - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::add_general - - This function implements the PADD-LAZY elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - z1: The third generator element. - t1: The fourth generator element. - x2: The first generator element. - y2: The second generator element. - z2: The third generator element. - t2: The fourth generator element. - twisted_d_chunk: The twisted d parameter. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. - """ - cond_sub_2u16_ext = functools.partial( - cond_sub_2u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - cond_sub_mod_u16_ext = functools.partial( - cond_sub_mod_u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - - twisted_d = jnp.asarray(twisted_d_chunk, dtype=jnp.uint16) - twisted_d = jax.lax.broadcast(twisted_d, [x1.shape[0]]) - - a = mod_mul_lazy_2u16(x1, x2) - b = mod_mul_lazy_2u16(y1, y2) - d = mod_mul_lazy_2u16(z1, z2) - c = mod_mul_lazy_2u16(t1, t2) - c = mod_mul_lazy_2u16(c, twisted_d) - - h = add_2u16(a, b) - h = cond_sub_mod_u16_ext(h) - e1 = add_2u16(x1, y1) - e1 = cond_sub_mod_u16_ext(e1) - e2 = add_2u16(x2, y2) - e2 = cond_sub_mod_u16_ext(e2) - e = mod_mul_lazy_2u16(e1, e2) - - e = cond_sub_2u16_ext(e, h) - - f = cond_sub_2u16_ext(d, c) - g = add_2u16(d, c) - g = cond_sub_mod_u16_ext(g) - - x3 = mod_mul_lazy_2u16(e, f) - y3 = mod_mul_lazy_2u16(g, h) - z3 = mod_mul_lazy_2u16(f, g) - t3 = mod_mul_lazy_2u16(e, h) - - return jnp.array([x3, y3, z3, t3]) - - -def pdul_lazy_twisted( - x1: jax.Array, - y1: jax.Array, - z1: jax.Array, - t1: jax.Array, -): - """PDUL-LAZY elliptic curve operation with packed arguments. 
- - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general - - This function implements the PDUL-LAZY elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - z1: The third generator element. - t1: The fourth generator element. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. - """ - cond_sub_2u16_ext = functools.partial( - cond_sub_2u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - cond_sub_mod_u16_ext = functools.partial( - cond_sub_mod_u16, - modulus_377_int_chunk=util.MODULUS_377_S16_INT_CHUNK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - modulus_377_int_array = jnp.asarray( - util.MODULUS_377_S16_INT_CHUNK, jnp.uint16 - ) - - a = mod_mul_lazy_2u16(x1, x1) - b = mod_mul_lazy_2u16(y1, y1) - - ct = mod_mul_lazy_2u16(z1, z1) - ct2 = add_2u16(ct, ct) # - ct2 = cond_sub_mod_u16_ext(ct2) # - - h = add_2u16(a, b) - h = cond_sub_2u16_ext(modulus_377_int_array, h) # - - et = add_2u16(x1, y1) # - et = cond_sub_mod_u16_ext(et) # - e = mod_mul_lazy_2u16(et, et) # - e = add_2u16(e, h) # - e = cond_sub_mod_u16_ext(e) # - - g = cond_sub_2u16_ext(b, a) # - f = cond_sub_2u16_ext(g, ct2) # - x3 = mod_mul_lazy_2u16(e, f) # - y3 = mod_mul_lazy_2u16(g, h) - z3 = mod_mul_lazy_2u16(f, g) - t3 = mod_mul_lazy_2u16(e, h) - return jnp.array([x3, y3, z3, t3]) - - -def pneg_lazy_twisted( - x1: jax.Array, - y1: jax.Array, - z1: jax.Array, - t1: jax.Array, -): - """PDUL-LAZY elliptic curve operation with packed arguments. 
- - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassTwisted::double_general - - This function implements the PDUL-LAZY elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - z1: The third generator element. - t1: The fourth generator element. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. - """ - - modulus_377_int_array = jnp.asarray( - util.MODULUS_377_S16_INT_CHUNK, jnp.uint16 - ) - sub_2u16_ext = functools.partial( - sub_2u16, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - ) - - x3 = sub_2u16_ext(modulus_377_int_array, x1) - y3 = y1 - z3 = z1 - t3 = sub_2u16_ext(modulus_377_int_array, t1) - - return jnp.array([x3, y3, z3, t3]) - - -def padd_lazy_twisted_pack(x1_y1_z1_t1: jax.Array, x2_y2_z2_t2: jax.Array): - return padd_lazy_twisted( - x1_y1_z1_t1[0], - x1_y1_z1_t1[1], - x1_y1_z1_t1[2], - x1_y1_z1_t1[3], - x2_y2_z2_t2[0], - x2_y2_z2_t2[1], - x2_y2_z2_t2[2], - x2_y2_z2_t2[3], - ) - - -def pdul_lazy_twisted_pack(x1_y1_z1_t1: jax.Array): - return pdul_lazy_twisted( - x1_y1_z1_t1[0], - x1_y1_z1_t1[1], - x1_y1_z1_t1[2], - x1_y1_z1_t1[3], - ) - - -def pneg_lazy_twisted_pack(x1_y1_z1_t1: jax.Array): - return pneg_lazy_twisted( - x1_y1_z1_t1[0], - x1_y1_z1_t1[1], - x1_y1_z1_t1[2], - x1_y1_z1_t1[3], - ) - - -# RNS Based Functions -@jax.named_call -@functools.partial(jax.jit, static_argnames="rns_mat") -def padd_rns_xyzz_pack( - x1_y1_zz1_zzz1: jax.Array, - x2_y2_zz2_zzz2: jax.Array, - rns_mat=util.RNS_MAT, -): - """PADD-RNS elliptic curve operation with packed arguments. 
- - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general - - This function implements the PADD-RNS elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1_y1_zz1_zzz1: The first point. - x2_y2_zz2_zzz2: The second point. - rns_mat: The RNS matrix. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. - """ - - # u1 = x1 * zz2 - # u2 = x2 * zz1 - # s1 = y1 * zzz2 - # s2 = y2 * zzz1 - # zz1_zz2 = zz1 * zz2 - # zzz1_zzz2 = zzz1 * zzz2 - num_moduli = x1_y1_zz1_zzz1.shape[-1] - inputsl = jnp.vstack((x1_y1_zz1_zzz1, x1_y1_zz1_zzz1[2:])).reshape( - -1, num_moduli - ) - inputsr = jnp.vstack( - (x2_y2_zz2_zzz2[2:], x2_y2_zz2_zzz2[2:], x2_y2_zz2_zzz2[:2]) - ).reshape(-1, num_moduli) - outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - u1, s1, zz1_zz2, zzz1_zzz2, u2, s2 = jnp.vsplit(outputs, 6) - - # p = u2 - u1 - # r = s2 - s1 - p = add_sub_rns_var(u2, negate_rns_for_var_add(u1)) - r = add_sub_rns_var(s2, negate_rns_for_var_add(s1)) - - # pp = p * p - # rr = r * r - pp = mod_mul_rns_2u16(p, p, rns_mat) - rr = mod_mul_rns_2u16(r, r, rns_mat) - - # ppp = p * pp - # q = u1 * pp - # zz3 = zz1_zz2 * pp - inputsl = jnp.vstack((p, u1, zz1_zz2)) - inputsr = jnp.vstack((pp, pp, pp)) - outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - ppp, q, zz3 = jnp.vsplit(outputs, 3) - - # x3 = r * r - ppp - (q + q) - # q_x3 = q - x3 - x3 = add_sub_rns_var( - rr, - negate_rns_for_var_add(ppp), - negate_rns_for_var_add(q), - negate_rns_for_var_add(q), - ) - q_x3 = add_sub_rns_var(q, negate_rns_for_var_add(x3)) - - # s1_ppp = s1 * ppp - # r_q_x3 = r * q_x3 - # zzz3 = zzz1_zzz2 * ppp - # y3 = r_q_x3 - s1_ppp - inputsl = jnp.vstack((s1, zzz1_zzz2, r)) - inputsr = jnp.vstack((ppp, ppp, q_x3)) - outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - s1_ppp, zzz3, r_q_x3 = jnp.vsplit(outputs, 3) - - y3 = 
add_sub_rns_var(r_q_x3, negate_rns_for_var_add(s1_ppp)) - - return jnp.array([x3, y3, zz3, zzz3]) - - -@jax.named_call -@functools.partial(jax.jit, static_argnames="rns_mat") -def pdul_rns_xyzz( - x1: jax.Array, - y1: jax.Array, - zz1: jax.Array, - zzz1: jax.Array, - rns_mat=util.RNS_MAT, -): - """PDUL-RNS elliptic curve operation with packed arguments. - - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general - - This function implements the PDUL-RNS elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - zz1: The third generator element. - zzz1: The third generator element. - rns_mat: The RNS matrix. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. - """ - # u = y1 + y1 - u = add_rns_2u16(y1, y1) - - # v = u * u - v = mod_mul_rns_2u16(u, u, rns_mat) - - # x1x1 = x1 * x1 - # w = u * v - # s = x1 * v - inputsl = jnp.vstack((x1, u, x1)) - inputsr = jnp.vstack((x1, v, v)) - output = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - x1x1, w, s = jnp.vsplit(output, 3) - - # m = (x1x1 + x1x1 + x1x1) + a * (zz1 * zz1), Note: a = 0 - m = add_rns_3u16(x1x1, x1x1, x1x1) - - # mm = m * m - # w_y1 = w * y1 - # zz3 = v * zz1 - # zzz3 = w * zzz1 - inputsl = jnp.vstack((m, w, v, w)) - inputsr = jnp.vstack((m, y1, zz1, zzz1)) - output = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - mm, w_y1, zz3, zzz3 = jnp.vsplit(output, 4) - - # x3 = mm - (s + s) - x3 = add_sub_rns_var(mm, negate_rns_for_var_add(s), negate_rns_for_var_add(s)) - - # s_x3 = s - x3 - s_x3 = add_sub_rns_var(s, negate_rns_for_var_add(mm), s, s) - - # m_s_x3 = m * s_x3 - w_y1 - m_s_x3 = mod_mul_rns_2u16(m, s_x3, rns_mat) - - # y3 = m_s_x3 - w_y1 - y3 = add_sub_rns_var(m_s_x3, negate_rns_for_var_add(w_y1)) - - return jnp.array([x3, y3, zz3, zzz3]) - - 
-@jax.named_call -@functools.partial(jax.jit, static_argnames="rns_mat") -def pdul_rns_xyzz_pack(x1_y1_zz1_zzz1: jax.Array, rns_mat=util.RNS_MAT): - return pdul_rns_xyzz( - x1_y1_zz1_zzz1[0], - x1_y1_zz1_zzz1[1], - x1_y1_zz1_zzz1[2], - x1_y1_zz1_zzz1[3], - rns_mat, - ) - - -# RNS Based Functions -@jax.named_call -@functools.partial(jax.jit, static_argnames=("rns_mat", "twist_d")) -def padd_rns_twisted_pack( - x1_y1_zz1_zzz1: jax.Array, - x2_y2_zz2_zzz2: jax.Array, - rns_mat=util.RNS_MAT, - twist_d=util.TWIST_D_RNS, -): - """PADD-RNS elliptic curve operation with packed arguments. - - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::add_general - - This function implements the PADD-RNS elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1_y1_zz1_zzz1: The first point. - x2_y2_zz2_zzz2: The second point. - rns_mat: The RNS matrix. - twist_d: curve parameter. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. 
- """ - twist_d = jnp.array(twist_d, dtype=jnp.uint16) - - inputsl = jnp.vstack(x1_y1_zz1_zzz1) - inputsr = jnp.vstack(x2_y2_zz2_zzz2) - outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - a, b, d, c = jnp.vsplit(outputs, 4) - - e1 = add_rns_2u16(x1_y1_zz1_zzz1[0], x1_y1_zz1_zzz1[1]) - e2 = add_rns_2u16(x2_y2_zz2_zzz2[0], x2_y2_zz2_zzz2[1]) - twist_d_here = jnp.broadcast_to( - twist_d.reshape(-1, twist_d.shape[0]), c.shape - ) - inputsl = jnp.vstack((e1, c)) - inputsr = jnp.vstack((e2, twist_d_here)) - outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - e3, c = jnp.vsplit(outputs, 2) - - # Issue happens here - e = add_sub_rns_var( - e3, - negate_rns_for_var_add_zero_check(a), - negate_rns_for_var_add_zero_check(b), - ) - f = add_sub_rns_var(d, negate_rns_for_var_add_zero_check(c)) - g = add_rns_2u16(d, c) - h = add_rns_2u16(a, b) - - inputsl = jnp.vstack((e, g, f, e)) - inputsr = jnp.vstack((f, h, g, h)) - outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - x3, y3, z3, t3 = jnp.vsplit(outputs, 4) - - return jnp.array([x3, y3, z3, t3]) - - -@jax.named_call -@functools.partial(jax.jit, static_argnames="rns_mat") -def pdul_rns_twisted( - x1: jax.Array, - y1: jax.Array, - z1: jax.Array, - t1: jax.Array, - rns_mat=util.RNS_MAT, -): - """PDUL-RNS elliptic curve operation with packed arguments. - - As for the algorithm, pls refer to - jaxite_ec/algorithm/elliptic_curve.py::ECCSWeierstrassXYZZ::double_general - - This function implements the PDUL-RNS elliptic curve operation with packed - arguments, which is used to compute the elliptic curve points of a given - group. - - Args: - x1: The first generator element. - y1: The second generator element. - z1: The third generator element. - t1: The third generator element. - rns_mat: The RNS matrix. - - Returns: - A tuple containing the third generator element and the elliptic curve points - of the group. 
- """ - et = add_rns_2u16(x1, y1) - inputsl = jnp.vstack((x1, y1, z1, et)) - outputs = mod_mul_rns_2u16(inputsl, inputsl, rns_mat) - a, b, ct, et2 = jnp.vsplit(outputs, 4) - - e = add_sub_rns_var(et2, negate_rns_for_var_add(a), negate_rns_for_var_add(b)) - g = add_sub_rns_var(b, negate_rns_for_var_add(a)) - f = add_sub_rns_var(g, negate_rns_for_var_add(ct), negate_rns_for_var_add(ct)) - h = add_sub_rns_var(negate_rns_for_var_add(a), negate_rns_for_var_add(b)) - - inputsl = jnp.vstack((e, g, f, e)) - inputsr = jnp.vstack((f, h, g, h)) - outputs = mod_mul_rns_2u16(inputsl, inputsr, rns_mat) - x3, y3, z3, t3 = jnp.vsplit(outputs, 4) - - return jnp.array([x3, y3, z3, t3]) - - -@jax.named_call -@functools.partial(jax.jit, static_argnames="rns_mat") -def pdul_rns_twisted_pack(x1_y1_zz1_zzz1: jax.Array, rns_mat=util.RNS_MAT): - return pdul_rns_twisted( - x1_y1_zz1_zzz1[0], - x1_y1_zz1_zzz1[1], - x1_y1_zz1_zzz1[2], - x1_y1_zz1_zzz1[3], - rns_mat, - ) - - -@jax.named_call -def rns_twist_zero(): - return jnp.array( - [rns_constant(0), rns_constant(1), rns_constant(1), rns_constant(0)] - ) diff --git a/jaxite_ec/elliptic_curve_context.py b/jaxite_ec/elliptic_curve_context.py new file mode 100644 index 0000000..bda2862 --- /dev/null +++ b/jaxite_ec/elliptic_curve_context.py @@ -0,0 +1,818 @@ +import abc +from concurrent.futures import ProcessPoolExecutor +import functools +import multiprocessing +import os +import warnings + +import jax +import jax.numpy as jnp +from jaxite.jaxite_ec import finite_field_context +from jaxite.jaxite_ec import utils +import numpy as np + +FiniteFieldContextBase = finite_field_context.FiniteFieldContextBase +abstractmethod = abc.abstractmethod +ABC = abc.ABC + + +# Use 'forkserver' to avoid JAX multithreading + fork deadlock +_MP_CONTEXT = multiprocessing.get_context("forkserver") + +JaxParameters = utils.JaxParameters +JaxKernelContextBase = utils.JaxKernelContextBase +hash_args = utils.hash_args +pad_jax_array = utils.pad_jax_array 
+store_jax_executable = utils.store_jax_kernel +load_jax_executable = utils.load_jax_kernel +jax_jit_lower_compile = utils.jax_jit_lower_compile +jax.config.update("jax_enable_x64", True) + + +class EllipticCurveContextBase(ABC): + """Abstract base class defining the interface for finite field operations. + + Subclasses must implement all abstract methods to provide concrete + finite field arithmetic operations. + """ + + @abstractmethod + def __init__(self, parameters: dict): + """Initialize the finite field context. + + Args: + parameters: Configuration dictionary containing field parameters. + """ + self.parameters = parameters + self.prime = parameters.get("prime", None) + assert self.prime is not None, "prime must be provided" + self.zero_point = None + ff_ctx_class = parameters.get("finite_field_context_class", None) + assert ( + ff_ctx_class is not None + ), "finite_field_context_class must be provided" + self.ff_ctx: FiniteFieldContextBase = ff_ctx_class( + parameters.get("finite_field_parameters", {}) + ) + + @abstractmethod + def to_computational_format(self, a) -> jnp.ndarray: + """Convert input to the internal computational representation. + + Args: + a: Input value in standard format. + + Returns: + Value converted to computational format (e.g., Montgomery form). + """ + pass + + @abstractmethod + def to_original_format(self, a): + """Convert from computational format back to standard representation. + + Args: + a: Value in computational format. + + Returns: + Value in standard integer representation. + """ + pass + + @abstractmethod + def point_add(self, a: jnp.ndarray, b: jnp.ndarray): + """Perform point addition: (a + b) + + Args: + a: First operand in computational format. + b: Second operand in computational format. + + Returns: + Product in computational format. + """ + pass + + @abstractmethod + def point_double(self, a: jnp.ndarray): + """Perform point doubling: (2 * a) + + Args: + a: First operand in computational format. 
+ + Returns: + Product in computational format. + """ + pass + + def get_finite_field_context(self) -> FiniteFieldContextBase: + return self.ff_ctx + + def _modular_multiply(self, a: int, b: int) -> int: + return (a * b) % self.prime + + def _modular_reduce(self, a: int) -> int: + return a % self.prime + + def _modular_divide(self, a: int, b: int) -> int: + assert b != 0, "ec divide: b is zero" + b_inv = pow(b, self.prime - 2, self.prime) + return (a * b_inv) % self.prime + + +class CPUWeierstrassAffineContext(EllipticCurveContextBase): + """CPU implementation of Weierstrass affine curve operations. + + This class provides CPU-based implementations for point addition and doubling + on a Weierstrass curve in affine coordinates. + This class is only for private functional testing, not for production use. + """ + + def __init__(self, parameters: dict): + super().__init__(parameters) + # warnings.warn("CPUWeierstrassAffineContext is only for private functional testing, not for production use",UserWarning, stacklevel=2) + + # Curve configuration + self.a = parameters["a"] + self.b = parameters["b"] + + def point_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + raise NotImplementedError( + "CPUWeierstrassAffineContext: point_add is not implemented" + ) + + def point_double(self, point: jnp.ndarray) -> jnp.ndarray: + raise NotImplementedError( + "CPUWeierstrassAffineContext: point_double is not implemented" + ) + + def to_computational_format(self, a: list) -> jnp.ndarray: + raise NotImplementedError( + "CPUWeierstrassAffineContext: to_computational_format is not" + " implemented" + ) + + def to_original_format(self, a: jnp.ndarray) -> list: + raise NotImplementedError( + "CPUWeierstrassAffineContext: to_original_format is not implemented" + ) + + def _point_add(self, point_a: list, point_b: list) -> list: + def single_point_add(point_a: list, point_b: list) -> list[int]: + x1, y1 = point_a + x2, y2 = point_b + slope = self._modular_divide( + 
self._modular_reduce(y2 - y1), self._modular_reduce(x2 - x1) + ) + x3 = self._modular_reduce(self._modular_multiply(slope, slope) - x1 - x2) + y3 = self._modular_reduce( + self._modular_multiply(slope, self._modular_reduce(x1 - x3)) - y1 + ) + return [x3, y3] + + list_depth = utils.nested_list_depth(point_a) + if list_depth == 1: + return single_point_add(point_a, point_b) + elif list_depth == 2: + return [ + single_point_add(point_a_i, point_b_i) + for point_a_i, point_b_i in zip(point_a, point_b) + ] + else: + raise ValueError( + f"Invalid list depth {list_depth} of input for point addition" + ) + + def _point_double(self, point: list) -> list: + def single_point_double(point: list) -> list[int]: + x, y = point + slope = self._modular_divide( + self._modular_reduce(3 * x * x + self.a), self._modular_reduce(2 * y) + ) + x3 = self._modular_reduce(self._modular_multiply(slope, slope) - 2 * x) + y3 = self._modular_reduce(self._modular_multiply(slope, x - x3) - y) + return [x3, y3] + + list_depth = utils.nested_list_depth(point) + if list_depth == 1: + return single_point_double(point) + elif list_depth == 2: + return [single_point_double(point_i) for point_i in point] + else: + raise ValueError("Invalid list depth of input for point doubling") + + +class ExtendedTwistedEdwardsContextBase(EllipticCurveContextBase): + + def __init__(self, parameters: dict): + super().__init__(parameters) + + # Curve configuration + self.a = parameters["a"] + self.twist_d = parameters["twist_d"] + self.alpha = parameters["alpha"] + self.s = parameters["s"] + self.A = parameters["MA"] + self.B = parameters["MB"] + self.t = parameters["t"] + self.k = self.twist_d + self.twist_d + self.zero_point = [0, 1, 1, 0] + + def _twist(self, coordinates: list[int]) -> list[int]: + assert ( + len(coordinates) == 2 + ), "Twisted Edwards coordinates must be of length 2" + x, y = coordinates + # Convert to montgomery (Notel it is ec montgomery not field montgomery) + xm = self._modular_reduce(self.s * 
(x - self.alpha)) + ym = self._modular_reduce(self.s * y) + # Convert to edwards + if ym == 0: + raise ValueError("ec twist: ym is zero") + xt = self._modular_divide(xm, ym) + + yt_denom = xm + 1 + if yt_denom == 0: + raise ValueError("ec twist: yt_denom is zero") + yt = self._modular_divide(xm - 1, yt_denom) + + xt = self._modular_multiply(xt, self.t) + return [xt, yt] + + def _untwist(self, coordinates: list[int]) -> list[int]: + assert ( + len(coordinates) == 2 + ), "Twisted Edwards coordinates must be of length 2" + xt, yt = coordinates + xt = self._modular_divide(xt, self.t) + # Convert to montgomery + xm = self._modular_divide((1 + yt), (1 - yt)) + ym = self._modular_divide((1 + yt), self._modular_multiply((1 - yt), xt)) + # Convert to weierstrass + x = self._modular_reduce( + self._modular_divide(xm, self.B) + + self._modular_divide(self.A, self._modular_multiply(3, self.B)) + ) + y = self._modular_divide(ym, self.B) + return [x, y] + + def _convert_to_edwards_affine(self, coordinates: list[int]) -> list[int]: + assert ( + len(coordinates) == 4 + ), "Twisted Edwards coordinates must be of length 2" + x, y, z, t = coordinates + z_inv = self._modular_divide(1, z) + x = self._modular_multiply(x, z_inv) + y = self._modular_multiply(y, z_inv) + return [x, y] + + def _convert_to_extended_twisted_edwards( + self, coordinates: list[int] + ) -> list[int]: + assert ( + len(coordinates) == 2 + ), "Twisted Edwards coordinates must be of length 2" + xt, yt = self._twist(coordinates) + return [xt, yt, 1, self._modular_multiply(xt, yt)] + + def _convert_to_weierstrass_affine(self, coordinates: list[int]) -> list[int]: + assert ( + len(coordinates) == 4 + ), "Extended Twisted Edwards coordinates must be of length 4" + affine_coords = self._convert_to_edwards_affine(coordinates) + untwisted_coords = self._untwist(affine_coords) + return untwisted_coords + + +class ExtendedTwistedEdwardsContext( + ExtendedTwistedEdwardsContextBase, JaxKernelContextBase +): + + def 
__init__(self, parameters: dict): + super().__init__(parameters) + JaxKernelContextBase.__init__(self) + self.jax_parameters = JaxParameters() + self._init_jax_parameters() + + def to_computational_format(self, a: list) -> jnp.ndarray: + list_depth = utils.nested_list_depth(a) + # NOTE: the dimension is (batch, coordinates) + if list_depth == 1: + twisted_coords = self._convert_to_extended_twisted_edwards(a) + elif list_depth == 2: + twisted_coords = [ + self._convert_to_extended_twisted_edwards(a_i) for a_i in a + ] + else: + raise ValueError( + "Invalid list depth of input for converting to extended twisted" + " edwards coordinates" + ) + result = self.ff_ctx.to_computational_format(twisted_coords) + if list_depth == 1: + result = jnp.broadcast_to(result, (result.shape[0], 1, result.shape[1])) + elif list_depth == 2: + result = result.transpose(1, 0, 2) + # NOTE: the computational format dim is (coordinates, batch, precision) + if self.use_sharding: + named_sharding, padded_shape = self.create_named_sharding( + shape=result.shape, axes=[1] + ) + result = pad_jax_array(result, padded_shape) + return result.to_device(named_sharding) + else: + return result.to_device(jax.devices()[0]) + + def to_original_format(self, a: jnp.ndarray) -> list: + dim = a.ndim + # NOTE: the computational format dim is (coordinates, batch, precision) + if dim == 3: + a = a.transpose( + 1, 0, 2 + ) # (coordinates, batch, precision) -> (batch, coordinates, precision) + a = self.ff_ctx.to_original_format(a) + if dim == 2: + affine_coords = self._convert_to_weierstrass_affine(a) + elif dim == 3: + affine_coords = [self._convert_to_weierstrass_affine(a_i) for a_i in a] + else: + raise ValueError( + "Invalid dimension of input for converting to weierstrass affine" + " coordinates" + ) + return affine_coords + + def _init_jax_parameters(self): + self.jax_parameters.set_parameter( + twist_d=self.ff_ctx.to_computational_format(self.twist_d), + ) + + def _point_add( + self, point_a: jnp.ndarray, 
point_b: jnp.ndarray + ) -> jnp.ndarray: + twist_d = self.jax_parameters.twist_d + + inputsl = point_a + inputsr = point_b + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + a, b, d, c = jnp.vsplit(outputs, 4) + # print(a.shape, b.shape, d.shape, c.shape) + + pax, pay, _, _ = jnp.vsplit(point_a, 4) + pbx, pby, _, _ = jnp.vsplit(point_b, 4) + + e1 = self.ff_ctx._modular_add(pax, pay) + e2 = self.ff_ctx._modular_add(pbx, pby) + twist_d_here = jnp.broadcast_to( + twist_d.reshape(-1, twist_d.shape[0]), c.shape + ) + if self.use_sharding: + twist_d_here = jax.sharding.reshard(twist_d_here, jax.typeof(e2).sharding) + inputsl = jnp.concatenate((e1, c), axis=0) + inputsr = jnp.concatenate((e2, twist_d_here), axis=0) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + e3, c = jnp.vsplit(outputs, 2) + + e = self.ff_ctx._modular_subtract(self.ff_ctx._modular_subtract(e3, a), b) + f = self.ff_ctx._modular_subtract(d, c) + g = self.ff_ctx._modular_add(d, c) + h = self.ff_ctx._modular_add(a, b) + + inputsl = jnp.concatenate((e, g, f, e), axis=0) + inputsr = jnp.concatenate((f, h, g, h), axis=0) + # print(inputsl.shape, inputsr.shape) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + return outputs.reshape(4, -1, outputs.shape[-1]) + + def _point_double(self, point: jnp.ndarray) -> jnp.ndarray: + x, y, z, t = point + + et = self.ff_ctx._modular_multiply(x, y) + inputsl = jnp.vstack((x, y, z, et)) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsl) + a, b, ct, et2 = jnp.vsplit(outputs, 4) + + h = self.ff_ctx._modular_negate(self.ff_ctx._modular_add(a, b)) + e = self.ff_ctx._modular_add(et2, h) + g = self.ff_ctx._modular_subtract(b, a) + f = self.ff_ctx._modular_subtract(g, self.ff_ctx._modular_add(ct, ct)) + + inputsl = jnp.vstack((e, g, f, e)) + inputsr = jnp.vstack((f, h, g, h)) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + return outputs.reshape(4, -1, outputs.shape[-1]) + + def point_add(self, a: jnp.ndarray, b: 
jnp.ndarray) -> jnp.ndarray: + + if self.use_compiled_kernels: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + return self.compiled_kernels[kernel_hash]["point_add"](a, b) + else: + return self._point_add(a, b) + + # TODO: fix the functional bug in point_double + def point_double(self, point: jnp.ndarray) -> jnp.ndarray: + raise ValueError("point_double has logic bug") + shape_dtype_struct = jax.ShapeDtypeStruct(point.shape, point.dtype) + if self.use_compiled_kernels: + return self.compiled_kernels[shape_dtype_struct.__hash__()][ + "point_double" + ](point) + else: + return self._point_double(point) + + def _get_shape_dtype_structs( + self, parameters: dict + ) -> list[jax.ShapeDtypeStruct]: + batch_size = parameters["batch_size"] + num_moduli = self.jax_parameters.twist_d.shape[0] + point_shape = (4, batch_size, num_moduli) + if self.use_sharding: + named_sharding, padded_shape = self.create_named_sharding( + shape=point_shape, axes=[1] + ) + return [ + jax.ShapeDtypeStruct( + padded_shape, jnp.uint32, sharding=named_sharding + ) + ] + return [jax.ShapeDtypeStruct(point_shape, jnp.uint32)] + + def context_hash(self) -> str: + return hash_args( + self.__class__.__name__, + self.ff_ctx.context_hash(), + self.a, + self.twist_d, + self.alpha, + self.s, + self.A, + self.B, + self.t, + self.use_sharding, + ) + + def serialize(self, parameters: dict): + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + kernel_hash = hash_args(self.context_hash(), parameters) + class_name = self.__class__.__name__ + + store_jax_executable( + self._point_add, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_point_add_{kernel_hash}", + ) + store_jax_executable( + self._point_double, + shape_dtype_structs[0], + name=f"{class_name}_point_double_{kernel_hash}", + ) + + def compile(self, parameters: dict): + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + kernel_hash = hash_args(self.context_hash(), parameters) + class_name = 
self.__class__.__name__ + + point_add_kernel = load_jax_executable( + f"{class_name}_point_add_{kernel_hash}" + ) + point_double_kernel = load_jax_executable( + f"{class_name}_point_double_{kernel_hash}" + ) + + if None in [point_add_kernel, point_double_kernel]: + warnings.warn( + f"Not found stored serialized compiled kernels, compiling...", + UserWarning, + stacklevel=2, + ) + + kernel_hash = hash_args( + shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__() + ) + self.compiled_kernels[kernel_hash] = { + "point_add": ( + point_add_kernel + if point_add_kernel is not None + else jax_jit_lower_compile( + self._point_add, shape_dtype_structs[0], shape_dtype_structs[0] + ) + ), + "point_double": ( + point_double_kernel + if point_double_kernel is not None + else jax_jit_lower_compile( + self._point_double, shape_dtype_structs[0] + ) + ), + } + self.use_compiled_kernels = True + + +def _twist_extend_and_rns_worker( + point, s, alpha, prime, prime_m2, t, rns_moduli, radix_bits +): + """Module-level worker: twist + extend + RNS conversion using gmpy2 (must be picklable).""" + import gmpy2 + + x, y = gmpy2.mpz(point[0]), gmpy2.mpz(point[1]) + xm = (s * (x - alpha)) % prime + ym = (s * y) % prime + if ym == 0: + raise ValueError("ec twist: ym is zero") + xt = (xm * gmpy2.powmod(ym, prime_m2, prime)) % prime + yt_denom = xm + 1 + if yt_denom == 0: + raise ValueError("ec twist: yt_denom is zero") + yt = ((xm - 1) * gmpy2.powmod(yt_denom, prime_m2, prime)) % prime + xt = (xt * t) % prime + coords = [int(xt), int(yt), 1, int((xt * yt) % prime)] + # RNS conversion inline: (a % m) << radix_bits) % m for each coordinate + return [[(((c % m) << radix_bits) % m) for m in rns_moduli] for c in coords] + + +class ExtendedTwistedEdwardsNDContext( + ExtendedTwistedEdwardsContextBase, JaxKernelContextBase +): + """Extended Twisted Edwards context supporting arbitrary batch dimensions. 
+ + Computational format layout: (coordinates=4, *batch_dims, precision) + where batch_dims can be any number of dimensions, e.g.: + - (4, batch, precision) -- 1D batch + - (4, batch1, batch2, precision) -- 2D batch + - (4, batch1, batch2, batch3, precision) -- 3D batch + + Input points are nested lists of [x, y] affine Weierstrass coordinates. + Nesting depth determines batch dimensions: + - [x, y] → (4, 1, precision) + - [[x,y], ...] → (4, batch1, precision) + - [[[x,y], ...], ...] → (4, batch1, batch2, precision) + """ + + def __init__(self, parameters: dict): + super().__init__(parameters) + JaxKernelContextBase.__init__(self) + self.jax_parameters = JaxParameters() + self._init_jax_parameters() + + def to_computational_format(self, a: list) -> jnp.ndarray: + list_depth = utils.nested_list_depth(a) + if list_depth < 1: + raise ValueError(f"Invalid list depth {list_depth} for point conversion") + + # Use parallel processing with gmpy2 for large flat batches of points (depth==2) + # Fuses twist + extend + RNS conversion into one parallel step + _PARALLEL_THRESHOLD = 2048 + if list_depth == 2 and len(a) >= _PARALLEL_THRESHOLD: + import gmpy2 + + ff_ctx = self.ff_ctx + worker = functools.partial( + _twist_extend_and_rns_worker, + s=gmpy2.mpz(self.s), + alpha=gmpy2.mpz(self.alpha), + prime=gmpy2.mpz(self.prime), + prime_m2=gmpy2.mpz(self.prime - 2), + t=gmpy2.mpz(self.t), + rns_moduli=ff_ctx.rns_moduli, + radix_bits=ff_ctx.radix_bits, + ) + num_workers = min(64, os.cpu_count() or 1, max(1, len(a) // 256)) + with ProcessPoolExecutor( + max_workers=num_workers, mp_context=_MP_CONTEXT + ) as pool: + rns_coords = list( + pool.map(worker, a, chunksize=max(1, len(a) // num_workers)) + ) + # rns_coords: list of (4, moduli_num) per point → (N, 4, moduli_num) array + result = jnp.array( + np.array(rns_coords, dtype=np.uint32), dtype=jnp.uint32 + ) + # (N, 4, moduli_num) → (4, N, moduli_num) + result = result.transpose(1, 0, 2) + else: + + def recursive_twist(lst, depth): + 
if depth == 1: + return self._convert_to_extended_twisted_edwards(lst) + return [recursive_twist(item, depth - 1) for item in lst] + + twisted_coords = recursive_twist(a, list_depth) + + result = self.ff_ctx.to_computational_format(twisted_coords) + + if list_depth == 1: + # (4, precision) → (4, 1, precision) + result = jnp.expand_dims(result, axis=1) + else: + # (*batch_dims, 4, precision) → (4, *batch_dims, precision) + ndim = result.ndim + perm = (ndim - 2,) + tuple(range(ndim - 2)) + (ndim - 1,) + result = result.transpose(perm) + + if self.use_sharding: + shard_axes = list(range(1, min(3, result.ndim - 1))) + if not shard_axes: + shard_axes = [1] + named_sharding, padded_shape = self.create_named_sharding( + shape=result.shape, axes=shard_axes + ) + result = pad_jax_array(result, padded_shape) + return result.to_device(named_sharding) + else: + return result.to_device(jax.devices()[0]) + + def to_original_format(self, a: jnp.ndarray) -> list: + ndim = a.ndim + if ndim < 2: + raise ValueError(f"Expected at least 2D array, got {ndim}D") + + if ndim == 2: + a_orig = self.ff_ctx.to_original_format(a) + return self._convert_to_weierstrass_affine(a_orig) + + # (4, *batch_dims, precision) → (*batch_dims, 4, precision) + perm = tuple(range(1, ndim - 1)) + (0, ndim - 1) + a = a.transpose(perm) + a_orig = self.ff_ctx.to_original_format(a) + + batch_depth = ndim - 2 + + def recursive_untwist(lst, depth): + if depth == 0: + return self._convert_to_weierstrass_affine(lst) + return [recursive_untwist(item, depth - 1) for item in lst] + + return recursive_untwist(a_orig, batch_depth) + + def _init_jax_parameters(self): + self.jax_parameters.set_parameter( + twist_d=self.ff_ctx.to_computational_format(self.twist_d), + ) + + def _point_add( + self, point_a: jnp.ndarray, point_b: jnp.ndarray + ) -> jnp.ndarray: + twist_d = self.jax_parameters.twist_d + + inputsl = point_a + inputsr = point_b + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + a, b, d, c = 
jnp.vsplit(outputs, 4) + + pax, pay, _, _ = jnp.vsplit(point_a, 4) + pbx, pby, _, _ = jnp.vsplit(point_b, 4) + + e1 = self.ff_ctx._modular_add(pax, pay) + e2 = self.ff_ctx._modular_add(pbx, pby) + twist_d_here = jnp.broadcast_to(twist_d, c.shape) + try: + _e2_sh = jax.typeof(e2).sharding + if _e2_sh is not None: + twist_d_here = jax.sharding.reshard(twist_d_here, _e2_sh) + except Exception: + pass + inputsl = jnp.concatenate((e1, c), axis=0) + inputsr = jnp.concatenate((e2, twist_d_here), axis=0) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + e3, c = jnp.vsplit(outputs, 2) + + e = self.ff_ctx._modular_subtract(self.ff_ctx._modular_subtract(e3, a), b) + f = self.ff_ctx._modular_subtract(d, c) + g = self.ff_ctx._modular_add(d, c) + h = self.ff_ctx._modular_add(a, b) + + inputsl = jnp.concatenate((e, g, f, e), axis=0) + inputsr = jnp.concatenate((f, h, g, h), axis=0) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + return outputs + + def _point_double(self, point: jnp.ndarray) -> jnp.ndarray: + original_shape = point.shape + x, y, z, t = point + + et = self.ff_ctx._modular_multiply(x, y) + inputsl = jnp.vstack((x, y, z, et)) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsl) + a, b, ct, et2 = jnp.vsplit(outputs, 4) + + h = self.ff_ctx._modular_negate(self.ff_ctx._modular_add(a, b)) + e = self.ff_ctx._modular_add(et2, h) + g = self.ff_ctx._modular_subtract(b, a) + f = self.ff_ctx._modular_subtract(g, self.ff_ctx._modular_add(ct, ct)) + + inputsl = jnp.vstack((e, g, f, e)) + inputsr = jnp.vstack((f, h, g, h)) + outputs = self.ff_ctx._modular_multiply(inputsl, inputsr) + return outputs.reshape(original_shape) + + def point_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + if self.use_compiled_kernels: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + return self.compiled_kernels[kernel_hash]["point_add"](a, b) + else: + return self._point_add(a, b) + + # TODO: fix the functional bug in point_double + def 
point_double(self, point: jnp.ndarray) -> jnp.ndarray: + raise ValueError("point_double has logic bug") + shape_dtype_struct = jax.ShapeDtypeStruct(point.shape, point.dtype) + if self.use_compiled_kernels: + return self.compiled_kernels[shape_dtype_struct.__hash__()][ + "point_double" + ](point) + else: + return self._point_double(point) + + def _get_shape_dtype_structs( + self, parameters: dict + ) -> list[jax.ShapeDtypeStruct]: + batch_shape = parameters.get("batch_shape", None) + if batch_shape is None: + batch_shape = (parameters["batch_size"],) + num_moduli = self.jax_parameters.twist_d.shape[0] + point_shape = (4,) + tuple(batch_shape) + (num_moduli,) + if self.use_sharding: + shard_axes = list(range(1, min(3, len(point_shape) - 1))) + if not shard_axes: + shard_axes = [1] + named_sharding, padded_shape = self.create_named_sharding( + shape=point_shape, axes=shard_axes + ) + return [ + jax.ShapeDtypeStruct( + padded_shape, jnp.uint32, sharding=named_sharding + ) + ] + return [jax.ShapeDtypeStruct(point_shape, jnp.uint32)] + + def context_hash(self) -> str: + return hash_args( + self.__class__.__name__, + self.ff_ctx.context_hash(), + self.a, + self.twist_d, + self.alpha, + self.s, + self.A, + self.B, + self.t, + self.use_sharding, + ) + + def serialize(self, parameters: dict): + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + kernel_hash = hash_args(self.context_hash(), parameters) + class_name = self.__class__.__name__ + + store_jax_executable( + self._point_add, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_point_add_{kernel_hash}", + ) + store_jax_executable( + self._point_double, + shape_dtype_structs[0], + name=f"{class_name}_point_double_{kernel_hash}", + ) + + def compile(self, parameters: dict): + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + kernel_hash = hash_args(self.context_hash(), parameters) + class_name = self.__class__.__name__ + + point_add_kernel = load_jax_executable( + 
f"{class_name}_point_add_{kernel_hash}" + ) + point_double_kernel = load_jax_executable( + f"{class_name}_point_double_{kernel_hash}" + ) + + if None in [point_add_kernel, point_double_kernel]: + warnings.warn( + f"Not found stored serialized compiled kernels, compiling...", + UserWarning, + stacklevel=2, + ) + + kernel_hash = hash_args( + shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__() + ) + self.compiled_kernels[kernel_hash] = { + "point_add": ( + point_add_kernel + if point_add_kernel is not None + else jax_jit_lower_compile( + self._point_add, shape_dtype_structs[0], shape_dtype_structs[0] + ) + ), + "point_double": ( + point_double_kernel + if point_double_kernel is not None + else jax_jit_lower_compile( + self._point_double, shape_dtype_structs[0] + ) + ), + } + self.use_compiled_kernels = True diff --git a/jaxite_ec/elliptic_curve_perf_test.py b/jaxite_ec/elliptic_curve_perf_test.py new file mode 100644 index 0000000..5007608 --- /dev/null +++ b/jaxite_ec/elliptic_curve_perf_test.py @@ -0,0 +1,73 @@ +import os + +import jax +from jaxite.jaxite_ec import elliptic_curve_context as ec_context +from jaxite.jaxite_ec import finite_field_context as ff_context +from jaxite.jaxite_ec import utils +import toml + +from absl.testing import absltest +from absl.testing import parameterized + +jax.config.update("jax_enable_x64", True) + +CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configurations.toml") + +BATCH_SIZE_LIST = [128, 256, 512, 1024, 2048, 4096] + +NUM_MODULI = 32 + +TEST_PARAMS_POINT_ADD = [ + ("point_add", BATCH_SIZE_LIST), +] + +TEST_PARAMS_POINT_DOUBLE = [ + ("point_double", BATCH_SIZE_LIST), +] + + +def _build_ec_context(): + ec_config = toml.load(CONFIG_PATH) + rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, 28) + finite_field_parameters = { + "prime": ec_config["ec_parameters_bls12_377_affine"]["prime"], + "rns_moduli": rns_moduli, + "precision_bits": 28, + "radix_bits": 32, + } + ete_cfg = 
ec_config["ec_parameters_bls12_377_extended_twisted_edwards"] + ec_parameters = { + "finite_field_context_class": ff_context.DRNSlazyContext, + "finite_field_parameters": finite_field_parameters, + "prime": ete_cfg["prime"], + "order": ete_cfg["order"], + "a": ete_cfg["a"], + "twist_d": ete_cfg["d"], + "alpha": ete_cfg["alpha"], + "b": ete_cfg["b"], + "s": ete_cfg["s"], + "MA": ete_cfg["MA"], + "MB": ete_cfg["MB"], + "t": ete_cfg["t"], + "generator": ete_cfg["generator"], + } + return ec_context.ExtendedTwistedEdwardsContext(ec_parameters) + + +def _point_add_kernel(point_a, point_b, parameters): + return parameters["ctx"]._point_add(point_a, point_b) + + +def _point_double_kernel(point, parameters): + return parameters["ctx"]._point_double(point) + + +class ECPointAddPerformanceTest(parameterized.TestCase): + + @parameterized.named_parameters(*TEST_PARAMS_POINT_ADD) + def test_point_add_performance(self, batch_size_list): + pass + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite_ec/elliptic_curve_test.py b/jaxite_ec/elliptic_curve_test.py index 4b09116..10a3410 100644 --- a/jaxite_ec/elliptic_curve_test.py +++ b/jaxite_ec/elliptic_curve_test.py @@ -1,939 +1,98 @@ -import functools +import os -import jax -import jax.numpy as jnp -from jaxite.jaxite_ec import util -from jaxite.jaxite_ec.algorithm import config_file -import jaxite.jaxite_ec.algorithm.elliptic_curve as ec -import jaxite.jaxite_ec.elliptic_curve as jec +from jaxite.jaxite_ec import elliptic_curve_context as ec_context +from jaxite.jaxite_ec import finite_field_context as ff_context +from jaxite.jaxite_ec import utils +import numpy as np +import toml from absl.testing import absltest +from absl.testing import parameterized + +# NOTE: Ensure all tests point are on the curve +BLS12_377_TEST_CASES = [ + ( + "0", + [ + [ + 0x01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76, + 
0x011DB83AEC88460820F4868A73B12309EE2E910526E62DB4ACCB303ABF50F86C3985A072ED07A4B81FFB82D8DD247283, + ], + [ + 0x0164DDDBF27670CE389E2992C0E7DAB7741F1B925EDBDC254D2BC0830BAF8E0B186F80F0DD4DE0F0EA6176E55934D45B, + 0x01908E9D77A0F8AD89AC41441F74248704E756BC59C38920617F51BFCDB738EE5B123876D489D09C9EB904A321A336EC, + ], + ], + [ + [ + 0x01546AF2ABB4E189E9BBC412FDBF2A8E5EC6E4A3B0AF132E21EE9CEC3EF5E226490FB98D662670FA3CFB3948B7E2A48C, + 0x002961A558A885DF227FDB09F8BDF57AF179CB9437FF8828F13E9DF01AE55502F409AAF5058B88F2F7CCC7BC0676A5D4, + ], + [ + 0x00B0630E7F192D20443A93860275447074CE77DF559907FA1900F378D4674649BF25F85C893E2A1916B1DA57594F2E17, + 0x01ACC84F362CF60A265C011F0FE4360A15F51BECF7E2C3923FE07C66D5D113104B56E8486C64204A2A9ECD75BA0C41A7, + ], + ], + ), +] + + +class BLS12_377_Test(parameterized.TestCase): + + def __init__(self, *args, **kwargs): + super(BLS12_377_Test, self).__init__(*args, **kwargs) + + @parameterized.named_parameters(*BLS12_377_TEST_CASES) + def test_ExtendedTwistedEdwards_point_add(self, point_batch_1, point_batch_2): + ec_config = toml.load( + os.path.join(os.path.dirname(__file__), "configurations.toml") + ) + rns_moduli = utils.find_moduli_specified_number(32, 28) + finite_field_parameters = { + "prime": ec_config["ec_parameters_bls12_377_affine"]["prime"], + "rns_moduli": rns_moduli, + "precision_bits": 28, + "radix_bits": 32, + } + ete_cfg = ec_config["ec_parameters_bls12_377_extended_twisted_edwards"] + ec_parameters = { + "finite_field_context_class": ff_context.DRNSlazyContext, + "finite_field_parameters": finite_field_parameters, + "prime": ete_cfg["prime"], + "order": ete_cfg["order"], + "a": ete_cfg["a"], + "twist_d": ete_cfg["d"], + "alpha": ete_cfg["alpha"], + "b": ete_cfg["b"], + "s": ete_cfg["s"], + "MA": ete_cfg["MA"], + "MB": ete_cfg["MB"], + "t": ete_cfg["t"], + "generator": ete_cfg["generator"], + } + affine_cfg = ec_config["ec_parameters_bls12_377_affine"] + ref_ec_parameters = { + "finite_field_parameters": 
finite_field_parameters, + "finite_field_context_class": ff_context.DRNSlazyContext, + "prime": affine_cfg["prime"], + "order": affine_cfg["order"], + "a": affine_cfg["a"], + "b": affine_cfg["b"], + "generator": affine_cfg["generator"], + } + + ec_ctx = ec_context.ExtendedTwistedEdwardsContext(ec_parameters) + ref_ec_ctx = ec_context.CPUWeierstrassAffineContext(ref_ec_parameters) + + point_batch_1_m = ec_ctx.to_computational_format(point_batch_1) + point_batch_2_m = ec_ctx.to_computational_format(point_batch_2) + result_m = ec_ctx.point_add(point_batch_1_m, point_batch_2_m) + result = ec_ctx.to_original_format(result_m) + + ref_result = ref_ec_ctx._point_add(point_batch_1, point_batch_2) + + np.testing.assert_array_equal(result, ref_result) -class TestEllipticCurve(absltest.TestCase): - - def setUp(self): - super().setUp() - self.coordinate_num = 4 - self.batch_size = 1 - self.x1_int_ = 0x01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76 - self.y1_int_ = 0x011DB83AEC88460820F4868A73B12309EE2E910526E62DB4ACCB303ABF50F86C3985A072ED07A4B81FFB82D8DD247283 - self.x2_int_ = 0x01546AF2ABB4E189E9BBC412FDBF2A8E5EC6E4A3B0AF132E21EE9CEC3EF5E226490FB98D662670FA3CFB3948B7E2A48C - self.y2_int_ = 0x002961A558A885DF227FDB09F8BDF57AF179CB9437FF8828F13E9DF01AE55502F409AAF5058B88F2F7CCC7BC0676A5D4 - self.point_a = [self.x1_int_, self.y1_int_] - self.point_b = [self.x2_int_, self.y2_int_] - self.zero_twisted = [0, 1, 1, 0] - self.ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377) - self.point_a_sys = self.ec_sys.generate_point(self.point_a) - self.point_b_sys = self.ec_sys.generate_point(self.point_b) - assert int(self.point_a_sys.coordinates[0].value) == self.point_a[0] - assert int(self.point_a_sys.coordinates[1].value) == self.point_a[1] - assert int(self.point_b_sys.coordinates[0].value) == self.point_b[0] - assert int(self.point_b_sys.coordinates[1].value) == self.point_b[1] - self.true_result_padd = self.point_a_sys + 
self.point_b_sys - self.true_result_padd_affine = self.true_result_padd.convert_to_affine() - self.true_result_pdub_a = self.point_a_sys + self.point_a_sys - self.true_result_pdub_a_affine = self.true_result_pdub_a.convert_to_affine() - self.true_result_pdub_b = self.point_b_sys + self.point_b_sys - self.true_result_pdub_b_affine = self.true_result_pdub_b.convert_to_affine() - - def test_padd_barrett_xyzz_pack(self): - point_a_jax = util.int_point_batch_to_jax_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] - ) - point_b_jax = util.int_point_batch_to_jax_point_pack( - [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] - ) - jit_padd_barrett_xyzz_pack = jax.jit(jec.padd_barrett_xyzz_pack) - - result_jax = jit_padd_barrett_xyzz_pack(point_a_jax, point_b_jax) - result_jax = util.jax_point_pack_to_int_point_batch(result_jax) - - self.assertEqual(result_jax[0][0], self.true_result_padd[0].get_value()) - self.assertEqual(result_jax[0][1], self.true_result_padd[1].get_value()) - self.assertEqual(result_jax[0][2], self.true_result_padd[2].get_value()) - self.assertEqual(result_jax[0][3], self.true_result_padd[3].get_value()) - - # performance measurement - tasks = [ - (jit_padd_barrett_xyzz_pack, (point_a_jax, point_b_jax)), - ] - profile_name = "jit_padd_barrett_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_pdul_barrett_xyzz(self): - point_a_jax = util.int_point_batch_to_jax_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] - ) - jit_pdul_barrett_xyzz_pack = jax.jit(jec.pdul_barrett_xyzz_pack) - result_jax = jit_pdul_barrett_xyzz_pack(point_a_jax) - result_jax = util.jax_point_pack_to_int_point_batch(result_jax) - - self.assertEqual(result_jax[0][0], self.true_result_pdub_a[0].get_value()) - self.assertEqual(result_jax[0][1], self.true_result_pdub_a[1].get_value()) - self.assertEqual(result_jax[0][2], self.true_result_pdub_a[2].get_value()) - 
self.assertEqual(result_jax[0][3], self.true_result_pdub_a[3].get_value()) - - # performance measurement - tasks = [ - (jit_pdul_barrett_xyzz_pack, (point_a_jax,)), - ] - profile_name = "jit_pdul_barrett_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_jit_pdul_barrett_xyzz_pack_two_no_batch(self): - point_a_jax = util.int_point_batch_to_jax_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] - ) - point_b_jax = util.int_point_batch_to_jax_point_pack( - [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] - ) - jit_pdul_barrett_xyzz_pack = jax.jit(jec.pdul_barrett_xyzz_pack) - result_a_jax = jit_pdul_barrett_xyzz_pack(point_a_jax) - result_a_int = util.jax_point_pack_to_int_point_batch(result_a_jax) - - result_b_jax = jit_pdul_barrett_xyzz_pack(point_b_jax) - result_b_int = util.jax_point_pack_to_int_point_batch(result_b_jax) - - self.assertEqual(result_a_int[0][0], self.true_result_pdub_a[0].get_value()) - self.assertEqual(result_a_int[0][1], self.true_result_pdub_a[1].get_value()) - self.assertEqual(result_a_int[0][2], self.true_result_pdub_a[2].get_value()) - self.assertEqual(result_a_int[0][3], self.true_result_pdub_a[3].get_value()) - self.assertEqual(result_b_int[0][0], self.true_result_pdub_b[0].get_value()) - self.assertEqual(result_b_int[0][1], self.true_result_pdub_b[1].get_value()) - self.assertEqual(result_b_int[0][2], self.true_result_pdub_b[2].get_value()) - self.assertEqual(result_b_int[0][3], self.true_result_pdub_b[3].get_value()) - - # performance measurement - tasks = [ - (jit_pdul_barrett_xyzz_pack, (point_a_jax,)), - (jit_pdul_barrett_xyzz_pack, (point_b_jax,)), - ] - profile_name = "jit_pdul_barrett_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_jit_pdul_barrett_xyzz_pack_two_batch(self): - point_a_jax = util.int_point_batch_to_jax_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] - ) - point_b_jax = 
util.int_point_batch_to_jax_point_pack( - [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] - ) - batch_point = jnp.concatenate([point_a_jax, point_b_jax], axis=1) - jit_pdul_barrett_xyzz_pack = jax.jit(jec.pdul_barrett_xyzz_pack) - result_jax = jit_pdul_barrett_xyzz_pack(batch_point) - result_int = util.jax_point_pack_to_int_point_batch(result_jax) - - self.assertEqual(result_int[0][0], self.true_result_pdub_a[0].get_value()) - self.assertEqual(result_int[0][1], self.true_result_pdub_a[1].get_value()) - self.assertEqual(result_int[0][2], self.true_result_pdub_a[2].get_value()) - self.assertEqual(result_int[0][3], self.true_result_pdub_a[3].get_value()) - self.assertEqual(result_int[1][0], self.true_result_pdub_b[0].get_value()) - self.assertEqual(result_int[1][1], self.true_result_pdub_b[1].get_value()) - self.assertEqual(result_int[1][2], self.true_result_pdub_b[2].get_value()) - self.assertEqual(result_int[1][3], self.true_result_pdub_b[3].get_value()) - - # performance measurement - tasks = [ - (jit_pdul_barrett_xyzz_pack, (batch_point,)), - ] - profile_name = "jit_pdul_barrett_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_lazy_xyzz_pack(self): - point_a_jax = util.int_point_batch_to_jax_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))], - chunk_num=util.U16_EXT_CHUNK_NUM, - ) - point_b_jax = util.int_point_batch_to_jax_point_pack( - [self.point_b + [1] * (self.coordinate_num - len(self.point_b))], - chunk_num=util.U16_EXT_CHUNK_NUM, - ) - # lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT) - jit_padd_lazy_xyzz_pack = jax.jit(jec.padd_lazy_xyzz_pack) - result_jax = jit_padd_lazy_xyzz_pack(point_a_jax, point_b_jax) - result_jax = util.jax_point_pack_to_int_point_batch(result_jax) - - self.assertEqual( - result_jax[0][0] % util.MODULUS_377_INT, - self.true_result_padd[0].get_value(), - ) - self.assertEqual( - result_jax[0][1] % util.MODULUS_377_INT, - 
self.true_result_padd[1].get_value(), - ) - self.assertEqual( - result_jax[0][2] % util.MODULUS_377_INT, - self.true_result_padd[2].get_value(), - ) - self.assertEqual( - result_jax[0][3] % util.MODULUS_377_INT, - self.true_result_padd[3].get_value(), - ) - - # performance measurement - tasks = [ - (jit_padd_lazy_xyzz_pack, (point_a_jax, point_b_jax)), - ] - profile_name = "jit_padd_lazy_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_pdul_lazy_xyzz_pack(self): - point_a_jax = util.int_point_batch_to_jax_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))], - chunk_num=util.U16_EXT_CHUNK_NUM, - ) - - # lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT) - jit_pdul_lazy_xyzz_pack = jax.jit(jec.pdul_lazy_xyzz_pack) - result_jax = jit_pdul_lazy_xyzz_pack(point_a_jax) - result_jax = util.jax_point_pack_to_int_point_batch(result_jax) - - self.assertEqual( - result_jax[0][0] % util.MODULUS_377_INT, - self.true_result_pdub_a[0].get_value(), - ) - self.assertEqual( - result_jax[0][1] % util.MODULUS_377_INT, - self.true_result_pdub_a[1].get_value(), - ) - self.assertEqual( - result_jax[0][2] % util.MODULUS_377_INT, - self.true_result_pdub_a[2].get_value(), - ) - self.assertEqual( - result_jax[0][3] % util.MODULUS_377_INT, - self.true_result_pdub_a[3].get_value(), - ) - - # performance measurement - tasks = [ - (jit_pdul_lazy_xyzz_pack, (point_a_jax,)), - ] - profile_name = "jit_pdul_lazy_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_lazy_twisted_pack(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) - - point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], 
chunk_num=util.U16_EXT_CHUNK_NUM - ) - - jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack) - result_jax = jit_padd_lazy_twisted_pack(point_a_jax, point_b_jax) - result_int = util.jax_point_pack_to_int_point_batch(result_jax) - - result_affine_point = twisted_ec_sys.generate_point( - result_int[0], twist=False - ).convert_to_affine() - - self.assertEqual( - result_affine_point[0].get_value(), - self.true_result_padd_affine[0].get_value(), - ) - self.assertEqual( - result_affine_point[1].get_value(), - self.true_result_padd_affine[1].get_value(), - ) - - def test_padd_lazy_twisted_pack_batch(self): - for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) - - point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_a_jax = jnp.broadcast_to( - point_a_jax, (point_a_jax.shape[0], batch_size, point_a_jax.shape[-1]) - ) - point_b_jax = jnp.broadcast_to( - point_b_jax, (point_b_jax.shape[0], batch_size, point_b_jax.shape[-1]) - ) - - jit_padd_lazy_twisted_pack_batch = jax.jit( - jax.named_call( - functools.partial(jec.padd_lazy_twisted_pack), - name=f"jit_padd_lazy_twisted_pack_batch_{batch_size}", - ), - ) - result_jax = jit_padd_lazy_twisted_pack_batch(point_a_jax, point_b_jax) - result_int = util.jax_point_pack_to_int_point_batch(result_jax) - result_affine_point = twisted_ec_sys.generate_point( - result_int[0], twist=False - ).convert_to_affine() - - self.assertEqual( - result_affine_point.coordinates[0].value % util.MODULUS_377_INT, - self.true_result_padd_affine[0].get_value(), - ) - self.assertEqual( - result_affine_point.coordinates[1].value % util.MODULUS_377_INT, - 
self.true_result_padd_affine[1].get_value(), - ) - self.assertEqual( - result_affine_point.coordinates[2].value % util.MODULUS_377_INT, - self.true_result_padd_affine[2].get_value(), - ) - self.assertEqual( - result_affine_point.coordinates[3].value % util.MODULUS_377_INT, - self.true_result_padd_affine[3].get_value(), - ) - - # performance measurement - tasks = [ - (jit_padd_lazy_twisted_pack_batch, (point_a_jax, point_b_jax)), - ] - profile_name = f"jit_padd_lazy_twisted_pack_batch_{batch_size}" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_same_lazy_twisted_pack(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a1 = twisted_ec_sys.twist_int_coordinates(self.point_a) - twist_a2 = twisted_ec_sys.twist_int_coordinates(self.point_a) - - point_a1_jax = util.int_point_batch_to_jax_point_pack( - [twist_a1], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_a2_jax = util.int_point_batch_to_jax_point_pack( - [twist_a2], chunk_num=util.U16_EXT_CHUNK_NUM - ) - - jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack) - result_jax = jit_padd_lazy_twisted_pack(point_a1_jax, point_a2_jax) - result_int = util.jax_point_pack_to_int_point_batch(result_jax) - - result_affine_point = twisted_ec_sys.generate_point( - result_int[0], twist=False - ).convert_to_affine() - - self.assertEqual( - result_affine_point[0].get_value(), - self.true_result_pdub_a_affine[0].get_value(), - ) - self.assertEqual( - result_affine_point[1].get_value(), - self.true_result_pdub_a_affine[1].get_value(), - ) - - # performance measurement - tasks = [ - (jit_padd_lazy_twisted_pack, (point_a1_jax, point_a2_jax)), - ] - profile_name = "jit_padd_lazy_twisted_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_pdul_lazy_twisted_pack(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - 
point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM - ) - - jit_pdul_lazy_twisted_pack = jax.jit(jec.pdul_lazy_twisted_pack) - result_jax = jit_pdul_lazy_twisted_pack(point_a_jax) - result_int = util.jax_point_pack_to_int_point_batch(result_jax) - - result_affine_point = twisted_ec_sys.generate_point( - result_int[0], twist=False - ).convert_to_affine() - self.assertEqual( - result_affine_point[0].get_value(), - self.true_result_pdub_a_affine[0].get_value(), - ) - self.assertEqual( - result_affine_point[1].get_value(), - self.true_result_pdub_a_affine[1].get_value(), - ) - - # performance measurement - tasks = [ - (jit_pdul_lazy_twisted_pack, (point_a_jax,)), - ] - profile_name = "jit_pdul_lazy_twisted_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_jit_pneg_lazy_twisted_pack(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) - - point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM - ) - - jit_padd_lazy_twisted_pack = jax.jit(jec.padd_lazy_twisted_pack) - jit_pneg_lazy_twisted_pack = jax.jit(jec.pneg_lazy_twisted_pack) - a_plus_b = jit_padd_lazy_twisted_pack(point_a_jax, point_b_jax) - neg_b = jit_pneg_lazy_twisted_pack(point_b_jax) - result_jax = jit_padd_lazy_twisted_pack(a_plus_b, neg_b) - result_int = util.jax_point_pack_to_int_point_batch(result_jax) - - result_affine_point = twisted_ec_sys.generate_point( - result_int[0], twist=False - ).convert_to_affine() - self.assertEqual( - result_affine_point[0].get_value(), self.point_a_sys[0].get_value() - ) - self.assertEqual( - result_affine_point[1].get_value(), self.point_a_sys[1].get_value() - ) - - # performance 
measurement - tasks = [ - (jit_pneg_lazy_twisted_pack, (point_b_jax,)), - ] - profile_name = "jit_pneg_lazy_twisted_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_rns_xyzz(self): - point_a_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] - ) - point_b_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] - ) - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - - jit_padd_rns_xyzz_pack = jax.jit( - jax.named_call( - functools.partial(jec.padd_rns_xyzz_pack, rns_mat=rns_mat), - name="jit_padd_rns_xyzz_pack", - ), - ) - result_jax = jit_padd_rns_xyzz_pack(point_a_jax, point_b_jax) - result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax) - - self.assertEqual( - result_jax[0][0] % util.MODULUS_377_INT, - self.true_result_padd[0].get_value(), - ) - self.assertEqual( - result_jax[0][1] % util.MODULUS_377_INT, - self.true_result_padd[1].get_value(), - ) - self.assertEqual( - result_jax[0][2] % util.MODULUS_377_INT, - self.true_result_padd[2].get_value(), - ) - self.assertEqual( - result_jax[0][3] % util.MODULUS_377_INT, - self.true_result_padd[3].get_value(), - ) - - # performance measurement - tasks = [ - (jit_padd_rns_xyzz_pack, (point_a_jax, point_b_jax)), - ] - profile_name = "jit_padd_rns_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_rns_xyzz_batch(self): - for batch_size in [1, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096]: - point_a_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] - ) - point_b_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.point_b + [1] * (self.coordinate_num - len(self.point_b))] - ) - point_a_jax = jnp.broadcast_to( - point_a_jax, (point_a_jax.shape[0], batch_size, point_a_jax.shape[-1]) - ) - point_b_jax = jnp.broadcast_to( - 
point_b_jax, (point_b_jax.shape[0], batch_size, point_b_jax.shape[-1]) - ) - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - - jit_padd_rns_xyzz_pack_batch = jax.jit( - jax.named_call( - functools.partial(jec.padd_rns_xyzz_pack, rns_mat=rns_mat), - name="jit_padd_rns_xyzz_pack_batch", - ), - ) - result_jax = jit_padd_rns_xyzz_pack_batch(point_a_jax, point_b_jax) - result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax) - - self.assertEqual( - result_jax[0][0] % util.MODULUS_377_INT, - self.true_result_padd[0].get_value(), - ) - self.assertEqual( - result_jax[0][1] % util.MODULUS_377_INT, - self.true_result_padd[1].get_value(), - ) - self.assertEqual( - result_jax[0][2] % util.MODULUS_377_INT, - self.true_result_padd[2].get_value(), - ) - self.assertEqual( - result_jax[0][3] % util.MODULUS_377_INT, - self.true_result_padd[3].get_value(), - ) - - # performance measurement - tasks = [ - (jit_padd_rns_xyzz_pack_batch, (point_a_jax, point_b_jax)), - ] - profile_name = f"jit_padd_rns_xyzz_pack_batch_{batch_size}" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_pdul_rns_xyzz_pack(self): - point_a_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.point_a + [1] * (self.coordinate_num - len(self.point_a))] - ) - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - - jit_pdul_rns_xyzz_pack = jax.jit( - jax.named_call( - functools.partial(jec.pdul_rns_xyzz_pack, rns_mat=rns_mat), - name="jit_pdul_rns_xyzz_pack", - ), - ) - result_jax = jit_pdul_rns_xyzz_pack(point_a_jax) - result_jax = util.jax_rns_point_pack_to_int_point_batch(result_jax) - - self.assertEqual( - result_jax[0][0] % util.MODULUS_377_INT, - self.true_result_pdub_a[0].get_value(), - ) - self.assertEqual( - result_jax[0][1] % util.MODULUS_377_INT, - self.true_result_pdub_a[1].get_value(), - ) - self.assertEqual( - result_jax[0][2] % util.MODULUS_377_INT, - self.true_result_pdub_a[2].get_value(), - ) - self.assertEqual( - result_jax[0][3] % 
util.MODULUS_377_INT, - self.true_result_pdub_a[3].get_value(), - ) - - # performance measurement - tasks = [ - (jit_pdul_rns_xyzz_pack, (point_a_jax,)), - ] - profile_name = "jit_pdul_rns_xyzz_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_rns_twisted_pack(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - a = self.point_a_sys.coordinates[:2] - b = self.point_b_sys.coordinates[:2] - project_twist_a = twisted_ec_sys.generate_point(a, twist=True) - project_twist_b = twisted_ec_sys.generate_point(b, twist=True) - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - point_a_jax = util.int_point_batch_to_jax_rns_point_pack( - [[c.get_value() for c in project_twist_a]] - ) - point_b_jax = util.int_point_batch_to_jax_rns_point_pack( - [[c.get_value() for c in project_twist_b]] - ) - jit_padd_rns_twisted_pack = jax.jit( - jax.named_call( - functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat), - name="jit_padd_rns_twisted_pack", - ), - ) - point_c_jax = jit_padd_rns_twisted_pack(point_a_jax, point_b_jax) - project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(point_c_jax)[ - 0 - ] - project_twist_sum_point = twisted_ec_sys.generate_point( - project_twist_sum, twist=False - ).convert_to_affine() - s = project_twist_sum_point.coordinates[:2] - correct_s = self.true_result_padd_affine.coordinates[:2] - self.assertEqual( - s[0].get_value() % util.MODULUS_377_INT, correct_s[0].get_value() - ) - self.assertEqual( - s[1].get_value() % util.MODULUS_377_INT, correct_s[1].get_value() - ) - - # performance measurement - tasks = [ - (jit_padd_rns_twisted_pack, (point_a_jax, point_b_jax)), - ] - profile_name = "jit_padd_rns_twisted_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_pdul_rns_twisted_pack(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - rns_mat = 
util.construct_rns_matrix(util.MODULUS_377_INT) - pg = twisted_ec_sys.generate_point(self.point_a, twist=False) - g_correct_ff = pg << 1 - gcf_af = g_correct_ff.convert_to_affine() - point_g_jax = util.int_point_batch_to_jax_rns_point_pack( - [[c.get_value() for c in pg]] - ) - jit_pdul_rns_twisted_pack = jax.jit( - jax.named_call( - functools.partial(jec.pdul_rns_twisted_pack, rns_mat=rns_mat), - name="jit_pdul_rns_twisted_pack", - ), - ) - point_2g_jax = jit_pdul_rns_twisted_pack(point_g_jax) - g_test = util.jax_rns_point_pack_to_int_point_batch(point_2g_jax)[0] - gtf_af = twisted_ec_sys.generate_point( - g_test, twist=False - ).convert_to_affine() - self.assertEqual( - gtf_af.coordinates[0].get_value() % util.MODULUS_377_INT, - gcf_af.coordinates[0].get_value(), - ) - self.assertEqual( - gtf_af.coordinates[1].get_value() % util.MODULUS_377_INT, - gcf_af.coordinates[1].get_value(), - ) - - # performance measurement - tasks = [ - (jit_pdul_rns_twisted_pack, (point_g_jax,)), - ] - profile_name = "jit_pdul_rns_twisted_pack" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_rns_twisted_pack_new_twisted(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) - - point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a) - point_b_jax = util.int_point_to_jax_rns_point_pack(twist_b) - - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - jit_padd_rns_twisted_pack = jax.jit( - jax.named_call( - functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat), - name="jit_padd_rns_twisted_pack", - ), - ) - point_c_jax = jit_padd_rns_twisted_pack(point_a_jax, point_b_jax) - project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(point_c_jax)[ - 0 - ] - project_twist_sum_point = twisted_ec_sys.generate_point( - project_twist_sum, twist=False - ).convert_to_affine() - s = 
project_twist_sum_point.coordinates[:2] - correct_s = self.true_result_padd_affine.coordinates[:2] - self.assertEqual( - s[0].get_value() % util.MODULUS_377_INT, correct_s[0].get_value() - ) - self.assertEqual( - s[1].get_value() % util.MODULUS_377_INT, correct_s[1].get_value() - ) - - # performance measurement - tasks = [ - (jit_padd_rns_twisted_pack, (point_a_jax, point_b_jax)), - ] - profile_name = "jit_padd_rns_twisted_pack_new_twist" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_pdul_rns_twisted_pack_new_twisted(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a) - jit_pdul_rns_twisted_pack = jax.jit( - jax.named_call( - functools.partial(jec.pdul_rns_twisted_pack, rns_mat=rns_mat), - name="jit_pdul_rns_twisted_pack", - ), - ) - point_2a_jax = jit_pdul_rns_twisted_pack(point_a_jax) - g_test = util.jax_rns_point_pack_to_int_point_batch(point_2a_jax)[0] - gtf_af = twisted_ec_sys.generate_point( - g_test, twist=False - ).convert_to_affine() - self.assertEqual( - gtf_af.coordinates[0].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_a_affine[0].get_value(), - ) - self.assertEqual( - gtf_af.coordinates[1].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_a_affine[1].get_value(), - ) - - # performance measurement - tasks = [ - (jit_pdul_rns_twisted_pack, (point_a_jax,)), - ] - profile_name = "jit_pdul_rns_twisted_pack_new_twist" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_rns_twisted_pack_new_twist_two_batch(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a).reshape( - util.COORDINATE_NUM, 1, 
util.NUM_MODULI - ) - - twist_b = twisted_ec_sys.twist_int_coordinates(self.point_b) - point_b_jax = util.int_point_to_jax_rns_point_pack(twist_b).reshape( - util.COORDINATE_NUM, 1, util.NUM_MODULI - ) - - batch_point = jnp.concatenate([point_a_jax, point_b_jax], axis=1) - - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - jit_padd_rns_twisted_pack_two_batch = jax.jit( - jax.named_call( - functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat), - name="jit_padd_rns_twisted_pack_two_batch", - ), - ) - result_batch = jit_padd_rns_twisted_pack_two_batch(batch_point, batch_point) - project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(result_batch) - point_2a_jax = twisted_ec_sys.generate_point( - project_twist_sum[0], twist=False - ).convert_to_affine() - point_2b_jax = twisted_ec_sys.generate_point( - project_twist_sum[1], twist=False - ).convert_to_affine() - self.assertEqual( - point_2a_jax[0].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_a_affine[0].get_value(), - ) - self.assertEqual( - point_2a_jax[1].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_a_affine[1].get_value(), - ) - self.assertEqual( - point_2b_jax[0].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_b_affine[0].get_value(), - ) - self.assertEqual( - point_2b_jax[1].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_b_affine[1].get_value(), - ) - - # performance measurement - tasks = [ - (jit_padd_rns_twisted_pack_two_batch, (batch_point, batch_point)), - ] - profile_name = "jit_padd_rns_twisted_pack_two_batch" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_pdul_rns_twisted_pack_new_twist_two_batch(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - point_a_jax = util.int_point_to_jax_rns_point_pack(twist_a).reshape( - util.COORDINATE_NUM, 1, util.NUM_MODULI - ) - - twist_b = 
twisted_ec_sys.twist_int_coordinates(self.point_b) - point_b_jax = util.int_point_to_jax_rns_point_pack(twist_b).reshape( - util.COORDINATE_NUM, 1, util.NUM_MODULI - ) - - batch_point = jnp.concatenate([point_a_jax, point_b_jax], axis=1) - - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - jit_pdul_rns_twisted_pack_two_batch = jax.jit( - jax.named_call( - functools.partial(jec.pdul_rns_twisted_pack, rns_mat=rns_mat), - name="jit_pdul_rns_twisted_pack_two_batch", - ), - ) - result_batch = jit_pdul_rns_twisted_pack_two_batch(batch_point) - project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(result_batch) - point_2a_jax = twisted_ec_sys.generate_point( - project_twist_sum[0], twist=False - ).convert_to_affine() - point_2b_jax = twisted_ec_sys.generate_point( - project_twist_sum[1], twist=False - ).convert_to_affine() - self.assertEqual( - point_2a_jax[0].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_a_affine[0].get_value(), - ) - self.assertEqual( - point_2a_jax[1].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_a_affine[1].get_value(), - ) - self.assertEqual( - point_2b_jax[0].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_b_affine[0].get_value(), - ) - self.assertEqual( - point_2b_jax[1].get_value() % util.MODULUS_377_INT, - self.true_result_pdub_b_affine[1].get_value(), - ) - - # performance measurement - tasks = [ - (jit_pdul_rns_twisted_pack_two_batch, (batch_point,)), - ] - profile_name = "jit_pdul_rns_twisted_pack_two_batch" - # copybara: util.profile_jax_functions(tasks, profile_name) - - def test_padd_zero_twisted_pack_new_twisted(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_zero_jax = util.int_point_batch_to_jax_point_pack( - [self.zero_twisted], 
chunk_num=util.U16_EXT_CHUNK_NUM - ) - - jit_padd_lazy_twisted_pack = jax.jit( - jax.named_call( - functools.partial(jec.padd_lazy_twisted_pack), - name="jit_padd_lazy_twisted_pack", - ), - ) - point_c_jax = jit_padd_lazy_twisted_pack(point_a_jax, point_zero_jax) - # point_c_jax = jec.padd_lazy_twisted_pack(point_a_jax, point_zero_jax) - project_twist_sum = util.jax_point_pack_to_int_point_batch(point_c_jax)[0] - project_twist_sum_point = twisted_ec_sys.generate_point( - project_twist_sum, twist=False - ).convert_to_affine() - self.assertEqual( - project_twist_sum_point[0].get_value() % util.MODULUS_377_INT, - self.point_a[0], - ) - self.assertEqual( - project_twist_sum_point[1].get_value() % util.MODULUS_377_INT, - self.point_a[1], - ) - - @absltest.skip("This is the known issue, which does not affect XYZZ.") - def test_padd_zero_rns_twisted_pack_new_twisted(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - twist_a = twisted_ec_sys.twist_int_coordinates(self.point_a) - - point_a_jax = util.int_point_batch_to_jax_rns_point_pack([twist_a]) - point_zero_jax = util.int_point_batch_to_jax_rns_point_pack( - [self.zero_twisted] - ) - - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - jit_padd_rns_twisted_pack = jax.jit( - jax.named_call( - functools.partial(jec.padd_rns_twisted_pack, rns_mat=rns_mat), - name="jit_padd_rns_twisted_pack", - ), - ) - point_c_jax = jit_padd_rns_twisted_pack(point_a_jax, point_zero_jax) - project_twist_sum = util.jax_rns_point_pack_to_int_point_batch(point_c_jax)[ - 0 - ] - project_twist_sum_point = twisted_ec_sys.generate_point( - project_twist_sum, twist=False - ).convert_to_affine() - self.assertEqual( - project_twist_sum_point[0].get_value() % util.MODULUS_377_INT, - self.point_a[0], - ) - self.assertEqual( - project_twist_sum_point[1].get_value() % util.MODULUS_377_INT, - self.point_a[1], - ) - - @absltest.skip("This is the known issue, which does not affect XYZZ.") - def 
test_padd_rns_a_point_add_zero_correctness(self): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended( - config_file.config_BLS12_377_t - ) - test_in_point = [ - 184360877740379501345167057231241011318955851506892921023614488028185166370541128591905842464011651119609504970811, - 47698235458971847835762299820400550031713475079888046003406323907410999702258242394959839249289205517205485978635, - ] - twist_a = twisted_ec_sys.twist_int_coordinates(test_in_point) - twist_b = [0, 1, 1, 0] - point_a_jax = util.int_point_batch_to_jax_point_pack( - [twist_a], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_b_jax = util.int_point_batch_to_jax_point_pack( - [twist_b], chunk_num=util.U16_EXT_CHUNK_NUM - ) - point_a_jax_rns = util.int_point_batch_to_jax_rns_point_pack([twist_a]) - point_b_jax_rns = util.int_point_batch_to_jax_rns_point_pack([twist_b]) - - point_c_jax_rns = jec.padd_rns_twisted_pack( - point_a_jax_rns, point_b_jax_rns - ) - project_twist_sum_rns = util.jax_rns_point_pack_to_int_point_batch( - point_c_jax_rns - )[0] - affine_sum_point_rns = twisted_ec_sys.generate_point( - project_twist_sum_rns, twist=False - ).convert_to_affine() - point_c_jax = jec.padd_lazy_twisted_pack(point_a_jax, point_b_jax) - project_twist_sum = util.jax_point_pack_to_int_point_batch(point_c_jax)[0] - affine_sum_point = twisted_ec_sys.generate_point( - project_twist_sum, twist=False - ).convert_to_affine() - # In Twisted Edward Representation Verification - self.assertEqual( - project_twist_sum[0] % util.MODULUS_377_INT, - project_twist_sum_rns[0] % util.MODULUS_377_INT, - ) - self.assertEqual( - project_twist_sum[1] % util.MODULUS_377_INT, - project_twist_sum_rns[1] % util.MODULUS_377_INT, - ) - self.assertEqual( - project_twist_sum[0] % util.MODULUS_377_INT, - project_twist_sum_rns[0] % util.MODULUS_377_INT, - ) - self.assertEqual( - project_twist_sum[1] % util.MODULUS_377_INT, - project_twist_sum_rns[1] % util.MODULUS_377_INT, - ) - - # Verification in affine - self.assertEqual( - 
affine_sum_point[0].get_value() % util.MODULUS_377_INT, - affine_sum_point_rns[0].get_value() % util.MODULUS_377_INT, - ) - self.assertEqual( - affine_sum_point[1].get_value() % util.MODULUS_377_INT, - affine_sum_point_rns[1].get_value() % util.MODULUS_377_INT, - ) - if __name__ == "__main__": absltest.main() diff --git a/jaxite_ec/finite_field.py b/jaxite_ec/finite_field.py deleted file mode 100644 index 863cd21..0000000 --- a/jaxite_ec/finite_field.py +++ /dev/null @@ -1,976 +0,0 @@ -"""library of finite field operations. - -This library is used to implement the finite field operations for the -high-precision -elliptic curve. - -- Data Representation. -The high-precision elliptic curve coordinate is represented as a vector of -uint16 24-bit integer. -The actual base is increased with the index. -E.g. - index [ 0, 1, 2, ...] -bit precision [0~7, 8~15, 16~23, ...] - -It includes Barrett based modular multiplication, *barrett_reduction_u16x2* - - -# Function Name Terminology -## _ indicates that the function only works for a single -## precision. -## indicates that the function works for general bit precision. - -# Terminology -## Chunk Reduction: e.g. u8-chunk -> u16-chunk or u32-chunk -## Chunk Decomposition <-> Chunk Merge: -### Chunk Decomposition: break int into multiple low-precision chunks. -### Chunk Merge: Merge multiple low-precision chunks into an int. -""" - -import functools - -import jax -import jax.numpy as jnp -from jaxite.jaxite_ec import util -import numpy as np - - -total_modulus = util.total_modulus -to_rns = util.to_rns - - -jax.config.update("jax_enable_x64", True) - - -@jax.named_call -@functools.partial( - jax.jit, static_argnames=("iter_num", "mask", "chunk_shift_bits") -) -def carry_add( - value_c: jax.Array, - iter_num=util.U16_CHUNK_NUM, - mask=util.U16_MASK, - chunk_shift_bits=util.U16_CHUNK_SHIFT_BITS, -): - """The purpose of this API is to enable general-purposed carry add, where the following knobs are known before runtime. 
- - iter_num = math.ceil(total_input_bitwidth / chunk_bitwidth), - mask = 2**chunk_bitwidth - 1, - chunk_shift_bits = chunk_bitwidth - - Args: - value_c: The value to carry add. - iter_num: The number of iterations to perform. - mask: The mask to apply to the value. - chunk_shift_bits: The number of bits to shift the value. - - Returns: - value_c: The value after carry adding. - """ - for _ in range(iter_num): - low = jnp.bitwise_and(value_c, mask) - high = jnp.right_shift(value_c, chunk_shift_bits) - high = jnp.roll(high, 1) - value_c = jnp.add(low, high) - return value_c - - -@jax.named_call -@functools.partial(jax.jit, static_argnames="chunk_shift_bits") -def check_any_chunk_with_carry( - value_c: jax.Array, - chunk_shift_bits=util.U16_CHUNK_SHIFT_BITS, -) -> jax.Array: - """This function check whether any chunk of input vector 'value_c' has carry. - - Args: - value_c: The value to carry add. - chunk_shift_bits: ideal bit precision of any given chunk. Note that: actual - bit precision of any given chunk might be higher than chunk_shift_bits - because it needs to hold the overflow. - - Returns: - cond: A boolean value indicating whether any chunk of input vector 'value_c' - has carry. - """ - high = jnp.right_shift(value_c, chunk_shift_bits) - cond = jnp.any(jnp.not_equal(high, 0)) - return cond - - -@jax.named_call -@functools.partial(jax.jit, static_argnames=("mask", "chunk_shift_bits")) -def carry_propagation( - value_c: jax.Array, - mask=util.U16_MASK, - chunk_shift_bits=util.U16_CHUNK_SHIFT_BITS, -): - """The purpose of this API is to enable carry propagation. - - Args: - value_c: The value to carry propagate. - mask: 2**chunk_bitwidth - 1, - chunk_shift_bits: chunk_bitwidth - - This function split each chunk into high and low parts, and high part is left - roll by 1 to carry the overflowed bits to the next chunk. - Note that: in a given jax.array, bit range of the chunk within the original - high precision value is increased from left to the right. 
- - Returns: - value_c: The value after carry adding. - """ - precision_dim = value_c.shape[-1] - roll_mat = jnp.array( - [0, 1] - + ([0] * (precision_dim) + [1]) * (precision_dim - 2) - + [1] - + [0] * (precision_dim - 1), - dtype=jnp.uint16, - ).reshape(precision_dim, precision_dim) - low = jnp.bitwise_and(value_c, mask) - high = jnp.right_shift(value_c, chunk_shift_bits).astype(jnp.uint16) - high = jnp.matmul(high, roll_mat, preferred_element_type=jnp.uint32).astype( - jnp.uint16 - ) - value_c = jnp.add(low, high) - return value_c - - -def conv_1d_2u16xn(value_a: jax.Array, value_b: jax.Array): - """This function performs a 1D convolution of two u16 arrays. - - Args: - value_a: The chunk-decomposition representation of the high-precision int. - value_b: The chunk-decomposition representation of the high-precision int. - - Returns: - conv: The convolution results of two input arrays being casted to uint8. - """ - value_a = jax.lax.bitcast_convert_type(value_a, jnp.uint8).reshape(-1) - value_b = jax.lax.bitcast_convert_type(value_b, jnp.uint8).reshape(-1) - conv = jnp.convolve( - value_a, - value_b, - preferred_element_type=jnp.uint32, - ) - return conv - - -@jax.named_call -@functools.partial(jax.jit, static_argnames=("chunk_num_u16", "chunk_num_u32")) -def rechunkify(mul_result: jax.Array, chunk_num_u16, chunk_num_u32): - """Given the carry add takes O(C) algorithm complexity, where C is the number of chunks. - - This function performs chunk reduction for ther results of the convolution, - i.e. merge two consecutive chunks into one chunk with double precision. - E.g. u8[0, 8, 8, 0] -> u16[8, 2048] 0-> u32[526336] - - Args: - mul_result: The chunk-wise multiplication (using convolution) result. - chunk_num_u16: The number of bits in each chunk. - chunk_num_u32: The number of bits in the second chunk. - - Returns: - value_c: The result of the chunk reduction. 
- """ - shift_0_8_u16x4 = jnp.array( - [[0, 8] for _ in range(chunk_num_u16 * 4)], dtype=jnp.uint8 - ) - shift_0_16_u32x4 = jnp.array( - [[0, 16] for _ in range(chunk_num_u32 * 4)], dtype=jnp.uint8 - ) - new_shape = ( - mul_result.shape[:-1] + (-1, 2) if mul_result.ndim == 2 else (-1, 2) - ) - value_c = mul_result.reshape(new_shape) - value_c = jnp.left_shift(value_c, shift_0_8_u16x4[:chunk_num_u16]) - value_c = jnp.sum(value_c, axis=-1) - value_c = value_c.reshape(new_shape).astype(jnp.uint64) - value_c = jnp.left_shift(value_c, shift_0_16_u32x4[:chunk_num_u32]) - value_c = jnp.sum(value_c, axis=-1) - return value_c - - -@jax.named_call -@functools.partial(jax.jit, static_argnames="chunk_num_u16") -def compare_u16( - value_a: jax.Array, value_b: jax.Array, chunk_num_u16=util.U16_CHUNK_NUM -): - """Compare two u16 values. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - chunk_num_u16: The number of chunks in the u16 value. - - Returns: - cond > 0 -> value_a > value_b - cond = 0 -> value_a = value_b - cond < 0 -> value_a < value_b - """ - sign = jnp.sign( - jnp.subtract(value_a.astype(jnp.int32), value_b.astype(jnp.int32)) - ) - comp_check_vec_weights = jnp.array( - [2**i for i in range(chunk_num_u16)], dtype=jnp.int32 - ) - weight = jnp.multiply(sign, comp_check_vec_weights) - cond = weight.sum(axis=-1) - return cond - - -def add_2u16(value_a: jax.Array, value_b: jax.Array): - """Add two u16 values. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - - Returns: - value_c: The result of the addition. 
- """ - value_c = jax.numpy.add( - value_a.astype(jnp.uint32), value_b.astype(jnp.uint32) - ) - value_c = jax.lax.while_loop( - check_any_chunk_with_carry, carry_propagation, value_c - ) - - return value_c.astype(jnp.uint16) - - -def add_3u16(value_a: jax.Array, value_b: jax.Array, value_d: jax.Array): - value_c = jax.numpy.add( - value_a.astype(jnp.uint32), value_b.astype(jnp.uint32) - ) - value_c = jax.numpy.add( - value_c.astype(jnp.uint32), value_d.astype(jnp.uint32) - ) - value_c = jax.lax.while_loop( - check_any_chunk_with_carry, carry_propagation, value_c - ) - - return value_c.astype(jnp.uint16) - - -@jax.named_call -@functools.partial(jax.jit, static_argnames=("mask", "chunk_num_u16")) -def sub_2u16( - value_a: jax.Array, - value_b: jax.Array, - mask=util.U16_MASK, - chunk_num_u16=util.U16_CHUNK_NUM, -): - """Subtract two u16 values. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - mask: The mask to apply to the value. - chunk_num_u16: The number of chunks in the u16 value (default: 24). - - Returns: - value_c: The result of the subtraction. 
- """ - borrow_high_u16_pad_zero_array = jnp.array( - [0] + [1] * (chunk_num_u16 - 2) + [0], dtype=jnp.uint32 - ) - borrow_low_u16_array = jnp.array( - [mask + 1] * (chunk_num_u16 - 1) + [0], dtype=jnp.uint32 - ) - value_a = jnp.add(value_a.astype(jnp.uint32), borrow_low_u16_array) - value_c = jnp.subtract(value_a, value_b) - value_c = jnp.subtract(value_c, borrow_high_u16_pad_zero_array) - - value_c = jax.lax.while_loop( - check_any_chunk_with_carry, carry_propagation, value_c - ) - if value_c.ndim == 1: - value_c = value_c.at[chunk_num_u16 - 1].set(value_c[chunk_num_u16 - 1] - 1) - else: - value_c = value_c.at[:, chunk_num_u16 - 1].set( - value_c[:, chunk_num_u16 - 1] - 1 - ) - - value_c = value_c.astype(jnp.uint16) - return value_c - - -@jax.named_call -@functools.partial( - jax.jit, static_argnames=("modulus_377_int_chunk", "chunk_num_u16") -) -def cond_sub_mod_u16( - value_a: jax.Array, - modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK, - chunk_num_u16=util.U16_CHUNK_NUM, -): - """Perform conditional subtraction: value_a - modulus_377_int. - - Args: - value_a: The minuend. - modulus_377_int_chunk: The modulus 377. - chunk_num_u16: The number of chunks in the u16 value (default: 24). - - Returns: - value_c: The result of the conditional subtraction. 
- """ - compare_u16_local = functools.partial( - compare_u16, chunk_num_u16=chunk_num_u16 - ) - sub_2u16_local = functools.partial(sub_2u16, chunk_num_u16=chunk_num_u16) - modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16) - - cond = compare_u16_local(value_a, modulus_377_int_array) - value_b = sub_2u16_local(value_a, modulus_377_int_array) - cond = jnp.greater_equal(cond, 0).reshape((cond.shape[0], 1)) - value_c = jnp.where(cond, value_b, value_a) - return value_c - - -@jax.named_call -@functools.partial( - jax.jit, static_argnames=("modulus_377_int_chunk", "chunk_num_u16") -) -def cond_sub_2u16( - value_a: jax.Array, - value_b: jax.Array, - modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK, - chunk_num_u16=util.U16_CHUNK_NUM, -): - """Perform conditional subtraction: value_a - value_b. - - Args: - value_a: The minuend. - value_b: The subtrahend. - modulus_377_int_chunk: The modulus 377. - chunk_num_u16: The number of chunks in the u16 value (default: 24). - - Returns: - value_c: The result of the conditional subtraction. 
- """ - modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16) - compare_u16_local = functools.partial( - compare_u16, chunk_num_u16=chunk_num_u16 - ) - sub_2u16_local = functools.partial(sub_2u16, chunk_num_u16=chunk_num_u16) - - cond = compare_u16_local(value_a, value_b) - cond = jnp.greater_equal(cond, 0).reshape((cond.shape[0], 1)) - - value_ap = jnp.add( - value_a.astype(jnp.uint32), modulus_377_int_array.astype(jnp.uint32) - ) - - value_a = jnp.where(cond, value_a.astype(jnp.uint32), value_ap) - value_c = sub_2u16_local(value_a, value_b) - return value_c - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=( - "mask", - "chunk_num_u16", - "chunk_shift_bits", - "output_dtype", - "vmap_axes", - ), -) -def mul_2u16( - value_a: jax.Array, - value_b: jax.Array, - mask=util.U32_MASK, - chunk_num_u16=util.U16_CHUNK_NUM, - chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS, - output_dtype=jnp.uint16, - vmap_axes=(0, 0), -): - """Multiply two u16 values. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - mask: The mask to apply to the value. - chunk_num_u16: The number of chunks in the u16 value (default: 24). - chunk_shift_bits: The number of bits to shift the value. - output_dtype: The desired output data type. - vmap_axes: The axes to use for vmap. 
- - Returns: - - cond > 0 -> value_a > value_b - cond = 0 -> value_a = value_b - cond < 0 -> value_a < value_b - """ - batch_dim = value_a.shape[0] - mul_result = jax.vmap(conv_1d_2u16xn, in_axes=vmap_axes)(value_a, value_b) - mul_result = jnp.pad(mul_result, ((0, 0), (0, 1))) - value_c = rechunkify(mul_result, 2 * chunk_num_u16, chunk_num_u16) - - value_c = jax.lax.while_loop( - functools.partial( - check_any_chunk_with_carry, chunk_shift_bits=chunk_shift_bits - ), - functools.partial( - carry_propagation, - mask=mask, - chunk_shift_bits=chunk_shift_bits, - ), - value_c, - ) - ratio = 4 if output_dtype == jnp.uint8 else 2 - value_c = jax.lax.bitcast_convert_type( - value_c.astype(jnp.uint32), output_dtype - ).reshape(batch_dim, -1)[:, : ratio * chunk_num_u16] - return value_c - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=( - "barrett_shift_u8", - "chunk_num_u16", - "chunk_num_u32", - "vmap_axes", - ), -) -def mul_shift_2u16x2x1( - value_a: jax.Array, - value_b: jax.Array, - mask=util.U32_MASK, - barrett_shift_u8=util.BARRETT_SHIFT_U8, - chunk_num_u16=util.U16_CHUNK_NUM, - chunk_num_u32=util.U32_CHUNK_NUM, - chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS, - vmap_axes=(0, None), -): - """Multiply and shift two u16 values. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - mask: The mask to apply to the value. - barrett_shift_u8: The number of bits to shift the value. - chunk_num_u16: The number of chunks in the u16 value. - chunk_num_u32: The number of chunks in the u32 value. - chunk_shift_bits: The number of bits to shift the value. - vmap_axes: (0, None) means axis 0 is the mapped access, and The rest is not. 
- - Returns: - - cond > 0 -> value_a > value_b - cond = 0 -> value_a = value_b - cond < 0 -> value_a < value_b - """ - batch_dim = value_a.shape[0] - conv = jax.vmap(conv_1d_2u16xn, in_axes=vmap_axes)(value_a, value_b) - conv = jnp.pad(conv, ((0, 0), (0, 1))) - value_c = rechunkify(conv, chunk_num_u16 * 3, chunk_num_u32 * 3) - value_c = jax.lax.while_loop( - functools.partial( - check_any_chunk_with_carry, chunk_shift_bits=chunk_shift_bits - ), - functools.partial( - carry_propagation, mask=mask, chunk_shift_bits=chunk_shift_bits - ), - value_c, - ) - - value_c = jax.lax.bitcast_convert_type( - value_c.astype(jnp.uint32), jnp.uint8 - ).reshape(batch_dim, -1)[:, barrett_shift_u8:] - value_c = jax.lax.bitcast_convert_type( - jnp.pad(value_c, ((0, 0), (0, 1))).reshape(batch_dim, -1, 2), jnp.uint16 - )[:, :chunk_num_u16] - return value_c - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=( - "mask", - "modulus_377_int_chunk", - "mu_377_int_chunk", - "chunk_num_u16", - "vmap_axes", - ), -) -def mod_mul_barrett_2u16( - value_a: jax.Array, - value_b: jax.Array, - mask=util.U16_MASK, - modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK, - mu_377_int_chunk=util.MU_377_INT_CHUNK, - chunk_num_u16=util.U16_CHUNK_NUM, - vmap_axes=(0, None), -): - """Multiply two u16 values with Barrett reduction. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - mask: The mask to apply to the value. - modulus_377_int_chunk: The modulus 377. - mu_377_int_chunk: The Barrett reduction coefficient. - chunk_num_u16: The number of chunks in the u16 value (default: 24). - vmap_axes: The axes to use for vmap. - - Returns: - value_c: The result of the multiplication. 
- """ - modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16) - mu_377_int_array = jnp.asarray(mu_377_int_chunk, jnp.uint16) - - mul_2u16_const = functools.partial(mul_2u16, vmap_axes=vmap_axes) - sub_2u16_const = functools.partial( - sub_2u16, mask=mask, chunk_num_u16=chunk_num_u16 * 2 - ) - value_x = mul_2u16(value_a, value_b) - value_d = mul_shift_2u16x2x1(value_x, mu_377_int_array) - value_e = mul_2u16_const(value_d, modulus_377_int_array) - value_t = sub_2u16_const(value_x, value_e) - value_c = cond_sub_mod_u16(value_t[:, :chunk_num_u16]) - return value_c - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=( - "mask", - "modulus_377_int_chunk", - "mu_377_int_chunk", - "chunk_num_u16", - "vmap_axes", - ), -) -def barrett_reduction_u16x2( - value_x: jax.Array, - mask=util.U16_MASK, - modulus_377_int_chunk=util.MODULUS_377_INT_CHUNK, - mu_377_int_chunk=util.MU_377_INT_CHUNK, - chunk_num_u16=util.U16_CHUNK_NUM, - vmap_axes=(0, None), -): - """Performs Barrett reduction on a u16x2 value. - - Args: - value_x: The u16x2 value to perform Barrett reduction on. - mask: The mask to apply to the value. - modulus_377_int_chunk: The modulus 377. - mu_377_int_chunk: The Barrett reduction coefficient. - chunk_num_u16: The number of chunks in the u16 value (default: 24). - vmap_axes: The axes to use for vmap. - - Returns: - value_c: The result of the Barrett reduction. 
- """ - modulus_377_int_array = jnp.asarray(modulus_377_int_chunk, jnp.uint16) - mu_377_int_array = jnp.asarray(mu_377_int_chunk, jnp.uint16) - - mul_2u16_const = functools.partial(mul_2u16, vmap_axes=vmap_axes) - value_d = mul_shift_2u16x2x1(value_x, mu_377_int_array) - value_e = mul_2u16_const(value_d, modulus_377_int_array) - value_t = sub_2u16(value_x, value_e, mask, chunk_num_u16 * 2) - value_c = cond_sub_mod_u16(value_t[:, :chunk_num_u16]) - return value_c - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=( - "modulus_lazy_mat", - "mask", - "chunk_num_u8", - "chunk_shift_bits", - ), -) -def mod_mul_lazy_2u16( - value_a, - value_b, - modulus_lazy_mat=util.MODULUS_377_LAZY_MAT, - mask=util.U32_MASK, - chunk_num_u8=util.U8_CHUNK_NUM, - chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS, -): - """Multiply two u16 values with lazy matrix reduction. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - modulus_lazy_mat: The lazy matrix. - mask: The mask to apply to the value. - chunk_num_u8: The number of chunks in the u8 value. - chunk_shift_bits: The number of bits to shift the value. - - Returns: - value_c: The result of the multiplication. 
- """ - batch_dim = value_a.shape[0] - modulus_lazy_mat = jnp.asarray(modulus_lazy_mat, dtype=jnp.uint16) - mul_2u8 = functools.partial( - mul_2u16, - mask=util.U32_MASK, - chunk_num_u16=util.U16_EXT_CHUNK_NUM, - chunk_shift_bits=util.U32_CHUNK_SHIFT_BITS, - output_dtype=jnp.uint8, - ) - value_c = mul_2u8(value_a, value_b) - standard_product_low = value_c[:, :chunk_num_u8] - standard_product_high = value_c[:, chunk_num_u8:] - - reduced = jnp.matmul( - standard_product_high.astype(jnp.uint16), - modulus_lazy_mat.astype(jnp.uint16), - preferred_element_type=jnp.uint32, - ) - value_c_reduced = jnp.add( - standard_product_low.astype(jnp.uint32), reduced.astype(jnp.uint32) - ) - value_c_reduced_u32 = rechunkify( - value_c_reduced, chunk_num_u8 // 2, chunk_num_u8 // 4 - ) - value_c_reduced_u32 = jnp.pad(value_c_reduced_u32, ((0, 0), (0, 1))) - - value_c_carried = jax.lax.while_loop( - functools.partial( - check_any_chunk_with_carry, chunk_shift_bits=chunk_shift_bits - ), - functools.partial( - carry_propagation, mask=mask, chunk_shift_bits=chunk_shift_bits - ), - value_c_reduced_u32, - ) - - value_c_u16 = jax.lax.bitcast_convert_type( - value_c_carried.astype(jnp.uint32), jnp.uint16 - ).reshape(batch_dim, -1)[:, : util.U16_EXT_CHUNK_NUM] - return value_c_u16 - - -def split_view_32_to_16(a: jnp.ndarray): - # Interpret each 32-bit element as two 16-bit numbers - # and reshape to add an extra dimension of size 2. - v = a.view(jnp.uint16).reshape(a.shape + (2,)) - # Assuming little-endian storage, the lower 16 bits are at index 0 - # and the upper 16 bits are at index 1. - lower = v[..., 0] - upper = v[..., 1] - return upper, lower - - -def split_view_32_to_16_8(a: jnp.ndarray): - # First, reshape the 32-bit integers as groups of 4 bytes. - v8 = a.view(jnp.uint8).reshape(a.shape + (4,)) - # Also, reshape as 16-bit integers (2 per 32-bit element) - v16 = a.view(jnp.uint16).reshape(a.shape + (2,)) - # For each 32-bit integer: - # v16[..., 0] gives the lower 16 bits. 
- # v8[..., 2] gives the third byte (i.e. the lower 8 bits of the upper 16 bits) - lower = v16[..., 0] - upper8 = v8[..., 2] - return upper8, lower - - -# Reduce via RNS modulus -@jax.named_call -@functools.partial( - jax.jit, - static_argnames="moduli_t", -) -def moduli_rns_red_internal_2u16(vals, moduli_t=util.RNS_MODULI_T): - """Reduce via RNS modulus. - - Args: - vals: The values to reduce. - moduli_t: The moduli for the target. - - Returns: - The reduced values. - """ - # See jaxite_ec/advanced_algorithm/rns_red.py for description - moduli_t = jnp.array(moduli_t, dtype=jnp.uint8) - u1, l1 = split_view_32_to_16(vals) - i1 = jnp.add( - l1.astype(jnp.uint32), - jnp.multiply(u1.astype(jnp.uint32), moduli_t), - ) - u2, l2 = split_view_32_to_16_8(i1) - i2 = jnp.add( - l2.astype(jnp.uint32), - jnp.multiply(u2.astype(jnp.uint16), moduli_t).astype( - jnp.uint32 - ), - ) - u3, l3 = split_view_32_to_16_8(i2) - out = jnp.add(l3, jnp.multiply(u3, moduli_t).astype(jnp.uint16)) - return out - - -# Reduce via prime modulus -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=("rns_mat", "moduli_t", "num_moduli", "precision"), -) -def mod_red_rns_2u16( - c_rns_reduced, - rns_mat=util.RNS_MAT, - moduli_t=util.RNS_MODULI_T, - num_moduli=util.NUM_MODULI, - precision=util.RNS_PRECISION, -): - """Reduce via RNS modulus. - - Args: - c_rns_reduced: The values to reduce. - rns_mat: The RNS precompute. - moduli_t: The moduli for the target. - num_moduli: The number of moduli. - precision: The precision. - - Returns: - The reduced values. 
- """ - rns_stacked_mat = jnp.array(rns_mat[0], jnp.uint8) - cor_mat = jnp.array(rns_mat[1], jnp.uint16) - - c_target = jnp.matmul( - c_rns_reduced.view(jnp.uint8), - rns_stacked_mat, - preferred_element_type=jnp.uint32, - ) - - mul_res_glb_red_u32 = c_target.reshape(*c_target.shape[:-1], -1, 2) - mul_res_glb_red_u32 = mul_res_glb_red_u32[..., 0] + ( - mul_res_glb_red_u32[..., 1] << 8 - ) - rns_reduce_u32, qe_u32 = jnp.split( - mul_res_glb_red_u32, [num_moduli], axis=1 - ) - - # obtain the high 32 bits from the quotient estimation results qe_u32 - k = (qe_u32 >> precision).astype(jnp.uint16) - c_corrected = rns_reduce_u32 + jnp.matmul( - k, cor_mat, preferred_element_type=jnp.uint32 - ) - - return moduli_rns_red_internal_2u16(c_corrected, moduli_t) - - -# Multiply, without reducing -@jax.named_call -@functools.partial( - jax.jit, - static_argnames="moduli_t", -) -def mul_unreduced_rns_2u16( - value_a, - value_b, - moduli_t=util.RNS_MODULI_T, -): - ab = jnp.multiply(value_a.astype(jnp.uint32), value_b.astype(jnp.uint32)) - return moduli_rns_red_internal_2u16(ab, moduli_t) - - -# Multiply and reduce -@jax.named_call -@functools.partial( - jax.jit, - static_argnames=("rns_mat", "moduli_t"), -) -def mod_mul_rns_2u16( - value_a, - value_b, - rns_mat=util.RNS_MAT, - moduli_t=util.RNS_MODULI_T, -): - """Multiply two u16 values with RNS reduction. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - rns_mat: The RNS precompute. - moduli_t: The moduli for the target. - - Returns: - The product of the two u16 values. - """ - ab = mul_unreduced_rns_2u16(value_a, value_b, moduli_t) - return mod_red_rns_2u16(ab, rns_mat, moduli_t) - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames="moduli_t", -) -def add_rns_2u16( - value_a: jax.Array, - value_b: jax.Array, - moduli_t=util.RNS_MODULI_T, -): - """Add two u16 values with RNS reduction. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. 
- moduli_t: The moduli for the target. - - Returns: - The sum of the two u16 values. - """ - return add_sub_rns_var(value_a, value_b, moduli_t=moduli_t) - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames="moduli_t", -) -def add_rns_3u16( - value_a: jax.Array, - value_b: jax.Array, - value_c: jax.Array, - moduli_t=util.RNS_MODULI_T, -): - """Add three u16 values with RNS reduction. - - Args: - value_a: The first u16 value. - value_b: The second u16 value. - value_c: The third u16 value. - moduli_t: The moduli for the target. - - Returns: - The sum of the three u16 values. - """ - return add_sub_rns_var(value_a, value_b, value_c, moduli_t=moduli_t) - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames="moduli_sub", -) -def negate_rns_for_var_add( - value_a: jax.Array, - moduli_sub=util.MODULI_SUB, -): - """Negate a value for use in subtraction. - - Do not use the output in any function but add_sub_rns_var -- may break - correctness. - - Args: - value_a: RNS array to negate - moduli_sub: Precomputed constants for performing negation, that depend on - the target modulus - - Returns: - An intermediate representing the negation of values_a in the target modulus - in RNS form. - - Note: original data precision is 16 bit, using uint32 to avoid overflow - """ - moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint32) - - return jnp.add( - jnp.negative(value_a.astype(jnp.uint16)).astype(jnp.uint32), - moduli_sub, - ) - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames="moduli_sub", -) -def negate_rns_for_var_add_zero_check( - value_a: jax.Array, - moduli_sub=util.MODULI_SUB, -): - """Negate a value for use in subtraction. - - Do not use the output in any function but add_sub_rns_var -- may break - correctness. 
- - Args: - value_a: RNS array to negate - moduli_sub: Precomputed constants for performing negation, that depend on - the target modulus - - Returns: - An intermediate representing the negation of values_a in the target modulus - in RNS form. - - Note: original data precision is 16 bit, using uint32 to avoid overflow - """ - - moduli_sub = jnp.array(moduli_sub, dtype=jnp.uint32) - a = value_a.astype(jnp.uint16) - - # Compute two's complement negation: for nonzero a, jnp.negative(a) computes - # (2^16 - a). - neg = jnp.negative(a).astype(jnp.uint32) - - # Build a branchless mask: 0 if a==0, 1 otherwise. - mask = (a != 0).astype(jnp.uint32) - - # For nonzero a: (2^16 - a) + moduli_sub; for zero: m + 0 multiplied by 0 - # gives 0. - return (neg + moduli_sub) * mask - - -@jax.named_call -@functools.partial( - jax.jit, - static_argnames="moduli_t", -) -def add_sub_rns_var(*values, moduli_t=util.RNS_MODULI_T): - """Evaluate an static set of additions and subtractions. - - Subtractions are implemented by calling negate_rns_for_var_add on inputs to - this function. Inputs should be "fresh" or multiplication outputs, and the - output should be used as a multiplication input. Any other usage produces - undefined behavior and may break correctness. - - Args: - *values: A list of RNS values to accumulate - moduli_t: The moduli for the RNS form. - - Returns: - The RNS form of the evaluation of the expession. 
- """ - assert len(values) > 0 - acc = None - for v in values: - if acc != None: - acc = jnp.add(v.astype(jnp.uint32), acc) - else: - acc = v.astype(jnp.uint32) - assert len(values) < 256 - moduli_t = jnp.array(moduli_t, dtype=jnp.uint8) - # u1 < 254 - u1, l1 = split_view_32_to_16_8(acc) - # i1 < 2**16 - 1 + 255t < 2**17 - t for 8 bit t - i1 = jnp.add( - jnp.multiply(u1.astype(np.uint16), moduli_t).astype(jnp.uint32), - l1.astype(jnp.uint32), - ) - # u2 = 0 or 1, but if u2 = 1 then l < 2**16 - t, so 2**16 - t + t < 2**16 - u2, l2 = split_view_32_to_16_8(i1) - return jnp.add(jnp.multiply(u2, moduli_t).astype(jnp.uint16), l2) - - -@functools.partial(jax.jit, static_argnames=("c", "num_moduli")) -def rns_constant(c, num_moduli=util.NUM_MODULI): - assert c >= 0 - assert c < 2**14 # small constants only please - return jnp.repeat(jnp.array([c], dtype=jnp.uint16), num_moduli) diff --git a/jaxite_ec/finite_field_context.py b/jaxite_ec/finite_field_context.py new file mode 100644 index 0000000..0a690ce --- /dev/null +++ b/jaxite_ec/finite_field_context.py @@ -0,0 +1,1204 @@ +from abc import ABC, abstractmethod +import math +from typing import Any, List, Union +import warnings +import jax +import jax.numpy as jnp +from jaxite.jaxite_ec import utils +import numpy as np + +JaxKernelContextBase = utils.JaxKernelContextBase +JaxParameters = utils.JaxParameters +hash_args = utils.hash_args +jax_jit_lower_compile = utils.jax_jit_lower_compile +load_jax_executable = utils.load_jax_executable +pad_jax_array = utils.pad_jax_array +store_jax_executable = utils.store_jax_executable + +jax.config.update("jax_enable_x64", True) + + +class FiniteFieldContextBase(ABC): + """Abstract base class defining the interface for finite field operations. + + Subclasses must implement all abstract methods to provide concrete + finite field arithmetic operations. 
+ """ + + prime: Any = None + parameters: Any = None + rns_moduli: Any = None + radix_bits: Any = None + + @abstractmethod + def __init__(self, parameters: dict): + """Initialize the finite field context. + + Args: + parameters: Configuration dictionary containing field parameters. + """ + self.prime = parameters.get("prime", None) + assert self.prime is not None, "prime must be provided" + self.parameters = parameters + + @abstractmethod + def to_computational_format(self, a) -> jnp.ndarray: + """Convert input to the internal computational representation. + + Args: + a: Input value in standard format. + + Returns: + Value converted to computational format (e.g., Montgomery form). + """ + pass + + def _modular_multiply(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return a + + def _modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return a + + def _modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return a + + def _modular_negate(self, a: jnp.ndarray) -> jnp.ndarray: + return a + + def _modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray: + return a + + @abstractmethod + def to_original_format(self, a) -> Union[int, List[Any]]: + """Convert from computational format back to standard representation. + + Args: + a: Value in computational format. + + Returns: + Value in standard integer representation. + """ + pass + + @abstractmethod + def context_hash(self) -> str: + pass + + @abstractmethod + def modular_multiply(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Perform modular multiplication: (a * b) mod p. + + Args: + a: First operand in computational format. + b: Second operand in computational format. + + Returns: + Product in computational format. + """ + pass + + @abstractmethod + def modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Perform modular addition: (a + b) mod p. + + Args: + a: First operand. + b: Second operand. + + Returns: + Sum reduced modulo p. 
+ """ + pass + + @abstractmethod + def modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Perform modular subtraction: (a - b) mod p. + + Args: + a: First operand. + b: Second operand. + + Returns: + Difference reduced modulo p. + """ + pass + + @abstractmethod + def modular_negate(self, a: jnp.ndarray) -> jnp.ndarray: + """Perform modular negation: -a mod p. + + Args: + a: Operand. + + Returns: + Negation reduced modulo p. + """ + pass + + @abstractmethod + def modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray: + """Reduce value modulo the field prime. + + Args: + a: Value to reduce (may be larger than prime). + + Returns: + Value reduced to [0, p). + """ + pass + + +class RNSContextBase(FiniteFieldContextBase): + + def __init__(self, parameters: dict): + super().__init__(parameters) + self.rns_moduli = parameters.get("rns_moduli", []) + assert len(self.rns_moduli) != 0, "rns_moduli must be non-empty" + self.total_modulus = math.prod(self.rns_moduli) + assert ( + self.total_modulus > self.prime + ), "RNS total modulus must be greater than prime" + self.precision_bits = parameters.get( + "precision_bits", 0 + ) # Default precision bits + assert self.precision_bits != 0, "precision bits must be non-zero" + + self.crt_factors = self._compute_crt_factors(self.rns_moduli) + # RNSContextBase._check_parameters(self) + + def _check_parameters(self): + pass # disable the warning for now + sum_moduli_bits = math.ceil(math.log2(sum(self.rns_moduli))) + if self.precision_bits < sum_moduli_bits: + warnings.warn( + "precision_bits is less than sum of moduli_bits, precision_bits:" + f" {self.precision_bits}, sum of moduli_bits: {sum_moduli_bits}." 
+ + "This may cause precision loss.", + UserWarning, + stacklevel=2, + ) + + max_modulus_bits = max(math.ceil(math.log2(m)) for m in self.rns_moduli) + required_precision = max_modulus_bits + math.log2(len(self.rns_moduli)) + 1 + if self.precision_bits < required_precision: + warnings.warn( + "precision_bits is less than required precision, precision_bits:" + f" {self.precision_bits}, required precision: {required_precision}." + + "This may cause precision loss.", + UserWarning, + stacklevel=2, + ) + + def _compute_crt_factors(self, moduli: list[int]): + ms = [self.total_modulus // m for m in moduli] + ms_inv = [ + utils.modular_inverse(ms[i], moduli[i]) for i in range(len(moduli)) + ] + return [ + (ms[i] * ms_inv[i]) % self.total_modulus for i in range(len(moduli)) + ] + + def _elementwise_add(self, a: list[int], b: list[int]): + assert len(a) == len(b), "a and b must have the same length" + return [a[i] + b[i] for i in range(len(a))] + + def _elementwise_subtract(self, a: list[int], b: list[int]): + assert len(a) == len(b), "a and b must have the same length" + return [a[i] - b[i] for i in range(len(a))] + + def _elementwise_multiply(self, a: list[int], b: list[int]): + assert len(a) == len(b), "a and b must have the same length" + return [a[i] * b[i] for i in range(len(a))] + + def _elementwise_reduce(self, z: list[int], m: list[int]): + assert len(z) == len(m), "z and m must have the same length" + return [z[i] % m[i] for i in range(len(m))] + + def _elementwise_left_shift(self, a: list[int], b: Union[list[int], int]): + if isinstance(b, list): + assert len(a) == len(b), "a and b must have the same length" + return [a[i] << b[i] for i in range(len(a))] + else: + return [a[i] << b for i in range(len(a))] + + def _elementwise_right_shift(self, a: list[int], b: Union[list[int], int]): + if isinstance(b, list): + assert len(a) == len(b), "a and b must have the same length" + return [a[i] >> b[i] for i in range(len(a))] + else: + return [a[i] >> b for i in 
range(len(a))] + + def _elementwise_and(self, a: list[int], b: Union[list[int], int]): + if isinstance(b, list): + assert len(a) == len(b), "a and b must have the same length" + return [a[i] & b[i] for i in range(len(a))] + else: + return [a[i] & b for i in range(len(a))] + + def _rns_decompose(self, x: int, moduli: list[int]): + return [(x % m) for m in moduli] + + def _get_crns_vector_I_before_reducing(self, moduli_m: list[int], y): + vector_i = [] + modular_M = math.prod(moduli_m) + for i, m_i in enumerate(moduli_m): + # M_y_over_m_i = (modular_M * y) // m_i + M_y_over_m_i = (modular_M) // m_i + # logging.info(f"M_y_over_m_i: {M_y_over_m_i}") + inv_M_y_over_m_i = utils.modular_inverse(M_y_over_m_i % m_i, m_i) + # logging.info(f"inv_M_y_over_m_i: {inv_M_y_over_m_i}") + I_i = inv_M_y_over_m_i * M_y_over_m_i * y + vector_i.append(I_i) + return vector_i + + def get_moduli_num(self) -> int: + return len(self.rns_moduli) + + def _crns( + self, + x: list[int], + matrix_E: list[list[int]], + vector_f_T: list[int], + vector_g: list[int], + moduli: list[int], + u: int, + ): + """CRNS Computation based on Algorithm steps 9-12 + + Args: + x: Input vector x_M in RNS representation + matrix_E: Precomputed matrix E from crns_precompute + vector_f_T: Precomputed vector f^T from crns_precompute + vector_g: Precomputed vector g from crns_precompute + moduli: RNS moduli (should be self.rns_moduli_n) + + Returns: + x_N: Result vector in RNS representation modulo N + """ + # Step 9: E, f^T, g = CRNSPRECOMPUTATION(M, y, N, z, u) - already done + # (matrix_E, vector_f_T, vector_g are passed as parameters) + + # Step 10: v = x_M · f^T (Dot product; can be parallelized with x_M · E) + v = sum(x[i] * vector_f_T[i] for i in range(len(x))) + + # Step 11: k = ⌊v/2^u⌋ (Bitshifting computes fixed-point quotient) + k = v >> u # Right shift by u bits is equivalent to floor division by 2^u + + # Step 12: return x_N = x_M · E + k·g (vector-matrix multiply, scalar-vector multiply, compute ICRT) + 
x_N = [] + for j in range(len(moduli)): + # Compute x_M · E for column j + dot_product = sum(x[i] * matrix_E[i][j] for i in range(len(x))) + dot_product = dot_product % moduli[j] + x_N.append(dot_product) + + for j in range(len(moduli)): + # Add k * g_j and take modulo n_j + # result = (x_N[j] + k * vector_g[j]) % moduli[j] + result = x_N[j] + k * vector_g[j] + x_N[j] = result + return x_N + + +class DRNSlazyContextBase(RNSContextBase): + + def __init__(self, parameters: dict): + super().__init__(parameters) + self.radix_bits = parameters.get("radix_bits", 0) + assert self.radix_bits != 0, " radix bits must be non-zero" + self.moduli_inv_on_radix = [ + utils.modular_inverse(m, 1 << self.radix_bits) for m in self.rns_moduli + ] + self.crns_y, self.crns_z = self._precompute_crns_parameters() + self.matrix_E, self.vector_f_T, self.vector_g = self._crns_precompute( + self.rns_moduli, + self.prime, + self.crns_y, + self.crns_z, + self.precision_bits, + ) + + def _check_parameters(self): + pass # disable the warning for now + if (self.prime << self.w) > self.total_modulus: + warnings.warn( + "Total modulus is not enough to hold prime in DRNSlazy, total" + f" modulus: {self.total_modulus}, prime: {self.prime}." 
+ + "This may cause finite field overflow.", + UserWarning, + stacklevel=2, + ) + + def _precompute_crns_parameters(self): + radix_inv = utils.modular_inverse(2**self.radix_bits, self.total_modulus) + return radix_inv, 1 << (2 * self.radix_bits) + + def _crns_precompute( + self, + moduli: list[int], + prime: int, + y: int, + z: int, + u: int, + ): + modular = math.prod(moduli) + vector_i = [] + for i, m_i in enumerate(moduli): + M_over_m_i = modular // m_i + inv_M_over_m_i = utils.modular_inverse(M_over_m_i % m_i, m_i) + I_i = (inv_M_over_m_i * M_over_m_i * y) % modular + I_i = I_i + vector_i.append(I_i) + + matrix_E = [] + for i, I_i in enumerate(vector_i): + E_row = [] + for j, n_j in enumerate(moduli): + E_ij = (z * (I_i % prime)) % n_j + E_row.append(E_ij) + matrix_E.append(E_row) + + vector_f_T = [] + for I_i in vector_i: + f_i_T = math.ceil((I_i * (1 << u)) / modular) + vector_f_T.append(f_i_T) + + vector_g = [] + for n_j in moduli: + g_j = (-z * (modular % prime)) % n_j + vector_g.append(g_j) + + return matrix_E, vector_f_T, vector_g + + def _elementwise_montgomery_reduce( + self, z: list[int], m: list[int], m_inv: list[int] + ): + assert len(z) == len(m), "z and m must have the same length" + assert len(z) == len(m_inv), "z and m_inv must have the same length" + mask = (1 << self.radix_bits) - 1 + z_low = self._elementwise_and(z, mask) + z_high = self._elementwise_right_shift(z, self.radix_bits) + q = self._elementwise_and(self._elementwise_multiply(z_low, m_inv), mask) + h = self._elementwise_right_shift( + self._elementwise_multiply(q, m), self.radix_bits + ) + t = self._elementwise_subtract(z_high, h) + t = self._elementwise_add(t, m) + return t + + +class DRNSlazyContext(DRNSlazyContextBase, JaxKernelContextBase): + + def __init__(self, parameters: dict): + super().__init__(parameters) + JaxKernelContextBase.__init__(self) + self.jax_parameters = JaxParameters() + self._init_jax_parameters() + + def to_computational_format(self, a: Union[int, list]) -> 
jnp.ndarray: + moduli = self.rns_moduli + radix_bits = self.radix_bits + + def individual_convert(a: int) -> list: + return [(((a % m) << radix_bits) % m) for m in moduli] + + def recursive_convert(a): + if isinstance(a, int): + return individual_convert(a) + return [recursive_convert(a_i) for a_i in a] + + converted_list = recursive_convert(a) + converted_a = jnp.array( + np.array(converted_list, dtype=np.uint32), dtype=jnp.uint32 + ) + if self.use_sharding: + named_sharding, padded_shape = self.create_named_sharding( + shape=converted_a.shape, axes=[0] + ) + converted_a = pad_jax_array(converted_a, padded_shape) + return converted_a.to_device(named_sharding) + else: + return converted_a.to_device(jax.devices()[0]) + + def to_original_format(self, a: jnp.ndarray) -> Union[int, list]: # type: ignore + def individual_convert(a: list[int]) -> int: + a = self._elementwise_montgomery_reduce( + a, self.rns_moduli, self.moduli_inv_on_radix + ) + r = 0 + for i in range(len(a)): + r = (r + a[i] * self.crt_factors[i]) % self.total_modulus + return r % self.prime + + def recursive_convert(a): + if a.ndim == 1: + return individual_convert(a.tolist()) + return [recursive_convert(a_i) for a_i in a] + + return recursive_convert(a) + + def _get_shape_dtype_structs( + self, parameters: dict + ) -> list[jax.ShapeDtypeStruct]: + batch_shape = parameters["batch_shape"] + oprand_shape = batch_shape + (len(self.rns_moduli),) + if self.use_sharding: + named_sharding, padded_shape = self.create_named_sharding( + shape=oprand_shape, axes=[0] + ) + return [ + jax.ShapeDtypeStruct( + padded_shape, jnp.uint32, sharding=named_sharding + ) + ] + return [jax.ShapeDtypeStruct(oprand_shape, jnp.uint32)] + + def context_hash(self) -> str: + return hash_args( + self.__class__.__name__, + self.prime, + self.rns_moduli, + self.precision_bits, + self.radix_bits, + self.use_sharding, + ) + + def serialize(self, parameters: dict): + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + 
kernel_hash = hash_args(self.context_hash(), parameters) + class_name = self.__class__.__name__ + + store_jax_executable( + self._modular_multiply, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_modular_multiply_{kernel_hash}", + ) + store_jax_executable( + self._modular_add, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_modular_add_{kernel_hash}", + ) + store_jax_executable( + self._modular_subtract, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_modular_subtract_{kernel_hash}", + ) + store_jax_executable( + self._modular_reduce, + shape_dtype_structs[0], + name=f"{class_name}_modular_reduce_{kernel_hash}", + ) + store_jax_executable( + self._modular_negate, + shape_dtype_structs[0], + name=f"{class_name}_modular_negate_{kernel_hash}", + ) + + def compile(self, parameters: dict): + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + kernel_hash = hash_args(self.context_hash(), parameters) + class_name = self.__class__.__name__ + + modular_multiply_kernel = load_jax_executable( + f"{class_name}_modular_multiply_{kernel_hash}" + ) + modular_add_kernel = load_jax_executable( + f"{class_name}_modular_add_{kernel_hash}" + ) + modular_subtract_kernel = load_jax_executable( + f"{class_name}_modular_subtract_{kernel_hash}" + ) + modular_reduce_kernel = load_jax_executable( + f"{class_name}_modular_reduce_{kernel_hash}" + ) + modular_negate_kernel = load_jax_executable( + f"{class_name}_modular_negate_{kernel_hash}" + ) + + if None in [ + modular_multiply_kernel, + modular_add_kernel, + modular_subtract_kernel, + modular_reduce_kernel, + modular_negate_kernel, + ]: + # if not self.use_sharding: + warnings.warn( + f"Not found stored serialized compiled kernels, compiling...", + UserWarning, + stacklevel=2, + ) + + kernel_hash = hash_args( + shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__() + ) + # print(f"kernel_hash: {kernel_hash}") + + 
self.compiled_kernels[kernel_hash] = { + "modular_multiply": ( + modular_multiply_kernel + if modular_multiply_kernel is not None + else jax_jit_lower_compile( + self._modular_multiply, + shape_dtype_structs[0], + shape_dtype_structs[0], + ) + ), + "modular_add": ( + modular_add_kernel + if modular_add_kernel is not None + else jax_jit_lower_compile( + self._modular_add, + shape_dtype_structs[0], + shape_dtype_structs[0], + ) + ), + "modular_subtract": ( + modular_subtract_kernel + if modular_subtract_kernel is not None + else jax_jit_lower_compile( + self._modular_subtract, + shape_dtype_structs[0], + shape_dtype_structs[0], + ) + ), + "modular_reduce": ( + modular_reduce_kernel + if modular_reduce_kernel is not None + else jax_jit_lower_compile( + self._modular_reduce, shape_dtype_structs[0] + ) + ), + "modular_negate": ( + modular_negate_kernel + if modular_negate_kernel is not None + else jax_jit_lower_compile( + self._modular_negate, shape_dtype_structs[0] + ) + ), + } + self.use_compiled_kernels = True + + def _jax_crns_precompute( + self, moduli: list[int], prime: int, y: int, z: int, u: int + ): + rns_moduli_bytes = 4 + vector_I = self._get_crns_vector_I_before_reducing(moduli, y) + modular = math.prod(moduli) + + vector_I_byteshifted = [] + for value_i in vector_I: + value_i_byteshifted = [ + (value_i << (8 * byte_idx)) % modular + for byte_idx in range(rns_moduli_bytes) + ] + vector_I_byteshifted = vector_I_byteshifted + value_i_byteshifted + + matrix_E = [] + for value_i_byteshifted in vector_I_byteshifted: + value_i_byteshifted = z * (value_i_byteshifted % prime) + matrix_E.append(self._rns_decompose(value_i_byteshifted, moduli)) + + matrix_E_np = np.array(matrix_E, dtype=np.uint32).reshape(-1, len(moduli)) + matrix_E_np_u8 = matrix_E_np.view(np.uint8) + + vector_f = [] + for value_i_byteshifted in vector_I_byteshifted: + value_i_byteshifted_f = math.ceil( + (value_i_byteshifted * (1 << u)) / modular + ) + vector_f.append(value_i_byteshifted_f) + 
vector_f_T_np = np.array(vector_f, dtype=np.uint32).reshape(-1, 1) + vector_f_T_np_u8 = vector_f_T_np.view(np.uint8) + + vector_g = self._rns_decompose(-z * (modular % prime), moduli) + vector_g_np = np.array(vector_g, dtype=np.uint32).reshape(1, -1) + + matrix_E_with_f_T_np = np.hstack((matrix_E_np_u8, vector_f_T_np_u8)) + matrix_E_with_f_T_np = matrix_E_with_f_T_np.reshape( + len(moduli), 4, len(moduli) + 1, 4 + ) # NOTE: reshape is special part for the optimization + return matrix_E_with_f_T_np.tolist(), vector_g_np.tolist() + + def _init_jax_parameters(self): + half_word_mask = 0xFFFF + half_word_bits = 16 + word_bits = 32 + word_mask = (1 << word_bits) - 1 + rns_moduli_low = [m & half_word_mask for m in self.rns_moduli] + rns_moduli_high = [m >> half_word_bits for m in self.rns_moduli] + rns_moduli_inv_word = [ + utils.modular_inverse(m, 2**word_bits) for m in self.rns_moduli + ] + crns_precision = self.precision_bits + crns_matrix_E_with_f_T, crns_vector_g = self._jax_crns_precompute( + self.rns_moduli, + self.prime, + self.crns_y, + self.crns_z, + self.precision_bits, + ) + num_moduli = len(self.rns_moduli) + moduli_sub = self.to_computational_format( + 256 * num_moduli * 4 * 2 * self.prime + ) + + self.jax_parameters.set_parameter( + word_mask=word_mask, + half_word_mask=half_word_mask, + half_word_bits=half_word_bits, + word_bits=word_bits, + rns_moduli=jnp.array(self.rns_moduli, dtype=jnp.uint64), + rns_moduli_low=jnp.array(rns_moduli_low, dtype=jnp.uint16), + rns_moduli_high=jnp.array(rns_moduli_high, dtype=jnp.uint16), + rns_moduli_inv_word=jnp.array(rns_moduli_inv_word, dtype=jnp.uint32), + crns_precision=jnp.array(crns_precision, dtype=jnp.uint32), + crns_stacked_mat_E_with_f_T=jnp.array( + crns_matrix_E_with_f_T, dtype=jnp.uint8 + ), + crns_vector_g=jnp.array(crns_vector_g, dtype=jnp.uint32), + rns_moduli_sub=jnp.array(moduli_sub, dtype=jnp.uint32), + rns_moduli_negate=jnp.array(self.rns_moduli, dtype=jnp.uint32), + ) + + def 
_jax_montgomery_reduce(self, z: jax.Array) -> jax.Array: + + # Computation + z_low = z.astype(jnp.uint32) + z_high = (z >> self.jax_parameters.word_bits).astype(jnp.uint32) + t = ( + z_low * self.jax_parameters.rns_moduli_inv_word + ) & self.jax_parameters.word_mask + t_low = t & self.jax_parameters.half_word_mask + t_high = ( + t >> self.jax_parameters.half_word_bits + ) & self.jax_parameters.half_word_mask + + prod_high = ( + t_high * self.jax_parameters.rns_moduli_high + ) # This contributes directly to upper 32 bits + prod_mid_high = ( + t_high * self.jax_parameters.rns_moduli_low + ) # Upper 16 bits go to upper 32 bits + prod_mid_low = ( + t_low * self.jax_parameters.rns_moduli_high + ) # Upper 16 bits go to upper 32 bits + prod_low = ( + t_low * self.jax_parameters.rns_moduli_low + ) # Upper 16 bits contribute to middle part + mid_low = ( + (prod_mid_high & self.jax_parameters.half_word_mask) + + (prod_mid_low & self.jax_parameters.half_word_mask) + + (prod_low >> self.jax_parameters.half_word_bits) + ) + mid_high = ( + (prod_mid_high >> self.jax_parameters.half_word_bits) + + (prod_mid_low >> self.jax_parameters.half_word_bits) + + (mid_low >> self.jax_parameters.half_word_bits) + ) + + # Final upper 32 bits + t_final = prod_high + mid_high + b = z_high + self.jax_parameters.rns_moduli - t_final + return b.astype(jnp.uint32) + + def _jax_crns(self, z: jax.Array) -> jax.Array: + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint64) + precision_mask = jnp.array( + (1 << self.jax_parameters.crns_precision) - 1, dtype=jnp.uint32 + ) + num_moduli_n = self.jax_parameters.crns_vector_g.shape[1] + + einsum_subscripts = "...kq,kqnp->...np" + # computation + x_u8 = jax.lax.bitcast_convert_type(z, jnp.uint8) + + # x_reconstruction_with_v = jnp.matmul(x_u8, stacked_mat_E_with_f_T, preferred_element_type=jnp.uint32) + x_reconstruction_with_v = jnp.einsum( + einsum_subscripts, + x_u8, + self.jax_parameters.crns_stacked_mat_E_with_f_T, + 
preferred_element_type=jnp.uint32, + ) + x_reconstruction_with_v_u64 = jnp.sum( + x_reconstruction_with_v.astype(jnp.uint64) << shift_factors, axis=(-1,) + ) + + x_n, vector_v = jnp.split( + x_reconstruction_with_v_u64, [num_moduli_n], axis=-1 + ) + vector_v = (vector_v >> self.jax_parameters.crns_precision) & precision_mask + + x_n = x_n + jnp.multiply(vector_v, self.jax_parameters.crns_vector_g) + return x_n + + def _modular_multiply(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + z = jnp.multiply(a.astype(jnp.uint64), b.astype(jnp.uint64)) + z_reduced = self._jax_montgomery_reduce(z) + z_rns_reduced = self._jax_crns( + z_reduced + ) # could be skipped for small prime + z_reduced = self._jax_montgomery_reduce( + z_rns_reduced + ) # could be skipped for small prime (paired with the above) + return z_reduced + + def _modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return a + b + + def _modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return self._modular_negate(b) + a + + def _modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray: + z_rns_reduced = self._jax_crns(a) + z_reduced = self._jax_montgomery_reduce(z_rns_reduced) + return z_reduced + + def _modular_negate(self, a: jax.Array) -> jax.Array: + return jnp.add( + jnp.subtract(self.jax_parameters.rns_moduli_negate, a), + self.jax_parameters.rns_moduli_sub, + ) + + def modular_multiply(self, a: jax.Array, b: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + print(f"using compiled kernel for modular_multiply: {kernel_hash}") + return self.compiled_kernels[kernel_hash]["modular_multiply"](a, b) + else: + return self._modular_multiply(a, b) + + def modular_add(self, a: jax.Array, b: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_add"](a, b) + else: + return self._modular_add(a, b) + + def 
modular_subtract(self, a: jax.Array, b: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_subtract"](a, b) + else: + return self._modular_subtract(a, b) + + def modular_reduce(self, a: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_reduce"](a) + else: + return self._modular_reduce(a) + + def modular_negate(self, a: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_negate"](a) + else: + return self._modular_negate(a) + + +# ============================================================================= +# Lazy matrix reduction context +# ============================================================================= + + +def _lazy_check_carry(value_c: jax.Array) -> jax.Array: + """Check whether any 32-bit chunk holds a value exceeding 32 bits.""" + return jnp.any(jnp.not_equal(jnp.right_shift(value_c, jnp.uint64(32)), 0)) + + +def _lazy_carry_propagate(value_c: jax.Array) -> jax.Array: + """Propagate carries between adjacent 32-bit chunks.""" + n = value_c.shape[-1] + roll_mat = jnp.array( + [0, 1] + ([0] * n + [1]) * (n - 2) + [1] + [0] * (n - 1), + dtype=jnp.uint16, + ).reshape(n, n) + low = jnp.bitwise_and(value_c, jnp.uint64(0xFFFFFFFF)) + high = jnp.right_shift(value_c, jnp.uint64(32)).astype(jnp.uint16) + high = jnp.matmul(high, roll_mat, preferred_element_type=jnp.uint32).astype( + jnp.uint16 + ) + return jnp.add(low, high.astype(jnp.uint64)) + + +class LazyContextBase(FiniteFieldContextBase): + """Base class for lazy matrix modular reduction. + + Represents field elements as little-endian arrays of uint32 chunks + (base 2^32). 
The number of chunks is chunk_num_u8 // 4 + 1 to + provide one extra chunk of headroom for intermediate overflow. + """ + + def __init__(self, parameters: dict): + super().__init__(parameters) + raw_chunk_num_u8 = parameters.get( + "chunk_num_u8", math.ceil(int(self.prime).bit_length() / 8) + ) + # The byte pipeline in ``_mul_to_u8`` emits a ``4 * 2 * chunk_num_u32``-byte + # buffer, and ``_modular_multiply`` slices ``high = vc[:, n8:2*n8+4]`` + # against a ``(n8+4, n8)`` lazy matrix. The slice only fits when + # ``chunk_num_u8`` is a multiple of 4; round up so the invariant holds + # for primes with arbitrary bit-length. + self.chunk_num_u8 = ((raw_chunk_num_u8 + 3) // 4) * 4 + self.chunk_num_u32 = self.chunk_num_u8 // 4 + 1 + self.rns_moduli = self.chunk_num_u32 + self.word_mask = (1 << 32) - 1 + self.modulus_lazy_mat = self._construct_lazy_matrix() + self.prime_chunk = tuple(self._int_to_array(self.prime, self.chunk_num_u32)) + + @staticmethod + def _int_to_array(x: int, size: int, base: int = 32) -> list: + mask = (1 << base) - 1 + elems = [] + while x > 0: + elems.append(int(x & mask)) + x >>= base + return elems[:size] + [0] * max(0, size - len(elems)) + + @staticmethod + def _array_to_int(arr, base: int = 32) -> int: + result = 0 + for i, v in enumerate(arr): + result |= int(v) << (i * base) + return result + + def _construct_lazy_matrix(self): + """Build the lazy reduction matrix. + + Row i = (256^(chunk_num_u8 + i)) % prime, expressed as chunk_num_u8 + little-endian uint8 chunks. Shape: (chunk_num_u8 + 4, chunk_num_u8). 
+ """ + n = self.chunk_num_u8 + return tuple( + tuple(self._int_to_array(pow(256, n + i, self.prime), n, base=8)) + for i in range(n + 4) + ) + + +class CROSSLazyContext(LazyContextBase, JaxKernelContextBase): + """Lazy matrix modular multiplication context for JAX.""" + + def __init__(self, parameters: dict): + super().__init__(parameters) + JaxKernelContextBase.__init__(self) + self._lazy_mat_jnp = jnp.array(self.modulus_lazy_mat, dtype=jnp.uint16) + self._prime_jnp = jnp.array(self.prime_chunk, dtype=jnp.uint32) + + # ---------- format conversion ---------- + + def to_computational_format(self, a) -> jnp.ndarray: + def _convert(x: int) -> jnp.ndarray: + return jnp.array( + self._int_to_array(x % self.prime, self.chunk_num_u32), + dtype=jnp.uint32, + ) + + def _recurse(x): + if isinstance(x, int): + return _convert(x) + return jnp.array([_recurse(xi) for xi in x], dtype=jnp.uint32) + + converted_a = _recurse(a) + if self.use_sharding: + named_sharding, padded_shape = self.create_named_sharding( + shape=converted_a.shape, axes=[0] + ) + converted_a = pad_jax_array(converted_a, padded_shape) + return converted_a.to_device(named_sharding) + else: + return converted_a.to_device(jax.devices()[0]) + + def to_original_format(self, a: jnp.ndarray): + def _convert(arr) -> int: + return self._array_to_int(arr.tolist()) % self.prime + + def _recurse(x): + if x.ndim == 1: + return _convert(x) + return [_recurse(xi) for xi in x] + + return _recurse(a) + + # ---------- internal JAX helpers ---------- + + @staticmethod + def _conv_1d(a: jax.Array, b: jax.Array) -> jax.Array: + a = jax.lax.bitcast_convert_type(a, jnp.uint8).reshape(-1) + b = jax.lax.bitcast_convert_type(b, jnp.uint8).reshape(-1) + if jax.default_backend() == "gpu": + # cuDNN's integer conv only supports s8, not u8, so route through + # float32 on GPU. u8*u8 sums fit exactly in float32's 24-bit mantissa. 
+ res = jnp.convolve(a.astype(jnp.float32), b.astype(jnp.float32)) + return res.astype(jnp.uint32) + return jnp.convolve(a, b, preferred_element_type=jnp.uint32) + + def _rechunkify(self, x: jax.Array, n_u16: int, n_u32: int) -> jax.Array: + """Merge adjacent uint8 coefficients into uint16, then uint32 chunks.""" + shift_u16 = jnp.array([[0, 8]] * n_u16, dtype=jnp.uint8) + shift_u32 = jnp.array([[0, 16]] * n_u32, dtype=jnp.uint8) + shape = x.shape[:-1] + (-1, 2) if x.ndim == 2 else (-1, 2) + x = jnp.sum(jnp.left_shift(x.reshape(shape), shift_u16), axis=-1) + x = jnp.sum( + jnp.left_shift(x.reshape(shape).astype(jnp.uint64), shift_u32), axis=-1 + ) + return x + + def _mul_to_u8(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Multiply two chunk arrays; return product as a flat uint8 array.""" + n = self.chunk_num_u32 + batch = a.shape[0] + res = jax.vmap(self._conv_1d)(a, b) + res = jnp.pad(res, ((0, 0), (0, 1))) + res = self._rechunkify(res, 4 * n, 2 * n) + res = jax.lax.while_loop(_lazy_check_carry, _lazy_carry_propagate, res) + return jax.lax.bitcast_convert_type( + res.astype(jnp.uint32), jnp.uint8 + ).reshape(batch, -1) + + def _sub_raw(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Compute a - b assuming a >= b (batched, shape [batch, chunk_num_u32]).""" + n = self.chunk_num_u32 + borrow_low = jnp.array( + [self.word_mask + 1] * (n - 1) + [0], dtype=jnp.uint64 + ) + borrow_high = jnp.array([0] + [1] * (n - 2) + [0], dtype=jnp.uint64) + c = jnp.subtract( + jnp.add(a.astype(jnp.uint64), borrow_low), b.astype(jnp.uint64) + ) + c = jnp.subtract(c, borrow_high) + c = jax.lax.while_loop(_lazy_check_carry, _lazy_carry_propagate, c) + c = c.at[:, n - 1].set(c[:, n - 1] - 1) + return c.astype(jnp.uint32) + + def _compare(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + """Return per-batch sign: >=0 means a>=b, <0 means a jnp.ndarray: + n8 = self.chunk_num_u8 + batch = a.shape[0] + vc = self._mul_to_u8(a, b) + low = vc[:, :n8] + high = vc[:, 
n8 : n8 * 2 + 4] + reduced = jnp.matmul( + high.astype(jnp.uint16), + self._lazy_mat_jnp, + preferred_element_type=jnp.uint32, + ) + vc2 = jnp.add(low.astype(jnp.uint32), reduced) + vc2 = self._rechunkify(vc2, n8 // 2, n8 // 4) + vc2 = jnp.pad(vc2, ((0, 0), (0, 1))) + vc2 = jax.lax.while_loop(_lazy_check_carry, _lazy_carry_propagate, vc2) + return vc2.astype(jnp.uint32)[:, : self.chunk_num_u32] + + def _modular_add(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + c = jax.lax.while_loop( + _lazy_check_carry, + _lazy_carry_propagate, + jnp.add(a.astype(jnp.uint64), b.astype(jnp.uint64)), + ) + return c.astype(jnp.uint32) + + def _modular_negate(self, a: jnp.ndarray) -> jnp.ndarray: + p = jnp.broadcast_to(self._prime_jnp, a.shape) + neg = self._sub_raw(p, a) + is_zero = jnp.all(a == 0, axis=-1, keepdims=True) + return jnp.where(is_zero, a, neg) + + def _modular_subtract(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return self._modular_add(a, self._modular_negate(b)) + + def _modular_reduce(self, a: jnp.ndarray) -> jnp.ndarray: + p = jnp.broadcast_to(self._prime_jnp, a.shape) + cond = jnp.greater_equal(self._compare(a, p), 0).reshape(-1, 1) + return jnp.where(cond, self._sub_raw(a, p), a) + + # ---------- public interface ---------- + + def modular_multiply(self, a: jax.Array, b: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + print(f"using compiled kernel for modular_multiply: {kernel_hash}") + return self.compiled_kernels[kernel_hash]["modular_multiply"](a, b) + else: + return self._modular_multiply(a, b) + + def modular_add(self, a: jax.Array, b: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_add"](a, b) + else: + return self._modular_add(a, b) + + def modular_subtract(self, a: jax.Array, b: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) 
+ if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_subtract"](a, b) + else: + return self._modular_subtract(a, b) + + def modular_reduce(self, a: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_reduce"](a) + else: + return self._modular_reduce(a) + + def modular_negate(self, a: jax.Array) -> jax.Array: + kernel_hash = hash_args(a.shape, a.dtype.__str__()) + if self.use_compiled_kernels: + return self.compiled_kernels[kernel_hash]["modular_negate"](a) + else: + return self._modular_negate(a) + + def context_hash(self) -> str: + return hash_args( + self.__class__.__name__, + self.prime, + self.chunk_num_u8, + self.use_sharding, + ) + + def _get_shape_dtype_structs( + self, parameters: dict + ) -> list[jax.ShapeDtypeStruct]: + batch_shape = parameters["batch_shape"] + operand_shape = batch_shape + (self.chunk_num_u32,) + if self.use_sharding: + named_sharding, padded_shape = self.create_named_sharding( + shape=operand_shape, axes=[0] + ) + return [ + jax.ShapeDtypeStruct( + padded_shape, jnp.uint32, sharding=named_sharding + ) + ] + return [jax.ShapeDtypeStruct(operand_shape, jnp.uint32)] + + def serialize(self, parameters): # pytype: disable=signature-mismatch + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + kernel_hash = hash_args(self.context_hash(), parameters) + class_name = self.__class__.__name__ + + store_jax_executable( + self._modular_multiply, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_modular_multiply_{kernel_hash}", + ) + store_jax_executable( + self._modular_add, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_modular_add_{kernel_hash}", + ) + store_jax_executable( + self._modular_subtract, + shape_dtype_structs[0], + shape_dtype_structs[0], + name=f"{class_name}_modular_subtract_{kernel_hash}", + ) + store_jax_executable( + 
self._modular_reduce, + shape_dtype_structs[0], + name=f"{class_name}_modular_reduce_{kernel_hash}", + ) + store_jax_executable( + self._modular_negate, + shape_dtype_structs[0], + name=f"{class_name}_modular_negate_{kernel_hash}", + ) + + def compile(self, parameters): # pytype: disable=signature-mismatch + shape_dtype_structs = self._get_shape_dtype_structs(parameters) + kernel_hash = hash_args(self.context_hash(), parameters) + class_name = self.__class__.__name__ + + modular_multiply_kernel = load_jax_executable( + f"{class_name}_modular_multiply_{kernel_hash}" + ) + modular_add_kernel = load_jax_executable( + f"{class_name}_modular_add_{kernel_hash}" + ) + modular_subtract_kernel = load_jax_executable( + f"{class_name}_modular_subtract_{kernel_hash}" + ) + modular_reduce_kernel = load_jax_executable( + f"{class_name}_modular_reduce_{kernel_hash}" + ) + modular_negate_kernel = load_jax_executable( + f"{class_name}_modular_negate_{kernel_hash}" + ) + + if None in [ + modular_multiply_kernel, + modular_add_kernel, + modular_subtract_kernel, + modular_reduce_kernel, + modular_negate_kernel, + ]: + # if not self.use_sharding: + warnings.warn( + f"Not found stored serialized compiled kernels, compiling...", + UserWarning, + stacklevel=2, + ) + + kernel_hash = hash_args( + shape_dtype_structs[0].shape, shape_dtype_structs[0].dtype.__str__() + ) + + self.compiled_kernels[kernel_hash] = { + "modular_multiply": ( + modular_multiply_kernel + if modular_multiply_kernel is not None + else jax_jit_lower_compile( + self._modular_multiply, + shape_dtype_structs[0], + shape_dtype_structs[0], + ) + ), + "modular_add": ( + modular_add_kernel + if modular_add_kernel is not None + else jax_jit_lower_compile( + self._modular_add, + shape_dtype_structs[0], + shape_dtype_structs[0], + ) + ), + "modular_subtract": ( + modular_subtract_kernel + if modular_subtract_kernel is not None + else jax_jit_lower_compile( + self._modular_subtract, + shape_dtype_structs[0], + 
shape_dtype_structs[0], + ) + ), + "modular_reduce": ( + modular_reduce_kernel + if modular_reduce_kernel is not None + else jax_jit_lower_compile( + self._modular_reduce, shape_dtype_structs[0] + ) + ), + "modular_negate": ( + modular_negate_kernel + if modular_negate_kernel is not None + else jax_jit_lower_compile( + self._modular_negate, shape_dtype_structs[0] + ) + ), + } + self.use_compiled_kernels = True diff --git a/jaxite_ec/finite_field_perf_test.py b/jaxite_ec/finite_field_perf_test.py new file mode 100644 index 0000000..0069470 --- /dev/null +++ b/jaxite_ec/finite_field_perf_test.py @@ -0,0 +1,108 @@ +import os + +import jax +import jax.numpy as jnp +from jaxite.jaxite_ec import finite_field_context as ff_context +from jaxite.jaxite_ec import profiler +from jaxite.jaxite_ec import utils + +from absl.testing import absltest +from absl.testing import parameterized + +jax.config.update("jax_enable_x64", True) + +PRIME = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001 + +BATCH_SIZE_LIST = [4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048, 4096] + +NUM_MODULI_LIST = [32] # 21 for 256-bit, 32 for 384-bit, 56 for 753-bit + +TEST_PARAMS = [(f"moduli_{n}", n, BATCH_SIZE_LIST) for n in NUM_MODULI_LIST] + + +def _modular_multiply_kernel(a, b, parameters): + return parameters["ctx"]._modular_multiply(a, b) + + +class FiniteFieldModularMultiplyPerformanceTest(parameterized.TestCase): + + def setUp(self): + super().setUp() + outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + if outputs_dir: + self.output_trace_root = os.path.join(outputs_dir, "log") + else: + self.output_trace_root = os.path.join(os.path.dirname(__file__), "log") + self.profiler_config = { + "iterations": 1, + "save_to_file": True, + } + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + root_dir = os.path.dirname(os.path.abspath(__file__)) + print(f"Collecting logs from: {root_dir}") + profiler.collect_logs(root_dir) + + def 
_create_kernel_wrapper(self, kernel_name, ctx, batch, num_moduli): + input_shape = (batch, num_moduli) + return profiler.KernelWrapper( + kernel_name=kernel_name, + function_to_wrap=_modular_multiply_kernel, + input_structs=[ + (input_shape, jnp.uint32), + (input_shape, jnp.uint32), + ], + parameters={"ctx": ctx}, + ) + + def _profile_modular_multiply(self, num_moduli, batch_size_list): + rns_moduli = utils.find_moduli_specified_number(num_moduli, 28) + + ctx = ff_context.DRNSlazyContext({ + "prime": PRIME, + "rns_moduli": rns_moduli, + "precision_bits": 28, + "radix_bits": 32, + }) + + profiler_instance = profiler.Profiler( + output_trace_path=self.output_trace_root, + profile_naming=f"ff_modular_multiply_moduli_{num_moduli}", + configuration=self.profiler_config, + ) + + for batch in batch_size_list: + kernel_name = f"ff_mod_mul_m{num_moduli}_b{batch}" + kernel_wrapper = self._create_kernel_wrapper( + kernel_name=kernel_name, + ctx=ctx, + batch=batch, + num_moduli=num_moduli, + ) + + profiler_instance.add_profile( + name=kernel_name, + kernel_wrapper=kernel_wrapper, + kernel_setting_cols={ + "num_moduli": num_moduli, + "batch": batch, + }, + ) + + profiler_instance.profile_all_profilers() + profiler_instance.post_process_all_profilers() + + @parameterized.named_parameters(*TEST_PARAMS) + def test_DRNSlazy_modular_multiply_performance( + self, num_moduli, batch_size_list + ): + self._profile_modular_multiply( + num_moduli=num_moduli, + batch_size_list=batch_size_list, + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/jaxite_ec/finite_field_test.py b/jaxite_ec/finite_field_test.py index 208eda0..7d8f9c3 100644 --- a/jaxite_ec/finite_field_test.py +++ b/jaxite_ec/finite_field_test.py @@ -1,254 +1,68 @@ -import random - import jax -import jax.numpy as jnp -from jaxite.jaxite_ec import util -import jaxite.jaxite_ec.algorithm.finite_field as pyff -import jaxite.jaxite_ec.finite_field as ff +from jaxite.jaxite_ec import finite_field_context as 
ff_context +from jaxite.jaxite_ec import utils import numpy as np from absl.testing import absltest +from absl.testing import parameterized jax.config.update("jax_enable_x64", True) -randint = random.randint - - -def list_operation(a, b, func): - return [func(ai, bi) for ai, bi in zip(a, b)] - - -def list_operation_three(a, b, c, func): - return [func(ai, bi, ci) for ai, bi, ci in zip(a, b, c)] - - -class FiniteFieldTest(absltest.TestCase): - - def setUp(self): - super().setUp() - self.value_a = [ - 0xBE4FBE5D03CE926E40E058BBDC3269C78CFAFED39796CD13EC8E9B0072DB2538DFFBCA05804574D9E2FF7EEB1DE219, - 0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF, - ] - self.value_b = [ - 0x82A0ED372BFAB8198D0667A1DC5E299C1F6C8FEB0ACD4D05A228325117BE63EAE5BABE6807F41C6C8016BDAC251CFE, - 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6, - ] - self.value_c = [ - 0x125E69CE765D167C0B19F8D6D6708D39C7782F33B6D320802E2FFA92BBB12DBB3897EAF9CC4CF67E487478F3C3FAD16, - 0x01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76, - ] - self.value_a_jax = util.int_list_to_array( - self.value_a, base=util.BASE, array_size=util.U16_CHUNK_NUM - ) - self.value_b_jax = util.int_list_to_array( - self.value_b, base=util.BASE, array_size=util.U16_CHUNK_NUM - ) - self.value_c_jax = util.int_list_to_array( - self.value_c, base=util.BASE, array_size=util.U16_CHUNK_NUM - ) - - @absltest.skip("This test is only needed if u need to generate RNS Matrix.") - def test_generate_rns_precompute_matrix(self): - rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - rns_stack_mat = rns_mat[0] - cor_mat = rns_mat[1] - print("Printing out RNS matrix") - print("(") - for i in range(len(rns_stack_mat)): - print("(", end="") - for j in range(len(rns_stack_mat[0])): - if j == len(rns_stack_mat[0]) - 1: - print(rns_stack_mat[i][j], end="), ") - else: - print(rns_stack_mat[i][j], 
end=", ") - print("") - print(")") - - print("Printing out Correction matrix") - print("(") - for i in range(len(cor_mat)): - print("( ", end="") - for j in range(len(cor_mat[0])): - if j == len(cor_mat[0]) - 1: - print(cor_mat[i][j], end="), ") - else: - print(cor_mat[i][j], end=", ") - print("") - print(")") - print() - - def test_add_two(self): - result_jax = ff.add_2u16(self.value_a_jax, self.value_b_jax) - result = util.array_to_int_list(result_jax, util.BASE) - - result_ref = list_operation(self.value_a, self.value_b, lambda a, b: a + b) - - self.assertEqual(result, result_ref) - - def test_add_three(self): - result_jax = ff.add_3u16( - self.value_a_jax, self.value_b_jax, self.value_c_jax - ) - result = util.array_to_int_list(result_jax, util.BASE) - result_ref = list_operation( - list_operation(self.value_a, self.value_b, lambda a, b: a + b), - self.value_c, - lambda a, b: a + b, - ) - - self.assertEqual(result, result_ref) - - def test_cond_sub_1(self): - result_jax = ff.cond_sub_2u16(self.value_a_jax, self.value_b_jax) - result = util.array_to_int_list(result_jax, util.BASE) - - def cond_sub(a, b): - if a < b: - return a + util.MODULUS_377_INT - b - else: - return a - b - - result_ref = list_operation(self.value_a, self.value_b, cond_sub) - self.assertEqual(result, result_ref) - - def test_cond_sub_2(self): - result_jax = ff.cond_sub_2u16(self.value_b_jax, self.value_a_jax) - result = util.array_to_int_list(result_jax, util.BASE) - - def cond_sub(a, b): - if a < b: - return a + util.MODULUS_377_INT - b - else: - return a - b - - result_ref = list_operation(self.value_b, self.value_a, cond_sub) - - self.assertEqual(result, result_ref) - - def test_cond_sub_mod_1(self): - value_list = [util.MODULUS_377_INT + 123, util.MODULUS_377_INT - 5432] - value_jax = util.int_list_to_array( - value_list, base=util.BASE, array_size=util.U16_CHUNK_NUM - ) - result_jax = ff.cond_sub_mod_u16(value_jax) - result = util.array_to_int_list(result_jax, util.BASE) - - def 
cond_sub_mod(a): - if a < util.MODULUS_377_INT: - return a - else: - return a - util.MODULUS_377_INT - - result_ref = [cond_sub_mod(a) for a in value_list] - - self.assertEqual(result, result_ref) - - def test_mul_1(self): - result_jax = ff.mul_2u16(self.value_a_jax, self.value_b_jax) - result = util.array_to_int_list(result_jax, util.BASE) - result_ref = list_operation(self.value_a, self.value_b, lambda a, b: a * b) - self.assertEqual(result, result_ref) - - def test_mod_mul_barrett_1(self): - result_jax = ff.mod_mul_barrett_2u16(self.value_a_jax, self.value_b_jax) - result = util.array_to_int_list(result_jax, util.BASE) - - def mod_mul_barrett(a, b): - value_a_barrett = pyff.FiniteFieldElementBarrett(a, util.MODULUS_377_INT) - value_b_barrett = pyff.FiniteFieldElementBarrett(b, util.MODULUS_377_INT) - return (value_a_barrett * value_b_barrett).get_value() - - result_ref = list_operation(self.value_a, self.value_b, mod_mul_barrett) - - self.assertEqual(result, result_ref) - - def test_jax_mod_mul_lazy_reduction(self): - """This test case check the jax version (TPU deployment) of the lazy reduction based modular multiplication algorithm.""" - batch_size = 16 - a_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)] - b_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)] - - a_batch = util.int_list_to_array( - a_list, base=util.BASE, array_size=util.U16_EXT_CHUNK_NUM - ) - b_batch = util.int_list_to_array( - b_list, base=util.BASE, array_size=util.U16_EXT_CHUNK_NUM - ) - c_batch = ff.mod_mul_lazy_2u16(a_batch, b_batch) - c_list = util.array_to_int_list(c_batch, util.BASE) - for i in range(len(a_list)): - np.testing.assert_equal( - c_list[i] % util.MODULUS_377_INT, - (a_list[i] * b_list[i]) % util.MODULUS_377_INT, - ) - - def test_jax_mod_mul_rns_reduction(self): - """This test case check the jax version (TPU deployment) of the rns reduction based modular multiplication algorithm.""" - batch_size = 16 - a_list = [randint(0, 
util.MODULUS_377_INT) for _ in range(batch_size)] - b_list = [randint(0, util.MODULUS_377_INT) for _ in range(batch_size)] - - modulus_rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - a_batch = util.int_list_to_array_rns(a_list) - b_batch = util.int_list_to_array_rns(b_list) - c_batch = ff.mod_mul_rns_2u16(a_batch, b_batch, modulus_rns_mat) - c_list = util.array_rns_to_int_list(c_batch) - for i in range(len(a_list)): - np.testing.assert_equal( - c_list[i] % util.MODULUS_377_INT, - (a_list[i] * b_list[i]) % util.MODULUS_377_INT, - ) - - def test_jax_add_rns(self): - max_val = [2**16 - 1 for _ in range(util.NUM_MODULI)] - max_normal_val = [m - 1 for m in util.MODULI] - zero = [0 for _ in range(util.NUM_MODULI)] - values = [zero, max_val, max_normal_val] - for a in values: - for b in values: - jax_a = jnp.array(a, dtype=jnp.uint16).reshape((1, util.NUM_MODULI)) - jax_b = jnp.array(b, dtype=jnp.uint16).reshape((1, util.NUM_MODULI)) - jax_sum = ff.add_rns_2u16(jax_a, jax_b, tuple(util.RNS_MODULI_T)) - jax_3sum = ff.add_rns_3u16( - jax_a, jax_b, jax_a, tuple(util.RNS_MODULI_T) - ) - for i in range(util.NUM_MODULI): - np.testing.assert_equal( - int(jax_sum[0, i]) % util.MODULI[i], - (a[i] + b[i]) % util.MODULI[i], - ) - np.testing.assert_equal( - int(jax_3sum[0, i]) % util.MODULI[i], - (a[i] + b[i] + a[i]) % util.MODULI[i], - ) - - def test_jax_sub_rns(self): - batch_size = 16 - bound = 256 * util.NUM_MODULI * util.MODULUS_377_INT - a_list = [randint(0, bound) for _ in range(batch_size)] - b_list = [randint(0, bound) for _ in range(batch_size)] - b_list[0] = bound - 1 - a_batch = util.int_list_to_array_rns(a_list) - b_batch = util.int_list_to_array_rns(b_list) - diff = ff.add_sub_rns_var(a_batch, ff.negate_rns_for_var_add(b_batch)) - diff_int = util.array_rns_to_int_list(diff) - for i in range(batch_size): - np.testing.assert_equal( - diff_int[i] % util.MODULUS_377_INT, - (a_list[i] - b_list[i]) % util.MODULUS_377_INT, - ) - - def 
test_jax_add_rns_specific_case(self): - e = 149025596882241982990837486539530757729373308235078472950338530041138824139683871423246406325832428746678481926350 - h = 64504914146370186321601383206234327860947337385636434408981741341000348311615332297641759363669577422255115417749 - q = util.MODULUS_377_INT - jax_a = util.int_list_to_array_rns([e]) - jax_b = util.int_list_to_array_rns([h]) - jax_sum = ff.mod_mul_rns_2u16(jax_a, jax_b) - val_sum = util.array_rns_to_int(jax_sum[0]) - np.testing.assert_equal(val_sum % util.MODULUS_377_INT, (e * h % q)) +FF = [ + ( + "0", + [ + 0xBE4FBE5D03CE926E40E058BBDC3269C78CFAFED39796CD13EC8E9B0072DB2538DFFBCA05804574D9E2FF7EEB1DE219, + 0x008848DEFE740A67C8FC6225BF87FF5485951E2CAA9D41BB188282C8BD37CB5CD5481512FFCD394EEAB9B16EB21BE9EF, + ], + [ + 0x82A0ED372BFAB8198D0667A1DC5E299C1F6C8FEB0ACD4D05A228325117BE63EAE5BABE6807F41C6C8016BDAC251CFE, + 0x01914A69C5102EFF1F674F5D30AFEEC4BD7FB348CA3E52D96D182AD44FB82305C2FE3D3634A9591AFD82DE55559C8EA6, + ], + ), +] + + +class FiniteFieldTest(parameterized.TestCase): + + def __init__(self, *args, **kwargs): + super(FiniteFieldTest, self).__init__(*args, **kwargs) + + @parameterized.named_parameters(*FF) + def test_DRNSlazy_modular_multiply(self, value_a, value_b): + prime = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001 + rns_moduli = utils.find_moduli_specified_number(32, 28) + ref_value_c = [(a * b) % prime for a, b in zip(value_a, value_b)] + + ctx = ff_context.DRNSlazyContext({ + "prime": prime, + "rns_moduli": rns_moduli, + "precision_bits": 28, + "radix_bits": 32, + }) + + value_a_m = ctx.to_computational_format(value_a) + value_b_m = ctx.to_computational_format(value_b) + value_c_m = ctx.modular_multiply(value_a_m, value_b_m) + value_c = ctx.to_original_format(value_c_m) + + np.testing.assert_array_equal(value_c, ref_value_c) + + @parameterized.named_parameters(*FF) + def test_lazy_modular_multiply(self, value_a, value_b): + prime = 
0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001 + + ref_value_c = [(a * b) % prime for a, b in zip(value_a, value_b)] + + ctx = ff_context.CROSSLazyContext({"prime": prime, "chunk_num_u8": 48}) + + value_a_m = ctx.to_computational_format(value_a) + value_b_m = ctx.to_computational_format(value_b) + value_c_m = ctx.modular_multiply(value_a_m, value_b_m) + value_c = ctx.to_original_format(value_c_m) + + np.testing.assert_array_equal(value_c, ref_value_c) if __name__ == "__main__": diff --git a/jaxite_ec/msm_test.py b/jaxite_ec/msm_test.py deleted file mode 100644 index cd78971..0000000 --- a/jaxite_ec/msm_test.py +++ /dev/null @@ -1,730 +0,0 @@ -import csv -import os -import sys - -import jax -import jax.numpy as jnp -from jaxite.jaxite_ec import pippenger -from jaxite.jaxite_ec import pippenger_rns -from jaxite.jaxite_ec import util -from jaxite.jaxite_ec.algorithm import config_file -import jaxite.jaxite_ec.algorithm.elliptic_curve as ec - -# copybara: from google3.pyglib import resources -from absl.testing import absltest -from absl.testing import parameterized - - -script_path = os.path.abspath(sys.argv[0]) -script_dir = os.path.dirname(script_path) - -jax.config.update("jax_traceback_filtering", "off") - -# Only needed when tryingt to understand the HLO dump. 
-os.environ["XLA_FLAGS"] = ( - "--xla_dump_to=sponge --xla_backend_optimization_level=4" -) - -TEST_PARAMS = [ - ( - "test_4_degree", - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv" - ), - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv" - ), - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv" - ), - ), - ( - "test_1024_degree", - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t1024/zprize_msm_curve_377_scalars_dim_1024_seed_0.csv" - ), - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t1024/zprize_msm_curve_377_bases_dim_1024_seed_0.csv" - ), - os.path.join(os.path.dirname(os.path.abspath(sys.argv[0])), - f"{script_dir}/jaxite_ec/test_case/t1024/zprize_msm_curve_377_res_dim_1024_seed_0.csv" - ), - ), -] - - -def twist_coordinates_list(ec_config, coordinates_list): - twisted_ec_sys = ec.ECCSTwistedEdwardsExtended(ec_config) - twisted_coordinates_list = [] - untwisted_coordinates_indeices = [] - for i, coordinates in enumerate(coordinates_list): - twisted_coordinates = twisted_ec_sys.twist_int_coordinates(coordinates) - twisted_coordinates_list.append(twisted_coordinates) - if twisted_coordinates == [0, 1, 1, 0]: - untwisted_coordinates_indeices.append(i) - return twisted_coordinates_list, untwisted_coordinates_indeices - - -class MSMTest(parameterized.TestCase): - def read_external_file(self, scalar_path, base_path, result_path): - scalars = [] - with open( - scalar_path, "r", newline="", encoding="utf-8" - ) as csvfile: # Handle potential encoding issues - csv_reader = csv.reader(csvfile) - for row in csv_reader: - scalars.append(int(row[-1][13:-2], 16)) - - points = [] - with open( 
- base_path, "r", newline="", encoding="utf-8" - ) as csvfile: # Handle potential encoding issues - csv_reader = csv.reader(csvfile) - for row in csv_reader: - points.append([int(row[8][13:-2], 16), int(row[-1][13:-2], 16)]) - - result_ref = [] - with open( - result_path, "r", newline="", encoding="utf-8" - ) as csvfile: # Handle potential encoding issues - csv_reader = csv.reader(csvfile) - for row in csv_reader: - result_ref.append(int(row[7][13:-2], 16)) - result_ref.append(int(row[-1][13:-2], 16)) - return scalars, points, result_ref - - @parameterized.named_parameters(*TEST_PARAMS) - def test_pippenger_index_selection(self, scalar_path, base_path, result_path): - """Normal version Pippenger.""" - scalars, points, result_ref = self.read_external_file( - scalar_path, base_path, result_path - ) - slice_length = 4 - msm_algo = pippenger.MSMPippenger(slice_length) - msm_algo.initialize(scalars, points) - - window_num = msm_algo.window_num - bucket_num_per_window = msm_algo.bucket_num_per_window - msm_length = msm_algo.msm_length - coordinate_num = msm_algo.coordinate_num - chunk_num = util.U16_EXT_CHUNK_NUM - - bucket_accumulation_index_scan_jit = ( - jax.jit( - pippenger.bucket_accumulation_index_scan_algorithm, - static_argnames="msm_length", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, bucket_num_per_window, chunk_num), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (msm_length, coordinate_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct((msm_length, window_num), dtype=jnp.uint32), - jax.ShapeDtypeStruct( - (msm_length, window_num, bucket_num_per_window), dtype=jnp.uint8 - ), - msm_length, - ) - .compile() - ) - - bucket_reduction_scan_jit = ( - jax.jit( - pippenger.bucket_reduction_scan_algorithm, - static_argnames="bucket_num_in_window", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, bucket_num_per_window, chunk_num), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (coordinate_num, window_num, 
chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (bucket_num_per_window, window_num), dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (bucket_num_per_window, window_num), dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - ( - bucket_num_per_window, - window_num, - ), - dtype=jnp.uint8, - ), - bucket_num_per_window, - ) - .compile() - ) - - window_merge_scan_jit = ( - jax.jit( - pippenger.window_merge_scan_algorithm, - static_argnames="slice_length", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - slice_length, - ) - .compile() - ) - - msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit) - msm_algo.bucket_reduction(bucket_reduction_scan_jit) - result = msm_algo.window_merge(window_merge_scan_jit) - result = util.jax_point_pack_to_int_point(result) - ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377) - result_affine_point = ec_sys.generate_point(result).convert_to_affine() - coordinates = ( - result_affine_point[0].get_value(), - result_affine_point[1].get_value(), - ) - self.assertEqual(coordinates[0], result_ref[0]) - self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement - tasks = [ - (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), - (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), - (msm_algo.window_merge, (window_merge_scan_jit,)), - ] - profile_name = "normal_pippenger_index_selection" - # copybara: util.profile_jax_functions(tasks, profile_name) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_pippenger_index_selection_twisted_edwards( - self, scalar_path, base_path, result_path - ): - scalars, points, result_ref = self.read_external_file( - scalar_path, base_path, result_path - ) - twisted_points, untwisted_coordinates_indeices = twist_coordinates_list( - config_file.config_BLS12_377_t, points - ) - assert not 
untwisted_coordinates_indeices - slice_length = 4 - parallel_num = 4 - msm_algo = pippenger.MSMPippengerTwisted(slice_length, parallel_num) - msm_algo.initialize(scalars, twisted_points) - - window_num = msm_algo.window_num - bucket_num_per_window = msm_algo.bucket_num_per_window - msm_length = msm_algo.msm_length - coordinate_num = msm_algo.coordinate_num - chunk_num = util.U16_EXT_CHUNK_NUM - - batch_window_num = window_num * parallel_num - batch_mem_length = msm_length // parallel_num - - bucket_accumulation_index_scan_jit = ( - jax.jit( - pippenger.bucket_accumulation_index_scan_parallel_algorithm_twisted, - static_argnames="msm_length", - ) - .lower( - jax.ShapeDtypeStruct( - ( - coordinate_num, - batch_window_num, - bucket_num_per_window, - chunk_num, - ), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (batch_mem_length, coordinate_num, parallel_num, chunk_num), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (batch_mem_length, batch_window_num), dtype=jnp.uint32 - ), - batch_mem_length, - ) - .compile() - ) - - bucket_reduction_scan_jit = ( - jax.jit( - pippenger.bucket_reduction_scan_algorithm_twisted, - static_argnames="bucket_num_in_window", - ) - .lower( - jax.ShapeDtypeStruct( - ( - coordinate_num, - batch_window_num, - bucket_num_per_window, - chunk_num, - ), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - bucket_num_per_window, - ) - .compile() - ) - - batch_window_summation_jit = ( - jax.jit( - pippenger.batch_window_summation_algorithm_twisted, - static_argnames="point_parallel", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - parallel_num, - ) - .compile() - ) - - window_merge_scan_jit = ( - jax.jit( - 
pippenger.window_merge_scan_algorithm_twisted, - static_argnames="slice_length", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - slice_length, - ) - .compile() - ) - - # HERE - msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit) - msm_algo.bucket_reduction(bucket_reduction_scan_jit) - msm_algo.batch_window_summation(batch_window_summation_jit) - result = msm_algo.window_merge(window_merge_scan_jit) - result = util.jax_point_pack_to_int_point(result) - # TO HERE - ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t) - result_affine_point = ec_sys.generate_point( - result, twist=False - ).convert_to_affine() - coordinates = ( - result_affine_point[0].get_value(), - result_affine_point[1].get_value(), - ) - self.assertEqual(coordinates[0], result_ref[0]) - self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement - tasks = [ - (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), - (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), - (msm_algo.batch_window_summation, (batch_window_summation_jit,)), - (msm_algo.window_merge, (window_merge_scan_jit,)), - ] - profile_name = "pippenger_index_selection_twisted_edwards" - # copybara: util.profile_jax_functions(tasks, profile_name) - - @parameterized.named_parameters(*TEST_PARAMS) - def test_pippenger_signed_index_selection_twisted_edwards( - self, scalar_path, base_path, result_path - ): - scalars, points, result_ref = self.read_external_file( - scalar_path, base_path, result_path - ) - twisted_points, untwisted_coordinates_indeices = twist_coordinates_list( - config_file.config_BLS12_377_t, points - ) - assert not untwisted_coordinates_indeices - slice_length = 4 - parallel_num = 4 - msm_algo = pippenger.MSMPippengerTwistedSigned(slice_length, parallel_num) - msm_algo.initialize(scalars, twisted_points) - - window_num = msm_algo.window_num - bucket_num_per_window = 
msm_algo.bucket_num_per_window - msm_length = msm_algo.msm_length - coordinate_num = msm_algo.coordinate_num - chunk_num = util.U16_EXT_CHUNK_NUM - - batch_window_num = window_num * parallel_num - batch_mem_length = msm_length // parallel_num - - bucket_accumulation_index_scan_jit = ( - jax.jit( - pippenger.bucket_accumulation_signed_index_scan_parallel_algorithm_twisted, - static_argnames="msm_length", - ) - .lower( - jax.ShapeDtypeStruct( - ( - coordinate_num, - batch_window_num, - bucket_num_per_window, - chunk_num, - ), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (batch_mem_length, coordinate_num, parallel_num, chunk_num), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (batch_mem_length, batch_window_num), dtype=jnp.uint32 - ), - jax.ShapeDtypeStruct( - (batch_mem_length, batch_window_num), dtype=jnp.uint8 - ), - batch_mem_length, - ) - .compile() - ) - - bucket_reduction_scan_jit = ( - jax.jit( - pippenger.bucket_reduction_scan_algorithm_twisted, - static_argnames="bucket_num_in_window", - ) - .lower( - jax.ShapeDtypeStruct( - ( - coordinate_num, - batch_window_num, - bucket_num_per_window, - chunk_num, - ), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - bucket_num_per_window, - ) - .compile() - ) - - batch_window_summation_jit = ( - jax.jit( - pippenger.batch_window_summation_algorithm_twisted, - static_argnames="point_parallel", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - parallel_num, - ) - .compile() - ) - - window_merge_scan_jit = ( - jax.jit( - pippenger.window_merge_scan_algorithm_twisted, - static_argnames="slice_length", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - 
), - slice_length, - ) - .compile() - ) - - # HERE - msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit) - msm_algo.bucket_reduction(bucket_reduction_scan_jit) - msm_algo.batch_window_summation(batch_window_summation_jit) - result = msm_algo.window_merge(window_merge_scan_jit) - result = util.jax_point_pack_to_int_point(result) - # TO HERE - ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t) - result_affine_point = ec_sys.generate_point( - result, twist=False - ).convert_to_affine() - coordinates = ( - result_affine_point[0].get_value(), - result_affine_point[1].get_value(), - ) - self.assertEqual(coordinates[0], result_ref[0]) - self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement - tasks = [ - (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), - (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), - (msm_algo.batch_window_summation, (batch_window_summation_jit,)), - (msm_algo.window_merge, (window_merge_scan_jit,)), - ] - profile_name = "pippenger_signed_index_selection_twisted_edwards" - # copybara: util.profile_jax_functions(tasks, profile_name) - - # @absltest.skip("test pass") - @parameterized.named_parameters(*TEST_PARAMS) - def test_pippenger_index_rns_selection( - self, scalar_path, base_path, result_path - ): - """RNS version Pippenger - XYZZ.""" - scalars, points, result_ref = self.read_external_file( - scalar_path, base_path, result_path - ) - slice_length = 4 - msm_algo = pippenger_rns.MSMPippenger(slice_length) - msm_algo.initialize(scalars, points) - - window_num = msm_algo.window_num - bucket_num_per_window = msm_algo.bucket_num_per_window - msm_length = msm_algo.msm_length - coordinate_num = msm_algo.coordinate_num - chunk_num = util.NUM_MODULI - - bucket_accumulation_index_scan_jit = ( - jax.jit( - pippenger_rns.bucket_accumulation_index_scan_algorithm, - static_argnames="msm_length", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, 
bucket_num_per_window, chunk_num), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (msm_length, coordinate_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct((msm_length, window_num), dtype=jnp.uint32), - jax.ShapeDtypeStruct( - (msm_length, window_num, bucket_num_per_window), dtype=jnp.uint8 - ), - msm_length, - ) - .compile() - ) - - bucket_reduction_scan_jit = ( - jax.jit( - pippenger_rns.bucket_reduction_scan_algorithm, - static_argnames="bucket_num_in_window", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, bucket_num_per_window, chunk_num), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (bucket_num_per_window, window_num), dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (bucket_num_per_window, window_num), dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (bucket_num_per_window, window_num), dtype=jnp.uint8 - ), - bucket_num_per_window, - ) - .compile() - ) - - window_merge_scan_jit = ( - jax.jit( - pippenger_rns.window_merge_scan_algorithm, - static_argnames="slice_length", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - slice_length, - ) - .compile() - ) - - # HERE - msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit) - msm_algo.bucket_reduction(bucket_reduction_scan_jit) - result = msm_algo.window_merge(window_merge_scan_jit) - result = util.jax_rns_point_pack_to_int_point(result) - # TO HERE - ec_sys = ec.ECCSWeierstrassXYZZ(config_file.config_BLS12_377) - result_affine_point = ec_sys.generate_point(result).convert_to_affine() - coordinates = ( - result_affine_point[0].get_value(), - result_affine_point[1].get_value(), - ) - self.assertEqual(coordinates[0], result_ref[0]) - self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement - tasks = [ - 
(msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), - (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), - (msm_algo.window_merge, (window_merge_scan_jit,)), - ] - profile_name = "pippenger_index_rns_selection" - # copybara: util.profile_jax_functions(tasks, profile_name) - - # @absltest.skip("has some bug in result") - @parameterized.named_parameters(*TEST_PARAMS) - def test_pippenger_index_selection_rns_twisted_edwards( - self, scalar_path, base_path, result_path - ): - scalars, points, result_ref = self.read_external_file( - scalar_path, base_path, result_path - ) - twisted_points, untwisted_coordinates_indeices = twist_coordinates_list( - config_file.config_BLS12_377_t, points - ) - assert not untwisted_coordinates_indeices - slice_length = 4 - parallel_num = 4 - msm_algo = pippenger_rns.MSMPippengerTwisted(slice_length, parallel_num) - msm_algo.initialize(scalars, twisted_points) - - window_num = msm_algo.window_num - bucket_num_per_window = msm_algo.bucket_num_per_window - msm_length = msm_algo.msm_length - coordinate_num = msm_algo.coordinate_num - chunk_num = util.NUM_MODULI - - batch_window_num = window_num * parallel_num - batch_mem_length = msm_length // parallel_num - - bucket_accumulation_index_scan_jit = ( - jax.jit( - pippenger_rns.bucket_accumulation_index_scan_parallel_algorithm_twisted, - static_argnames="msm_length", - ) - .lower( - jax.ShapeDtypeStruct( - ( - coordinate_num, - batch_window_num, - bucket_num_per_window, - chunk_num, - ), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (batch_mem_length, coordinate_num, parallel_num, chunk_num), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (batch_mem_length, batch_window_num), dtype=jnp.uint32 - ), - batch_mem_length, - ) - .compile() - ) - - bucket_reduction_scan_jit = ( - jax.jit( - pippenger_rns.bucket_reduction_scan_algorithm_twisted, - static_argnames="bucket_num_in_window", - ) - .lower( - jax.ShapeDtypeStruct( - ( - coordinate_num, - batch_window_num, - 
bucket_num_per_window, - chunk_num, - ), - dtype=jnp.uint16, - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - bucket_num_per_window, - ) - .compile() - ) - - batch_window_summation_jit = ( - jax.jit( - pippenger_rns.batch_window_summation_algorithm_twisted, - static_argnames="point_parallel", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (coordinate_num, batch_window_num, chunk_num), dtype=jnp.uint16 - ), - parallel_num, - ) - .compile() - ) - - window_merge_scan_jit = ( - jax.jit( - pippenger_rns.window_merge_scan_algorithm_twisted, - static_argnames="slice_length", - ) - .lower( - jax.ShapeDtypeStruct( - (coordinate_num, window_num, chunk_num), dtype=jnp.uint16 - ), - slice_length, - ) - .compile() - ) - - # HERE - msm_algo.bucket_accumulation(bucket_accumulation_index_scan_jit) - msm_algo.bucket_reduction(bucket_reduction_scan_jit) - msm_algo.batch_window_summation(batch_window_summation_jit) - result = msm_algo.window_merge(window_merge_scan_jit) - result = util.jax_rns_point_pack_to_int_point(result) - # TO HERE - ec_sys = ec.ECCSTwistedEdwardsExtended(config_file.config_BLS12_377_t) - result_affine_point = ec_sys.generate_point( - result, twist=False - ).convert_to_affine() - coordinates = ( - result_affine_point[0].get_value() % util.MODULUS_377_INT, - result_affine_point[1].get_value() % util.MODULUS_377_INT, - ) - self.assertEqual(coordinates[0], result_ref[0]) - self.assertEqual(coordinates[1], result_ref[1]) - - # performance measurement - tasks = [ - (msm_algo.bucket_accumulation, (bucket_accumulation_index_scan_jit,)), - (msm_algo.bucket_reduction, (bucket_reduction_scan_jit,)), - (msm_algo.batch_window_summation, (batch_window_summation_jit,)), - (msm_algo.window_merge, (window_merge_scan_jit,)), - ] - profile_name = 
"pippenger_index_selection_rns_twisted_edwards" - # copybara: util.profile_jax_functions(tasks, profile_name) - - -if __name__ == "__main__": - absltest.main() diff --git a/jaxite_ec/multiscalar_multiplication_context.py b/jaxite_ec/multiscalar_multiplication_context.py new file mode 100644 index 0000000..a78252e --- /dev/null +++ b/jaxite_ec/multiscalar_multiplication_context.py @@ -0,0 +1,2772 @@ +from abc import ABC, abstractmethod +import ctypes +import math +import random +from typing import Any, Optional +import warnings +import jax +import jax.ffi as ffi +import jax.numpy as jnp +from jaxite.jaxite_ec.c_kernels.build import ensure_distribution_kernel +from jaxite.jaxite_ec.elliptic_curve_context import EllipticCurveContextBase +import jaxite.jaxite_ec.utils as utils +from jaxite.jaxite_ec.utils import JaxKernelContextBase, hash_args, jax_jit_lower_compile, load_jax_executable, store_jax_executable +import numpy as np + +jax.config.update("jax_enable_x64", True) + + +class MultiscalarMultiplicationContextBase(ABC): + + def __init__(self, parameters: dict): + self.parameters = parameters + self.ec_ctx_class = parameters.get("elliptic_curve_context_class", None) + assert ( + self.ec_ctx_class is not None + ), "elliptic_curve_context_class must be provided" + self.ec_ctx: EllipticCurveContextBase = self.ec_ctx_class( + parameters.get("elliptic_curve_parameters", {}) + ) + + def _padd(self, a: jnp.ndarray, b: jnp.ndarray) -> jnp.ndarray: + return self.ec_ctx.point_add(a, b) + + @abstractmethod + def to_computational_format(self, a: Any) -> Any: + pass + + @abstractmethod + def to_original_format(self, a: Any) -> Any: + pass + + @abstractmethod + def multiscalar_multiply(self, tiled_slices: jnp.ndarray): + pass + + +class CPUDistributionMSMContextBase(MultiscalarMultiplicationContextBase): + + def __init__(self, parameters: dict): + super().__init__(parameters) + self.scalar_bits = 0 + self.slice_bits = 0 + self.tile_length = 0 + self._init_config_parameters() + 
self._init_jax_data() + self._init_cpu_kernels() + self._init_point_parameters() + + def _init_config_parameters(self): + pass + + def _init_jax_data(self): + pass + + def _init_cpu_kernels(self): + pass + + def _init_point_parameters(self): + raw_points = utils.read_external_msm_file( + self.parameters.get("points_path"), "points" + ) + with jax.default_device(jax.devices("cpu")[0]): + points = self.ec_ctx.to_computational_format(raw_points).to_device( + jax.devices("cpu")[0] + ) + self.points = points.transpose(1, 0, 2) + + def _preprocess_scalars(self, scalars: list): + tiled_scalar_list = utils.split_list(scalars, self.tile_length) + tiled_slices_list = [] + for tiled_scalars in tiled_scalar_list: + sliced_scalars = utils.slice_scalars( + tiled_scalars, self.scalar_bits, self.slice_bits + ) + tiled_slices_list.append(sliced_scalars) + + with jax.default_device(jax.devices("cpu")[0]): + tiled_slices_list = jnp.array(tiled_slices_list, dtype=jnp.int32) + return tiled_slices_list + + +class OldCPUDistributionMSMContext(CPUDistributionMSMContextBase): + + def __init__(self, parameters: dict): + CPUDistributionMSMContextBase.__init__(self, parameters) + + def _init_config_parameters(self): + self.coordinate_dim = self.parameters.get("coordinate_dim", 4) + self.msm_length = self.parameters.get("msm_length") + self.tile_length = self.parameters.get("tile_length") + assert ( + self.msm_length % self.tile_length == 0 + ), "msm_length must be divisible by tile_length" + self.tile_num = self.msm_length // self.tile_length + self.slice_bits = self.parameters.get("slice_bits") + self.scalar_bits = self.parameters.get("scalar_bits") + self.order = self.parameters.get("order") + self.window_num = int(math.ceil(self.scalar_bits / self.slice_bits)) # + self.batch_window_num = self.window_num + self.bucket_num_per_window = ( + 2**self.slice_bits - 1 + ) # Note: here remove the bucket_0 + self.bucket_num_last_window = self.order >> ( + (self.window_num - 1) * self.slice_bits + ) 
+ self.moduli_num = self.ec_ctx.get_finite_field_context().get_moduli_num() + + # Special bucket optimization + self.log_special_duplication_ratio = math.ceil( + math.log2(self.bucket_num_per_window / self.bucket_num_last_window) + ) + self.special_duplication_ratio = 2**self.log_special_duplication_ratio + self.bucket_num_duplication = ( + self.bucket_num_last_window * self.special_duplication_ratio + ) + + def _init_jax_data(self): + self.zero_point = ( + self.ec_ctx.get_finite_field_context().to_computational_format( + self.ec_ctx.zero_point + ) + ) + self.all_buckets = jnp.broadcast_to( + self.zero_point.reshape(1, self.coordinate_dim, 1, self.moduli_num), + ( + self.batch_window_num, + self.coordinate_dim, + self.bucket_num_per_window, + self.moduli_num, + ), + ) + self.temp_sum = jnp.array( + [self.zero_point for _ in range(self.batch_window_num)] + ).transpose(1, 0, 2) + self.window_sum = jnp.array( + [self.zero_point for _ in range(self.batch_window_num)] + ).transpose(1, 0, 2) + + def _init_cpu_kernels(self): + expected_regular_bucket_size = self.tile_length / ( + self.bucket_num_per_window + 1 + ) + expected_special_bucket_size = math.ceil( + self.tile_length / self.bucket_num_duplication + ) + # print("expected_regular_bucket_size", expected_regular_bucket_size) + # print("expected_special_bucket_size", expected_special_bucket_size) + self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"] + self.c_kernel_regular_bucket_size = int( + expected_regular_bucket_size * self.expend_ratio + ) + self.c_kernel_special_bucket_size = int( + expected_special_bucket_size * self.expend_ratio + ) + + regular_shape = ( + self.window_num - 1, + self.bucket_num_per_window, + self.c_kernel_regular_bucket_size, + self.coordinate_dim, + self.moduli_num, + ) + + special_shape = ( + self.bucket_num_last_window, + self.special_duplication_ratio, + self.c_kernel_special_bucket_size, + self.coordinate_dim, + self.moduli_num, + ) + self.ba_input_regular_shape = 
regular_shape + self.ba_input_special_shape = special_shape + + lib = ctypes.cdll.LoadLibrary(ensure_distribution_kernel()) + jax.ffi.register_ffi_target( + "distribute_buf", jax.ffi.pycapsule(lib.DistributeBuf), platform="cpu" + ) + self.distribution_buf_c_kernel_call = ffi.ffi_call( + "distribute_buf", + ( + jax.ShapeDtypeStruct(regular_shape, jnp.uint32), + jax.ShapeDtypeStruct(special_shape, jnp.uint32), + jax.ShapeDtypeStruct((2,), jnp.uint32), + ), + ) + + def _bucket_accumulation_per_window( + self, buckets: jnp.ndarray, window_points: jnp.ndarray + ) -> jnp.ndarray: + """Accumulate points within buckets for a single window. + + Args: + buckets: Initial bucket values (coordinate_dim, bucket_dim, + precision_dim). + window_points: Points to accumulate (bucket_size_dim, coordinate_dim, + bucket_dim, precision_dim). + parameters: Computation parameters. + + Returns: + Accumulated bucket values. + """ + bucket_size_dim = window_points.shape[0] + + def scan_body(buckets, points): + buckets = self._padd(buckets, points) + return buckets, None + + buckets, _ = jax.lax.scan( + scan_body, buckets, window_points, length=bucket_size_dim + ) + return buckets + + def _bucket_accumulation_regular_windows( + self, + regular_buckets: jnp.ndarray, + all_points: jnp.ndarray, + ) -> jnp.ndarray: + """Accumulate points for all regular windows. + + Args: + regular_buckets: Initial bucket values (window_dim, coordinate_dim, + bucket_dim, precision_dim). + all_points: Points for all windows (window_dim, bucket_size_dim, + coordinate_dim, bucket_dim, precision_dim). + + Returns: + Accumulated bucket values for all regular windows. 
+ """ + window_dim = regular_buckets.shape[0] + + def scan_body(empty, window_bucket_point_pack): + window_buckets, points = window_bucket_point_pack + buckets = self._bucket_accumulation_per_window(window_buckets, points) + return None, buckets + + _, buckets = jax.lax.scan( + scan_body, None, (regular_buckets, all_points), length=window_dim + ) + return buckets + + def _bucket_accumulation_last_window( + self, + buckets_in: jnp.ndarray, + window_points: jnp.ndarray, + ) -> jnp.ndarray: + """Accumulate points for the last (special) window with duplication handling. + + Args: + buckets_in: Initial bucket values (coordinate_dim, bucket_dim, + precision_dim). + window_points: Points with duplication (bucket_size_dim, coordinate_dim, + bucket_dup_dim, bucket_dim, precision_dim). + parameters: Computation parameters. + + Returns: + Accumulated bucket values. + """ + coordinate_dim, buckets_dim, precision_dim = buckets_in.shape + bucket_size_dim, _, bucket_dup_dim, _, _ = window_points.shape + + # Reshape for processing + window_points = window_points.reshape( + bucket_size_dim, coordinate_dim, -1, precision_dim + ) + base_dup_buckets = window_points[0] + + def scan_body(buckets, points): + buckets = self._padd(buckets, points) + return buckets, None + + dup_buckets, _ = jax.lax.scan( + scan_body, + base_dup_buckets, + window_points[1:], + length=bucket_size_dim - 1, + ) + + # Reduce duplicated buckets using tree reduction + log_bucket_dup_dim = int(math.log2(bucket_dup_dim)) + for _ in range(log_bucket_dup_dim): + buckets_split = jnp.split(dup_buckets, 2, axis=1) + dup_buckets = self._padd(buckets_split[0], buckets_split[1]) + + # Add to input buckets + buckets_in = self._padd(buckets_in, dup_buckets) + return buckets_in + + def _bucket_accumulation_all_windows( + self, + all_buckets: jnp.ndarray, + regular_points: jnp.ndarray, + last_window_points: jnp.ndarray, + ) -> jnp.ndarray: + """Accumulate points for all windows with distributed optimization. 
+ + This is the main bucket accumulation kernel that handles both regular + windows and the special last window with duplication optimization. + + Args: + all_buckets: Initial bucket values (window_dim, coordinate_dim, + bucket_dim, precision_dim). + regular_points: Points for regular windows (window_dim-1, bucket_num, + bucket_size, coord_dim, prec_dim). + last_window_points: Points for last window (bucket_num, dup_ratio, + bucket_size, coord_dim, prec_dim). + parameters: Computation parameters. + + Returns: + Accumulated bucket values for all windows. + """ + # Transpose for computation + regular_points = regular_points.transpose(0, 2, 3, 1, 4) + last_window_points = last_window_points.transpose(2, 3, 1, 0, 4) + + window_dim, coordinate_dim, buckets_dim, precision_dim = all_buckets.shape + _, _, last_window_bucket_dup, last_window_bucket_dim, _ = ( + last_window_points.shape + ) + + # Process regular windows + regular_buckets = all_buckets[: window_dim - 1] + regular_buckets = self._bucket_accumulation_regular_windows( + regular_buckets, regular_points + ) + + # Process last window + last_point_buckets = all_buckets[ + window_dim - 1, :, :last_window_bucket_dim, : + ] + last_blank_buckets = all_buckets[ + window_dim - 1, :, last_window_bucket_dim:, : + ] + last_point_buckets = self._bucket_accumulation_last_window( + last_point_buckets, last_window_points + ) + + # Combine last window buckets + last_buckets = jax.lax.broadcast( + jnp.concatenate((last_point_buckets, last_blank_buckets), axis=1), (1,) + ) + + # Combine all buckets + all_buckets = jnp.concatenate((regular_buckets, last_buckets), axis=0) + return all_buckets + + def _bucket_reduction( + self, + all_buckets: jnp.ndarray, + temp_sum: jnp.ndarray, + window_sum: jnp.ndarray, + ) -> jnp.ndarray: + """Reduce buckets to window sums using scan algorithm. + + Implements the bucket reduction phase of Pippenger's algorithm using + a scan-based approach for efficiency. 
+ + Args: + all_buckets: Bucket values (window_dim, coordinate_dim, bucket_dim, + precision_dim). + temp_sum: Temporary sum array (coordinate_dim, window_dim, + precision_dim). + window_sum: Initial window sum (coordinate_dim, window_dim, + precision_dim). + bucket_num_in_window: Number of buckets per window. + parameters: Computation parameters. + + Returns: + Window sums (coordinate_dim, window_dim, precision_dim). + """ + # Transpose for scan + bucket_num_in_window = all_buckets.shape[2] + all_buckets = all_buckets.transpose(2, 1, 0, 3) + + def scan_body(temp_and_window_sum_pack, buckets): + temp_sum, window_sum = temp_and_window_sum_pack + temp_sum = self._padd(temp_sum, buckets) + window_sum = self._padd(window_sum, temp_sum) + return (temp_sum, window_sum), None + + (_, window_sum), _ = jax.lax.scan( + scan_body, + (temp_sum, window_sum), + all_buckets[:bucket_num_in_window], + length=bucket_num_in_window, + reverse=True, + ) + return window_sum + + def _window_merge(self, window_sum: jnp.ndarray) -> jnp.ndarray: + """Merge window results into final MSM result using scan algorithm. + + Implements the window merging phase of Pippenger's algorithm. + + Args: + window_sum: Window sums (coordinate_dim, window_dim, precision_dim). + slice_length: Bit width of each window. + parameters: Computation parameters. + + Returns: + Final MSM result (coordinate_dim, precision_dim). 
+ """ + coordinate_dim, window_dim, precision_dim = window_sum.shape + window_sum = window_sum.transpose(1, 0, 2).reshape( + (window_dim, coordinate_dim, 1, precision_dim) + ) + result = window_sum[window_dim - 1] + + def fori_loop_body(i, result): + result = self._padd(result, result) + return result + + def scan_body(result, window_sum): + result = jax.lax.fori_loop(0, self.slice_bits, fori_loop_body, result) + result = self._padd(result, window_sum) + return result, None + + result, _ = jax.lax.scan( + scan_body, + result, + window_sum[: window_dim - 1], + reverse=True, + length=window_dim - 1, + ) + result = result.reshape((coordinate_dim, precision_dim)) + return result + + def distribute_buckets( + self, tiled_slices: jnp.ndarray, tiled_points: jnp.ndarray + ) -> tuple[jnp.ndarray, jnp.ndarray]: + with jax.default_device(jax.devices("cpu")[0]): + tiled_slices = jax.device_put(tiled_slices, jax.devices("cpu")[0]) + tiled_points = jax.device_put(tiled_points, jax.devices("cpu")[0]) + zero_point = jax.device_put(self.zero_point, jax.devices("cpu")[0]) + regular_buckets, last_window_buckets, metadata = ( + self.distribution_buf_c_kernel_call( + tiled_slices, + tiled_points, + zero_point, + window_num=np.uint32(self.window_num), + regular_bucket_num=np.uint32(self.bucket_num_per_window), + special_bucket_num=np.uint32(self.bucket_num_last_window), + msm_length=np.uint32(self.tile_length), + fixed_regular_padding_size=np.uint32( + self.c_kernel_regular_bucket_size + ), + fixed_special_padding_size=np.uint32( + self.c_kernel_special_bucket_size + * self.special_duplication_ratio + ), + ) + ) + return regular_buckets, last_window_buckets + + def multiscalar_multiply(self, tiled_slices: jnp.ndarray): + + for tile_index in range(self.tile_num): + idx_start = tile_index * self.tile_length + idx_end = idx_start + self.tile_length + tiled_slices_tile = tiled_slices[tile_index] + tiled_points_tile = self.points[idx_start:idx_end] + regular_buckets, last_window_buckets = 
self.distribute_buckets( + tiled_slices_tile, tiled_points_tile + ) + regular_buckets = jax.device_put(regular_buckets, jax.devices()[0]) + last_window_buckets = jax.device_put( + last_window_buckets, jax.devices()[0] + ) + self.all_buckets = self._bucket_accumulation_all_windows( + self.all_buckets, regular_buckets, last_window_buckets + ) + + window_sum = self._bucket_reduction( + self.all_buckets, self.temp_sum, self.window_sum + ) + result = self._window_merge(window_sum) + return result + + def to_original_format(self, a: jnp.ndarray) -> list: + return self.ec_ctx.to_original_format(a) + + def to_computational_format(self, scalars: list) -> jnp.ndarray: + tiled_slices = self._preprocess_scalars(scalars) + return tiled_slices + + +class CPUDistributionMSMContext( + OldCPUDistributionMSMContext, JaxKernelContextBase +): + + def __init__(self, parameters: dict): + super().__init__(parameters) + JaxKernelContextBase.__init__(self) + + def _init_config_parameters(self): + self.coordinate_dim = self.parameters.get("coordinate_dim", 4) + self.msm_length = self.parameters.get("msm_length") + self.tile_length = self.parameters.get("tile_length") + assert ( + self.msm_length % self.tile_length == 0 + ), "msm_length must be divisible by tile_length" + self.tile_num = self.msm_length // self.tile_length + self.slice_bits = self.parameters.get("slice_bits") + self.scalar_bits = self.parameters.get("scalar_bits") + self.order = self.parameters.get("order") + self.window_num = int(math.ceil(self.scalar_bits / self.slice_bits)) # + self.batch_window_num = self.window_num + self.bucket_num_per_window = 2**self.slice_bits # Note: Include Bucket 0 + orig_bucket_num_last_window = ( + self.order >> ((self.window_num - 1) * self.slice_bits) + ) + 1 # Note: Include Bucket 0 + # Now pad to nearest value so bucket_num_last_window % 8 == 0 + added_padding = (8 - orig_bucket_num_last_window % 8) % 8 + if added_padding > 0: + if added_padding / orig_bucket_num_last_window > 0.1: + print( 
+ f"[bucket_num_last_window] Added {added_padding} to make it" + " divisible by 8, but it is too large for" + f" {orig_bucket_num_last_window}. Setting to 0." + ) + added_padding = 0 + else: + print( + f"[bucket_num_last_window] Was {orig_bucket_num_last_window}, " + f"added {added_padding} to make it divisible by 8" + ) + self.bucket_num_last_window = orig_bucket_num_last_window + added_padding + self.moduli_num = self.ec_ctx.get_finite_field_context().get_moduli_num() + + # Special bucket optimization + self.log_special_duplication_ratio = math.ceil( + math.log2(self.bucket_num_per_window / self.bucket_num_last_window) + ) + self.special_duplication_ratio = 2**self.log_special_duplication_ratio + self.bucket_num_duplication = ( + self.bucket_num_last_window * self.special_duplication_ratio + ) + + def _init_jax_data(self): + self.zero_point = ( + self.ec_ctx.get_finite_field_context().to_computational_format( + self.ec_ctx.zero_point + ) + ) + self.all_buckets = jnp.broadcast_to( + self.zero_point.reshape(1, self.coordinate_dim, 1, self.moduli_num), + ( + self.batch_window_num, + self.coordinate_dim, + self.bucket_num_per_window, + self.moduli_num, + ), + ) + self.temp_sum = jnp.array( + [self.zero_point for _ in range(self.batch_window_num)] + ).transpose(1, 0, 2) + self.window_sum = jnp.array( + [self.zero_point for _ in range(self.batch_window_num)] + ).transpose(1, 0, 2) + self.window_sum = jnp.broadcast_to( + self.window_sum.reshape( + self.coordinate_dim, 1, self.batch_window_num, self.moduli_num + ), + ( + self.coordinate_dim, + self.bucket_num_per_window, + self.batch_window_num, + self.moduli_num, + ), + ) + + # Pre-store shardings for bucket accumulation inputs + self.ba_all_buckets_sharding = None + self.ba_regular_points_sharding = None + self.ba_special_points_sharding = None + # Pre-store shardings for bucket reduction inputs + self.br_all_buckets_sharding = None + self.br_window_sum_sharding = None + # Pre-store shardings for window merge input + 
self.wm_window_sum_sharding = None + + def _init_shardings(self): + """Pre-compute and store all shardings for compiled/sharded execution.""" + all_buckets_shape = ( + self.batch_window_num, + self.coordinate_dim, + self.bucket_num_per_window, + self.moduli_num, + ) + regular_points_shape = self.ba_input_regular_shape + special_points_shape = self.ba_input_special_shape + window_sum_shape = ( + self.coordinate_dim, + self.bucket_num_per_window, + self.batch_window_num, + self.moduli_num, + ) + wm_window_sum_shape = ( + self.coordinate_dim, + self.batch_window_num, + self.moduli_num, + ) + + # Bucket accumulation shardings + self.ba_all_buckets_sharding, _ = self.create_named_sharding( + shape=all_buckets_shape, axes=[2] + ) + self.ba_regular_points_sharding, _ = self.create_named_sharding( + shape=regular_points_shape, axes=[1] + ) + self.ba_special_points_sharding, _ = self.create_named_sharding( + shape=special_points_shape, axes=[0, 1] + ) + # Bucket reduction shardings + self.br_all_buckets_sharding, _ = self.create_named_sharding( + shape=all_buckets_shape, axes=[2] + ) + self.br_window_sum_sharding, _ = self.create_named_sharding( + shape=window_sum_shape, axes=[1] + ) + # Window merge shardings + self.wm_window_sum_sharding, _ = self.create_named_sharding( + shape=wm_window_sum_shape, axes=[] + ) + + def set_use_sharding(self, use_sharding: bool): + super().set_use_sharding(use_sharding) + if use_sharding: + self._init_shardings() + + def _init_cpu_kernels(self): + expected_regular_bucket_size = self.tile_length / ( + self.bucket_num_per_window + ) + expected_special_bucket_size = math.ceil( + self.tile_length / self.bucket_num_duplication + ) + self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"] + self.c_kernel_regular_bucket_size = int( + expected_regular_bucket_size * self.expend_ratio + ) + self.c_kernel_special_bucket_size = int( + expected_special_bucket_size * self.expend_ratio + ) + + regular_shape = ( + self.window_num - 1, + 
self.bucket_num_per_window, + self.c_kernel_regular_bucket_size, + self.coordinate_dim, + self.moduli_num, + ) + + special_shape = ( + self.bucket_num_last_window, + self.special_duplication_ratio, + self.c_kernel_special_bucket_size, + self.coordinate_dim, + self.moduli_num, + ) + self.ba_input_regular_shape = regular_shape + self.ba_input_special_shape = special_shape + + lib = ctypes.cdll.LoadLibrary(ensure_distribution_kernel()) + jax.ffi.register_ffi_target( + "distribute_buf", + jax.ffi.pycapsule(lib.DistributeBufZero), + platform="cpu", + ) + self.distribution_buf_c_kernel_call = ffi.ffi_call( + "distribute_buf", + ( + jax.ShapeDtypeStruct(regular_shape, jnp.uint32), + jax.ShapeDtypeStruct(special_shape, jnp.uint32), + jax.ShapeDtypeStruct((2,), jnp.uint32), + ), + ) + + def _bucket_accumulation_regular_windows_opt( + self, + regular_buckets: jnp.ndarray, + all_points: jnp.ndarray, + ) -> jnp.ndarray: + """Accumulate points for all regular windows. + + Args: + regular_buckets: Initial bucket values (window_dim, coordinate_dim, + bucket_dim, precision_dim). + all_points: Points for all windows (window_dim, bucket_size_dim, + coordinate_dim, bucket_dim, precision_dim). + + Returns: + Accumulated bucket values for all regular windows. + """ + + bucket_size_dim = all_points.shape[1] + regular_buckets = regular_buckets.transpose(1, 0, 2, 3) + all_points = all_points.transpose(1, 2, 0, 3, 4) + + def scan_body(buckets, points): + buckets = self._padd(buckets, points) + return buckets, None + + buckets, _ = jax.lax.scan( + scan_body, regular_buckets, all_points, length=bucket_size_dim + ) + buckets = buckets.transpose(1, 0, 2, 3) + return buckets + + def _bucket_accumulation_all_windows_opt( + self, + all_buckets: jnp.ndarray, + regular_points: jnp.ndarray, + last_window_points: jnp.ndarray, + ) -> jnp.ndarray: + """Accumulate points for all windows with distributed optimization. 
+ + This is the main bucket accumulation kernel that handles both regular + windows and the special last window with duplication optimization. + + Args: + all_buckets: Initial bucket values (window_dim, coordinate_dim, + bucket_dim, precision_dim). + regular_points: Points for regular windows (window_dim-1, bucket_num, + bucket_size, coord_dim, prec_dim). + last_window_points: Points for last window (bucket_num, dup_ratio, + bucket_size, coord_dim, prec_dim). + parameters: Computation parameters. + + Returns: + Accumulated bucket values for all windows. + """ + # Transpose for computation + regular_points = regular_points.transpose(0, 2, 3, 1, 4) + last_window_points = last_window_points.transpose(2, 3, 1, 0, 4) + + window_dim, coordinate_dim, buckets_dim, precision_dim = all_buckets.shape + _, _, last_window_bucket_dup, last_window_bucket_dim, _ = ( + last_window_points.shape + ) + + # Process regular windows + regular_buckets = all_buckets[: window_dim - 1] + regular_buckets = self._bucket_accumulation_regular_windows_opt( + regular_buckets, regular_points + ) + + # Process last window + last_point_buckets = all_buckets[ + window_dim - 1, :, :last_window_bucket_dim, : + ] + last_blank_buckets = all_buckets[ + window_dim - 1, :, last_window_bucket_dim:, : + ] + last_point_buckets = self._bucket_accumulation_last_window( + last_point_buckets, last_window_points + ) + + # Combine last window buckets + last_buckets = jax.lax.broadcast( + jnp.concatenate((last_point_buckets, last_blank_buckets), axis=1), (1,) + ) + + # Combine all buckets + all_buckets = jnp.concatenate((regular_buckets, last_buckets), axis=0) + return all_buckets + + def _bucket_reduction( + self, + all_buckets: jnp.ndarray, + temp_sum: jnp.ndarray, + window_sum: jnp.ndarray, + ) -> jnp.ndarray: + """Reduce buckets to window sums using tree-based parallel algorithm. + + Computes S = sum_{i=0}^{n-1} i * B[i] per window via adjacent-pair + tree reduction in O(log n) parallel steps. 
+ + Tracks H = 2^k * bucket_sums (doubling each level) so that H[1::2] + provides the correction weight directly. No separate bucket array needed. + + Args: + all_buckets: Bucket values (window_dim, coordinate_dim, bucket_dim, + precision_dim). + temp_sum: Unused (kept for interface compatibility). + window_sum: Initial window sum (coordinate_dim, bucket_dim, window_dim, + precision_dim). + + Returns: + Window sums (coordinate_dim, window_dim, precision_dim). + """ + bucket_num_in_window = all_buckets.shape[2] + # (window, coord, bucket, prec) → (coord, bucket, window, prec) + all_buckets = all_buckets.transpose(1, 2, 0, 3) + iter_num = int(math.log2(bucket_num_in_window)) + + cd, m, wd, pd = all_buckets.shape + for _ in range(iter_num): + m = all_buckets.shape[1] + half = m // 2 + + all_buckets = all_buckets.reshape(cd, half, 2, wd, pd) + window_sum = window_sum.reshape(cd, half, 2, wd, pd) + + all_buckets_left, all_buckets_right = ( + all_buckets[:, :, 0], + all_buckets[:, :, 1], + ) + window_sum_left, window_sum_right = ( + window_sum[:, :, 0], + window_sum[:, :, 1], + ) + + window_sum = self._padd( + self._padd(window_sum_left, window_sum_right), all_buckets_right + ) + bucket_sum = self._padd(all_buckets_left, all_buckets_right) + all_buckets = self._padd(bucket_sum, bucket_sum) + + return window_sum[:, 0] + + def _get_ba_shape_dtype_structs(self): + """Get ShapeDtypeStructs for bucket_accumulation_all_windows inputs.""" + all_buckets_shape = ( + self.batch_window_num, + self.coordinate_dim, + self.bucket_num_per_window, + self.moduli_num, + ) + regular_points_shape = self.ba_input_regular_shape + special_points_shape = self.ba_input_special_shape + + if self.use_sharding: + return [ + jax.ShapeDtypeStruct( + all_buckets_shape, + jnp.uint32, + sharding=self.ba_all_buckets_sharding, + ), + jax.ShapeDtypeStruct( + regular_points_shape, + jnp.uint32, + sharding=self.ba_regular_points_sharding, + ), + jax.ShapeDtypeStruct( + special_points_shape, + jnp.uint32, + 
sharding=self.ba_special_points_sharding, + ), + ] + return [ + jax.ShapeDtypeStruct(all_buckets_shape, jnp.uint32), + jax.ShapeDtypeStruct(regular_points_shape, jnp.uint32), + jax.ShapeDtypeStruct(special_points_shape, jnp.uint32), + ] + + def _get_br_shape_dtype_structs(self): + """Get ShapeDtypeStructs for bucket_reduction inputs.""" + all_buckets_shape = ( + self.batch_window_num, + self.coordinate_dim, + self.bucket_num_per_window, + self.moduli_num, + ) + temp_sum_shape = ( + self.coordinate_dim, + self.batch_window_num, + self.moduli_num, + ) + window_sum_shape = ( + self.coordinate_dim, + self.bucket_num_per_window, + self.batch_window_num, + self.moduli_num, + ) + + if self.use_sharding: + return [ + jax.ShapeDtypeStruct( + all_buckets_shape, + jnp.uint32, + sharding=self.br_all_buckets_sharding, + ), + jax.ShapeDtypeStruct(temp_sum_shape, jnp.uint32), + jax.ShapeDtypeStruct( + window_sum_shape, jnp.uint32, sharding=self.br_window_sum_sharding + ), + ] + return [ + jax.ShapeDtypeStruct(all_buckets_shape, jnp.uint32), + jax.ShapeDtypeStruct(temp_sum_shape, jnp.uint32), + jax.ShapeDtypeStruct(window_sum_shape, jnp.uint32), + ] + + def _get_wm_shape_dtype_structs(self): + """Get ShapeDtypeStructs for window_merge input.""" + window_sum_shape = ( + self.coordinate_dim, + self.batch_window_num, + self.moduli_num, + ) + + if self.use_sharding: + return [ + jax.ShapeDtypeStruct( + window_sum_shape, jnp.uint32, sharding=self.wm_window_sum_sharding + ) + ] + return [jax.ShapeDtypeStruct(window_sum_shape, jnp.uint32)] + + def context_hash(self) -> str: + return hash_args( + self.__class__.__name__, + self.ec_ctx.context_hash() + if hasattr(self.ec_ctx, "context_hash") + else str(self.ec_ctx.__class__.__name__), + self.slice_bits, + self.scalar_bits, + self.msm_length, + self.tile_length, + self.bucket_num_per_window, + self.bucket_num_last_window, + self.use_sharding, + ) + + def serialize(self, parameters: Optional[dict] = None): + ba_structs = 
def compile(self, parameters: Optional[dict] = None):
  """Load or build the three MSM kernels and enable compiled dispatch.

  For each of bucket accumulation, bucket reduction and window merge, a
  stored executable is looked up by name via `load_jax_executable`. If a
  lookup returns None, the kernel is built with `jax_jit_lower_compile` from
  the ShapeDtypeStructs of its inputs. The results are stored in
  `self.compiled_kernels` and `self.use_compiled_kernels` is set to True.

  Args:
    parameters: Optional dict hashed together with `self.context_hash()` to
      form the suffix of the stored executable names. A falsy value is
      hashed as an empty dict.
  """
  ba_structs = self._get_ba_shape_dtype_structs()
  br_structs = self._get_br_shape_dtype_structs()
  wm_structs = self._get_wm_shape_dtype_structs()
  # Suffix shared by the three stored-executable names.
  kernel_hash = hash_args(
      self.context_hash(), parameters if parameters else {}
  )
  class_name = self.__class__.__name__

  ba_kernel = load_jax_executable(
      f"{class_name}_bucket_accumulation_all_windows_{kernel_hash}"
  )
  br_kernel = load_jax_executable(
      f"{class_name}_bucket_reduction_{kernel_hash}"
  )
  wm_kernel = load_jax_executable(f"{class_name}_window_merge_{kernel_hash}")

  # Warn once if at least one kernel has to be compiled from scratch.
  if None in [ba_kernel, br_kernel, wm_kernel]:
    warnings.warn(
        "Not found stored serialized compiled kernels for MSM, compiling...",
        UserWarning,
        stacklevel=2,
    )

  # Dispatch keys for self.compiled_kernels. Each key hashes the shape and
  # dtype of the FIRST input only, passed twice. The dispatch methods
  # (bucket_accumulation_all_windows, bucket_reduction, window_merge) build
  # their lookup key the same way, so the two must be changed together.
  # NOTE(review): the repeated arguments look like a copy-paste of a
  # multi-input key; confirm whether the other inputs were meant to be
  # hashed.
  ba_hash = hash_args(
      ba_structs[0].shape,
      ba_structs[0].dtype.__str__(),
      ba_structs[0].shape,
      ba_structs[0].dtype.__str__(),
  )
  br_hash = hash_args(
      br_structs[0].shape,
      br_structs[0].dtype.__str__(),
      br_structs[0].shape,
      br_structs[0].dtype.__str__(),
  )
  wm_hash = hash_args(
      wm_structs[0].shape,
      wm_structs[0].dtype.__str__(),
      wm_structs[0].shape,
      wm_structs[0].dtype.__str__(),
  )
  # Use the stored executable when one was found; otherwise compile now.
  # NOTE(review): this compiles self._bucket_accumulation_all_windows, while
  # the non-compiled dispatch path calls
  # self._bucket_accumulation_all_windows_opt; confirm this is intended.
  self.compiled_kernels.setdefault(ba_hash, {})[
      "bucket_accumulation_all_windows"
  ] = (
      ba_kernel
      if ba_kernel is not None
      else jax_jit_lower_compile(
          self._bucket_accumulation_all_windows,
          ba_structs[0],
          ba_structs[1],
          ba_structs[2],
      )
  )
  self.compiled_kernels.setdefault(br_hash, {})["bucket_reduction"] = (
      br_kernel
      if br_kernel is not None
      else jax_jit_lower_compile(
          self._bucket_reduction, br_structs[0], br_structs[1], br_structs[2]
      )
  )
  self.compiled_kernels.setdefault(wm_hash, {})["window_merge"] = (
      wm_kernel
      if wm_kernel is not None
      else jax_jit_lower_compile(self._window_merge, wm_structs[0])
  )
  # From here on the public dispatch methods use the compiled kernels.
  self.use_compiled_kernels = True
range(self.tile_num): + idx_start = tile_index * self.tile_length + idx_end = idx_start + self.tile_length + tiled_slices_tile = tiled_slices[tile_index] + # tiled_points_tile = self.points[:, idx_start:idx_end].transpose(1, 0, 2) + tiled_points_tile = self.points[idx_start:idx_end] + regular_buckets, last_window_buckets = self.distribute_buckets( + tiled_slices_tile, tiled_points_tile + ) + if self.use_sharding: + regular_buckets = regular_buckets.to_device( + self.ba_regular_points_sharding + ) + last_window_buckets = last_window_buckets.to_device( + self.ba_special_points_sharding + ) + else: + regular_buckets = jax.device_put(regular_buckets, jax.devices()[0]) + last_window_buckets = jax.device_put( + last_window_buckets, jax.devices()[0] + ) + self.all_buckets = self.bucket_accumulation_all_windows( + self.all_buckets, regular_buckets, last_window_buckets + ) + + window_sum = self.bucket_reduction( + self.all_buckets, self.temp_sum, self.window_sum + ) + result = self.window_merge(window_sum) + return result + + +class TPUDistributionMSMContext(CPUDistributionMSMContext): + """MSM context where bucket distribution runs on TPU using a sort-based + + bucketize algorithm instead of a CPU C kernel. + + Distribution becomes a major TPU kernel alongside bucket_accumulation, + bucket_reduction, and window_merge — all four are compiled by `compile()` + and dispatched the same way. 
+ """ + + def __init__(self, parameters: dict): + MultiscalarMultiplicationContextBase.__init__(self, parameters) + self._init_config_parameters() + self._init_jax_data() + self._init_kernels() + self._init_point_parameters() + JaxKernelContextBase.__init__(self) + + def _init_kernels(self): + expected_regular_bucket_size = self.tile_length / self.bucket_num_per_window + expected_special_bucket_size = math.ceil( + self.tile_length / self.bucket_num_duplication + ) + self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"] + self.regular_bucket_size = int( + expected_regular_bucket_size * self.expend_ratio + ) + self.special_bucket_size = int( + expected_special_bucket_size * self.expend_ratio + ) + + self.ba_input_regular_shape = ( + self.window_num - 1, + self.bucket_num_per_window, + self.regular_bucket_size, + self.coordinate_dim, + self.moduli_num, + ) + self.ba_input_special_shape = ( + self.bucket_num_last_window, + self.special_duplication_ratio, + self.special_bucket_size, + self.coordinate_dim, + self.moduli_num, + ) + self.dist_input_slices_shape = (self.window_num, self.tile_length) + self.dist_input_points_shape = ( + self.tile_length, + self.coordinate_dim, + self.moduli_num, + ) + + def _init_point_parameters(self): + raw_points = utils.read_external_msm_file( + self.parameters.get("points_path"), "points" + ) + points = self.ec_ctx.to_computational_format(raw_points) + points = points.transpose(1, 0, 2) # (N, coordinate_dim, moduli_num) + self.points = jax.device_put(points, jax.devices()[0]) + + def _bucketize_regular_windows( + self, items: jnp.ndarray, bucket_ids: jnp.ndarray + ) -> jnp.ndarray: + """Sort-based bucketization for regular windows. + + Args: + items: Points to distribute (tile_length, coordinate_dim, moduli_num). + bucket_ids: Slice values for the regular windows (window_num - 1, + tile_length). + + Returns: + Bucketed points (window_num - 1, bucket_num_per_window, + regular_bucket_size, + coordinate_dim, moduli_num). 
+ """ + n = self.tile_length + b = self.window_num - 1 + + bucket_ids_i32 = bucket_ids.astype(jnp.int32) + sorted_indices = jnp.argsort(bucket_ids_i32, axis=1, stable=True) + sorted_buckets = jnp.take_along_axis(bucket_ids_i32, sorted_indices, axis=1) + + is_boundary = jnp.concatenate( + [ + jnp.ones_like(sorted_buckets[:, :1], dtype=jnp.bool_), + sorted_buckets[:, 1:] != sorted_buckets[:, :-1], + ], + axis=1, + ) + positions = jnp.broadcast_to(jnp.arange(n, dtype=jnp.int32), (b, n)) + boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0)) + bucket_starts = jax.lax.associative_scan( + jnp.maximum, boundary_positions, axis=1 + ) + within_rank = positions - bucket_starts + + batch_idx = jnp.broadcast_to( + jnp.arange(b, dtype=jnp.int32)[:, None], (b, n) + ) + zeros_bn = jnp.zeros((b, n), dtype=jnp.int32) + inv_rank = zeros_bn.at[batch_idx, sorted_indices].set(within_rank) + items_b = jnp.broadcast_to(items[None, :, :, :], (b, n, 4, 32)) + + # items_sorted = items[sorted_indices] + + output = jnp.broadcast_to( + self.zero_point, + ( + b, + self.bucket_num_per_window, + self.regular_bucket_size, + self.coordinate_dim, + self.moduli_num, + ), + ).copy() + # output = output.at[batch_idx, sorted_buckets, within_rank].set(items_sorted) + output = output.at[batch_idx, bucket_ids_i32, inv_rank].set(items_b) + return output + + def _bucketize_last_window( + self, items: jnp.ndarray, last_slices: jnp.ndarray + ) -> jnp.ndarray: + """Sort-based bucketization for the last (special) window with duplication. + + Each logical bucket s in [0, bucket_num_last_window) is replicated + special_duplication_ratio times. Item i is mapped to duplicate (i mod sdup), + spreading items across the duplicates so the per-slot capacity stays at + special_bucket_size. Reduction back to logical buckets is handled later by + _bucket_accumulation_last_window. + + Args: + items: Points to distribute (tile_length, coordinate_dim, moduli_num). 
+ last_slices: Slice values for the last window (tile_length,). + + Returns: + Bucketed points (bucket_num_last_window, special_duplication_ratio, + special_bucket_size, coordinate_dim, moduli_num). + """ + n = self.tile_length + sdup = self.special_duplication_ratio + + dup_ids = jnp.arange(n, dtype=jnp.int32) % sdup + bucket_ids_i32 = last_slices.astype(jnp.int32) * sdup + dup_ids + + sorted_indices = jnp.argsort(bucket_ids_i32, stable=True) + sorted_buckets = bucket_ids_i32[sorted_indices] + + is_boundary = jnp.concatenate([ + jnp.ones((1,), dtype=jnp.bool_), + sorted_buckets[1:] != sorted_buckets[:-1], + ]) + positions = jnp.arange(n, dtype=jnp.int32) + boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0)) + bucket_starts = jax.lax.associative_scan(jnp.maximum, boundary_positions) + within_rank = positions - bucket_starts + + zeros_n = jnp.zeros((n,), dtype=jnp.int32) + inv_rank = zeros_n.at[sorted_indices].set(within_rank) + + output = jnp.broadcast_to( + self.zero_point, + ( + self.bucket_num_duplication, + self.special_bucket_size, + self.coordinate_dim, + self.moduli_num, + ), + ).copy() + output = output.at[bucket_ids_i32, inv_rank].set(items) + return output.reshape( + self.bucket_num_last_window, + self.special_duplication_ratio, + self.special_bucket_size, + self.coordinate_dim, + self.moduli_num, + ) + + def _distribute_buckets( + self, + regular_slices: jnp.ndarray, + last_window_slice: jnp.ndarray, + tiled_points: jnp.ndarray, + ): + """Major TPU kernel: distribute one tile of points into buckets. + + Args: + regular_slices: Slice values for regular windows (window_num - 1, + tile_length). + last_window_slice: Slice values for the last window (tile_length,). + tiled_points: Points for one tile (tile_length, coordinate_dim, + moduli_num). + + Returns: + regular_buckets: (window_num - 1, bucket_num_per_window, + regular_bucket_size, + coordinate_dim, moduli_num). 
+ last_window_buckets: (bucket_num_last_window, special_duplication_ratio, + special_bucket_size, coordinate_dim, moduli_num). + """ + regular_buckets = self._bucketize_regular_windows( + tiled_points, regular_slices + ) + last_window_buckets = self._bucketize_last_window( + tiled_points, last_window_slice + ) + return regular_buckets, last_window_buckets + + def _get_dist_shape_dtype_structs(self): + """Get ShapeDtypeStructs for _distribute_buckets inputs.""" + regular_slices_shape = (self.window_num - 1, self.tile_length) + last_window_slice_shape = (self.tile_length,) + return [ + jax.ShapeDtypeStruct(regular_slices_shape, jnp.int32), + jax.ShapeDtypeStruct(last_window_slice_shape, jnp.int32), + jax.ShapeDtypeStruct(self.dist_input_points_shape, jnp.uint32), + ] + + def compile(self, parameters: Optional[dict] = None): + dist_structs = self._get_dist_shape_dtype_structs() + ba_structs = self._get_ba_shape_dtype_structs() + br_structs = self._get_br_shape_dtype_structs() + wm_structs = self._get_wm_shape_dtype_structs() + + dist_hash = hash_args( + dist_structs[0].shape, + dist_structs[0].dtype.__str__(), + dist_structs[1].shape, + dist_structs[1].dtype.__str__(), + dist_structs[2].shape, + dist_structs[2].dtype.__str__(), + ) + ba_hash = hash_args( + ba_structs[0].shape, + ba_structs[0].dtype.__str__(), + ba_structs[0].shape, + ba_structs[0].dtype.__str__(), + ) + br_hash = hash_args( + br_structs[0].shape, + br_structs[0].dtype.__str__(), + br_structs[0].shape, + br_structs[0].dtype.__str__(), + ) + wm_hash = hash_args( + wm_structs[0].shape, + wm_structs[0].dtype.__str__(), + wm_structs[0].shape, + wm_structs[0].dtype.__str__(), + ) + + self.compiled_kernels.setdefault(dist_hash, {})["distribute_buckets"] = ( + jax_jit_lower_compile( + self._distribute_buckets, + dist_structs[0], + dist_structs[1], + dist_structs[2], + ) + ) + self.compiled_kernels.setdefault(ba_hash, {})[ + "bucket_accumulation_all_windows" + ] = jax_jit_lower_compile( + 
def multiscalar_multiply(self, tiled_slices: jnp.ndarray):
  """Run distribution, accumulation, reduction and merge over all tiles.

  The slices are first placed on `jax.devices()[0]`. For every tile, the
  tile's slices are split into the regular-window rows and the last-window
  row, the matching range of `self.points` is distributed into buckets, and
  the result is accumulated into `self.all_buckets`. That attribute is
  overwritten on every tile and keeps the accumulated buckets after the
  call.

  Args:
    tiled_slices: Slice values indexed as [tile, window, position].

  Returns:
    The value returned by `self.window_merge` for the reduced window sums.
  """
  device = jax.devices()[0]
  slices_on_device = jax.device_put(tiled_slices, device)

  for tile in range(self.tile_num):
    first = tile * self.tile_length
    stop = first + self.tile_length
    tile_slices = slices_on_device[tile]

    # Rows before window_num - 1 are regular windows; the row at
    # window_num - 1 is the last window.
    buckets_regular, buckets_last = self.distribute_buckets(
        tile_slices[: self.window_num - 1],
        tile_slices[self.window_num - 1],
        self.points[first:stop],
    )
    self.all_buckets = self.bucket_accumulation_all_windows(
        self.all_buckets, buckets_regular, buckets_last
    )

  reduced = self.bucket_reduction(
      self.all_buckets, self.temp_sum, self.window_sum
  )
  return self.window_merge(reduced)
def __init__(self, parameters: dict):
  """Initialise the fused MSM context from a parameter dict.

  Runs the base-class initialisers, then derives the configuration values,
  the JAX buffers and the points, and finally disables the fused path.

  Args:
    parameters: Configuration dict, passed on to the base initialisers.
  """
  # NOTE(review): with the bases declared as
  # (MultiscalarMultiplicationContextBase, JaxKernelContextBase),
  # super().__init__ resolves to MultiscalarMultiplicationContextBase, which
  # is then invoked again explicitly on the next line. This looks like a
  # double initialisation; confirm it is intended (TPUDistributionMSMContext
  # only makes the explicit call).
  super().__init__(parameters)
  MultiscalarMultiplicationContextBase.__init__(self, parameters)
  self._init_config_parameters()
  self._init_jax_data()
  self._init_point_parameters()
  # Initialised last, after the attributes above have been set.
  JaxKernelContextBase.__init__(self)
  self.use_fused = False
+ ) + self.scalar_bits = self.parameters.get("scalar_bits") + self.order = self.parameters.get("order") + self.window_num = int(math.ceil(self.scalar_bits / self.slice_bits)) # + self.batch_window_num = self.window_num + self.bucket_num_per_window = 2**self.slice_bits # Note: Include Bucket 0 + orig_bucket_num_last_window = ( + self.order >> ((self.window_num - 1) * self.slice_bits) + ) + 1 # Note: Include Bucket 0 + # Now pad to nearest value so bucket_num_last_window % 8 == 0 + added_padding = (8 - orig_bucket_num_last_window % 8) % 8 + if added_padding > 0: + if added_padding / orig_bucket_num_last_window > 0.1: + print( + f"[bucket_num_last_window] Added {added_padding} to make it" + " divisible by 8, but it is too large for" + f" {orig_bucket_num_last_window}. Setting to 0." + ) + added_padding = 0 + else: + print( + f"[bucket_num_last_window] Was {orig_bucket_num_last_window}, " + f"added {added_padding} to make it divisible by 8" + ) + self.bucket_num_last_window = orig_bucket_num_last_window + added_padding + self.moduli_num = self.ec_ctx.get_finite_field_context().get_moduli_num() + + # Special bucket optimization + self.log_special_duplication_ratio = math.ceil( + math.log2(self.bucket_num_per_window / self.bucket_num_last_window) + ) + self.special_duplication_ratio = 2**self.log_special_duplication_ratio + self.bucket_num_duplication = ( + self.bucket_num_last_window * self.special_duplication_ratio + ) + expected_regular_bucket_size = self.tile_length / self.bucket_num_per_window + expected_special_bucket_size = math.ceil( + self.tile_length / self.bucket_num_duplication + ) + self.expend_ratio = self.parameters["c_kernel_ret_space_ratio"] + self.regular_bucket_size = int( + expected_regular_bucket_size * self.expend_ratio + ) + self.special_bucket_size = int( + expected_special_bucket_size * self.expend_ratio + ) + + def _init_point_parameters(self): + if self.parameters.get("points_path") is not None: + raw_points = utils.read_external_msm_file( + 
self.parameters.get("points_path"), "points" + ) + points = self.ec_ctx.to_computational_format(raw_points) + points = points.transpose(1, 0, 2) # (N, coordinate_dim, moduli_num) + else: + points = jax.random.randint( + jax.random.PRNGKey(0), + (self.msm_length, self.coordinate_dim, self.moduli_num), + 0, + 2**16, + dtype=jnp.uint32, + ) + # Reshape once here into per-tile layout so both execution paths index + # tiles with a single leading-axis lookup (self.points[tile_index]) and + # the fused path can pass self.points directly as tiled_points. + self.points = points.reshape( + self.tile_num, + self.tile_length, + self.coordinate_dim, + self.moduli_num, + ) + + def _init_jax_data(self): + self.zero_point = ( + self.ec_ctx.get_finite_field_context().to_computational_format( + self.ec_ctx.zero_point + ) + ) + all_buckets = jnp.broadcast_to( + self.zero_point.reshape(self.coordinate_dim, 1, 1, self.moduli_num), + ( + self.coordinate_dim, + self.batch_window_num, + self.bucket_num_per_window, + self.moduli_num, + ), + ) + self.regular_window_buckets = all_buckets[:, : self.window_num - 1] + self.last_window_buckets = all_buckets[:, self.window_num - 1] + self.window_sum = jnp.broadcast_to( + self.zero_point.reshape(self.coordinate_dim, 1, 1, self.moduli_num), + ( + self.coordinate_dim, + self.bucket_num_per_window, + self.batch_window_num, + self.moduli_num, + ), + ) + self.fused_reg_slices_padded = None + self.fused_reg_slices_sharding = None + self.fused_last_slices_padded = None + self.fused_last_slices_sharding = None + self.fused_tiled_points_padded = None + self.fused_tiled_points_sharding = None + self.bba_reg_slices_padded = None + self.bba_reg_slices_sharding = None + self.bba_last_slice_padded = None + self.bba_last_slice_sharding = None + self.bba_points_padded = None + self.bba_points_sharding = None + self.bba_reg_buckets_padded = None + self.bba_reg_buckets_sharding = None + self.bba_last_buckets_padded = None + self.bba_last_buckets_sharding = None + 
def to_computational_format(
    self, scalars: Optional[list] = None
) -> jnp.ndarray:
  """Turn scalars into tiled slices, padded when sharding is enabled.

  Args:
    scalars: Scalars to convert. When None, a warning is issued and
      `msm_length` random scalars in [0, order - 1] are generated.

  Returns:
    The output of `self._preprocess_scalars`. With sharding enabled, its
    last axis is padded to `self.bba_reg_slices_padded[1]` if it differs.
  """
  if scalars is None:
    warnings.warn("No scalars provided, generating random scalars.")
    scalars = [
        random.randint(0, self.order - 1) for _ in range(self.msm_length)
    ]
  slices = self._preprocess_scalars(scalars)

  if not self.use_sharding:
    return slices

  # With sharding, pad the tile_length axis once here so every per-tile
  # slice already has the padded shape and the hot loop does no shape work.
  padded_tile_length = self.bba_reg_slices_padded[1]
  if slices.shape[-1] == padded_tile_length:
    return slices

  target_shape = (slices.shape[0], slices.shape[1], padded_tile_length)
  warnings.warn(
      f"Tiled slices shape {slices.shape} does not match padded"
      f" tile length {padded_tile_length}, padding to {target_shape}."
  )
  return utils.pad_jax_array(slices, target_shape)
+ + Args: + items: Points to distribute (tile_length, coordinate_dim, moduli_num). + bucket_ids: Slice values for the regular windows (window_num - 1, + tile_length). + + Returns: + Bucketed points (regular_bucket_size, coordinate_dim, + window_num - 1, bucket_num_per_window, moduli_num). + """ + n = self.tile_length + b = self.window_num - 1 + + bucket_ids_i32 = bucket_ids.astype(jnp.int32) + sorted_indices = jnp.argsort(bucket_ids_i32, axis=1, stable=True) + sorted_buckets = jnp.take_along_axis(bucket_ids_i32, sorted_indices, axis=1) + + is_boundary = jnp.concatenate( + [ + jnp.ones_like(sorted_buckets[:, :1], dtype=jnp.bool_), + sorted_buckets[:, 1:] != sorted_buckets[:, :-1], + ], + axis=1, + ) + positions = jnp.broadcast_to(jnp.arange(n, dtype=jnp.int32), (b, n)) + boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0)) + bucket_starts = jax.lax.associative_scan( + jnp.maximum, boundary_positions, axis=1 + ) + within_rank = positions - bucket_starts + + batch_idx = jnp.broadcast_to( + jnp.arange(b, dtype=jnp.int32)[:, None], (b, n) + ) + zeros_bn = jnp.zeros((b, n), dtype=jnp.int32) + sharding = jax.typeof(sorted_indices).sharding + zeros_bn = jax.device_put(zeros_bn, sharding) + batch_idx = jax.device_put(batch_idx, sharding) + inv_rank = zeros_bn.at[batch_idx, sorted_indices].set( + within_rank, out_sharding=sharding + ) + items_b = jnp.broadcast_to( + items[None, :, :, :], (b, n, self.coordinate_dim, self.moduli_num) + ) + + output = jnp.broadcast_to( + self.zero_point.reshape(1, self.coordinate_dim, 1, 1, self.moduli_num), + ( + self.regular_bucket_size, + self.coordinate_dim, + b, + self.bucket_num_per_window, + self.moduli_num, + ), + ).copy() + sharding_out = jax.typeof(output).sharding + output = output.at[inv_rank, :, batch_idx, bucket_ids_i32].set( + items_b, out_sharding=sharding_out + ) + return output + + def _bucketize_last_window( + self, items: jnp.ndarray, last_slices: jnp.ndarray + ) -> jnp.ndarray: + """Sort-based bucketization 
for the last (special) window with duplication. + + Each logical bucket s in [0, bucket_num_last_window) is replicated + special_duplication_ratio times. Item i is mapped to duplicate (i mod sdup), + spreading items across the duplicates so the per-slot capacity stays at + special_bucket_size. Reduction back to logical buckets is handled later by + _bucket_accumulation_last_window. + + Args: + items: Points to distribute (tile_length, coordinate_dim, moduli_num). + last_slices: Slice values for the last window (tile_length,). + + Returns: + Bucketed points (special_bucket_size, coordinate_dim, + special_duplication_ratio, bucket_num_last_window, + moduli_num). + """ + n = self.tile_length + sdup = self.special_duplication_ratio + + dup_ids = jnp.arange(n, dtype=jnp.int32) % sdup + bucket_ids_i32 = dup_ids * self.bucket_num_last_window + last_slices.astype( + jnp.int32 + ) + + sorted_indices = jnp.argsort(bucket_ids_i32, stable=True) + sorted_buckets = bucket_ids_i32[sorted_indices] + + is_boundary = jnp.concatenate([ + jnp.ones((1,), dtype=jnp.bool_), + sorted_buckets[1:] != sorted_buckets[:-1], + ]) + positions = jnp.arange(n, dtype=jnp.int32) + boundary_positions = jnp.where(is_boundary, positions, jnp.int32(0)) + bucket_starts = jax.lax.associative_scan(jnp.maximum, boundary_positions) + within_rank = positions - bucket_starts + + zeros_n = jnp.zeros((n,), dtype=jnp.int32) + inv_rank = zeros_n.at[sorted_indices].set(within_rank) + + output = jnp.broadcast_to( + self.zero_point.reshape(1, self.coordinate_dim, 1, self.moduli_num), + ( + self.special_bucket_size, + self.coordinate_dim, + self.bucket_num_duplication, + self.moduli_num, + ), + ).copy() + output = output.at[inv_rank, :, bucket_ids_i32].set(items) + return output.reshape( + self.special_bucket_size, + self.coordinate_dim, + self.special_duplication_ratio, + self.bucket_num_last_window, + self.moduli_num, + ) + + def _distribute_buckets( + self, + regular_slices: jnp.ndarray, + last_window_slice: 
def _bucket_accumulation_last_window(
    self,
    buckets_in: jnp.ndarray,
    window_points: jnp.ndarray,
) -> jnp.ndarray:
  """Accumulate points for the last (special) window with duplication handling.

  The duplicate and bucket axes of ``window_points`` are flattened, the
  slots are summed with a scan, the duplicates are folded together by
  repeated halving, and the result is added onto ``buckets_in``.

  Args:
    buckets_in: Initial bucket values (coordinate_dim, bucket_dim,
      precision_dim).
    window_points: Points with duplication (bucket_size_dim, coordinate_dim,
      bucket_dup_dim, bucket_dim, precision_dim).

  Returns:
    Accumulated bucket values, same shape as ``buckets_in``.

  Raises:
    ValueError: If ``bucket_dup_dim`` is not a power of two. The halving
      reduction below only folds every duplicate onto its logical bucket
      for power-of-two duplication ratios.
  """
  coordinate_dim, _, precision_dim = buckets_in.shape
  bucket_size_dim, _, bucket_dup_dim, _, _ = window_points.shape

  if bucket_dup_dim < 1 or bucket_dup_dim & (bucket_dup_dim - 1):
    raise ValueError(
        "last-window duplication ratio must be a power of two, got "
        f"{bucket_dup_dim}"
    )

  # Flatten (duplicate, bucket) into one duplicate-major axis.
  window_points = window_points.reshape(
      bucket_size_dim, coordinate_dim, -1, precision_dim
  )

  def scan_body(buckets, points):
    return self._padd(buckets, points), None

  # Slot 0 seeds the running sum; the remaining slots are scanned onto it.
  dup_buckets, _ = jax.lax.scan(
      scan_body,
      window_points[0],
      window_points[1:],
      length=bucket_size_dim - 1,
  )

  # Tree reduction over the duplicates: because the flat axis is
  # duplicate-major, halving it pairs duplicate d with duplicate
  # d + dup/2 of the same logical bucket.
  for _ in range(bucket_dup_dim.bit_length() - 1):
    lower_half, upper_half = jnp.split(dup_buckets, 2, axis=1)
    dup_buckets = self._padd(lower_half, upper_half)

  # Add to input buckets.
  return self._padd(buckets_in, dup_buckets)
def _bucket_reduction(
    self,
    regular_window_buckets: jnp.ndarray,
    last_window_buckets: jnp.ndarray,
    window_sum: jnp.ndarray,
) -> jnp.ndarray:
  """Reduce buckets to window sums using tree-based parallel algorithm.

  Computes S = sum_{i=0}^{n-1} i * B[i] per window via adjacent-pair
  tree reduction in log2(n) levels.

  Tracks H = 2^k * bucket_sums (doubling each level) so that H[1::2]
  provides the correction weight directly. No separate bucket array needed.

  Args:
    regular_window_buckets: Buckets of the regular windows (coordinate_dim,
      window_dim - 1, bucket_dim, precision_dim).
    last_window_buckets: Buckets of the last window (coordinate_dim,
      bucket_dim, precision_dim); appended as one more window.
    window_sum: Initial window sum (coordinate_dim, bucket_dim, window_dim,
      precision_dim).

  Returns:
    Window sums (coordinate_dim, window_dim, precision_dim).

  Raises:
    ValueError: If ``bucket_dim`` is not a power of two; the pairwise
      reshape at every level needs an even bucket count.
  """
  all_buckets = jnp.concatenate(
      (regular_window_buckets, last_window_buckets[:, jnp.newaxis]), axis=1
  )
  # (coord, window, bucket, prec) → (coord, bucket, window, prec)
  all_buckets = all_buckets.transpose(0, 2, 1, 3)
  cd, bucket_num, wd, pd = all_buckets.shape

  if bucket_num < 1 or bucket_num & (bucket_num - 1):
    raise ValueError(
        f"bucket count per window must be a power of two, got {bucket_num}"
    )

  for _ in range(bucket_num.bit_length() - 1):
    half = all_buckets.shape[1] // 2

    # Pair up adjacent buckets: index 0 is the even (left) member of a
    # pair, index 1 the odd (right) member.
    all_buckets = all_buckets.reshape(cd, half, 2, wd, pd)
    window_sum = window_sum.reshape(cd, half, 2, wd, pd)
    buckets_left, buckets_right = all_buckets[:, :, 0], all_buckets[:, :, 1]
    sum_left, sum_right = window_sum[:, :, 0], window_sum[:, :, 1]

    # The right member sits one (scaled) index above the left one, so its
    # bucket value is added once more on top of the two partial sums.
    window_sum = self._padd(self._padd(sum_left, sum_right), buckets_right)
    # Merge the pair and double it, keeping H = 2^k * bucket_sums.
    pair_sum = self._padd(buckets_left, buckets_right)
    all_buckets = self._padd(pair_sum, pair_sum)

  return window_sum[:, 0]
+ """ + coordinate_dim, window_dim, precision_dim = window_sum.shape + window_sum = window_sum.transpose(1, 0, 2).reshape( + (window_dim, coordinate_dim, 1, precision_dim) + ) + result = window_sum[window_dim - 1] + + def fori_loop_body(i, result): + result = self._padd(result, result) + return result + + def scan_body(result, window_sum): + result = jax.lax.fori_loop(0, self.slice_bits, fori_loop_body, result) + result = self._padd(result, window_sum) + return result, None + + result, _ = jax.lax.scan( + scan_body, + result, + window_sum[: window_dim - 1], + reverse=True, + length=window_dim - 1, + ) + result = result.reshape((coordinate_dim, precision_dim)) + return result + + def _bba_and_bws( + self, + regular_window_slices: jnp.ndarray, + last_window_slices: jnp.ndarray, + tiled_points: jnp.ndarray, + regular_window_buckets: jnp.ndarray, + last_window_buckets: jnp.ndarray, + window_sum: jnp.ndarray, + ) -> jnp.ndarray: + """Fused kernel: run the last-tile BBA and then BWS in a single jit. + + Saves a kernel launch + the BBA->BWS round-trip via the host/HBM: XLA + sees the accumulated buckets immediately feed into bucket reduction and + can overlap the window-axis -> bucket-axis reshard with the reduction + tree. + """ + regular_window_buckets, last_window_buckets = ( + self._bucketize_and_bucket_accumulation( + regular_window_slices, + last_window_slices, + tiled_points, + regular_window_buckets, + last_window_buckets, + ) + ) + # Sharding conversion at the BBA -> BWS boundary. + # BBA output: regular_window_buckets on axis[1] (window), + # last_window_buckets on axis[1] (bucket). + # BWS input : regular_window_buckets on axis[2] (bucket), + # last_window_buckets on axis[1] (bucket, unchanged). + # The last_window_buckets layout already matches BWS; only the + # regular buckets need a window-axis -> bucket-axis reshard. 
def _init_shardings(self):
  """Pre-compute named shardings and padded shapes for BBA/BWS kernels.

  For every kernel input listed below, ``self.create_named_sharding`` is
  called with the logical shape and the axes to shard. The returned pair is
  stored as ``self.<prefix>_sharding`` / ``self.<prefix>_padded``.

  BBA (bucketize + bucket accumulation):
    regular_slices          (W-1, tile_length)  -> axis[0] (window)
    last_window_slice       (tile_length,)      -> no sharded axis
    tiled_points            (tile_length, C, M) -> no sharded axis
    regular_window_buckets  (C, W-1, B, M)      -> axis[1] (window)
    last_window_buckets     (C, B, M)           -> axis[1] (bucket)

  BWS (bucket reduction + window merge):
    regular_window_buckets  (C, W-1, B, M)      -> axis[2] (bucket)
    last_window_buckets     (C, B, M)           -> axis[1] (bucket)
    window_sum              (C, B, W, M)        -> axis[1] (bucket)

  Whole-MSM fused kernel (a tile axis is prepended, so the regular window
  axis moves from axis[0] to axis[1]):
    regular_tiled_slices    (tile_num, W-1, tile_length)  -> axis[1]
    last_tiled_slices       (tile_num, tile_length)       -> no sharded axis
    tiled_points            (tile_num, tile_length, C, M) -> no sharded axis

  A warning is issued for every input whose padded shape differs from its
  logical shape.
  """
  coord = self.coordinate_dim
  buckets = self.bucket_num_per_window
  moduli = self.moduli_num
  windows = self.window_num
  tile_len = self.tile_length

  regular_buckets_shape = (coord, windows - 1, buckets, moduli)
  last_buckets_shape = (coord, buckets, moduli)

  def _register(prefix, label, shape, axes):
    # Called directly from _init_shardings so that stacklevel=2 attributes
    # the warning to this method.
    sharding, padded = self.create_named_sharding(shape=shape, axes=axes)
    setattr(self, f"{prefix}_sharding", sharding)
    setattr(self, f"{prefix}_padded", padded)
    if tuple(shape) != tuple(padded):
      warnings.warn(
          f"[FusionMSMContext sharding] '{label}' shape {shape} not "
          f"divisible by mesh; padded to {padded}.",
          stacklevel=2,
      )

  # (attribute prefix, name used in the warning, logical shape, axes)
  per_kernel_specs = (
      # ----- BBA kernel shardings -----
      ("bba_reg_slices", "bba_regular_slices", (windows - 1, tile_len), [0]),
      ("bba_last_slice", "bba_last_window_slice", (tile_len,), []),
      ("bba_points", "bba_tiled_points", (tile_len, coord, moduli), []),
      (
          "bba_reg_buckets",
          "bba_regular_window_buckets",
          regular_buckets_shape,
          [1],
      ),
      (
          "bba_last_buckets",
          "bba_last_window_buckets",
          last_buckets_shape,
          [1],
      ),
      # ----- BWS kernel shardings -----
      (
          "bws_reg_buckets",
          "bws_regular_window_buckets",
          regular_buckets_shape,
          [2],
      ),
      (
          "bws_last_buckets",
          "bws_last_window_buckets",
          last_buckets_shape,
          [1],
      ),
      (
          "bws_window_sum",
          "bws_window_sum",
          (coord, buckets, windows, moduli),
          [1],
      ),
  )
  for spec in per_kernel_specs:
    _register(*spec)

  # ----- Whole-MSM fused kernel shardings -----
  tiles = self.tile_num
  fused_specs = (
      (
          "fused_reg_slices",
          "fused_regular_tiled_slices",
          (tiles, windows - 1, tile_len),
          [1],
      ),
      (
          "fused_last_slices",
          "fused_last_tiled_slices",
          (tiles, tile_len),
          [],
      ),
      (
          "fused_tiled_points",
          "fused_tiled_points",
          (tiles, tile_len, coord, moduli),
          [],
      ),
  )
  for spec in fused_specs:
    _register(*spec)
def _get_bba_shape_dtype_structs(self):
  """Describe the five inputs of _bucketize_and_bucket_accumulation.

  The structs are returned in argument order: regular slices, last-window
  slice, tiled points, regular-window buckets, last-window buckets. Slices
  are int32, everything else uint32.

  With sharding enabled the structs carry the ``bba_*_padded`` shapes and
  ``bba_*_sharding`` shardings stored on ``self``; otherwise the shapes are
  derived from the context dimensions and no sharding is attached.
  """
  if self.use_sharding:
    sharded_specs = (
        (self.bba_reg_slices_padded, jnp.int32, self.bba_reg_slices_sharding),
        (self.bba_last_slice_padded, jnp.int32, self.bba_last_slice_sharding),
        (self.bba_points_padded, jnp.uint32, self.bba_points_sharding),
        (
            self.bba_reg_buckets_padded,
            jnp.uint32,
            self.bba_reg_buckets_sharding,
        ),
        (
            self.bba_last_buckets_padded,
            jnp.uint32,
            self.bba_last_buckets_sharding,
        ),
    )
    return [
        jax.ShapeDtypeStruct(shape, dtype, sharding=sharding)
        for shape, dtype, sharding in sharded_specs
    ]

  regular_windows = self.window_num - 1
  bucket_tail = (self.bucket_num_per_window, self.moduli_num)
  unsharded_specs = (
      ((regular_windows, self.tile_length), jnp.int32),
      ((self.tile_length,), jnp.int32),
      ((self.tile_length, self.coordinate_dim, self.moduli_num), jnp.uint32),
      ((self.coordinate_dim, regular_windows) + bucket_tail, jnp.uint32),
      ((self.coordinate_dim,) + bucket_tail, jnp.uint32),
  )
  return [
      jax.ShapeDtypeStruct(shape, dtype) for shape, dtype in unsharded_specs
  ]
def compile(self, parameters: Optional[dict] = None):
  """Compile the kernels needed for the chosen execution path.

  Compiled executables are stored in ``self.compiled_kernels`` under a hash
  of the input shapes and dtypes, then under the kernel name. Finally
  ``self.use_compiled_kernels`` is set to True.

  Args:
    parameters: Compile-time options. Recognised keys:
      ``use_fused`` (bool, default False). The fused last-tile kernel
      (``_bba_and_bws``) is compiled in either case. When True, the
      whole-MSM kernel ``_multiscalar_multiply_fused`` is compiled in
      addition and nothing else. When False, the standalone
      ``_bucket_and_window_sum`` is compiled, and the standalone
      ``_bucketize_and_bucket_accumulation`` only if ``tile_num > 1``.
  """
  options = parameters or {}
  use_fused = options.get("use_fused", False)
  self.use_fused = use_fused

  def _signature_hash(structs):
    # Hash the flat (shape, dtype-string, shape, dtype-string, ...) list.
    flat = []
    for struct in structs:
      flat.extend((struct.shape, str(struct.dtype)))
    return hash_args(*flat)

  def _register(kernel_hash, kernel_name, kernel_fn, structs):
    # Compile first, then file the executable under its hash and name.
    executable = jax_jit_lower_compile(kernel_fn, *structs)
    self.compiled_kernels.setdefault(kernel_hash, {})[kernel_name] = executable

  bba_structs = self._get_bba_shape_dtype_structs()
  bws_structs = self._get_bws_shape_dtype_structs()
  bba_bws_structs = self._get_bba_bws_shape_dtype_structs()

  bba_hash = _signature_hash(bba_structs)
  bws_hash = _signature_hash(bws_structs)
  bba_bws_hash = _signature_hash(bba_bws_structs)

  # The fused BBA+BWS kernel handles the last tile in both execution
  # paths, so it is always compiled.
  _register(
      bba_bws_hash,
      "bba_and_bws",
      self._bba_and_bws,
      [bba_bws_structs[i] for i in range(6)],
  )

  if not use_fused:
    # Tiles 0..N-2 go through the standalone BBA; with a single tile
    # there are none, so that compile is skipped.
    if self.tile_num > 1:
      _register(
          bba_hash,
          "bucketize_and_bucket_accumulation",
          self._bucketize_and_bucket_accumulation,
          [bba_structs[i] for i in range(5)],
      )
    # Standalone BWS stays available for direct callers of
    # bucket_and_window_sum.
    _register(
        bws_hash,
        "bucket_and_window_sum",
        self._bucket_and_window_sum,
        [bws_structs[i] for i in range(3)],
    )
  else:
    # The whole-MSM kernel runs BBA (via scan) and BWS internally, so the
    # per-stage kernels are not compiled on this path.
    fused_structs = self._get_fused_shape_dtype_structs()
    _register(
        _signature_hash(fused_structs),
        "multiscalar_multiply_fused",
        self._multiscalar_multiply_fused,
        fused_structs,
    )

  self.use_compiled_kernels = True
def _multiscalar_multiply(self, tiled_slices: jnp.ndarray):
  """Run the per-stage MSM path over all tiles.

  Tiles ``0 .. tile_num - 2`` are fed to ``bucketize_and_bucket_accumulation``
  and the returned buckets are written back to
  ``self.regular_window_buckets`` / ``self.last_window_buckets``. The last
  tile goes through ``bba_and_bws`` and its result is returned.

  Args:
    tiled_slices: Per-tile slices, indexed as ``[tile, window]``; the first
      ``window_num - 1`` windows are the regular ones, window
      ``window_num - 1`` is the last window.

  Returns:
    Whatever ``bba_and_bws`` returns for the last tile, or None when
    ``tile_num`` is smaller than 1.
  """
  tile_count = self.tile_num
  if tile_count < 1:
    return None

  split_at = self.window_num - 1

  def _tile_inputs(index):
    # Slice out one tile and, when sharding is on, place the slices on the
    # BBA input shardings. Points are used as stored on self.points.
    tile = tiled_slices[index]
    regular, last = tile[:split_at], tile[split_at]
    points = self.points[index]
    if self.use_sharding:
      regular = regular.to_device(self.bba_reg_slices_sharding)
      last = last.to_device(self.bba_last_slice_sharding)
    return regular, last, points

  for index in range(tile_count - 1):
    regular, last, points = _tile_inputs(index)
    self.regular_window_buckets, self.last_window_buckets = (
        self.bucketize_and_bucket_accumulation(
            regular,
            last,
            points,
            self.regular_window_buckets,
            self.last_window_buckets,
        )
    )

  # Last tile: fused BBA + BWS in a single kernel call.
  regular, last, points = _tile_inputs(tile_count - 1)
  return self.bba_and_bws(
      regular,
      last,
      points,
      self.regular_window_buckets,
      self.last_window_buckets,
      self.window_sum,
  )
def _get_fused_shape_dtype_structs(self):
  """Describe the six inputs of the whole-MSM fused kernel.

  Returned in argument order: regular tiled slices (int32), last tiled
  slices (int32), tiled points (uint32), followed by the two bucket structs
  from ``_get_bba_shape_dtype_structs`` (entries 3 and 4) and the
  window-sum struct from ``_get_bws_shape_dtype_structs`` (entry 2).

  With sharding enabled the first three structs use the ``fused_*_padded``
  shapes and ``fused_*_sharding`` shardings stored on ``self``. Otherwise
  their shapes are the per-tile BBA shapes with a leading ``tile_num`` axis
  and no sharding is attached.
  """
  bba_structs = self._get_bba_shape_dtype_structs()
  bws_structs = self._get_bws_shape_dtype_structs()

  tiles = self.tile_num
  tile_len = bba_structs[0].shape[1]
  unsharded_shapes = (
      (tiles, self.window_num - 1, tile_len),
      (tiles, tile_len),
      (tiles,) + tuple(bba_structs[2].shape),
  )
  dtypes = (jnp.int32, jnp.int32, jnp.uint32)

  if self.use_sharding:
    # Reuse the padded shapes and shardings stored on self so that the
    # compiled kernel is lowered with those same placements.
    placements = (
        (self.fused_reg_slices_padded, self.fused_reg_slices_sharding),
        (self.fused_last_slices_padded, self.fused_last_slices_sharding),
        (self.fused_tiled_points_padded, self.fused_tiled_points_sharding),
    )
    leading = [
        jax.ShapeDtypeStruct(shape, dtype, sharding=sharding)
        for (shape, sharding), dtype in zip(placements, dtypes)
    ]
  else:
    leading = [
        jax.ShapeDtypeStruct(shape, dtype)
        for shape, dtype in zip(unsharded_shapes, dtypes)
    ]

  return leading + [bba_structs[3], bba_structs[4], bws_structs[2]]
+ regular_tiled_slices = tiled_slices[:, : self.window_num - 1] + last_tiled_slices = tiled_slices[:, self.window_num - 1] + + if self.use_sharding: + regular_tiled_slices = regular_tiled_slices.to_device( + self.fused_reg_slices_sharding + ) + last_tiled_slices = last_tiled_slices.to_device( + self.fused_last_slices_sharding + ) + + args = ( + regular_tiled_slices, + last_tiled_slices, + self.points, + self.regular_window_buckets, + self.last_window_buckets, + self.window_sum, + ) + if self.use_compiled_kernels: + kernel_hash = hash_args( + *(v for a in args for v in (a.shape, a.dtype.__str__())) + ) + return self.compiled_kernels[kernel_hash]["multiscalar_multiply_fused"]( + *args + ) + return self._multiscalar_multiply_fused(*args) diff --git a/jaxite_ec/multiscalar_multiplication_perf_test.py b/jaxite_ec/multiscalar_multiplication_perf_test.py new file mode 100644 index 0000000..c7eaf2d --- /dev/null +++ b/jaxite_ec/multiscalar_multiplication_perf_test.py @@ -0,0 +1,211 @@ +import os + +import jax +from jaxite.jaxite_ec import utils +import jaxite.jaxite_ec.elliptic_curve_context as ec_context +import jaxite.jaxite_ec.finite_field_context as ff_context +import jaxite.jaxite_ec.multiscalar_multiplication_context as msm_context +from jaxite.jaxite_ec.profiler import ( + PrecompiledKernelWrapper, + Profiler, + collect_logs, +) +from jaxite.jaxite_ec.utils import hash_args + +from absl.testing import absltest +from absl.testing import parameterized + +jax.config.update("jax_enable_x64", True) + +MODULUS_377_INT = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001 +NUM_MODULI = 32 +RNS_MODULI = utils.find_moduli_specified_number(NUM_MODULI, 28) + +MSM_LENGTH_LIST = [2**10] +SLICE_BITS = 10 +SCALAR_BITS = 253 + +TEST_PARAMS_MSM_FUSION = [ + ("msm_fusion", MSM_LENGTH_LIST), +] + + +def _build_msm_parameters(msm_length: int): + ff_parameters = { + "prime": MODULUS_377_INT, + "rns_moduli": RNS_MODULI, + "precision_bits": 
def _build_fusion_ctx(msm_length: int):
  """Create a fused MSM context with sharding and pre-compiled kernels on.

  Mirrors scratch_profile_msm_fused.py. The returned context's
  ``multiscalar_multiply`` dispatches to kernels that are already compiled
  and already sharded, so it must NOT be re-traced under ``jax.jit``.

  Args:
    msm_length: Number of points/scalars, forwarded to
      ``_build_msm_parameters``.

  Returns:
    The compiled ``FusionMSMContext``.
  """
  fusion_ctx = msm_context.FusionMSMContext(_build_msm_parameters(msm_length))
  # Both switches are set before compile(), as in the reference script.
  fusion_ctx.set_use_compiled_kernels(True)
  fusion_ctx.set_use_sharding(True)
  fusion_ctx.compile(parameters={"use_fused": True})
  return fusion_ctx
@parameterized.named_parameters(*TEST_PARAMS_MSM_FUSION)
def test_fusion_msm_performance(self, msm_length_list):
  """Profile the fused MSM kernel once per requested MSM length.

  For each length a fresh fusion context is built, its pre-compiled fused
  kernel is wrapped for the profiler, and the profile is registered. All
  registered profiles are then run and post-processed together.

  Args:
    msm_length_list: MSM lengths to profile.
  """
  profiler = Profiler(
      output_trace_path=self.output_trace_root,
      profile_naming="msm_fusion",
      configuration=self.profiler_config,
  )

  for length in msm_length_list:
    fusion_ctx = _build_fusion_ctx(length)
    # NOTE(review): ``None`` is passed in place of scalars; presumably the
    # context supplies its own inputs in that case -- confirm.
    slices = fusion_ctx.to_computational_format(None)

    profile_name = f"msm_fusion_n{length}"
    wrapper = self._create_fusion_msm_wrapper(
        kernel_name=profile_name,
        ctx=fusion_ctx,
        tiled_slices=slices,
    )
    settings = {
        "msm_length": length,
        "slice_bits": SLICE_BITS,
        "scalar_bits": SCALAR_BITS,
        "num_moduli": NUM_MODULI,
        "use_sharding": True,
        "use_fused": True,
    }
    profiler.add_profile(
        name=profile_name,
        kernel_wrapper=wrapper,
        kernel_setting_cols=settings,
    )

  profiler.profile_all_profilers()
  profiler.post_process_all_profilers()
class MsmBls12377Test(parameterized.TestCase):
  """Checks every MSM context against the stored reference result."""

  def setUp(self):
    """Load the MSM parameters, the input scalars and the expected result."""
    super().setUp()
    self.msm_parameters = _build_msm_parameters()
    self.scalars = utils.read_external_msm_file(SCALARS_PATH, "scalars")
    self.ref_result = utils.read_external_msm_file(
        REF_RESULT_PATH, "result_ref"
    )

  def _run_msm(self, context_class, use_fused):
    """Compile ``context_class`` and run one MSM over ``self.scalars``.

    Args:
      context_class: MSM context class to instantiate with the test
        parameters.
      use_fused: When true, ``{"use_fused": True}`` is passed to
        ``compile``; otherwise an empty parameter dict is passed.

    Returns:
      The MSM result converted back with ``to_original_format``.
    """
    msm_ctx = context_class(self.msm_parameters)
    msm_ctx.set_use_compiled_kernels(True)
    if use_fused:
      compile_parameters = {"use_fused": True}
    else:
      compile_parameters = {}
    msm_ctx.compile(parameters=compile_parameters)

    slices = msm_ctx.to_computational_format(self.scalars)
    return msm_ctx.to_original_format(msm_ctx.multiscalar_multiply(slices))

  @parameterized.named_parameters(*MSM_CONTEXT_CASES)
  def test_multiscalar_multiply_matches_reference(
      self, context_class, use_fused
  ):
    """The computed MSM must equal the reference result exactly."""
    actual = np.asarray(self._run_msm(context_class, use_fused))
    expected = np.asarray(self.ref_result)
    np.testing.assert_array_equal(actual, expected)
def get_bit_reverse_perm(n):
  """Generate the bit-reversal permutation of ``range(n)``.

  Entry ``i`` of the result is ``i`` with its ``log2(n)`` low bits written
  in reverse order.

  Args:
    n: Permutation size. Must be a power of two when positive.

  Returns:
    A list of ``n`` indices; the empty list when ``n <= 0``.

  Raises:
    ValueError: If ``n`` is positive but not a power of two. Bit reversal
      is only a permutation for power-of-two sizes; for other sizes the
      reversed indices collide (e.g. ``n = 3`` would give ``[0, 1, 0]``),
      which would silently corrupt any table indexed with the result.
  """
  if n <= 0:
    return []
  if n & (n - 1):
    raise ValueError(
        f"bit-reversal permutation requires a power-of-two size, got {n}"
    )
  bits = n.bit_length() - 1
  perm = [0] * n
  # rev(i) is rev(i >> 1) shifted down one place, with the low bit of i
  # moved into the top position.
  for i in range(1, n):
    perm[i] = (perm[i >> 1] >> 1) | ((i & 1) << (bits - 1))
  return perm
########################
# BAT subscript generator (right-convention only)
########################
# Letters with a fixed meaning in the generated einsum string; they are
# never handed out as labels for the leading axes of the operand.
_BAT_RESERVED_LETTERS = frozenset("mqpj")


def _bat_subscripts_right(v_ndim: int, contract_axis: int) -> str:
  """Build the einsum subscripts for a BAT right-convention matmul.

  The operand ``v`` has shape ``(*leading, M)`` and gains a trailing
  size-4 byte axis once bitcast to uint8, so it is labelled
  ``<leading>mq``. The twiddle handle has shape ``(K, 4, J, 4, M)`` and is
  labelled ``<k>qjpm``, where ``<k>`` is the label of the contracted
  leading axis. In the output that axis is replaced by ``j`` and the byte
  axis becomes ``p``.

  Args:
    v_ndim: Rank of ``v`` before the bitcast (leading axes plus ``M``).
    contract_axis: Leading axis of ``v`` to contract; negative values
      count from the end of ``v``.

  Returns:
    The subscript string ``"<v>,<handle>-><out>"``.

  Raises:
    ValueError: If ``contract_axis`` does not name a leading axis.
  """
  if contract_axis < 0:
    contract_axis = contract_axis + v_ndim
  leading_count = v_ndim - 1
  if contract_axis < 0 or contract_axis >= leading_count:
    raise ValueError(
        "contract_axis must be a leading axis of v "
        f"(got {contract_axis} for v.ndim={v_ndim})"
    )

  # Hand out 'a', 'b', ... to the leading axes, skipping reserved letters.
  labels = []
  code = ord("a")
  while len(labels) < leading_count:
    candidate = chr(code)
    code += 1
    if candidate in _BAT_RESERVED_LETTERS:
      continue
    labels.append(candidate)

  operand_spec = "".join(labels) + "mq"
  handle_spec = labels[contract_axis] + "qjpm"
  produced = labels[:contract_axis] + ["j"] + labels[contract_axis + 1 :]
  result_spec = "".join(produced) + "mp"
  return operand_spec + "," + handle_spec + "->" + result_spec
def preprocess_elementwise(self, mat_nd) -> jnp.ndarray:
  """Encode an integer tensor of any rank in DRNS computational form.

  Every entry is reduced modulo each RNS modulus, shifted left by
  ``self.radix_bits`` and reduced again, which adds a trailing axis of
  length ``len(self.rns_moduli)``. The arithmetic is done on Python ints
  (object dtype) so entries wider than 64 bits are handled exactly.

  Args:
    mat_nd: Array-like of integers.

  Returns:
    A uint32 array of shape ``mat_nd.shape + (num_moduli,)``.
  """
  values = np.asarray(mat_nd)
  exact = values if values.dtype == object else values.astype(object)
  moduli = np.array(self.rns_moduli, dtype=object)
  residues = exact[..., np.newaxis] % moduli
  scaled = (residues << self.radix_bits) % moduli
  return jnp.array(scaled.astype(np.uint32))
def _encode_chunks(self, mat) -> jnp.ndarray:
  """Encode an integer tensor as little-endian uint32 chunks.

  Each entry is reduced modulo ``self.prime`` and split into
  ``self.chunk_num_u32`` 32-bit limbs, least-significant limb first. Bits
  above the last limb are dropped.

  Args:
    mat: Array-like of integers.

  Returns:
    A uint32 array of shape ``mat.shape + (chunk_num_u32,)``.
  """
  source = np.asarray(mat)
  limb_count = self.chunk_num_u32
  residues = [int(entry) % self.prime for entry in source.reshape(-1).tolist()]

  limbs = np.zeros((len(residues), limb_count), dtype=np.uint32)
  # ``row`` is a view into ``limbs``, so the writes below land in place.
  for row, residue in zip(limbs, residues):
    for slot in range(limb_count):
      row[slot] = (residue >> (32 * slot)) & 0xFFFFFFFF
  return jnp.asarray(limbs.reshape(source.shape + (limb_count,)))
class NumpyCPUContext:
  """Pure-numpy CPU reference context compatible with the unified NTT stack.

  Exposes the same ``preprocess_matmul`` / ``preprocess_elementwise``
  / ``modular_matmul`` / ``modular_multiply_broadcast`` /
  ``to_computational_format`` / ``to_original_format`` API as
  :class:`DRNSLazyExtensionContext` and :class:`CROSSLazyExtensionContext`,
  so that :class:`NTT3Step` / :class:`NTT5Step` / :class:`NTT7Step`
  reproduce the same outputs as the CPU reference classes
  (:class:`CPUCROSSContext`, :class:`CPUCROSS5StepContext`,
  :class:`CPUCROSS7StepContext`).

  Each field element is stored as a single ``uint64`` with a trailing
  size-1 "chunks" axis, so the prime must fit in 64 bits. Products and
  dot products are computed on Python ints (object dtype) and reduced
  ``% prime`` before being stored back: a ``uint64`` multiply wraps
  silently once a product reaches 2**64, which happens for any prime
  above roughly 2**32. Correctness only, no JAX / no TPU.
  """

  def __init__(self, parameters: Dict[str, Any]):
    """Read ``parameters["prime"]`` and mark the context as NTT-capable."""
    self.prime = int(parameters["prime"])
    self.for_ntt = True

  # ---------- internal helpers ----------

  def _reduce_to_uint64(self, values) -> np.ndarray:
    """Reduce exact (object-dtype) values mod ``prime`` into uint64."""
    reduced = np.asarray(values, dtype=object) % self.prime
    # The second asarray re-wraps the 0-d case, where ``%`` returns a
    # bare Python int instead of an array.
    return np.asarray(reduced, dtype=object).astype(np.uint64)

  @staticmethod
  def _as_exact(values) -> np.ndarray:
    """View stored uint64 field elements as Python ints (object dtype)."""
    return np.asarray(values, dtype=np.uint64).astype(object)

  # ---------- format conversion ----------

  def to_computational_format(self, a) -> np.ndarray:
    """Wrap ``a`` as an ``(..., 1)`` uint64 ndarray."""
    arr = np.asarray(a, dtype=np.uint64)
    return arr[..., np.newaxis]

  def to_original_format(self, a):
    """Reduce, flatten, return a flat list of Python ints."""
    arr = np.asarray(a, dtype=np.uint64)
    return (arr.reshape(-1) % self.prime).tolist()

  # ---------- twiddle preprocessing ----------

  def preprocess_matmul(self, mat_2d) -> np.ndarray:
    """Encode a 2-D ``(K, J)`` integer matrix as a ``(K, J, 1)`` handle.

    Raises:
      ValueError: If ``mat_2d`` is not 2-D.
    """
    arr = np.asarray(mat_2d)
    if arr.ndim != 2:
      raise ValueError(f"matmul twiddle must be 2-D, got shape {arr.shape}")
    return self._reduce_to_uint64(arr)[..., np.newaxis]

  def preprocess_elementwise(self, mat_nd) -> np.ndarray:
    """Encode an integer tensor of any rank, adding the chunks axis."""
    return self._reduce_to_uint64(np.asarray(mat_nd))[..., np.newaxis]

  # ---------- modular arithmetic ----------

  def modular_matmul(
      self, operand: np.ndarray, handle: np.ndarray, contract_axis: int
  ) -> np.ndarray:
    """Right-convention modular matmul along ``contract_axis``.

    ``operand`` has trailing size-1 axis (the "chunks" axis).
    ``handle`` has shape ``(K, J, 1)`` — axis 0 is the contracted dim.

    Args:
      operand: Field elements, shape ``(..., K, ..., 1)``.
      handle: Twiddle handle from :py:meth:`preprocess_matmul`.
      contract_axis: Axis of ``operand`` to contract; negative values
        count from the end of ``operand`` (chunks axis included).

    Returns:
      uint64 array shaped like ``operand`` with the contracted axis
      replaced by one of size ``J``.
    """
    op = self._as_exact(operand)[..., 0]
    h = self._as_exact(handle)[..., 0]
    if contract_axis < 0:
      contract_axis += op.ndim + 1  # +1 because we squeezed trailing chunks
    # tensordot contracts op.axes[contract_axis] with h.axis[0]; the new
    # J axis from h lands at the end of the result, so move it back to
    # ``contract_axis``. Object dtype keeps the accumulated sum exact.
    out = np.tensordot(op, h, axes=([contract_axis], [0]))
    out = np.moveaxis(out, -1, contract_axis)
    return self._reduce_to_uint64(out)[..., np.newaxis]

  def modular_multiply_broadcast(
      self, a: np.ndarray, b: np.ndarray
  ) -> np.ndarray:
    """Broadcasting element-wise modular multiply (exact, then reduced)."""
    return self._reduce_to_uint64(self._as_exact(a) * self._as_exact(b))
+ ``"cpu"`` No extra parameters. + ============ ======================================================= + """ + + @staticmethod + def _build_ff_ctx(parameters: dict): + """Auto-construct an extension context from top-level parameters. + + If ``finite_field_context`` is already an instance, return it as-is. + If it's a backend string, build the matching extension context using + ``prime`` and the optional sizing parameters from ``parameters``. + """ + ff_ctx = parameters["finite_field_context"] + if isinstance(ff_ctx, str): + prime = parameters["prime"] + backend = ff_ctx.lower() + if backend == "drns": + num_moduli = parameters.get("num_moduli", 21) + precision_bits = parameters.get("precision_bits", 28) + radix_bits = parameters.get("radix_bits", 32) + rns_moduli = utils.find_moduli_specified_number( + num_moduli, precision_bits + ) + return DRNSLazyExtensionContext({ + "prime": prime, + "rns_moduli": rns_moduli, + "precision_bits": precision_bits, + "radix_bits": radix_bits, + }) + elif backend == "cross": + ctx_params = {"prime": prime} + if "chunk_num_u8" in parameters: + ctx_params["chunk_num_u8"] = parameters["chunk_num_u8"] + return CROSSLazyExtensionContext(ctx_params) + elif backend == "cpu": + return NumpyCPUContext({"prime": prime}) + else: + raise ValueError( + f"Unknown backend {ff_ctx!r}; use 'drns', 'cross', or 'cpu'" + ) + return ff_ctx + + def __init__(self, ff_ctx, spatial_shape: tuple): + if not getattr(ff_ctx, "for_ntt", False): + raise TypeError( + "finite_field_context does not declare NTT support; " + f"{type(ff_ctx).__name__} must set ``self.for_ntt = True``" + ) + self.ff_ctx = ff_ctx + self._spatial_shape = spatial_shape + + def to_computational_format(self, a): + return self.ff_ctx.to_computational_format(a) + + def to_original_format(self, a): + a = jnp.asarray(a) + trailing = a.shape[-1] + return self.ff_ctx.to_original_format(a.reshape(-1, trailing)) + + def _ensure_ntt_shape(self, v: jnp.ndarray) -> jnp.ndarray: + v = jnp.asarray(v) + 
class NTT3Step(NTTBase):
  """Unified 3-step NTT. Input/output shape: ``(B, R, C, trailing)``.

  The forward transform is a matmul along ``R``, an element-wise twiddle
  multiply, then a matmul along ``C``; the inverse runs the mirrored
  sequence with inverse twiddles. All arithmetic goes through the
  extension context resolved from ``parameters``.
  """

  def __init__(self, parameters: Dict[str, Any]):
    """Precompute the forward and inverse twiddle tables.

    Args:
      parameters: Must provide ``prime``, ``r``, ``c`` and
        ``finite_field_context``. ``psi`` is optional; when absent it is
        obtained from ``utils.root_of_unity(2 * r * c, prime)``.
    """
    self.prime = parameters["prime"]
    self.r = parameters["r"]
    self.c = parameters["c"]
    ff_ctx = self._build_ff_ctx(parameters)
    super().__init__(ff_ctx, spatial_shape=(self.r, self.c))

    self.transform_length = self.r * self.c
    psi = parameters.get("psi")
    if psi is None:
      self.psi = utils.root_of_unity(2 * self.transform_length, self.prime)
    else:
      self.psi = int(psi)
    self.omega = (self.psi**2) % self.prime

    prime, rows, cols = self.prime, self.r, self.c

    # --- forward twiddles (plain integer form) ---
    omega_col = pow(self.omega, cols, prime)
    omega_row = pow(self.omega, rows, prime)
    fwd_rows = gen_twiddle_matrix(rows, rows, prime, omega_col)
    fwd_mid = gen_twiddle_matrix(rows, cols, prime, self.omega)
    fwd_cols = gen_twiddle_matrix(cols, cols, prime, omega_row)

    # --- inverse twiddles; the 1/C and 1/R scalings are folded into the
    # element-wise table and the final matmul table respectively ---
    inv_omega_col = pow(omega_col, -1, prime)
    inv_omega_row = pow(omega_row, -1, prime)
    col_inv = pow(cols, -1, prime)
    row_inv = pow(rows, -1, prime)
    inv_cols = gen_twiddle_matrix(cols, cols, prime, inv_omega_row)
    inv_mid = (
        gen_twiddle_matrix_inv(rows, cols, prime, self.omega) * col_inv
    ) % prime
    inv_rows = (
        gen_twiddle_matrix(rows, rows, prime, inv_omega_col) * row_inv
    ) % prime

    # --- bit-reversal permutations ---
    perm_r = get_bit_reverse_perm(rows)
    perm_c = get_bit_reverse_perm(cols)
    fwd_rows = fwd_rows[perm_r, :]
    fwd_mid = fwd_mid[perm_r, :]
    fwd_cols = fwd_cols[:, perm_c]
    inv_cols = inv_cols[perm_c, :]
    inv_mid = inv_mid[perm_r, :]
    inv_rows = inv_rows[:, perm_r]

    # Right-convention handles: the contracted dim is at axis 0 of the
    # handle. Step 1 of NTT is logically "T1 @ v along R", which becomes
    # right-matmul of v against T1.T; step 3 is already "v @ T3".
    self.ntt_t1 = ff_ctx.preprocess_matmul(fwd_rows.T)
    self.ntt_t2 = ff_ctx.preprocess_elementwise(fwd_mid)
    self.ntt_t3 = ff_ctx.preprocess_matmul(fwd_cols)
    self.intt_t1 = ff_ctx.preprocess_matmul(inv_cols)
    self.intt_t2 = ff_ctx.preprocess_elementwise(inv_mid)
    self.intt_t3 = ff_ctx.preprocess_matmul(inv_rows.T)

  def ntt(self, v: jnp.ndarray):
    """Forward transform: matmul on axis 1, twiddle, matmul on axis 2."""
    ctx = self.ff_ctx
    x = self._ensure_ntt_shape(v)  # (B, R, C, trailing)
    x = ctx.modular_matmul(x, self.ntt_t1, contract_axis=1)
    x = ctx.modular_multiply_broadcast(x, self.ntt_t2)
    return ctx.modular_matmul(x, self.ntt_t3, contract_axis=2)

  def intt(self, v: jnp.ndarray):
    """Inverse transform: matmul on axis 2, twiddle, matmul on axis 1."""
    ctx = self.ff_ctx
    x = self._ensure_ntt_shape(v)
    x = ctx.modular_matmul(x, self.intt_t1, contract_axis=2)
    x = ctx.modular_multiply_broadcast(x, self.intt_t2)
    return ctx.modular_matmul(x, self.intt_t3, contract_axis=1)
Shape: ``(B, RR, RC, C, trailing)``.""" + + def __init__(self, parameters: Dict[str, Any]): + self.prime = parameters["prime"] + self.rr = parameters["rr"] + self.rc = parameters["rc"] + self.c = parameters["c"] + ff_ctx = self._build_ff_ctx(parameters) + super().__init__(ff_ctx, spatial_shape=(self.rr, self.rc, self.c)) + + self.transform_length = self.rr * self.rc * self.c + R = self.rr * self.rc + psi = parameters.get("psi") + self.psi = ( + int(psi) + if psi is not None + else utils.root_of_unity(2 * self.transform_length, self.prime) + ) + self.omega = (self.psi**2) % self.prime + + omega_R = pow(self.omega, self.c, self.prime) + omega_RR = pow(omega_R, self.rc, self.prime) + omega_RC = pow(omega_R, self.rr, self.prime) + omega_C = pow(self.omega, R, self.prime) + + ntt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, omega_RR) + ntt_T2 = gen_twiddle_matrix(self.rr, self.rc, self.prime, omega_R) + ntt_T3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, omega_RC) + ntt_T4 = gen_twiddle_matrix(R, self.c, self.prime, self.omega).reshape( + self.rr, self.rc, self.c + ) + ntt_T5 = gen_twiddle_matrix(self.c, self.c, self.prime, omega_C) + + inv_omega_RR = pow(omega_RR, -1, self.prime) + inv_omega_RC = pow(omega_RC, -1, self.prime) + inv_omega_C = pow(omega_C, -1, self.prime) + rr_inv = pow(self.rr, -1, self.prime) + rc_inv = pow(self.rc, -1, self.prime) + c_inv = pow(self.c, -1, self.prime) + + intt_T5 = gen_twiddle_matrix(self.c, self.c, self.prime, inv_omega_C) + intt_T4 = gen_twiddle_matrix_inv(R, self.c, self.prime, self.omega).reshape( + self.rr, self.rc, self.c + ) + intt_T4 = (intt_T4 * c_inv) % self.prime + intt_t3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, inv_omega_RC) + intt_T2 = gen_twiddle_matrix_inv(self.rr, self.rc, self.prime, omega_R) + intt_T2 = (intt_T2 * rc_inv) % self.prime + intt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, inv_omega_RR) + intt_T1 = (intt_T1 * rr_inv) % self.prime + + perm_rr = 
get_bit_reverse_perm(self.rr) + perm_rc = get_bit_reverse_perm(self.rc) + perm_c = get_bit_reverse_perm(self.c) + + ntt_T1 = ntt_T1[perm_rr, :] + ntt_T2 = ntt_T2[perm_rr, :] + ntt_T3 = ntt_T3[:, perm_rc] + perm_R = get_bit_reverse_perm(R) + ntt_T4 = ntt_T4.reshape(R, self.c)[perm_R, :].reshape( + self.rr, self.rc, self.c + ) + ntt_T5 = ntt_T5[:, perm_c] + + intt_T5 = intt_T5[perm_c, :] + intt_T4 = intt_T4.reshape(R, self.c)[perm_R, :].reshape( + self.rr, self.rc, self.c + ) + intt_t3 = intt_t3[perm_rc, :] + intt_T2 = intt_T2[perm_rr, :] + intt_T1 = intt_T1[:, perm_rr] + + # Matmul twiddles: transpose when the original step was left-matmul. + self.ntt_T1 = ff_ctx.preprocess_matmul( + ntt_T1.T + ) # left-mat → right via transpose + self.ntt_T2 = ff_ctx.preprocess_elementwise(ntt_T2) # (RR, RC, trailing) + self.ntt_T3 = ff_ctx.preprocess_matmul(ntt_T3) # right-mat + self.ntt_T4 = ff_ctx.preprocess_elementwise(ntt_T4) # (RR, RC, C, trailing) + self.ntt_T5 = ff_ctx.preprocess_matmul(ntt_T5) # right-mat + + self.intt_T5 = ff_ctx.preprocess_matmul(intt_T5) + self.intt_T4 = ff_ctx.preprocess_elementwise(intt_T4) + self.intt_T3 = ff_ctx.preprocess_matmul(intt_t3) + self.intt_T2 = ff_ctx.preprocess_elementwise(intt_T2) + self.intt_T1 = ff_ctx.preprocess_matmul( + intt_T1.T + ) # undo the step-1 left-mat + + def ntt(self, v: jnp.ndarray): + v = self._ensure_ntt_shape(v) # (B, RR, RC, C, trailing) + v = self.ff_ctx.modular_matmul(v, self.ntt_T1, contract_axis=1) + v = self.ff_ctx.modular_multiply_broadcast(v, self.ntt_T2[:, :, None, :]) + v = self.ff_ctx.modular_matmul(v, self.ntt_T3, contract_axis=2) + v = self.ff_ctx.modular_multiply_broadcast(v, self.ntt_T4[None]) + v = self.ff_ctx.modular_matmul(v, self.ntt_T5, contract_axis=3) + return v + + def intt(self, v: jnp.ndarray): + v = self._ensure_ntt_shape(v) + v = self.ff_ctx.modular_matmul(v, self.intt_T5, contract_axis=3) + v = self.ff_ctx.modular_multiply_broadcast(v, self.intt_T4[None]) + v = self.ff_ctx.modular_matmul(v, 
self.intt_T3, contract_axis=2) + v = self.ff_ctx.modular_multiply_broadcast(v, self.intt_T2[:, :, None, :]) + v = self.ff_ctx.modular_matmul(v, self.intt_T1, contract_axis=1) + return v + + +class NTT7Step(NTTBase): + """Unified 7-step NTT. Shape: ``(B, RR, RC, CR, CC, trailing)``.""" + + def __init__(self, parameters: Dict[str, Any]): + self.prime = parameters["prime"] + self.rr = parameters["rr"] + self.rc = parameters["rc"] + self.cr = parameters["cr"] + self.cc = parameters["cc"] + ff_ctx = self._build_ff_ctx(parameters) + super().__init__(ff_ctx, spatial_shape=(self.rr, self.rc, self.cr, self.cc)) + + r_total = self.rr * self.rc + c_total = self.cr * self.cc + self.transform_length = r_total * c_total + psi = parameters.get("psi") + self.psi = ( + int(psi) + if psi is not None + else utils.root_of_unity(2 * self.transform_length, self.prime) + ) + self.omega = (self.psi**2) % self.prime + + omega_r = pow(self.omega, c_total, self.prime) + omega_rr = pow(omega_r, self.rc, self.prime) + omega_rc = pow(omega_r, self.rr, self.prime) + omega_c = pow(self.omega, r_total, self.prime) + omega_cr = pow(omega_c, self.cc, self.prime) + omega_cc = pow(omega_c, self.cr, self.prime) + + ntt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, omega_rr) + ntt_T2 = gen_twiddle_matrix(self.rr, self.rc, self.prime, omega_r) + ntt_T3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, omega_rc) + ntt_T4 = gen_twiddle_matrix( + r_total, c_total, self.prime, self.omega + ).reshape(self.rr, self.rc, self.cr, self.cc) + ntt_T5 = gen_twiddle_matrix(self.cr, self.cr, self.prime, omega_cr) + ntt_T6 = gen_twiddle_matrix(self.cr, self.cc, self.prime, omega_c) + ntt_T7 = gen_twiddle_matrix(self.cc, self.cc, self.prime, omega_cc) + + inv_omega_rr = pow(omega_rr, -1, self.prime) + inv_omega_rc = pow(omega_rc, -1, self.prime) + inv_omega_cr = pow(omega_cr, -1, self.prime) + inv_omega_cc = pow(omega_cc, -1, self.prime) + rr_inv = pow(self.rr, -1, self.prime) + rc_inv = pow(self.rc, -1, 
self.prime) + cr_inv = pow(self.cr, -1, self.prime) + cc_inv = pow(self.cc, -1, self.prime) + + intt_T7 = gen_twiddle_matrix(self.cc, self.cc, self.prime, inv_omega_cc) + intt_T6 = gen_twiddle_matrix_inv(self.cr, self.cc, self.prime, omega_c) + intt_T6 = (intt_T6 * cc_inv) % self.prime + intt_T5 = gen_twiddle_matrix(self.cr, self.cr, self.prime, inv_omega_cr) + intt_T5 = (intt_T5 * cr_inv) % self.prime + intt_T4 = gen_twiddle_matrix_inv( + r_total, c_total, self.prime, self.omega + ).reshape(self.rr, self.rc, self.cr, self.cc) + intt_T3 = gen_twiddle_matrix(self.rc, self.rc, self.prime, inv_omega_rc) + intt_T2 = gen_twiddle_matrix_inv(self.rr, self.rc, self.prime, omega_r) + intt_T2 = (intt_T2 * rc_inv) % self.prime + intt_T1 = gen_twiddle_matrix(self.rr, self.rr, self.prime, inv_omega_rr) + intt_T1 = (intt_T1 * rr_inv) % self.prime + + perm_rr = get_bit_reverse_perm(self.rr) + perm_rc = get_bit_reverse_perm(self.rc) + perm_cr = get_bit_reverse_perm(self.cr) + perm_cc = get_bit_reverse_perm(self.cc) + + ntt_T1 = ntt_T1[perm_rr, :] + ntt_T2 = ntt_T2[perm_rr, :] + ntt_T3 = ntt_T3[:, perm_rc] + perm_r = get_bit_reverse_perm(r_total) + ntt_T4 = ntt_T4.reshape(r_total, c_total)[perm_r, :].reshape( + self.rr, self.rc, self.cr, self.cc + ) + ntt_T5 = ntt_T5[perm_cr, :] + ntt_T6 = ntt_T6[perm_cr, :] + ntt_T7 = ntt_T7[:, perm_cc] + + intt_T7 = intt_T7[perm_cc, :] + intt_T6 = intt_T6[perm_cr, :] + intt_T5 = intt_T5[:, perm_cr] + intt_T4 = intt_T4.reshape(r_total, c_total)[perm_r, :].reshape( + self.rr, self.rc, self.cr, self.cc + ) + intt_T3 = intt_T3[perm_rc, :] + intt_T2 = intt_T2[perm_rr, :] + intt_T1 = intt_T1[:, perm_rr] + + # Steps 1 and 5 are logically left-matmul; steps 3 and 7 are right. 
+ self.ntt_T1 = ff_ctx.preprocess_matmul(ntt_T1.T) + self.ntt_T2 = ff_ctx.preprocess_elementwise(ntt_T2) + self.ntt_T3 = ff_ctx.preprocess_matmul(ntt_T3) + self.ntt_T4 = ff_ctx.preprocess_elementwise(ntt_T4) + self.ntt_T5 = ff_ctx.preprocess_matmul(ntt_T5.T) + self.ntt_T6 = ff_ctx.preprocess_elementwise(ntt_T6) + self.ntt_T7 = ff_ctx.preprocess_matmul(ntt_T7) + + self.intt_T7 = ff_ctx.preprocess_matmul(intt_T7) + self.intt_T6 = ff_ctx.preprocess_elementwise(intt_T6) + self.intt_T5 = ff_ctx.preprocess_matmul(intt_T5.T) + self.intt_T4 = ff_ctx.preprocess_elementwise(intt_T4) + self.intt_T3 = ff_ctx.preprocess_matmul(intt_T3) + self.intt_T2 = ff_ctx.preprocess_elementwise(intt_T2) + self.intt_T1 = ff_ctx.preprocess_matmul(intt_T1.T) + + def ntt(self, v: jnp.ndarray): + v = self._ensure_ntt_shape(v) # (B, RR, RC, CR, CC, trailing) + # Steps 1-3: inner NTT on R = RR * RC. + v = self.ff_ctx.modular_matmul(v, self.ntt_T1, contract_axis=1) + v = self.ff_ctx.modular_multiply_broadcast( + v, self.ntt_T2[:, :, None, None, :] + ) + v = self.ff_ctx.modular_matmul(v, self.ntt_T3, contract_axis=2) + # Step 4: outer twiddle (broadcast over batch only). + v = self.ff_ctx.modular_multiply_broadcast(v, self.ntt_T4[None]) + # Steps 5-7: inner NTT on C = CR * CC. 
+ v = self.ff_ctx.modular_matmul(v, self.ntt_T5, contract_axis=3) + v = self.ff_ctx.modular_multiply_broadcast( + v, self.ntt_T6[None, None, None, :, :, :] + ) + v = self.ff_ctx.modular_matmul(v, self.ntt_T7, contract_axis=4) + return v + + def intt(self, v: jnp.ndarray): + v = self._ensure_ntt_shape(v) + v = self.ff_ctx.modular_matmul(v, self.intt_T7, contract_axis=4) + v = self.ff_ctx.modular_multiply_broadcast( + v, self.intt_T6[None, None, None, :, :, :] + ) + v = self.ff_ctx.modular_matmul(v, self.intt_T5, contract_axis=3) + v = self.ff_ctx.modular_multiply_broadcast(v, self.intt_T4[None]) + v = self.ff_ctx.modular_matmul(v, self.intt_T3, contract_axis=2) + v = self.ff_ctx.modular_multiply_broadcast( + v, self.intt_T2[:, :, None, None, :] + ) + v = self.ff_ctx.modular_matmul(v, self.intt_T1, contract_axis=1) + return v diff --git a/jaxite_ec/number_theory_transform_perf_test.py b/jaxite_ec/number_theory_transform_perf_test.py new file mode 100644 index 0000000..b990f51 --- /dev/null +++ b/jaxite_ec/number_theory_transform_perf_test.py @@ -0,0 +1,447 @@ +"""Performance tests for batched NTT. + +Profiles the six NTT designs produced by crossing + * step count (:class:`NTT3Step`, :class:`NTT5Step`, :class:`NTT7Step`) + * backend (DRNS via :class:`DRNSLazyExtensionContext`, + CROSS via :class:`CROSSLazyExtensionContext`) + +The batch (leading) dimension is sharded across all available JAX +devices. Performance is measured via ``jax.profiler`` through the +``KernelWrapper`` / ``Profiler`` helpers in ``profiler.py``, NOT +wall-clock time. Sharded-correctness checks live in +``number_theory_transform_test.py``. 
+""" + +import os + +import jax +from jax.experimental.shard_map import shard_map +import jax.numpy as jnp +import jaxite.jaxite_ec.number_theory_transform_context as ntt_context +from jaxite.jaxite_ec.profiler import KernelWrapper, Profiler, collect_logs +import jaxite.jaxite_ec.utils as utils + +from absl.testing import absltest +from absl.testing import parameterized + +jax.config.update("jax_enable_x64", True) + +# --------------------------------------------------------------------------- +# Trailing-dimension sizing — change these to sweep field-element width. +# NUM_MODULI drives the prime preset below (21 → 256-bit prime, 56 → 753-bit). +# --------------------------------------------------------------------------- +NUM_MODULI = 21 # DRNS: number of RNS moduli (trailing dim size) +PRECISION_BITS = 28 # DRNS: bit-width per modulus +RADIX_BITS = 32 # DRNS: Montgomery radix + +# --------------------------------------------------------------------------- +# Prime & 2N-th-root presets, keyed by NUM_MODULI. Each preset supplies a +# matching Q_PERF and PSI_PERF_BY_DEGREE (primitive 2*2^degree-th roots of +# unity mod Q_PERF). The per-degree psi sidesteps utils.root_of_unity(), +# which trial-divides q-1 and is infeasible for the 753-bit prime. +# --------------------------------------------------------------------------- +_PERF_PRIME_PRESETS: dict[int, dict] = { + # 256-bit prime — fits NUM_MODULI * PRECISION_BITS = 21 * 28 = 588 ≥ 256. 
+ 21: { + "q": 0x8000000000000000000000000000000000000000000000000000000070000001, + "psi_by_degree": { + 14: ( + 0x210D1D264152132AE3E5610B7E230BCD0058FE66FB35C5713527EA1FA40D1845 + ), + 16: ( + 0x40D7C3F33672325E7B65C4A20B0BE07DD32F3EBB05C33DD8675D68EB3A8BDB6B + ), + 18: ( + 0x568FDDCD95737AC264EAADA546D74B051CA1B7FC5B8427DCE706674011E009E0 + ), + 20: ( + 0x3AAE36A59E8E4F95E3118AA64270D0E122E0FC9585380815F737A67D613B5516 + ), + 22: ( + 0x23E461BCC11091F4A355AD034B454991F9CFA113272B8DBDF38E895C68BE3702 + ), + }, + }, + # 753-bit prime — fits NUM_MODULI * PRECISION_BITS = 56 * 28 = 1568 ≥ 753. + 56: { + "q": ( + 0x100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000023C00001 + ), + "psi_by_degree": { + 14: ( + 0x987B6025AB8620C5CB8376257CBC1FED3F6A9CD8003B6CB78C442378C5CB76DF824DDFAD28F53E6EFED050F1193BD6B114DDBAF944860CE6CEEC6A4ECCC92D690996CD24ED61167E18B61F33C7F45CA1231F30751C16AA586DA157BC12DA + ), + 16: ( + 0xC35116503813C2FEA4AA3458B0F4F8B28BBCC8BEE7004D34FB47B57FE855C6D764AF4F7CE54C9334C6C957CDE2D613405AD78F5BF210772600A9E8FBC797E35DDFCB55643E3F9ADE2C9B4AD8D08F00D7224EE0E4E176C98E61232E23114C + ), + 18: ( + 0x80DE16A04909C865C8C3954F81D93780D98B4826B7C8AEC95A3149697831B5CE3BE84CEA28866B38A0458484ACDFBFC0FB26215FFF7083F52721E3BBE094FA42CCAF9D9DF1730211F6A211BED8C8128041609A585C1234113AC33FD11FD2 + ), + 20: ( + 0x4AE6B7005951673890AC5AEA5DFCA80F449C73138FBBA6EF4E26D43DDA812A3BE60C97E60450C3AA294B751F6FCAB9E3736C02191A19E87C08AB00F3A3D2B599CFC3D52A0886F3EDA8941813BCDBCD51930F838B04CFE32A1AEF52510324 + ), + 22: ( + 0x26DFA808FE842C5A3C50EED23D433805D7CA3BC5FE9EDFA38C0A325159D75B8DBF7ED80054D92678DCCC4CC6B9A1D47976DFDA07D39C463F312EC45147860C46F8E4ACE696AB8B7F789EBE170FE615C18902C15F97ABC7300DBD9F1870A + ), + }, + }, +} + +if NUM_MODULI not in _PERF_PRIME_PRESETS: + raise ValueError( + f"No Q_PERF/PSI preset for NUM_MODULI={NUM_MODULI}; 
" + f"available: {sorted(_PERF_PRIME_PRESETS)}" + ) +Q_PERF: int = _PERF_PRIME_PRESETS[NUM_MODULI]["q"] +PSI_PERF_BY_DEGREE: dict[int, int] = _PERF_PRIME_PRESETS[NUM_MODULI][ + "psi_by_degree" +] + +# DRNS configs (BAT einsum + Montgomery/CRNS — runs fast, can push to deg=22). +PERF_CONFIGS_DRNS_3STEP = [ + {"degree": 14, "r1": 7, "c": 7}, + # {"degree": 16, "r1": 8, "c": 8}, + # {"degree": 18, "r1": 9, "c": 9}, + # {"degree": 20, "r1": 10, "c": 10}, + # {"degree": 22, "r1": 11, "c": 11}, +] + +PERF_CONFIGS_DRNS_5STEP = [ + {"degree": 14, "r1": 5, "r2": 5, "c": 4}, + # {"degree": 16, "r1": 5, "r2": 5, "c": 6}, + # {"degree": 18, "r1": 6, "r2": 6, "c": 6}, + # {"degree": 20, "r1": 6, "r2": 6, "c": 8}, + # {"degree": 22, "r1": 7, "r2": 7, "c": 8}, +] + +PERF_CONFIGS_DRNS_7STEP = [ + {"degree": 16, "r1": 4, "r2": 4, "c1": 4, "c2": 4}, + # {"degree": 20, "r1": 5, "r2": 5, "c1": 5, "c2": 5}, +] + +# CROSS configs — CROSS's fori_loop-based matmul is ~1000× slower than +# DRNS BAT per NTT, so keep sizes modest to finish in reasonable time. 
+PERF_CONFIGS_CROSS_3STEP = [ + {"degree": 14, "r1": 7, "c": 7}, + # {"degree": 16, "r1": 8, "c": 8}, + # {"degree": 18, "r1": 9, "c": 9}, + # {"degree": 20, "r1": 10, "c": 10}, + # {"degree": 22, "r1": 11, "c": 11}, +] + +PERF_CONFIGS_CROSS_5STEP = [ + {"degree": 14, "r1": 5, "r2": 5, "c": 4}, + # {"degree": 16, "r1": 5, "r2": 5, "c": 6}, + # {"degree": 18, "r1": 6, "r2": 6, "c": 6}, + # {"degree": 20, "r1": 6, "r2": 6, "c": 8}, + # {"degree": 22, "r1": 7, "r2": 7, "c": 8}, +] + +PERF_CONFIGS_CROSS_7STEP = [ + {"degree": 16, "r1": 4, "r2": 4, "c1": 4, "c2": 4}, + # {"degree": 20, "r1": 5, "r2": 5, "c1": 5, "c2": 5}, +] + + +# --------------------------------------------------------------------------- +# Extension-context builders +# --------------------------------------------------------------------------- +def _create_drns_ff_ctx(prime): + rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS) + return ntt_context.DRNSLazyExtensionContext({ + "prime": prime, + "rns_moduli": rns_moduli, + "precision_bits": PRECISION_BITS, + "radix_bits": RADIX_BITS, + }) + + +def _create_cross_ff_ctx(prime): + params = {"prime": prime} + return ntt_context.CROSSLazyExtensionContext(params) + + +# --------------------------------------------------------------------------- +# Sharding helpers +# --------------------------------------------------------------------------- +def _create_sharding(): + """Create default batch sharding for the current device mesh.""" + available_devices = jax.devices() + if not available_devices: + raise RuntimeError("No devices available for sharding test.") + if len(available_devices) == 8: + mesh_shape = (2, 4) + elif len(available_devices) == 4: + mesh_shape = (2, 2) + elif len(available_devices) == 2: + mesh_shape = (2, 1) + else: + mesh_shape = (1, 1) + + mesh = jax.make_mesh(mesh_shape, ("x", "y")) + return mesh, jax.sharding.PartitionSpec + + +def _batch_sharding(mesh, partition_spec, ndim): + """NamedSharding that partitions 
only the leading (batch) axis.""" + axis_names = mesh.axis_names + batch_partition = axis_names if len(axis_names) > 1 else axis_names[0] + spec = (batch_partition,) + (None,) * (ndim - 1) + return jax.sharding.NamedSharding(mesh, partition_spec(*spec)) + + +# --------------------------------------------------------------------------- +# Kernel-wrapper helpers +# --------------------------------------------------------------------------- +def _ntt_kernel(input_array, parameters): + return parameters["ctx"].ntt(input_array) + + +def _intt_kernel(input_array, parameters): + return parameters["ctx"].intt(input_array) + + +def _shard_mapped_kernel(method_name): + """Build a kernel entry that shard-maps ``ctx.{method_name}`` over batch. + + CROSS's NTT path uses ``jax.vmap`` internally over every leading axis + of ``_modular_multiply``. Under a plain jit with a batch-sharded + input, the broadcast of replicated twiddles against the sharded input + creates vmap axis-spec mismatches. Running the kernel under + ``shard_map`` gives every device a shard-local (unsharded) view, so + all broadcast / reshape / vmap ops inside the kernel see regular- + strided arrays while the heavy lifting still parallelizes across all + devices at the outer (shard_map) level. + """ + + def kernel(input_array, parameters): + ctx = parameters["ctx"] + mesh = parameters["mesh"] + batch_spec = parameters["batch_spec"] + fn = getattr(ctx, method_name) + mapped = shard_map( + fn, + mesh=mesh, + in_specs=batch_spec, + out_specs=batch_spec, + check_rep=False, + ) + return mapped(input_array) + + return kernel + + +# --------------------------------------------------------------------------- +# Performance profiling (jax.profiler traces via KernelWrapper + Profiler). +# One test method per design: 3 DRNS designs + 3 CROSS designs = 6 tests. 
+# --------------------------------------------------------------------------- +class NTTShardedPerformanceTest(parameterized.TestCase): + """Profile all six NTT designs (step-count × backend) at sharded batch.""" + + def setUp(self): + super().setUp() + outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR") + self.output_trace_root = ( + os.path.join(outputs_dir, "log") + if outputs_dir + else os.path.join(os.path.dirname(os.path.abspath(__file__)), "log") + ) + self.profiler_config = { + "iterations": 1, + "save_to_file": True, + "enable_sharding": True, + } + + @classmethod + def tearDownClass(cls): + super().tearDownClass() + collect_logs(os.path.dirname(os.path.abspath(__file__))) + + # ------------------------------------------------------------------------- + # Profile driver: uniform entry point for DRNS and CROSS. Selects the + # right extension context, trailing-dim size, and per-device kernel + # wrapping (plain jit for DRNS, shard_map-wrapped jit for CROSS). + # ------------------------------------------------------------------------- + def _profile_design( + self, variant_name, configs, ntt_cls, backend, make_params + ): + assert backend in ("drns", "cross") + if backend == "drns": + ff_ctx = _create_drns_ff_ctx(Q_PERF) + trailing = len(ff_ctx.rns_moduli) + trailing_key = "num_moduli" + kernel_factory = lambda direction: ( + _ntt_kernel if direction == "ntt" else _intt_kernel + ) + else: + ff_ctx = _create_cross_ff_ctx(Q_PERF) + trailing = ff_ctx.chunk_num_u32 + trailing_key = "chunk_num_u32" + kernel_factory = lambda direction: _shard_mapped_kernel(direction) + + mesh, partition_spec = _create_sharding() + num_devices = len(jax.devices()) + + profiler = Profiler( + output_trace_path=self.output_trace_root, + profile_naming=f"sharded_{variant_name}_{ntt_cls.__name__}", + configuration=self.profiler_config, + ) + + for cfg in configs: + degree = cfg["degree"] + spatial_params, spatial_shape = make_params(cfg) + params = { + "prime": Q_PERF, + 
"finite_field_context": ff_ctx, + **spatial_params, + } + psi = PSI_PERF_BY_DEGREE.get(degree) + if psi is not None: + params["psi"] = psi + ntt_ctx = ntt_cls(params) + + ndim = 1 + len(spatial_shape) + 1 + batch_sharding = _batch_sharding(mesh, partition_spec, ndim) + + axis_names = mesh.axis_names + batch_partition = axis_names if len(axis_names) > 1 else axis_names[0] + batch_spec = partition_spec(batch_partition, *([None] * (ndim - 1))) + + setting_base = { + "context": ntt_cls.__name__, + "backend": backend, + "degree": degree, + "spatial_shape": str(spatial_shape), + trailing_key: trailing, + "num_devices": num_devices, + } + + batch = num_devices + input_shape = (batch,) + spatial_shape + (trailing,) + for direction in ("ntt", "intt"): + name = f"{variant_name}_{direction}_deg{degree}_batch{batch}" + wrapper = KernelWrapper( + kernel_name=name, + function_to_wrap=kernel_factory(direction), + input_structs=[(input_shape, jnp.uint32)], + parameters={"ctx": ntt_ctx, "mesh": mesh, "batch_spec": batch_spec}, + mesh=mesh, + input_shardings=(batch_sharding,), + output_sharding=batch_sharding, + enable_sharding=True, + ) + profiler.add_profile( + name=name, + kernel_wrapper=wrapper, + kernel_setting_cols={ + **setting_base, + "direction": direction, + "batch": batch, + }, + ) + + profiler.profile_all_profilers() + profiler.post_process_all_profilers() + + # ------------------------------------------------------------------------- + # 3 DRNS-based designs + # ------------------------------------------------------------------------- + def test_sharded_drns_3step(self): + """DRNS 3-step NTT (``NTT3Step`` + ``DRNSLazyExtensionContext``).""" + + def make_params(cfg): + r, c = 2 ** cfg["r1"], 2 ** cfg["c"] + return {"r": r, "c": c}, (r, c) + + self._profile_design( + "drns_3step", + PERF_CONFIGS_DRNS_3STEP, + ntt_context.NTT3Step, + "drns", + make_params, + ) + + def test_sharded_drns_5step(self): + """DRNS 5-step NTT (``NTT5Step`` + ``DRNSLazyExtensionContext``).""" + + 
def make_params(cfg): + rr, rc, c = 2 ** cfg["r1"], 2 ** cfg["r2"], 2 ** cfg["c"] + return {"rr": rr, "rc": rc, "c": c}, (rr, rc, c) + + self._profile_design( + "drns_5step", + PERF_CONFIGS_DRNS_5STEP, + ntt_context.NTT5Step, + "drns", + make_params, + ) + + def test_sharded_drns_7step(self): + """DRNS 7-step NTT (``NTT7Step`` + ``DRNSLazyExtensionContext``).""" + + def make_params(cfg): + rr, rc = 2 ** cfg["r1"], 2 ** cfg["r2"] + cr, cc = 2 ** cfg["c1"], 2 ** cfg["c2"] + return {"rr": rr, "rc": rc, "cr": cr, "cc": cc}, (rr, rc, cr, cc) + + self._profile_design( + "drns_7step", + PERF_CONFIGS_DRNS_7STEP, + ntt_context.NTT7Step, + "drns", + make_params, + ) + + # ------------------------------------------------------------------------- + # 3 CROSS-backed designs + # ------------------------------------------------------------------------- + def test_sharded_cross_3step(self): + """CROSS 3-step NTT (``NTT3Step`` + ``CROSSLazyExtensionContext``).""" + + def make_params(cfg): + r, c = 2 ** cfg["r1"], 2 ** cfg["c"] + return {"r": r, "c": c}, (r, c) + + self._profile_design( + "cross_3step", + PERF_CONFIGS_CROSS_3STEP, + ntt_context.NTT3Step, + "cross", + make_params, + ) + + def test_sharded_cross_5step(self): + """CROSS 5-step NTT (``NTT5Step`` + ``CROSSLazyExtensionContext``).""" + + def make_params(cfg): + rr, rc, c = 2 ** cfg["r1"], 2 ** cfg["r2"], 2 ** cfg["c"] + return {"rr": rr, "rc": rc, "c": c}, (rr, rc, c) + + self._profile_design( + "cross_5step", + PERF_CONFIGS_CROSS_5STEP, + ntt_context.NTT5Step, + "cross", + make_params, + ) + + def test_sharded_cross_7step(self): + """CROSS 7-step NTT (``NTT7Step`` + ``CROSSLazyExtensionContext``).""" + + def make_params(cfg): + rr, rc = 2 ** cfg["r1"], 2 ** cfg["r2"] + cr, cc = 2 ** cfg["c1"], 2 ** cfg["c2"] + return {"rr": rr, "rc": rc, "cr": cr, "cc": cc}, (rr, rc, cr, cc) + + self._profile_design( + "cross_7step", + PERF_CONFIGS_CROSS_7STEP, + ntt_context.NTT7Step, + "cross", + make_params, + ) + + +if __name__ == 
"__main__": + absltest.main() diff --git a/jaxite_ec/number_theory_transform_test.py b/jaxite_ec/number_theory_transform_test.py new file mode 100644 index 0000000..c972a48 --- /dev/null +++ b/jaxite_ec/number_theory_transform_test.py @@ -0,0 +1,428 @@ +import jax +import jax.numpy as jnp +import jaxite.jaxite_ec.number_theory_transform_context as ntt_context +import jaxite.jaxite_ec.utils as utils +import numpy as np + +from absl.testing import absltest +from absl.testing import parameterized + +MODULI = 134219681 + +# --------------------------------------------------------------------------- +# Trailing-dimension sizing — change these to sweep field-element width. +# --------------------------------------------------------------------------- +NUM_MODULI = 21 # DRNS: number of RNS moduli (trailing dim size), 21 for 256-bit, 56 for 753-bit +PRECISION_BITS = 28 # DRNS: bit-width per modulus +RADIX_BITS = 32 # DRNS: Montgomery radix +CHUNK_NUM_U8 = 32 # CROSS: override chunk_num_u8 (None = auto from prime), 32 for 256-bit, 95 for 753-bit + +TEST_VECTOR = { + "coef_in": [ + 105825732, + 68433452, + 36629220, + 126901109, + 89469849, + 106633716, + 15102657, + 108374459, + 68789927, + 23451922, + 93538050, + 20585372, + 30604976, + 37517995, + 65644325, + 102451383, + ], + "eval_in": [ + 26196696, + 45475009, + 10055359, + 23277424, + 69041040, + 71916973, + 73894069, + 3311254, + 44646798, + 49882443, + 28097016, + 70484730, + 10811958, + 11946041, + 61318182, + 19099272, + ], +} + + +def make_ntt_case(name, *layout_dims): + return ( + name, + MODULI, + None, + 1, + *layout_dims, + TEST_VECTOR["coef_in"], + TEST_VECTOR["eval_in"], + ) + + +NTT_3STEP = [make_ntt_case("0", 4, 4)] # Layout: (r, c) +NTT_5STEP = [make_ntt_case("0", 2, 2, 4)] # Layout: (rr, rc, c) +NTT_7STEP = [make_ntt_case("0", 2, 2, 2, 2)] # Layout: (rr, rc, cr, cc) + + +# --------------------------------------------------------------------------- +# Sharded correctness configs: (case_name, 
spatial_params, ntt_cls, spatial_shape) +# --------------------------------------------------------------------------- +SHARDED_CORRECTNESS_CONFIGS = [ + ("3step", {"r": 4, "c": 4}, ntt_context.NTT3Step, (4, 4)), + ("5step", {"rr": 2, "rc": 2, "c": 4}, ntt_context.NTT5Step, (2, 2, 4)), + ( + "7step", + {"rr": 2, "rc": 2, "cr": 2, "cc": 2}, + ntt_context.NTT7Step, + (2, 2, 2, 2), + ), +] + + +# --------------------------------------------------------------------------- +# Sharding helpers +# --------------------------------------------------------------------------- +def _create_sharding(): + """Create default batch sharding for the current device mesh.""" + available_devices = jax.devices() + if not available_devices: + raise RuntimeError("No devices available for sharding test.") + if len(available_devices) == 8: + mesh_shape = (2, 4) + elif len(available_devices) == 4: + mesh_shape = (2, 2) + elif len(available_devices) == 2: + mesh_shape = (2, 1) + else: + mesh_shape = (1, 1) + + mesh = jax.make_mesh(mesh_shape, ("x", "y")) + return mesh, jax.sharding.PartitionSpec + + +def _batch_sharding(mesh, partition_spec, ndim): + """NamedSharding that partitions only the leading (batch) axis.""" + axis_names = mesh.axis_names + batch_partition = axis_names if len(axis_names) > 1 else axis_names[0] + spec = (batch_partition,) + (None,) * (ndim - 1) + return jax.sharding.NamedSharding(mesh, partition_spec(*spec)) + + +def _tile_to_batch(single, shard_batch): + """Replicate a single ``(1, *spatial, trailing)`` array across batch dim.""" + return jnp.tile(single, (shard_batch,) + (1,) * (single.ndim - 1)) + + +def _create_drns_ff_ctx(prime): + rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS) + return ntt_context.DRNSLazyExtensionContext({ + "prime": prime, + "rns_moduli": rns_moduli, + "precision_bits": PRECISION_BITS, + "radix_bits": RADIX_BITS, + }) + + +class NTTTest(parameterized.TestCase): + + # @absltest.skip("Skip DRNS NTT tests") + 
@parameterized.named_parameters(*NTT_3STEP) + def test_DRNS_NTT_3step(self, q, psi, batch, r, c, coef_in, eval_in): + """Validate the 3-step NTT with DRNS lazy reduction.""" + rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS) + ff_ctx = ntt_context.DRNSLazyExtensionContext({ + "prime": q, + "rns_moduli": rns_moduli, + "precision_bits": PRECISION_BITS, + "radix_bits": RADIX_BITS, + }) + ntt_ctx = ntt_context.NTT3Step( + {"prime": q, "r": r, "c": c, "finite_field_context": ff_ctx} + ) + + # to_computational_format gives (N, num_moduli); ntt auto-reshapes + coef_drns = ntt_ctx.to_computational_format(coef_in) + + # Forward NTT + ntt_result = ntt_ctx.ntt(coef_drns) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + + # Inverse NTT + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # @absltest.skip("Skip DRNS NTT tests") + @parameterized.named_parameters(*NTT_5STEP) + def test_DRNS_NTT_5step(self, q, psi, batch, rr, rc, c, coef_in, eval_in): + """Validate the 5-step NTT with DRNS lazy reduction.""" + rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS) + ff_ctx = ntt_context.DRNSLazyExtensionContext({ + "prime": q, + "rns_moduli": rns_moduli, + "precision_bits": PRECISION_BITS, + "radix_bits": RADIX_BITS, + }) + ntt_ctx = ntt_context.NTT5Step( + {"prime": q, "rr": rr, "rc": rc, "c": c, "finite_field_context": ff_ctx} + ) + + # to_computational_format gives (N, num_moduli); ntt auto-reshapes + coef_drns = ntt_ctx.to_computational_format(coef_in) + + # Forward NTT + ntt_result = ntt_ctx.ntt(coef_drns) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + + # Inverse NTT + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # @absltest.skip("Skip DRNS NTT tests") + 
@parameterized.named_parameters(*NTT_7STEP) + def test_DRNS_NTT_7step( + self, q, psi, batch, rr, rc, cr, cc, coef_in, eval_in + ): + """Validate the 7-step NTT with DRNS lazy reduction.""" + rns_moduli = utils.find_moduli_specified_number(NUM_MODULI, PRECISION_BITS) + ff_ctx = ntt_context.DRNSLazyExtensionContext({ + "prime": q, + "rns_moduli": rns_moduli, + "precision_bits": PRECISION_BITS, + "radix_bits": RADIX_BITS, + }) + ntt_ctx = ntt_context.NTT7Step({ + "prime": q, + "rr": rr, + "rc": rc, + "cr": cr, + "cc": cc, + "finite_field_context": ff_ctx, + }) + + coef_drns = ntt_ctx.to_computational_format(coef_in) + + # Forward NTT + ntt_result = ntt_ctx.ntt(coef_drns) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + + # Inverse NTT + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # @absltest.skip("Skip CROSS NTT tests") + @parameterized.named_parameters(*NTT_3STEP) + def test_CROSS_NTT_3step(self, q, psi, batch, r, c, coef_in, eval_in): + """Validate the 3-step NTT with CROSS lazy matrix reduction.""" + cross_params = {"prime": q} + if CHUNK_NUM_U8 is not None: + cross_params["chunk_num_u8"] = CHUNK_NUM_U8 + ff_ctx = ntt_context.CROSSLazyExtensionContext(cross_params) + ntt_ctx = ntt_context.NTT3Step( + {"prime": q, "r": r, "c": c, "finite_field_context": ff_ctx} + ) + + coef_cross = ntt_ctx.to_computational_format(coef_in) + + ntt_result = ntt_ctx.ntt(coef_cross) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # @absltest.skip("Skip CROSS NTT tests") + @parameterized.named_parameters(*NTT_5STEP) + def test_CROSS_NTT_5step(self, q, psi, batch, rr, rc, c, coef_in, eval_in): + """Validate the 5-step NTT with CROSS lazy matrix reduction.""" + cross_params = {"prime": q} + if 
CHUNK_NUM_U8 is not None: + cross_params["chunk_num_u8"] = CHUNK_NUM_U8 + ff_ctx = ntt_context.CROSSLazyExtensionContext(cross_params) + ntt_ctx = ntt_context.NTT5Step( + {"prime": q, "rr": rr, "rc": rc, "c": c, "finite_field_context": ff_ctx} + ) + + coef_cross = ntt_ctx.to_computational_format(coef_in) + + ntt_result = ntt_ctx.ntt(coef_cross) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # @absltest.skip("Skip CROSS NTT tests") + @parameterized.named_parameters(*NTT_7STEP) + def test_CROSS_NTT_7step( + self, q, psi, batch, rr, rc, cr, cc, coef_in, eval_in + ): + """Validate the 7-step NTT with CROSS lazy matrix reduction.""" + cross_params = {"prime": q} + if CHUNK_NUM_U8 is not None: + cross_params["chunk_num_u8"] = CHUNK_NUM_U8 + ff_ctx = ntt_context.CROSSLazyExtensionContext(cross_params) + ntt_ctx = ntt_context.NTT7Step({ + "prime": q, + "rr": rr, + "rc": rc, + "cr": cr, + "cc": cc, + "finite_field_context": ff_ctx, + }) + + coef_cross = ntt_ctx.to_computational_format(coef_in) + + ntt_result = ntt_ctx.ntt(coef_cross) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # --------------------------------------------------------------------- + # Unified NTT + NumpyCPUContext parity tests: NTT{3,5,7}Step with a + # NumpyCPUContext backend must produce the same output as the + # CPUCROSS{,5Step,7Step}Context legacy reference classes. 
+ # --------------------------------------------------------------------- + # @absltest.skip("Skip NumpyCPU NTT tests") + @parameterized.named_parameters(*NTT_3STEP) + def test_NumpyCPU_NTT_3step(self, q, psi, batch, r, c, coef_in, eval_in): + ff_ctx = ntt_context.NumpyCPUContext({"prime": q}) + ntt_ctx = ntt_context.NTT3Step( + {"prime": q, "r": r, "c": c, "finite_field_context": ff_ctx} + ) + coef_cpu = ntt_ctx.to_computational_format(coef_in) + ntt_result = ntt_ctx.ntt(coef_cpu) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # @absltest.skip("Skip NumpyCPU NTT tests") + @parameterized.named_parameters(*NTT_5STEP) + def test_NumpyCPU_NTT_5step(self, q, psi, batch, rr, rc, c, coef_in, eval_in): + ff_ctx = ntt_context.NumpyCPUContext({"prime": q}) + ntt_ctx = ntt_context.NTT5Step( + {"prime": q, "rr": rr, "rc": rc, "c": c, "finite_field_context": ff_ctx} + ) + coef_cpu = ntt_ctx.to_computational_format(coef_in) + ntt_result = ntt_ctx.ntt(coef_cpu) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, ntt_ctx.to_original_format(intt_result) + ) + + # @absltest.skip("Skip NumpyCPU NTT tests") + @parameterized.named_parameters(*NTT_7STEP) + def test_NumpyCPU_NTT_7step( + self, q, psi, batch, rr, rc, cr, cc, coef_in, eval_in + ): + ff_ctx = ntt_context.NumpyCPUContext({"prime": q}) + ntt_ctx = ntt_context.NTT7Step({ + "prime": q, + "rr": rr, + "rc": rc, + "cr": cr, + "cc": cc, + "finite_field_context": ff_ctx, + }) + coef_cpu = ntt_ctx.to_computational_format(coef_in) + ntt_result = ntt_ctx.ntt(coef_cpu) + np.testing.assert_array_equal( + eval_in, ntt_ctx.to_original_format(ntt_result) + ) + intt_result = ntt_ctx.intt(ntt_result) + np.testing.assert_array_equal( + coef_in, 
ntt_ctx.to_original_format(intt_result) + ) + + +# --------------------------------------------------------------------------- +# Sharded correctness (DRNS only — CROSS correctness is covered by the +# per-backend tests above since CROSS's device-0 closure constants conflict +# with shard_map under the small-mesh jit harness). +# --------------------------------------------------------------------------- +class NTTShardedCorrectnessTest(parameterized.TestCase): + """Verifies that batched, device-sharded NTT is element-wise correct.""" + + # @absltest.skip("Skip DRNS NTT correctness tests") + @parameterized.named_parameters(*SHARDED_CORRECTNESS_CONFIGS) + def test_sharded_ntt_correctness( + self, spatial_params, ctx_cls, spatial_shape + ): + coef_in = TEST_VECTOR["coef_in"] + eval_in = TEST_VECTOR["eval_in"] + ff_ctx = _create_drns_ff_ctx(MODULI) + params = {"prime": MODULI, "finite_field_context": ff_ctx, **spatial_params} + ntt_ctx = ctx_cls(params) + + coef_drns = ntt_ctx.to_computational_format(coef_in) + coef_drns = ntt_ctx._ensure_ntt_shape(coef_drns) + + mesh, partition_spec = _create_sharding() + shard_batch = len(jax.devices()) + batched = _tile_to_batch(coef_drns, shard_batch) + sharding = _batch_sharding(mesh, partition_spec, batched.ndim) + batched_sharded = jax.device_put(batched, sharding) + + jit_ntt = jax.jit(ntt_ctx.ntt, out_shardings=sharding) + jit_intt = jax.jit(ntt_ctx.intt, out_shardings=sharding) + + ntt_result = jit_ntt(batched_sharded) + ntt_host = np.asarray(ntt_result) + for i in range(shard_batch): + np.testing.assert_array_equal( + eval_in, + ntt_ctx.to_original_format(ntt_host[i]), + err_msg=f"NTT mismatch at batch index {i}", + ) + + intt_result = jit_intt(ntt_result) + intt_host = np.asarray(intt_result) + for i in range(shard_batch): + np.testing.assert_array_equal( + coef_in, + ntt_ctx.to_original_format(intt_host[i]), + err_msg=f"INTT mismatch at batch index {i}", + ) + + +if __name__ == "__main__": + absltest.main() diff --git 
a/jaxite_ec/pippenger.py b/jaxite_ec/pippenger.py deleted file mode 100644 index 63d5372..0000000 --- a/jaxite_ec/pippenger.py +++ /dev/null @@ -1,1319 +0,0 @@ -"""Pippenger algorithm for elliptic curves. - -This module implements the Pippenger algorithm for elliptic curves. The -algorithm -is a generalization of the elliptic curve algorithm that can be used to find -elliptic curves of arbitrary order. - -The Pippenger algorithm works by first finding a set of "scalars" and "points" -that lie on the elliptic curve. These scalars and points are then used to -construct -a "bucket" for each window of the elliptic curve. The buckets are then reduced -to a single point for each window. Finally, the points from all of the windows -are merged together to form the final elliptic curve. - -The Pippenger algorithm is a powerful tool for finding elliptic curves, and it -has been used to find elliptic curves of arbitrary order and dimension. - -Example Compiled Function: - -bucket_accumulation_scan_jax_jit = jax.jit( - bucket_accumulation_scan_jax, - static_argnames='msm_length').lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW), - dtype=jnp.uint8 - ), - MSM_LENGTH - ).compile() - -bucket_accumulation_index_scan_jax_jit = jax.jit( - bucket_accumulation_index_scan_jax, - static_argnames='msm_length').lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM,BUCKET_NUM_PER_WINDOW, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM), - dtype=jnp.uint32 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM,BUCKET_NUM_PER_WINDOW), - 
dtype=jnp.uint8 - ), - MSM_LENGTH - ).compile() - -bucket_reduction_scan_jax_jit = jax.jit( - bucket_reduction_scan_jax, - static_argnames='bucket_num_in_window').lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM, BUCKET_NUM_PER_WINDOW, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (BUCKET_NUM_PER_WINDOW, WINDOW_NUM), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (BUCKET_NUM_PER_WINDOW, WINDOW_NUM), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (BUCKET_NUM_PER_WINDOW, WINDOW_NUM,), - dtype=jnp.uint8 - ), - BUCKET_NUM_PER_WINDOW - ).compile() - -window_merge_scan_jax_jit = jax.jit( - window_merge_scan_jax, - static_argnames='slice_length').lower( - jax.ShapeDtypeStruct((COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM), - dtype=jnp.uint16), - SLICE_LENGTH - ).compile() - -selective_padd_with_zero_jit = jax.jit( - selective_padd_with_zero).lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, COORDINATE_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (COORDINATE_NUM, MSM_LENGTH, COORDINATE_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,), - dtype=jnp.uint8 - ), - ).compile() -""" - -import copy -import math -from typing import List - -import jax -import jax.numpy as jnp -from jaxite.jaxite_ec import util -import jaxite.jaxite_ec.elliptic_curve as jec -import jaxite.jaxite_ec.finite_field as ff -import numpy as np - - -deepcopy = copy.deepcopy - -""" -Example Parameters: -SLICE_LENGTH = 4 -WINDOW_NUM = int(math.ceil(255 / SLICE_LENGTH)) -BUCKET_NUM_PER_WINDOW = 2**SLICE_LENGTH -MSM_LENGTH = 1024 -COORDINATE_NUM = 4 -""" - - -def selective_padd_with_zero(partial_sum, single_point, select, is_zero): - 
"""Padd the partial sum with the single point, but only if the selection state is 1. - - Args: - partial_sum: The partial sum. - single_point: The single point. - select: The selection state. - is_zero: The zero states. - - Returns: - The new partial sum. - """ - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point) - - cond_select = jnp.equal(select, 1).reshape(1, batch_dim, 1) - sum_result = jnp.where(cond_select, new_partial_sum, partial_sum) - - cond_zero = jnp.equal(is_zero, 1).reshape(1, batch_dim, 1) - cond_select_and_zero = jnp.logical_and(cond_select, cond_zero) - result = jnp.where(cond_select_and_zero, single_point, sum_result) - return result - - -def padd_with_zero(partial_sum, single_point, ps_is_zero, sp_is_zero): - """Padd the partial sum with the single point. - - Check if the partial sum is equal to the single point first. - - Args: - partial_sum: The partial sum. - single_point: The single point. - ps_is_zero: The zero states of the partial sum. - sp_is_zero: The zero states of the single point. - - Returns: - The new partial sum. 
- """ - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point) - cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1) - cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1) - result_1 = jnp.where(cond_sp_zero, partial_sum, single_point) - result_2 = jnp.where( - jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum - ) - return result_2 - - -def padd_with_zero_alter(partial_sum, single_point, ps_is_zero): - - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point) - cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1) - result_2 = jnp.where(cond_ps_zero, single_point, new_partial_sum) - return result_2 - - -def padd_with_zero_and_pdul_check( - partial_sum, single_point, ps_is_zero, sp_is_zero -): - """Padd the partial sum with the single point. - - Check if the partial sum is equal to the single point first. If they are - equal, then double the partial sum. - - Args: - partial_sum: The partial sum. - single_point: The single point. - ps_is_zero: The zero states of the partial sum. - sp_is_zero: The zero states of the single point. - - Returns: - The new partial sum. 
- """ - # coordinate_dim, batch_dim, precision_dim = partial_sum.shape - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_lazy_xyzz_pack(partial_sum, single_point) - double_partial_sum = jec.pdul_lazy_xyzz_pack(partial_sum) - cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape( - 1, batch_dim, 1 - ) - cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1) - cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1) - result_1 = jnp.where(cond_sp_zero, partial_sum, single_point) - result_2 = jnp.where( - jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum - ) - reuslt_3 = jnp.where(cond_equal, double_partial_sum, result_2) - return reuslt_3 - - -def bucket_accumulation_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Non-scan version BA.""" - coordinate_dim, buckets_dim, precision_dim = all_buckets.shape - for i in range(msm_length): - point = jax.lax.broadcast_in_dim( - all_points[i], (coordinate_dim, buckets_dim, precision_dim), (0, 2) - ) - all_buckets = selective_padd_with_zero( - all_buckets, point, selection_list[i], zero_states_list[i] - ) - return all_buckets - - -def bucket_accumulation_scan_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Scan version BA.""" - coordinate_dim, buckets_dim, precision_dim = all_buckets.shape - - def scan_body(buckets, point_with_cond_pack): - point, selection, zero_states = point_with_cond_pack - point = jax.lax.broadcast_in_dim( - point, (coordinate_dim, buckets_dim, precision_dim), (0, 2) - ) - all_buckets = selective_padd_with_zero( - buckets, point, selection, zero_states - ) - return all_buckets, None - - all_buckets, _ = jax.lax.scan( - scan_body, - all_buckets, - (all_points, selection_list, zero_states_list), - length=msm_length, - ) - 
return all_buckets - - -def bucket_accumulation_index_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_index_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Non-scan version BA with index selection.""" - # buckets_dim is not used in the algorithm, changed it into _. - coordinate_dim, window_dim, _, precision_dim = all_buckets.shape - for i in range(msm_length): - point = jax.lax.broadcast_in_dim( - all_points[i], (coordinate_dim, window_dim, precision_dim), (0, 2) - ) - selective_buckets = all_buckets[ - :, jnp.arange(window_dim), selection_index_list[i], : - ] - selective_zero_states = zero_states_list[ - i, jnp.arange(window_dim), selection_index_list[i] - ] - selective_update = padd_with_zero_alter( - selective_buckets, point, selective_zero_states - ) - all_buckets = all_buckets.at[ - :, jnp.arange(window_dim), selection_index_list[i], : - ].set(selective_update) - return all_buckets - - -def bucket_accumulation_index_scan_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_index_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Scan version BA with index selection.""" - # buckets_dim is not used in the algorithm, changed it into _. 
- coordinate_dim, window_dim, _, precision_dim = all_buckets.shape - - def scan_body(buckets, point_with_cond_pack): - point, selection_index, zero_states = point_with_cond_pack - point = jax.lax.broadcast_in_dim( - point, (coordinate_dim, window_dim, precision_dim), (0, 2) - ) - selective_buckets = buckets[:, jnp.arange(window_dim), selection_index, :] - selective_zero_states = zero_states[jnp.arange(window_dim), selection_index] - selective_update = padd_with_zero_alter( - selective_buckets, point, selective_zero_states - ) - return ( - buckets.at[:, jnp.arange(window_dim), selection_index, :].set( - selective_update - ), - None, - ) - - all_buckets, _ = jax.lax.scan( - scan_body, - all_buckets, - (all_points, selection_index_list, zero_states_list), - length=msm_length, - ) - return all_buckets - - -def bucket_reduction_algorithm( - all_buckets: jnp.ndarray, - temp_sum: jnp.ndarray, - window_sum: jnp.ndarray, - bucket_zero_states_list: jnp.ndarray, - temp_zero_states_list: jnp.ndarray, - window_zero_states_list: jnp.ndarray, - bucket_num_in_window: int, -): - """Non-scan version BR.""" - all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0) - for i in range(bucket_num_in_window - 1): - temp_sum = padd_with_zero( - temp_sum, - all_buckets[i], - temp_zero_states_list[i], - bucket_zero_states_list[i], - ) - window_sum = padd_with_zero_and_pdul_check( - window_sum, - temp_sum, - window_zero_states_list[i], - temp_zero_states_list[i + 1], - ) - return window_sum - - -def bucket_reduction_scan_algorithm( - all_buckets: jnp.ndarray, - temp_sum: jnp.ndarray, - window_sum: jnp.ndarray, - bucket_zero_states_list: jnp.ndarray, - temp_zero_states_list: jnp.ndarray, - window_zero_states_list: jnp.ndarray, - bucket_num_in_window: int, -): - """Scan version BR.""" - all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0) - - def scan_body(temp_and_window_sum_pack, bucket_with_cond_pack): - temp_sum, window_sum = temp_and_window_sum_pack - ( - bucket, - 
bucket_zero_states, - temp_zero_states, - temp_zero_states1, - window_zero_states, - ) = bucket_with_cond_pack - temp_sum = padd_with_zero( - temp_sum, bucket, temp_zero_states, bucket_zero_states - ) - window_sum = padd_with_zero_and_pdul_check( - window_sum, temp_sum, window_zero_states, temp_zero_states1 - ) - return (temp_sum, window_sum), None - - (_, window_sum), _ = jax.lax.scan( - scan_body, - (temp_sum, window_sum), - ( - all_buckets[: bucket_num_in_window - 1], - bucket_zero_states_list[: bucket_num_in_window - 1], - temp_zero_states_list[: bucket_num_in_window - 1], - temp_zero_states_list[1:], - window_zero_states_list[: bucket_num_in_window - 1], - ), - length=bucket_num_in_window - 1, - ) - - return window_sum - - -def window_merge_algorithm(window_sum: jnp.ndarray, slice_length: int): - """Non-scan version WM.""" - coordinate_dim, window_dim, precision_dim = window_sum.shape - result = window_sum[:, window_dim - 1, :].reshape( - (coordinate_dim, 1, precision_dim) - ) - for w in range(window_dim - 2, -1, -1): - for _ in range(slice_length): - result = jec.pdul_lazy_xyzz_pack(result) - result = jec.padd_lazy_xyzz_pack( - result, - window_sum[:, w, :].reshape( - (coordinate_dim, 1, util.U16_EXT_CHUNK_NUM) - ), - ) - - result = result.reshape((coordinate_dim, precision_dim)) - return result - - -def window_merge_scan_algorithm(window_sum: jnp.ndarray, slice_length: int): - """Scan version WM.""" - coordinate_dim, window_dim, precision_dim = window_sum.shape - window_sum = window_sum.transpose(1, 0, 2) - result = window_sum[window_dim - 1, :, :].reshape( - (coordinate_dim, 1, precision_dim) - ) - - def fori_loop_body(_, result): - result = jec.pdul_lazy_xyzz_pack(result) - return result - - def scan_body(result, window_sum): - result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result) - result = jec.padd_lazy_xyzz_pack( - result, window_sum.reshape((coordinate_dim, 1, util.U16_EXT_CHUNK_NUM)) - ) - return result, None - - result, _ = 
jax.lax.scan( - scan_body, - result, - window_sum[: window_dim - 1, :, :], - reverse=True, - length=window_dim - 1, - ) - result = result.reshape((coordinate_dim, precision_dim)) - return result - - -class MSMPippenger: - """Pippenger algorithm for elliptic curves. - - Attributes: - coordinate_num: The number of coordinates in the elliptic curve. - slice_length: The length of each slice in the elliptic curve. - window_num: The number of windows in the elliptic curve. - bucket_num_per_window: The number of buckets in each window. - bucket_num_in_window: The number of buckets in each window. - slice_mask: The mask for the slices in the elliptic curve. - blank_point: A JAX array of zeros, used to initialize the buckets. - all_buckets: A JAX array of all the buckets in the elliptic curve. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. - scalars: A list of integers, where each integer represents an Orignal scalar - from the trace. - all_points: A JAX array of all the points in the elliptic curve. from the - trace. - window_sum: A JAX array of the window sum. - bucket_zero_states: A JAX array of the zero states for the buckets. - temp_sum_zero_states: A JAX array of the zero states for the temp sum. - window_sum_zero_states: A JAX array of the zero states for the window sum. - zero_states_list: A JAX array of the zero states for the buckets. - selection_list: A JAX array of the selection states for the buckets. - selection_index_list: A JAX array of the selection index for the buckets. - msm_length: The length of the MSM trace. - result: The final elliptic curve. - lazy_mat: The lazy matrix used for padding and doubling. 
- """ - - def __init__(self, slice_length): - self.coordinate_num = util.COORDINATE_NUM - - self.slice_length = slice_length - self.window_num = int(math.ceil(254 / self.slice_length)) # - self.bucket_num_per_window = 2**self.slice_length - self.slice_mask = self.bucket_num_per_window - 1 - self.blank_point = util.int_list_to_array( - [0, 0, 0, 0], util.BASE, util.U16_EXT_CHUNK_NUM - ).reshape(self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM) - - self.all_buckets = jnp.broadcast_to( - self.blank_point.reshape( - 1, self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM - ).transpose(1, 0, 2, 3), - ( - self.coordinate_num, - self.window_num, - self.bucket_num_per_window, - util.U16_EXT_CHUNK_NUM, - ), - ) - - self.window_sum: jnp.ndarray - - self.msm_length = 0 - self.bucket_zero_states: jnp.ndarray - self.temp_sum_zero_states: jnp.ndarray - self.window_sum_zero_states: jnp.ndarray - self.zero_states_list: jnp.ndarray - self.selection_list: jnp.ndarray - self.selection_index_list: jnp.ndarray - self.all_points: jnp.ndarray - - self.scalars: List[int] = [] # Orignal scalar from the trace - # [Points, Points, ..., Points] - self.points: List[jnp.ndarray] = [] # Orignal points from the trace - self.lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT) - - self.result = None - - def initialize(self, scalars, points): - """Initialize the Pippenger algorithm. - - Args: - scalars: A list of integers, where each integer represents an Orignal - scalar from the trace. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. 
- """ - # Initial internal selection from the scalar - self.scalars = scalars - self.msm_length = len(scalars) - - # Convert high-precision points into a vector of low-precision chunks - self.points = [ - util.int_list_to_array( - coordinates + [1, 1], util.BASE, util.U16_EXT_CHUNK_NUM - ) - for coordinates in points - ] # pytype: disable=container-type-mismatch - - self.all_points = jnp.array(self.points) - - # For BA - zero_states_pylist, selection_pylist, selection_index_pylist = ( - self.construct_ba_zero_states_and_selection() - ) - self.selection_list = jnp.array(selection_pylist, dtype=jnp.uint8).reshape( - (-1, self.window_num * self.bucket_num_per_window) - ) - # For index selection version BA - self.zero_states_list = jnp.array(zero_states_pylist, dtype=jnp.uint8) - self.selection_index_list = jnp.array( - selection_index_pylist, dtype=jnp.uint32 - ) - - # For BR - ( - bucket_zero_states_py, - temp_sum_zero_states_py, - window_sum_zero_states_py, - ) = self.construct_br_zero_states( - zero_states_pylist[len(zero_states_pylist) - 1] - ) - self.bucket_zero_states = jnp.array(bucket_zero_states_py, dtype=jnp.uint8) - self.temp_sum_zero_states = jnp.array( - temp_sum_zero_states_py, dtype=jnp.uint8 - ) - self.window_sum_zero_states = jnp.array( - window_sum_zero_states_py, dtype=jnp.uint8 - ) - - def bucket_accumulation(self, bucket_accumulation_index_func): - """BA index selection version.""" - self.all_buckets = bucket_accumulation_index_func( - self.all_buckets, - self.all_points, - self.selection_index_list, - self.zero_states_list[: self.msm_length], - ) - - return self.all_buckets - - def bucket_reduction(self, bucket_reduction_func): - """Reduce the buckets to a single point for each window.""" - temp_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.U16_EXT_CHUNK_NUM, - ), 
- ) - self.window_sum = bucket_reduction_func( - self.all_buckets, - temp_sum, - window_sum, - self.bucket_zero_states[: self.bucket_num_per_window], - self.temp_sum_zero_states[: self.bucket_num_per_window], - self.window_sum_zero_states[: self.bucket_num_per_window], - ) - return self.window_sum - - def window_merge(self, window_merge_func): - """Merge the windows to form the final elliptic curve. - - Args: - window_merge_func: The function to merge the windows. - - Returns: - The final elliptic curve. - """ - self.result = window_merge_func(self.window_sum) - return self.result - - def construct_ba_zero_states_and_selection(self): - """Construct the zero states and selection for the bucket accumulation (BA) step. - - Returns: - A tuple of two lists: the zero states for the bucket accumulation, and the - selection for the bucket accumulation. - """ - zero_states = [ - deepcopy([1] * self.bucket_num_per_window) - for _ in range(self.window_num) - ] - zero_states_list = [] - zero_states_list.append(deepcopy(zero_states)) - selection_list = [] - selection_index_list = [] # Used for index selection - for scalar in self.scalars: - # Compute the zero states for each scalar by time dependence - selection = [ - deepcopy(([0] * self.bucket_num_per_window)) - for _ in range(self.window_num) - ] - selection_index = [] - for w in range(self.window_num): - slice_index = (scalar >> (w * self.slice_length)) & self.slice_mask - zero_states[w][slice_index] = 0 - selection[w][slice_index] = 1 - selection_index.append(slice_index) - - selection_list.append(deepcopy(selection)) - zero_states_list.append(deepcopy(zero_states)) - selection_index_list.append(deepcopy(selection_index)) - return zero_states_list, selection_list, selection_index_list - - def construct_br_zero_states(self, bucket_zero_states): - """Construct the zero states for the bucket reduction (BR) step. - - Args: - bucket_zero_states: The zero states of the buckets. 
- - Returns: - A tuple of three lists: the zero states for the bucket reduction, the zero - states for the temporary sum, and the zero states for the window sum. - """ - temp_sum_zero_states = np.array([1] * self.window_num) - window_sum_zero_states = np.array([1] * self.window_num) - temp_sum_zero_states_list = [] - window_sum_zero_states_list = [] - temp_sum_zero_states_list.append(temp_sum_zero_states) - window_sum_zero_states_list.append(window_sum_zero_states) - bucket_zero_states_list = np.flip( - np.array(bucket_zero_states).transpose(1, 0), axis=0 - ) - for b in range(self.bucket_num_per_window): - next_temp_sum_zero_states = ( - temp_sum_zero_states_list[b] & bucket_zero_states_list[b] - ) - next_window_sum_zero_states = ( - window_sum_zero_states_list[b] & next_temp_sum_zero_states - ) - temp_sum_zero_states_list.append(next_temp_sum_zero_states) - window_sum_zero_states_list.append(next_window_sum_zero_states) - return ( - bucket_zero_states_list, - temp_sum_zero_states_list, - window_sum_zero_states_list, - ) - - -######################### -# Functions for twisted curve -######################### - - -def padd(partial_sum, single_point): - return jec.padd_lazy_twisted_pack(partial_sum, single_point) - - -def padd_with_pdul_check(partial_sum, single_point): - # coordinate_dim, batch_dim, precision_dim = partial_sum.shape - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_lazy_twisted_pack(partial_sum, single_point) - double_partial_sum = jec.pdul_lazy_twisted_pack(partial_sum) - cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape( - 1, batch_dim, 1 - ) - return jnp.where(cond_equal, double_partial_sum, new_partial_sum) - - -def bucket_accumulation_index_scan_parallel_algorithm_twisted( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_index_list: jnp.ndarray, - msm_length: int, -): - """Scan version BA with index selection.""" - # buckets_dim is not used in the algorithm, changed it into _. 
- coordinate_dim, batch_window_dim, _, precision_dim = all_buckets.shape - _, _, parallel_dim, _ = ( - all_points.shape - ) # (serial_dim, coordinate_dim, parallel_dim, precision_dim) - single_window_dim = batch_window_dim // parallel_dim - - def scan_body(buckets, point_with_cond_pack): - point, selection_index = point_with_cond_pack - point = jax.lax.broadcast_in_dim( - point, - (coordinate_dim, parallel_dim, single_window_dim, precision_dim), - (0, 1, 3), - ) - point = point.reshape((coordinate_dim, batch_window_dim, precision_dim)) - selective_buckets = buckets[ - :, jnp.arange(batch_window_dim), selection_index, : - ] - selective_update = padd(selective_buckets, point) - return ( - buckets.at[:, jnp.arange(batch_window_dim), selection_index, :].set( - selective_update - ), - None, - ) - - all_buckets, _ = jax.lax.scan( - scan_body, - all_buckets, - (all_points, selection_index_list), - length=msm_length, - ) - return all_buckets - - -def bucket_reduction_scan_algorithm_twisted( - all_buckets: jnp.ndarray, - temp_sum: jnp.ndarray, - window_sum: jnp.ndarray, - bucket_num_in_window: int, -): - """Scan version BR.""" - all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0) - - def scan_body(temp_and_window_sum_pack, buckets): - temp_sum, window_sum = temp_and_window_sum_pack - temp_sum = padd(temp_sum, buckets) - window_sum = padd_with_pdul_check(window_sum, temp_sum) - return (temp_sum, window_sum), None - - (_, window_sum), _ = jax.lax.scan( - scan_body, - (temp_sum, window_sum), - all_buckets[:bucket_num_in_window], - length=bucket_num_in_window, - ) - - return window_sum - - -def batch_window_summation_algorithm_twisted( - batch_window_sum: jnp.ndarray, - all_window_sum: jnp.ndarray, - point_parallel: int, -): - """Batch window summation algorithm for twisted curve.""" - # batch_window_dim is not used in the algorithm, changed it into _. 
- coordinate_dim, _, precision_dim = all_window_sum.shape - all_window_sum = all_window_sum.reshape( - (coordinate_dim, point_parallel, -1, precision_dim) - ).transpose(1, 0, 2, 3) - - def scan_body(batch_window_sum, single_window_sum): - batch_window_sum = padd(batch_window_sum, single_window_sum) - return batch_window_sum, None - - batch_window_sum, _ = jax.lax.scan( - scan_body, batch_window_sum, all_window_sum, length=point_parallel - ) - return batch_window_sum - - -def window_merge_scan_algorithm_twisted( - window_sum: jnp.ndarray, slice_length: int -): - """Scan version WM.""" - coordinate_dim, window_dim, precision_dim = window_sum.shape - window_sum = window_sum.transpose(1, 0, 2) - result = window_sum[window_dim - 1, :, :].reshape( - (coordinate_dim, 1, precision_dim) - ) - - def fori_loop_body(_, result): - result = jec.pdul_lazy_twisted_pack(result) - return result - - def scan_body(result, window_sum): - result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result) - result = jec.padd_lazy_twisted_pack( - result, window_sum.reshape((coordinate_dim, 1, util.U16_EXT_CHUNK_NUM)) - ) - return result, None - - result, _ = jax.lax.scan( - scan_body, - result, - window_sum[: window_dim - 1, :, :], - reverse=True, - length=window_dim - 1, - ) - result = result.reshape((coordinate_dim, precision_dim)) - return result - - -class MSMPippengerTwisted: - """Pippenger algorithm for elliptic curves with twisted points. - - Attributes: - coordinate_num: The number of coordinates in the elliptic curve. - slice_length: The length of each slice in the elliptic curve. - point_parallel: The number of parallel points in the elliptic curve. - window_num: The number of windows in the elliptic curve. - batch_window_num: The number of batch windows in the elliptic curve. - bucket_num_per_window: The number of buckets in each window. - slice_mask: The mask for the slices in the elliptic curve. - blank_point: A JAX array of zeros, used to initialize the buckets. 
- all_buckets: A JAX array of all the buckets in the elliptic curve. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. - scalars: A list of integers, where each integer represents an Orignal scalar - from the trace. - all_points: A JAX array of all the points in the elliptic curve. from the - trace. - window_sum: A JAX array of the window sum. - br_temp_sum: A JAX array of the temp sum for bucket reduction. - batch_window_sum: A JAX array of the batch window sum. - selection_index_list: A JAX array of the selection index for the buckets. - msm_length: The length of the MSM trace. - result: The final elliptic curve. - lazy_mat: The lazy matrix used for padding and doubling. - """ - - def __init__(self, slice_length: int, point_parallel: int): - - self.coordinate_num = util.COORDINATE_NUM - - self.slice_length = slice_length - self.point_parallel = point_parallel - self.window_num = int(math.ceil(254 / self.slice_length)) # - self.batch_window_num = self.window_num * self.point_parallel - self.bucket_num_per_window = ( - 2**self.slice_length - 1 - ) # Note: here remove the bucket_0 - self.slice_mask = 2**self.slice_length - 1 - self.blank_point = util.int_list_to_array( - [0, 1, 1, 0], util.BASE, util.U16_EXT_CHUNK_NUM - ).reshape(self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM) - - self.all_buckets = jnp.broadcast_to( - self.blank_point.reshape( - 1, self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM - ).transpose(1, 0, 2, 3), - ( - self.coordinate_num, - self.window_num, - self.bucket_num_per_window, - util.U16_EXT_CHUNK_NUM, - ), - ) - - self.all_buckets = jnp.tile( - self.all_buckets, (1, self.point_parallel, 1, 1) - ) - - self.window_sum: jnp.ndarray - self.br_temp_sum: jnp.ndarray - self.batch_window_sum: jnp.ndarray - - self.msm_length = 0 - - self.selection_index_list: jnp.ndarray - self.all_points: jnp.ndarray - - self.scalars: List[int] = [] # Orignal scalar from the trace - # [Points, Points, ..., Points] - 
self.points: List[jnp.ndarray] = [] # Orignal points from the trace - self.lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT) - - self.result = None - - def initialize(self, scalars, points): - """Initialize the Pippenger algorithm. - - Args: - scalars: A list of integers, where each integer represents an Orignal - scalar from the trace. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. - """ - # Initial internal selection from the scalar - self.scalars = scalars - self.msm_length = len(scalars) - - # Convert high-precision points into a vector of low-precision chunks - self.points = [ - util.int_list_to_array(coordinates, util.BASE, util.U16_EXT_CHUNK_NUM) - for coordinates in points - ] # pytype: disable=container-type-mismatch - - self.all_points = jnp.array(self.points) - _, coordinate_dim, precision_dim = self.all_points.shape - - # For BA - selection_index_pylist = self.construct_ba_selection() - # Note: it contains uint(-1) for the bucket_0. - # In BA, it may cause some undefined behavior when do bucket selection - # it is correct now, because when setting buckets after the computation, - # jax.numpy will ignore the index with uint(-1) out of index. 
- self.selection_index_list = jnp.array(selection_index_pylist).astype( - jnp.uint32 - ) - _, window_dim = self.selection_index_list.shape - - # Batch construction - self.all_points = self.all_points.reshape( - (-1, self.point_parallel, coordinate_dim, precision_dim) - ).transpose(0, 2, 1, 3) - self.selection_index_list = self.selection_index_list.reshape( - (-1, window_dim * self.point_parallel) - ) - self.br_temp_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - self.window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - self.batch_window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - - def bucket_accumulation(self, bucket_accumulation_index_func): - """BA index selection version.""" - self.all_buckets = bucket_accumulation_index_func( - self.all_buckets, self.all_points, self.selection_index_list - ) - return self.all_buckets - - def bucket_reduction(self, bucket_reduction_func): - """Reduce the buckets to a single point for each window.""" - temp_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - self.window_sum = bucket_reduction_func( - self.all_buckets, temp_sum, window_sum - ) - return self.window_sum - - def batch_window_summation(self, batch_window_summation_func): - """Sum the batch windows to form the final window sum.""" - batch_window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - self.window_sum = batch_window_summation_func( - batch_window_sum, self.window_sum - ) - return self.window_sum - - def 
window_merge(self, window_merge_func): - """Merge the windows to form the final elliptic curve.""" - self.result = window_merge_func(self.window_sum) - return self.result - - def construct_ba_selection(self): - selection_index_list = [] # Used for index selection - for scalar in self.scalars: - # Compute the zero states for each scalar by time dependence - selection_index = [] - for w in range(self.window_num): - slice_index = ( - (scalar >> (w * self.slice_length)) & self.slice_mask - ) - 1 - selection_index.append(slice_index) - selection_index_list.append(deepcopy(selection_index)) - return selection_index_list - - -######################### -# Functions for Signed bucket + twisted curve -######################### -def padd_with_sign(partial_sum, single_point, sign): - neg_single_point = jec.pneg_lazy_twisted_pack(single_point) - _, batch_dim, _ = partial_sum.shape - cond_neg = jnp.equal(sign, 1).reshape(1, batch_dim, 1) - signed_point = jnp.where(cond_neg, neg_single_point, single_point) - result = jec.padd_lazy_twisted_pack(partial_sum, signed_point) - return result - - -def bucket_accumulation_signed_index_scan_parallel_algorithm_twisted( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_index_list: jnp.ndarray, - selection_sign_list: jnp.ndarray, - msm_length: int, -): - """Scan version BA with index selection.""" - coordinate_dim, batch_window_dim, _, precision_dim = all_buckets.shape - _, _, parallel_dim, _ = ( - all_points.shape - ) # (serial_dim, coordinate_dim, parallel_dim, precision_dim) - single_window_dim = batch_window_dim // parallel_dim - - def scan_body(buckets, point_with_cond_pack): - point, selection_index, selection_sign = point_with_cond_pack - point = jax.lax.broadcast_in_dim( - point, - (coordinate_dim, parallel_dim, single_window_dim, precision_dim), - (0, 1, 3), - ) - point = point.reshape((coordinate_dim, batch_window_dim, precision_dim)) - selective_buckets = buckets[ - :, jnp.arange(batch_window_dim), 
selection_index, : - ] - selective_update = padd_with_sign(selective_buckets, point, selection_sign) - return ( - buckets.at[:, jnp.arange(batch_window_dim), selection_index, :].set( - selective_update - ), - None, - ) - - all_buckets, _ = jax.lax.scan( - scan_body, - all_buckets, - (all_points, selection_index_list, selection_sign_list), - length=msm_length, - ) - return all_buckets - - -class MSMPippengerTwistedSigned: - """Pippenger algorithm for elliptic curves with twisted and signed points. - - Attributes: - coordinate_num: The number of coordinates in the elliptic curve. - slice_length: The length of each slice in the elliptic curve. - point_parallel: The number of parallel points in the elliptic curve. - window_num: The number of windows in the elliptic curve. - batch_window_num: The number of batch windows in the elliptic curve. - bucket_num_per_window: The number of buckets in each window. - slice_mask: The mask for the slices in the elliptic curve. - blank_point: A JAX array of zeros, used to initialize the buckets. - all_buckets: A JAX array of all the buckets in the elliptic curve. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. - scalars: A list of integers, where each integer represents an Orignal scalar - from the trace. - all_points: A JAX array of all the points in the elliptic curve. from the - trace. - window_sum: A JAX array of the window sum. - zero_states_list: A JAX array of the zero states for the buckets. - selection_list: A JAX array of the selection states for the buckets. - selection_index_list: A JAX array of the selection index for the buckets. - selection_sign_list: A JAX array of the selection sign for the buckets. - all_points: A JAX array of all the points in the elliptic curve. from the - trace. - msm_length: The length of the MSM trace. - result: The final elliptic curve. - lazy_mat: The lazy matrix used for padding and doubling. 
- """ - - def __init__(self, slice_length: int, point_parallel: int): - self.coordinate_num = util.COORDINATE_NUM - - self.slice_length = slice_length - self.point_parallel = point_parallel - self.window_num = int(math.ceil(254 / self.slice_length)) # - self.batch_window_num = self.window_num * self.point_parallel - self.bucket_num_per_window = 2 ** (self.slice_length - 1) - self.slice_mask = 2**self.slice_length - 1 - self.blank_point = util.int_list_to_array( - [0, 1, 1, 0], util.BASE, util.U16_EXT_CHUNK_NUM - ).reshape(self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM) - - self.all_buckets = jnp.broadcast_to( - self.blank_point.reshape( - 1, self.coordinate_num, 1, util.U16_EXT_CHUNK_NUM - ).transpose(1, 0, 2, 3), - ( - self.coordinate_num, - self.window_num, - self.bucket_num_per_window, - util.U16_EXT_CHUNK_NUM, - ), - ) - - self.all_buckets = jnp.tile( - self.all_buckets, (1, self.point_parallel, 1, 1) - ) - - self.window_sum: jnp.ndarray - - self.msm_length = 0 - - self.zero_states_list: jnp.ndarray - self.selection_list: jnp.ndarray - self.selection_index_list: jnp.ndarray - self.selection_sign_list: jnp.ndarray - self.all_points: jnp.ndarray - - self.scalars: List[int] = [] # Orignal scalar from the trace - # [Points, Points, ..., Points] - self.points: List[jnp.ndarray] = [] # Orignal points from the trace - self.lazy_mat = util.construct_lazy_matrix(util.MODULUS_377_INT) - - self.result = None - - def initialize(self, scalars, points): - """Initialize the Pippenger algorithm. - - Args: - scalars: A list of integers, where each integer represents an Orignal - scalar from the trace. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. 
- """ - # Initial internal selection from the scalar - self.scalars = scalars - self.msm_length = len(scalars) - - # Convert high-precision points into a vector of low-precision chunks - self.points = [ - util.int_list_to_array(coordinates, util.BASE, util.U16_EXT_CHUNK_NUM) - for coordinates in points - ] - self.all_points = jnp.array(self.points) - _, coordinate_dim, precision_dim = self.all_points.shape - - # For BA - selection_index_pylist, selection_sign_pylist = ( - self.construct_ba_selection_with_sign() - ) - self.selection_index_list = jnp.asarray(selection_index_pylist).astype( - jnp.uint32 - ) - self.selection_sign_list = jnp.array(selection_sign_pylist, dtype=jnp.uint8) - _, window_dim = self.selection_index_list.shape - - # Batch construction - self.all_points = self.all_points.reshape( - (-1, self.point_parallel, coordinate_dim, precision_dim) - ).transpose(0, 2, 1, 3) - self.selection_index_list = self.selection_index_list.reshape( - (-1, window_dim * self.point_parallel) - ) - self.selection_sign_list = self.selection_sign_list.reshape( - (-1, window_dim * self.point_parallel) - ) - - def bucket_accumulation(self, bucket_accumulation_index_algorithm): - """BA index selection version.""" - self.all_buckets = bucket_accumulation_index_algorithm( - self.all_buckets, - self.all_points, - self.selection_index_list, - self.selection_sign_list, - ) - - return self.all_buckets - - def bucket_reduction(self, bucket_reduction_func): - """Reduce the buckets to a single point for each window.""" - temp_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - self.window_sum = bucket_reduction_func( - self.all_buckets, temp_sum, window_sum - ) - return self.window_sum - - def batch_window_summation(self, batch_window_summation_algorithm): - 
batch_window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.U16_EXT_CHUNK_NUM, - ), - ) - self.window_sum = batch_window_summation_algorithm( - batch_window_sum, self.window_sum - ) - return self.window_sum - - def window_merge(self, window_merge_func): - """Merge the windows to form the final elliptic curve.""" - self.result = window_merge_func(self.window_sum) - return self.result - - def construct_ba_selection_with_sign(self): - """Construct the selection index and sign for the bucket accumulation (BA) step. - - Returns: - A tuple of two lists: the selection index for the bucket accumulation, and - the selection sign for the bucket accumulation. - """ - selection_index_list = [] # Used for index selection - selection_sign_list = [] - slice_max = 2**self.slice_length - slice_half = 2 ** (self.slice_length - 1) - for scalar in self.scalars: - # Compute the zero states for each scalar by time dependence - selection_index = [] - selection_sign = [] - carry = 0 - for w in range(self.window_num): - slice_index = (scalar >> (w * self.slice_length)) & self.slice_mask - slice_index = slice_index + carry - if slice_index >= slice_half: - new_slice_index = abs(slice_index - slice_max) - carry = 1 - else: - new_slice_index = slice_index - carry = 0 - selection_index.append(new_slice_index - 1) - selection_sign.append(carry) - assert carry == 0 - selection_index_list.append(deepcopy(selection_index)) - selection_sign_list.append(deepcopy(selection_sign)) - return selection_index_list, selection_sign_list diff --git a/jaxite_ec/pippenger_rns.py b/jaxite_ec/pippenger_rns.py deleted file mode 100644 index fd27be3..0000000 --- a/jaxite_ec/pippenger_rns.py +++ /dev/null @@ -1,1057 +0,0 @@ -"""RNS Pippenger algorithm for elliptic curves. - -This module implements the Pippenger algorithm for elliptic curves. 
The -algorithm -is a generalization of the elliptic curve algorithm that can be used to find -elliptic curves of arbitrary order. - -The Pippenger algorithm works by first finding a set of "scalars" and "points" -that lie on the elliptic curve. These scalars and points are then used to -construct -a "bucket" for each window of the elliptic curve. The buckets are then reduced -to a single point for each window. Finally, the points from all of the windows -are merged together to form the final elliptic curve. - -The Pippenger algorithm is a powerful tool for finding elliptic curves, and it -has been used to find elliptic curves of arbitrary order and dimension. - -Example Compiled Function: - -bucket_accumulation_scan_jax_jit = jax.jit( - bucket_accumulation_scan_jax, - static_argnames='msm_length').lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM*BUCKET_NUM_PER_WINDOW), - dtype=jnp.uint8 - ), - MSM_LENGTH - ).compile() - -bucket_accumulation_index_scan_jax_jit = jax.jit( - bucket_accumulation_index_scan_jax, - static_argnames='msm_length').lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM,BUCKET_NUM_PER_WINDOW, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH, COORDINATE_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM), - dtype=jnp.uint32 - ), - jax.ShapeDtypeStruct( - (MSM_LENGTH,WINDOW_NUM,BUCKET_NUM_PER_WINDOW), - dtype=jnp.uint8 - ), - MSM_LENGTH - ).compile() - -bucket_reduction_scan_jax_jit = jax.jit( - bucket_reduction_scan_jax, - static_argnames='bucket_num_in_window').lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM, BUCKET_NUM_PER_WINDOW, CHUNK_NUM), - 
dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (BUCKET_NUM_PER_WINDOW, WINDOW_NUM), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (BUCKET_NUM_PER_WINDOW, WINDOW_NUM), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (BUCKET_NUM_PER_WINDOW, WINDOW_NUM,), - dtype=jnp.uint8 - ), - BUCKET_NUM_PER_WINDOW - ).compile() - -window_merge_scan_jax_jit = jax.jit( - window_merge_scan_jax, - static_argnames='slice_length').lower( - jax.ShapeDtypeStruct((COORDINATE_NUM, WINDOW_NUM, CHUNK_NUM), - dtype=jnp.uint16), - SLICE_LENGTH - ).compile() - -selective_padd_with_zero_jit = jax.jit( - selective_padd_with_zero).lower( - jax.ShapeDtypeStruct( - (COORDINATE_NUM, WINDOW_NUM*BUCKET_NUM_PER_WINDOW, COORDINATE_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (COORDINATE_NUM, MSM_LENGTH, COORDINATE_NUM), - dtype=jnp.uint16 - ), - jax.ShapeDtypeStruct( - (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,), - dtype=jnp.uint8 - ), - jax.ShapeDtypeStruct( - (WINDOW_NUM*BUCKET_NUM_PER_WINDOW,), - dtype=jnp.uint8 - ), - ).compile() -""" - -import copy -import math -from typing import List - -import jax -import jax.numpy as jnp -from jaxite.jaxite_ec import util -import jaxite.jaxite_ec.elliptic_curve as jec -import jaxite.jaxite_ec.finite_field as ff -import numpy as np - - -deepcopy = copy.deepcopy - -""" -Example Parameters: -SLICE_LENGTH = 4 -WINDOW_NUM = int(math.ceil(255 / SLICE_LENGTH)) -BUCKET_NUM_PER_WINDOW = 2**SLICE_LENGTH -MSM_LENGTH = 1024 -COORDINATE_NUM = 4 -""" - - -def selective_padd_with_zero(partial_sum, single_point, select, is_zero): - """Padd the partial sum with the single point, but only if the selection state is 1. - - Args: - partial_sum: The partial sum. - single_point: The single point. - select: The selection state. - is_zero: The zero states. - - Returns: - The new partial sum. 
- """ - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point) - - cond_select = jnp.equal(select, 1).reshape(1, batch_dim, 1) - sum_result = jnp.where(cond_select, new_partial_sum, partial_sum) - - cond_zero = jnp.equal(is_zero, 1).reshape(1, batch_dim, 1) - cond_select_and_zero = jnp.logical_and(cond_select, cond_zero) - result = jnp.where(cond_select_and_zero, single_point, sum_result) - return result - - -def padd_with_zero(partial_sum, single_point, ps_is_zero, sp_is_zero): - """Padd the partial sum with the single point. - - Check if the partial sum is equal to the single point first. - - Args: - partial_sum: The partial sum. - single_point: The single point. - ps_is_zero: The zero states of the partial sum. - sp_is_zero: The zero states of the single point. - - Returns: - The new partial sum. - """ - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point) - cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1) - cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1) - result_1 = jnp.where(cond_sp_zero, partial_sum, single_point) - result_2 = jnp.where( - jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum - ) - return result_2 - - -def padd_with_zero_alter(partial_sum, single_point, ps_is_zero): - - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point) - cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1) - result_2 = jnp.where(cond_ps_zero, single_point, new_partial_sum) - return result_2 - - -def padd_with_zero_and_pdul_check( - partial_sum, single_point, ps_is_zero, sp_is_zero -): - """Padd the partial sum with the single point. - - Check if the partial sum is equal to the single point first. If they are - equal, then double the partial sum. - - Args: - partial_sum: The partial sum. - single_point: The single point. 
- ps_is_zero: The zero states of the partial sum. - sp_is_zero: The zero states of the single point. - - Returns: - The new partial sum. - """ - # coordinate_dim, batch_dim, precision_dim = partial_sum.shape - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_rns_xyzz_pack(partial_sum, single_point) - double_partial_sum = jec.pdul_rns_xyzz_pack(partial_sum) - cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape( - 1, batch_dim, 1 - ) - cond_sp_zero = jnp.equal(sp_is_zero, 1).reshape(1, batch_dim, 1) - cond_ps_zero = jnp.equal(ps_is_zero, 1).reshape(1, batch_dim, 1) - result_1 = jnp.where(cond_sp_zero, partial_sum, single_point) - result_2 = jnp.where( - jnp.logical_or(cond_sp_zero, cond_ps_zero), result_1, new_partial_sum - ) - reuslt_3 = jnp.where(cond_equal, double_partial_sum, result_2) - return reuslt_3 - - -def bucket_accumulation_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Non-scan version BA.""" - coordinate_dim, buckets_dim, precision_dim = all_buckets.shape - for i in range(msm_length): - point = jax.lax.broadcast_in_dim( - all_points[i], (coordinate_dim, buckets_dim, precision_dim), (0, 2) - ) - all_buckets = selective_padd_with_zero( - all_buckets, point, selection_list[i], zero_states_list[i] - ) - return all_buckets - - -def bucket_accumulation_scan_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Scan version BA.""" - coordinate_dim, buckets_dim, precision_dim = all_buckets.shape - - def scan_body(buckets, point_with_cond_pack): - point, selection, zero_states = point_with_cond_pack - point = jax.lax.broadcast_in_dim( - point, (coordinate_dim, buckets_dim, precision_dim), (0, 2) - ) - all_buckets = selective_padd_with_zero( - buckets, point, selection, zero_states - ) - return all_buckets, None - - 
all_buckets, _ = jax.lax.scan( - scan_body, - all_buckets, - (all_points, selection_list, zero_states_list), - length=msm_length, - ) - return all_buckets - - -def bucket_accumulation_index_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_index_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Non-scan version BA with index selection.""" - # buckets_dim is not used in the algorithm, changed it into _. - coordinate_dim, window_dim, _, precision_dim = all_buckets.shape - for i in range(msm_length): - point = jax.lax.broadcast_in_dim( - all_points[i], (coordinate_dim, window_dim, precision_dim), (0, 2) - ) - selective_buckets = all_buckets[ - :, jnp.arange(window_dim), selection_index_list[i], : - ] - selective_zero_states = zero_states_list[ - i, jnp.arange(window_dim), selection_index_list[i] - ] - selective_update = padd_with_zero_alter( - selective_buckets, point, selective_zero_states - ) - all_buckets = all_buckets.at[ - :, jnp.arange(window_dim), selection_index_list[i], : - ].set(selective_update) - return all_buckets - - -def bucket_accumulation_index_scan_algorithm( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_index_list: jnp.ndarray, - zero_states_list: jnp.ndarray, - msm_length: int, -): - """Scan version BA with index selection.""" - # buckets_dim is not used in the algorithm, changed it into _. 
- coordinate_dim, window_dim, _, precision_dim = all_buckets.shape - - def scan_body(buckets, point_with_cond_pack): - point, selection_index, zero_states = point_with_cond_pack - point = jax.lax.broadcast_in_dim( - point, (coordinate_dim, window_dim, precision_dim), (0, 2) - ) - selective_buckets = buckets[:, jnp.arange(window_dim), selection_index, :] - selective_zero_states = zero_states[jnp.arange(window_dim), selection_index] - selective_update = padd_with_zero_alter( - selective_buckets, point, selective_zero_states - ) - return ( - buckets.at[:, jnp.arange(window_dim), selection_index, :].set( - selective_update - ), - None, - ) - - all_buckets, _ = jax.lax.scan( - scan_body, - all_buckets, - (all_points, selection_index_list, zero_states_list), - length=msm_length, - ) - return all_buckets - - -def bucket_reduction_algorithm( - all_buckets: jnp.ndarray, - temp_sum: jnp.ndarray, - window_sum: jnp.ndarray, - bucket_zero_states_list: jnp.ndarray, - temp_zero_states_list: jnp.ndarray, - window_zero_states_list: jnp.ndarray, - bucket_num_in_window: int, -): - """Non-scan version BR.""" - all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0) - for i in range(bucket_num_in_window - 1): - temp_sum = padd_with_zero( - temp_sum, - all_buckets[i], - temp_zero_states_list[i], - bucket_zero_states_list[i], - ) - window_sum = padd_with_zero_and_pdul_check( - window_sum, - temp_sum, - window_zero_states_list[i], - temp_zero_states_list[i + 1], - ) - return window_sum - - -def bucket_reduction_scan_algorithm( - all_buckets: jnp.ndarray, - temp_sum: jnp.ndarray, - window_sum: jnp.ndarray, - bucket_zero_states_list: jnp.ndarray, - temp_zero_states_list: jnp.ndarray, - window_zero_states_list: jnp.ndarray, - bucket_num_in_window: int, -): - """Scan version BR.""" - all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0) - - def scan_body(temp_and_window_sum_pack, bucket_with_cond_pack): - temp_sum, window_sum = temp_and_window_sum_pack - ( - bucket, - 
bucket_zero_states, - temp_zero_states, - temp_zero_states1, - window_zero_states, - ) = bucket_with_cond_pack - temp_sum = padd_with_zero( - temp_sum, bucket, temp_zero_states, bucket_zero_states - ) - window_sum = padd_with_zero_and_pdul_check( - window_sum, temp_sum, window_zero_states, temp_zero_states1 - ) - return (temp_sum, window_sum), None - - (_, window_sum), _ = jax.lax.scan( - scan_body, - (temp_sum, window_sum), - ( - all_buckets[: bucket_num_in_window - 1], - bucket_zero_states_list[: bucket_num_in_window - 1], - temp_zero_states_list[: bucket_num_in_window - 1], - temp_zero_states_list[1:], - window_zero_states_list[: bucket_num_in_window - 1], - ), - length=bucket_num_in_window - 1, - ) - - return window_sum - - -def window_merge_algorithm(window_sum: jnp.ndarray, slice_length: int): - """Non-scan version WM.""" - coordinate_dim, window_dim, precision_dim = window_sum.shape - result = window_sum[:, window_dim - 1, :].reshape( - (coordinate_dim, 1, precision_dim) - ) - for w in range(window_dim - 2, -1, -1): - for _ in range(slice_length): - result = jec.pdul_rns_xyzz_pack(result) - result = jec.padd_rns_xyzz_pack( - result, - window_sum[:, w, :].reshape((coordinate_dim, 1, util.NUM_MODULI)), - ) - - result = result.reshape((coordinate_dim, precision_dim)) - return result - - -def window_merge_scan_algorithm(window_sum: jnp.ndarray, slice_length: int): - """Scan version WM.""" - coordinate_dim, window_dim, precision_dim = window_sum.shape - window_sum = window_sum.transpose(1, 0, 2) - result = window_sum[window_dim - 1, :, :].reshape( - (coordinate_dim, 1, precision_dim) - ) - - def fori_loop_body(_, result): - result = jec.pdul_rns_xyzz_pack(result) - return result - - def scan_body(result, window_sum): - result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result) - result = jec.padd_rns_xyzz_pack( - result, window_sum.reshape((coordinate_dim, 1, util.NUM_MODULI)) - ) - return result, None - - result, _ = jax.lax.scan( - scan_body, - result, 
- window_sum[: window_dim - 1, :, :], - reverse=True, - length=window_dim - 1, - ) - result = result.reshape((coordinate_dim, precision_dim)) - return result - - -class MSMPippenger: - """Pippenger algorithm for elliptic curves. - - Attributes: - coordinate_num: The number of coordinates in the elliptic curve. - slice_length: The length of each slice in the elliptic curve. - window_num: The number of windows in the elliptic curve. - bucket_num_per_window: The number of buckets in each window. - bucket_num_in_window: The number of buckets in each window. - slice_mask: The mask for the slices in the elliptic curve. - blank_point: A JAX array of zeros, used to initialize the buckets. - all_buckets: A JAX array of all the buckets in the elliptic curve. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. - scalars: A list of integers, where each integer represents an Orignal scalar - from the trace. - all_points: A JAX array of all the points in the elliptic curve. from the - trace. - window_sum: A JAX array of the window sum. - bucket_zero_states: A JAX array of the zero states for the buckets. - temp_sum_zero_states: A JAX array of the zero states for the temp sum. - window_sum_zero_states: A JAX array of the zero states for the window sum. - zero_states_list: A JAX array of the zero states for the buckets. - selection_list: A JAX array of the selection states for the buckets. - selection_index_list: A JAX array of the selection index for the buckets. - msm_length: The length of the MSM trace. - result: The final elliptic curve. - rns_mat: The lazy matrix used for padding and doubling. 
- """ - - def __init__(self, slice_length): - self.coordinate_num = util.COORDINATE_NUM - - self.slice_length = slice_length - self.window_num = int(math.ceil(254 / self.slice_length)) # - self.bucket_num_per_window = 2**self.slice_length - self.slice_mask = self.bucket_num_per_window - 1 - self.blank_point = util.int_list_to_array( - [0, 0, 0, 0], util.BASE, util.NUM_MODULI - ).reshape(self.coordinate_num, 1, util.NUM_MODULI) - - self.all_buckets = jnp.broadcast_to( - self.blank_point.reshape( - 1, self.coordinate_num, 1, util.NUM_MODULI - ).transpose(1, 0, 2, 3), - ( - self.coordinate_num, - self.window_num, - self.bucket_num_per_window, - util.NUM_MODULI, - ), - ) - - self.window_sum: jnp.ndarray - - self.msm_length = 0 - self.bucket_zero_states: jnp.ndarray - self.temp_sum_zero_states: jnp.ndarray - self.window_sum_zero_states: jnp.ndarray - self.zero_states_list: jnp.ndarray - self.selection_list: jnp.ndarray - self.selection_index_list: jnp.ndarray - self.all_points: jnp.ndarray - - self.scalars: List[int] = [] # Orignal scalar from the trace - # [Points, Points, ..., Points] - self.points: List[jnp.ndarray] = [] # Orignal points from the trace - self.rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - - self.result = None - - def initialize(self, scalars, points): - """Initialize the Pippenger algorithm. - - Args: - scalars: A list of integers, where each integer represents an Orignal - scalar from the trace. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. 
- """ - # Initial internal selection from the scalar - self.scalars = scalars - self.msm_length = len(scalars) - - # Convert high-precision points into a vector of low-precision chunks - self.points = [ - util.int_list_to_array_rns(coordinates + [1, 1]) - for coordinates in points - ] # pytype: disable=container-type-mismatch - - self.all_points = jnp.array(self.points).astype(jnp.uint16) - - # For BA - zero_states_pylist, selection_pylist, selection_index_pylist = ( - self.construct_ba_zero_states_and_selection() - ) - self.selection_list = jnp.array(selection_pylist, dtype=jnp.uint8).reshape( - (-1, self.window_num * self.bucket_num_per_window) - ) - # For index selection version BA - self.zero_states_list = jnp.array(zero_states_pylist, dtype=jnp.uint8) - self.selection_index_list = jnp.array( - selection_index_pylist, dtype=jnp.uint32 - ) - - # For BR - ( - bucket_zero_states_py, - temp_sum_zero_states_py, - window_sum_zero_states_py, - ) = self.construct_br_zero_states( - zero_states_pylist[len(zero_states_pylist) - 1] - ) - self.bucket_zero_states = jnp.array(bucket_zero_states_py, dtype=jnp.uint8) - self.temp_sum_zero_states = jnp.array( - temp_sum_zero_states_py, dtype=jnp.uint8 - ) - self.window_sum_zero_states = jnp.array( - window_sum_zero_states_py, dtype=jnp.uint8 - ) - - def bucket_accumulation(self, bucket_accumulation_index_func): - """BA index selection version.""" - self.all_buckets = bucket_accumulation_index_func( - self.all_buckets, - self.all_points, - self.selection_index_list, - self.zero_states_list[: self.msm_length], - ) - - return self.all_buckets - - def bucket_reduction(self, bucket_reduction_func): - """Reduce the buckets to a single point for each window.""" - temp_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.NUM_MODULI, - ), - ) - window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.NUM_MODULI, - ), - ) - self.window_sum = 
bucket_reduction_func( - self.all_buckets, - temp_sum, - window_sum, - self.bucket_zero_states[: self.bucket_num_per_window], - self.temp_sum_zero_states[: self.bucket_num_per_window], - self.window_sum_zero_states[: self.bucket_num_per_window], - ) - return self.window_sum - - def window_merge(self, window_merge_func): - """Merge the windows to form the final elliptic curve. - - Args: - window_merge_func: The function to merge the windows. - - Returns: - The final elliptic curve. - """ - self.result = window_merge_func(self.window_sum) - return self.result - - def construct_ba_zero_states_and_selection(self): - """Construct the zero states and selection for the bucket accumulation (BA) step. - - Returns: - A tuple of two lists: the zero states for the bucket accumulation, and the - selection for the bucket accumulation. - """ - zero_states = [ - deepcopy([1] * self.bucket_num_per_window) - for _ in range(self.window_num) - ] - zero_states_list = [] - zero_states_list.append(deepcopy(zero_states)) - selection_list = [] - selection_index_list = [] # Used for index selection - for scalar in self.scalars: - # Compute the zero states for each scalar by time dependence - selection = [ - deepcopy(([0] * self.bucket_num_per_window)) - for _ in range(self.window_num) - ] - selection_index = [] - for w in range(self.window_num): - slice_index = (scalar >> (w * self.slice_length)) & self.slice_mask - zero_states[w][slice_index] = 0 - selection[w][slice_index] = 1 - selection_index.append(slice_index) - - selection_list.append(deepcopy(selection)) - zero_states_list.append(deepcopy(zero_states)) - selection_index_list.append(deepcopy(selection_index)) - return zero_states_list, selection_list, selection_index_list - - def construct_br_zero_states(self, bucket_zero_states): - """Construct the zero states for the bucket reduction (BR) step. - - Args: - bucket_zero_states: The zero states of the buckets. 
- - Returns: - A tuple of three lists: the zero states for the bucket reduction, the zero - states for the temporary sum, and the zero states for the window sum. - """ - temp_sum_zero_states = np.array([1] * self.window_num) - window_sum_zero_states = np.array([1] * self.window_num) - temp_sum_zero_states_list = [] - window_sum_zero_states_list = [] - temp_sum_zero_states_list.append(temp_sum_zero_states) - window_sum_zero_states_list.append(window_sum_zero_states) - bucket_zero_states_list = np.flip( - np.array(bucket_zero_states).transpose(1, 0), axis=0 - ) - for b in range(self.bucket_num_per_window): - next_temp_sum_zero_states = ( - temp_sum_zero_states_list[b] & bucket_zero_states_list[b] - ) - next_window_sum_zero_states = ( - window_sum_zero_states_list[b] & next_temp_sum_zero_states - ) - temp_sum_zero_states_list.append(next_temp_sum_zero_states) - window_sum_zero_states_list.append(next_window_sum_zero_states) - return ( - bucket_zero_states_list, - temp_sum_zero_states_list, - window_sum_zero_states_list, - ) - - -######################### -# Functions for twisted curve -######################### - - -def padd(partial_sum, single_point): - return jec.padd_rns_twisted_pack(partial_sum, single_point) - - -def padd_with_pdul_check(partial_sum, single_point): - # coordinate_dim, batch_dim, precision_dim = partial_sum.shape - _, batch_dim, _ = partial_sum.shape - new_partial_sum = jec.padd_rns_twisted_pack(partial_sum, single_point) - double_partial_sum = jec.pdul_rns_twisted_pack(partial_sum) - cond_equal = jnp.all(partial_sum == single_point, axis=(0, 2)).reshape( - 1, batch_dim, 1 - ) - return jnp.where(cond_equal, double_partial_sum, new_partial_sum) - - -def bucket_accumulation_index_scan_parallel_algorithm_twisted( - all_buckets: jnp.ndarray, - all_points: jnp.ndarray, - selection_index_list: jnp.ndarray, - msm_length: int, -): - """Scan version BA with index selection.""" - # buckets_dim is not used in the algorithm, changed it into _. 
- coordinate_dim, batch_window_dim, _, precision_dim = all_buckets.shape - _, _, parallel_dim, _ = ( - all_points.shape - ) # (serial_dim, coordinate_dim, parallel_dim, precision_dim) - single_window_dim = batch_window_dim // parallel_dim - - def scan_body(buckets, point_with_cond_pack): - point, selection_index = point_with_cond_pack - point = jax.lax.broadcast_in_dim( - point, - (coordinate_dim, parallel_dim, single_window_dim, precision_dim), - (0, 1, 3), - ) - point = point.reshape((coordinate_dim, batch_window_dim, precision_dim)) - selective_buckets = buckets[ - :, jnp.arange(batch_window_dim), selection_index, : - ] - selective_update = padd(selective_buckets, point) - return ( - buckets.at[:, jnp.arange(batch_window_dim), selection_index, :].set( - selective_update - ), - None, - ) - - all_buckets, _ = jax.lax.scan( - scan_body, - all_buckets, - (all_points, selection_index_list), - length=msm_length, - ) - return all_buckets - - -def bucket_reduction_scan_algorithm_twisted( - all_buckets: jnp.ndarray, - temp_sum: jnp.ndarray, - window_sum: jnp.ndarray, - bucket_num_in_window: int, -): - """Scan version BR.""" - all_buckets = jnp.flip(all_buckets.transpose(2, 0, 1, 3), axis=0) - - def scan_body(temp_and_window_sum_pack, buckets): - temp_sum, window_sum = temp_and_window_sum_pack - temp_sum = padd(temp_sum, buckets) - window_sum = padd_with_pdul_check(window_sum, temp_sum) - return (temp_sum, window_sum), None - - (_, window_sum), _ = jax.lax.scan( - scan_body, - (temp_sum, window_sum), - all_buckets[:bucket_num_in_window], - length=bucket_num_in_window, - ) - - return window_sum - - -def batch_window_summation_algorithm_twisted( - batch_window_sum: jnp.ndarray, - all_window_sum: jnp.ndarray, - point_parallel: int, -): - """Batch window summation algorithm for twisted curve.""" - # batch_window_dim is not used in the algorithm, changed it into _. 
- coordinate_dim, _, precision_dim = all_window_sum.shape - all_window_sum = all_window_sum.reshape( - (coordinate_dim, point_parallel, -1, precision_dim) - ).transpose(1, 0, 2, 3) - - def scan_body(batch_window_sum, single_window_sum): - batch_window_sum = padd(batch_window_sum, single_window_sum) - return batch_window_sum, None - - batch_window_sum, _ = jax.lax.scan( - scan_body, batch_window_sum, all_window_sum, length=point_parallel - ) - return batch_window_sum - - -def window_merge_scan_algorithm_twisted( - window_sum: jnp.ndarray, slice_length: int -): - """Scan version WM.""" - coordinate_dim, window_dim, precision_dim = window_sum.shape - window_sum = window_sum.transpose(1, 0, 2) - result = window_sum[window_dim - 1, :, :].reshape( - (coordinate_dim, 1, precision_dim) - ) - - def fori_loop_body(_, result): - result = jec.pdul_rns_twisted_pack(result) - return result - - def scan_body(result, window_sum): - result = jax.lax.fori_loop(0, slice_length, fori_loop_body, result) - result = jec.padd_rns_twisted_pack( - result, window_sum.reshape((coordinate_dim, 1, util.NUM_MODULI)) - ) - return result, None - - result, _ = jax.lax.scan( - scan_body, - result, - window_sum[: window_dim - 1, :, :], - reverse=True, - length=window_dim - 1, - ) - result = result.reshape((coordinate_dim, precision_dim)) - return result - - -class MSMPippengerTwisted: - """Pippenger algorithm for elliptic curves with twisted points. - - Attributes: - coordinate_num: The number of coordinates in the elliptic curve. - slice_length: The length of each slice in the elliptic curve. - point_parallel: The number of parallel points in the elliptic curve. - window_num: The number of windows in the elliptic curve. - batch_window_num: The number of batch windows in the elliptic curve. - bucket_num_per_window: The number of buckets in each window. - slice_mask: The mask for the slices in the elliptic curve. - blank_point: A JAX array of zeros, used to initialize the buckets. 
- all_buckets: A JAX array of all the buckets in the elliptic curve. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. - scalars: A list of integers, where each integer represents an Orignal scalar - from the trace. - all_points: A JAX array of all the points in the elliptic curve. from the - trace. - window_sum: A JAX array of the window sum. - br_temp_sum: A JAX array of the temp sum for bucket reduction. - batch_window_sum: A JAX array of the batch window sum. - selection_index_list: A JAX array of the selection index for the buckets. - msm_length: The length of the MSM trace. - result: The final elliptic curve. - rns_mat: The lazy matrix used for padding and doubling. - """ - - def __init__(self, slice_length: int, point_parallel: int): - - self.coordinate_num = util.COORDINATE_NUM - - self.slice_length = slice_length - self.point_parallel = point_parallel - self.window_num = int(math.ceil(254 / self.slice_length)) # - self.batch_window_num = self.window_num * self.point_parallel - self.bucket_num_per_window = ( - 2**self.slice_length - 1 - ) # Note: here remove the bucket_0 - self.slice_mask = 2**self.slice_length - 1 - self.blank_point = ( - util.int_list_to_array_rns([0, 1, 1, 0]) - .reshape(self.coordinate_num, 1, util.NUM_MODULI) - .astype(jnp.uint16) - ) - - self.all_buckets = jnp.broadcast_to( - self.blank_point.reshape( - 1, self.coordinate_num, 1, util.NUM_MODULI - ).transpose(1, 0, 2, 3), - ( - self.coordinate_num, - self.window_num, - self.bucket_num_per_window, - util.NUM_MODULI, - ), - ) - - self.all_buckets = jnp.tile( - self.all_buckets, (1, self.point_parallel, 1, 1) - ) - - self.window_sum: jnp.ndarray - self.br_temp_sum: jnp.ndarray - self.batch_window_sum: jnp.ndarray - - self.msm_length = 0 - - self.selection_index_list: jnp.ndarray - self.all_points: jnp.ndarray - - self.scalars: List[int] = [] # Orignal scalar from the trace - # [Points, Points, ..., Points] - self.points: List[jnp.ndarray] = [] 
# Orignal points from the trace - self.rns_mat = util.construct_rns_matrix(util.MODULUS_377_INT) - - self.result = None - - def initialize(self, scalars, points): - """Initialize the Pippenger algorithm. - - Args: - scalars: A list of integers, where each integer represents an Orignal - scalar from the trace. - points: A list of JAX arrays, where each array represents an Orignal point - from the trace. - """ - # Initial internal selection from the scalar - self.scalars = scalars - self.msm_length = len(scalars) - - # Convert high-precision points into a vector of low-precision chunks - self.points = [ - util.int_list_to_array_rns(coordinates) for coordinates in points - ] # pytype: disable=container-type-mismatch - - self.all_points = jnp.array(self.points).astype(jnp.uint16) - _, coordinate_dim, precision_dim = self.all_points.shape - - # For BA - selection_index_pylist = self.construct_ba_selection() - # Note: it contains uint(-1) for the bucket_0. - # In BA, it may cause some undefined behavior when do bucket selection - # it is correct now, because when setting buckets after the computation, - # jax.numpy will ignore the index with uint(-1) out of index. 
- self.selection_index_list = jnp.array(selection_index_pylist).astype( - jnp.uint32 - ) - _, window_dim = self.selection_index_list.shape - - # Batch construction - self.all_points = self.all_points.reshape( - (-1, self.point_parallel, coordinate_dim, precision_dim) - ).transpose(0, 2, 1, 3) - self.selection_index_list = self.selection_index_list.reshape( - (-1, window_dim * self.point_parallel) - ) - self.br_temp_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.NUM_MODULI, - ), - ) - self.window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.NUM_MODULI, - ), - ) - self.batch_window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.NUM_MODULI, - ), - ) - - def bucket_accumulation(self, bucket_accumulation_index_func): - """BA index selection version.""" - self.all_buckets = bucket_accumulation_index_func( - self.all_buckets, self.all_points, self.selection_index_list - ) - return self.all_buckets - - def bucket_reduction(self, bucket_reduction_func): - """Reduce the buckets to a single point for each window.""" - temp_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.NUM_MODULI, - ), - ) - window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.batch_window_num, - util.NUM_MODULI, - ), - ) - self.window_sum = bucket_reduction_func( - self.all_buckets, temp_sum, window_sum - ) - return self.window_sum - - def batch_window_summation(self, batch_window_summation_func): - """Sum the batch windows to form the final window sum.""" - batch_window_sum = jnp.broadcast_to( - self.blank_point, - ( - self.coordinate_num, - self.window_num, - util.NUM_MODULI, - ), - ) - self.window_sum = batch_window_summation_func( - batch_window_sum, self.window_sum - ) - return self.window_sum - - def window_merge(self, window_merge_func): - """Merge the 
windows to form the final elliptic curve.""" - self.result = window_merge_func(self.window_sum) - return self.result - - def construct_ba_selection(self): - selection_index_list = [] # Used for index selection - for scalar in self.scalars: - # Compute the zero states for each scalar by time dependence - selection_index = [] - for w in range(self.window_num): - slice_index = ( - (scalar >> (w * self.slice_length)) & self.slice_mask - ) - 1 - selection_index.append(slice_index) - selection_index_list.append(deepcopy(selection_index)) - return selection_index_list diff --git a/jaxite_ec/profiler.py b/jaxite_ec/profiler.py new file mode 100644 index 0000000..85713d7 --- /dev/null +++ b/jaxite_ec/profiler.py @@ -0,0 +1,1067 @@ +import csv +import gzip +import json +import os +import statistics +from typing import Any, Callable, Dict, List, Optional, Tuple, cast +import warnings +import jax +import jax.numpy as jnp +import pandas as pd + + +class DataFrameGenerator: + """A utility class for building pandas DataFrames from column data.""" + + def __init__(self): + """Initialize an empty DataFrameGenerator.""" + self.data: Dict[str, List[Any]] = {} + + def add_data(self, column_name: str, values: List[Any]) -> None: + """Add data to a specific column. + + Args: + column_name: Name of the column to add data to + values: List of values to add to the column + """ + if not isinstance(column_name, str): + raise ValueError("column_name must be a string") + if not isinstance(values, list): + raise ValueError("values must be a list") + + if column_name not in self.data: + self.data[column_name] = [] + self.data[column_name].extend(values) + + def add_single_value(self, column_name: str, value: Any) -> None: + """Add a single value to a specific column. 
class DataFrameGenerator:
  """Accumulates named columns of values and exports them as a DataFrame.

  Columns live in ``self.data`` as plain lists keyed by column name. They may
  temporarily differ in length; ``to_dataframe`` reconciles that.
  """

  def __init__(self):
    """Initialize an empty DataFrameGenerator."""
    self.data: Dict[str, List[Any]] = {}

  def add_data(self, column_name: str, values: List[Any]) -> None:
    """Append ``values`` to a column, creating the column on first use.

    Args:
      column_name: Name of the column to add data to.
      values: List of values to append to the column.

    Raises:
      ValueError: If ``column_name`` is not a str or ``values`` is not a list.
    """
    if not isinstance(column_name, str):
      raise ValueError("column_name must be a string")
    if not isinstance(values, list):
      raise ValueError("values must be a list")
    self.data.setdefault(column_name, []).extend(values)

  def add_single_value(self, column_name: str, value: Any) -> None:
    """Append a single value to a column.

    Args:
      column_name: Name of the column to add data to.
      value: Single value to append to the column.
    """
    self.add_data(column_name, [value])

  def get_column_lengths(self) -> Dict[str, int]:
    """Return a mapping from column name to the number of stored values."""
    return {col: len(values) for col, values in self.data.items()}

  def is_balanced(self) -> bool:
    """Return True if all columns have the same length (or there are none)."""
    if not self.data:
      return True
    return len({len(col) for col in self.data.values()}) == 1

  def to_dataframe(self, auto_balance: bool = True) -> pd.DataFrame:
    """Convert the stored data to a pandas DataFrame.

    Args:
      auto_balance: If True, every column is trimmed to the length of the
        shortest column. If False, columns of different lengths are an error.

    Returns:
      A DataFrame with one column per stored column; empty if nothing is
      stored.

    Raises:
      ValueError: If ``auto_balance`` is False and column lengths differ.
    """
    if not self.data:
      return pd.DataFrame()

    if not auto_balance and not self.is_balanced():
      lengths = self.get_column_lengths()
      raise ValueError(f"Columns have different lengths: {lengths}")

    # Trim every column to the shortest one so the frame is rectangular.
    min_len = min(len(col) for col in self.data.values())
    trimmed_data = {k: v[:min_len] for k, v in self.data.items()}
    return pd.DataFrame(trimmed_data)

  def clear(self) -> None:
    """Remove all stored columns."""
    self.data.clear()

  def get_column_names(self) -> List[str]:
    """Return the column names in insertion order."""
    return list(self.data.keys())

  def has_column(self, column_name: str) -> bool:
    """Return True if ``column_name`` is a stored column."""
    return column_name in self.data

  def merge(self, other_dataframe_generator: "DataFrameGenerator"):
    """Append the rows of another DataFrameGenerator to this one, in place.

    If this generator is empty it takes a copy of the other generator's
    columns. Otherwise both generators must have the same set of column names;
    on a mismatch a message is printed and nothing is merged.

    Args:
      other_dataframe_generator: The generator whose data is appended.

    Returns:
      None. This generator is modified in place.

    Raises:
      ValueError: If the argument is not a DataFrameGenerator.
    """
    if not isinstance(other_dataframe_generator, DataFrameGenerator):
      raise ValueError("other_dataframe_generator must be a DataFrameGenerator")
    if not self.data:
      # Copy the dict and each column list. Sharing the other generator's
      # objects would make later add_data()/clear() calls on either generator
      # silently change both.
      self.data = {
          column_name: list(values)
          for column_name, values in other_dataframe_generator.data.items()
      }
      return
    if set(self.get_column_names()) != set(
        other_dataframe_generator.get_column_names()
    ):
      # Best-effort by design: report the mismatch and skip the merge.
      print("The two DataFrameGenerators have different column names")
      return
    for column_name in other_dataframe_generator.get_column_names():
      self.add_data(column_name, other_dataframe_generator.data[column_name])

  def get_header(self) -> List[str]:
    """Return the column names (same result as ``get_column_names``)."""
    return self.get_column_names()

  def get_row_dict(self, index: int) -> Dict[str, Any]:
    """Return one row as a mapping from column name to the value at ``index``.

    Args:
      index: Position of the row inside every column list.

    Returns:
      Dictionary of column names and values.
    """
    return {
        column_name: self.data[column_name][index]
        for column_name in self.get_column_names()
    }
class TraceParser:
  """Locates and reads the newest ``*.trace.json.gz`` file under a directory."""

  def __init__(self, trace_dir: str):
    """Store the directory that is searched for trace files."""
    self.trace_dir = trace_dir

  def set_trace_dir(self, new_dir: str):
    """Set a new trace directory for the parser."""
    self.trace_dir = new_dir

  def find_trace_file(self):
    """Recursively search ``trace_dir`` for the latest ``.trace.json.gz`` file.

    Returns:
      The full path of the most recently modified match, or None if there is
      no match.
    """
    trace_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(self.trace_dir)
        for file in files
        if file.endswith(".trace.json.gz")
    ]
    if not trace_files:
      return None
    # Return the most recently modified file.
    return max(trace_files, key=os.path.getmtime)

  def read_trace_json(self):
    """Find, decompress and parse the JSON content of the trace file.

    Returns:
      The loaded JSON object, or None if no trace file exists or reading it
      fails (a message is printed in both cases).
    """
    trace_file = self.find_trace_file()
    if trace_file is None:
      print("No trace file found.")
      return None
    try:
      with gzip.open(trace_file, "rt", encoding="utf-8") as f:
        return json.load(f)
    except Exception as e:  # Best-effort reader: report and carry on.
      print(f"Error reading trace file: {e}")
      return None

  def parse_trace_csv(self):
    """Write the trace events to ``<trace_dir>/trace_events.csv``.

    Each event becomes one row with the columns pid, tid, ts, dur, ph, name
    and args. ``args`` is stored as a JSON string (empty when the event has
    none); fields missing from an event are left blank and fields outside the
    seven columns are dropped.

    Returns:
      None. A message is printed and nothing is written if the trace cannot be
      read or contains no events.
    """
    csv_file = os.path.join(self.trace_dir, "trace_events.csv")

    # Read the trace JSON data
    trace_data = self.read_trace_json()
    if trace_data is None:
      print("Failed to read trace data")
      return None

    # Extract trace events
    trace_events = trace_data.get("traceEvents", [])
    if not trace_events:
      print("No trace events found in the data")
      return None

    headers = ["pid", "tid", "ts", "dur", "ph", "name", "args"]
    with open(csv_file, "w", newline="", encoding="utf-8") as f:
      # extrasaction="ignore": events may carry keys beyond the seven columns;
      # the default ("raise") would abort with ValueError part-way through the
      # file on the first such event.
      writer = csv.DictWriter(f, fieldnames=headers, extrasaction="ignore")
      writer.writeheader()
      for event in trace_events:
        # Work on a copy so the loaded trace data is not modified.
        row = dict(event)
        row["args"] = json.dumps(event["args"]) if "args" in event else ""
        writer.writerow(row)
    print(f"Trace events written to: {csv_file}")
def calculate_statistics(data: List[Any]) -> Dict[str, Any]:
  """Summarize a list of numbers.

  Args:
    data: The values to summarize.

  Returns:
    A dictionary with the keys ``mean``, ``std``, ``min``, ``max`` and
    ``median``. ``std`` is the sample standard deviation, reported as 0 when
    ``data`` holds exactly one value.
  """
  summary: Dict[str, Any] = {"mean": statistics.mean(data)}
  # statistics.stdev needs at least two points, so a single sample reports 0.
  summary["std"] = 0 if len(data) == 1 else statistics.stdev(data)
  summary["min"] = min(data)
  summary["max"] = max(data)
  summary["median"] = statistics.median(data)
  return summary
def analyze_trace_bottlenecks(
    trace_folder: str,
    top_k: int = 10,
    device_pid: int = 3,
) -> List[Dict[str, Any]]:
  """Return the ``top_k`` longest device events of the trace in a folder.

  The newest trace file under ``trace_folder`` is loaded through
  ``TraceParser``. Only events whose ``pid`` equals ``device_pid``, whose
  ``ph`` is ``"X"`` and which carry a ``dur`` are considered.

  Args:
    trace_folder: Folder that is searched for the trace file.
    top_k: Number of events to return, longest duration first.
    device_pid: ``pid`` an event must carry to be counted (default 3).

  Returns:
    A list of dicts with the keys ``name``, ``dur_us``, ``pct`` (share of the
    summed duration of all considered events), ``source``, ``source_stack``
    (stack lines that do not mention ``site-packages``), ``hlo_category``,
    ``model_flops`` and ``bytes_accessed``. Empty if the trace cannot be read
    or has no matching events.
  """
  trace_json = TraceParser(trace_folder).read_trace_json()
  if trace_json is None:
    return []

  # Keep only duration events that belong to the target PID.
  candidates = []
  for event in trace_json.get("traceEvents", []):
    if event.get("pid") != device_pid:
      continue
    if event.get("ph") != "X":
      continue
    if event.get("dur") is None:
      continue
    candidates.append(event)
  if not candidates:
    return []

  total_us = sum(event["dur"] for event in candidates)
  # Stable descending sort: ties keep their order of appearance in the trace.
  ranked = sorted(candidates, key=lambda event: event["dur"], reverse=True)

  report = []
  for event in ranked[:top_k]:
    event_args = event.get("args") or {}
    raw_stack = event_args.get("source_stack", "")
    # Drop blank lines and frames that come from site-packages.
    frames = []
    if raw_stack:
      stripped = (line.strip() for line in raw_stack.strip().split("\n"))
      frames = [
          line for line in stripped if line and "site-packages" not in line
      ]
    duration = event["dur"]
    report.append({
        "name": event.get("name", ""),
        "dur_us": duration,
        "pct": 100.0 * duration / total_us if total_us else 0,
        "source": event_args.get("source", ""),
        "source_stack": frames,
        "hlo_category": event_args.get("hlo_category", ""),
        "model_flops": event_args.get("model_flops", 0),
        "bytes_accessed": event_args.get("bytes_accessed", 0),
    })
  return report
class KernelWrapperBase:
  """Common interface that ``Profiler`` uses to drive a kernel.

  A subclass supplies the callable via ``get_compiled_function``; the other
  methods expose the metadata stored at construction time. ``Profiler`` only
  relies on the methods declared here.
  """

  def __init__(
      self,
      kernel_name: str,
      input_structs: List[Tuple[Tuple[int, ...], jnp.dtype]],
      mesh: Optional[jax.sharding.Mesh] = None,
      input_shardings: Optional[Tuple[jax.sharding.Sharding, ...]] = None,
      output_sharding: Optional[jax.sharding.Sharding] = None,
      enable_sharding: bool = False,
  ):
    """Record the kernel metadata; nothing is compiled or placed here."""
    self.kernel_name = kernel_name
    self.input_structs = input_structs
    # Sharding configuration; only consulted when enable_sharding is set.
    self.enable_sharding = enable_sharding
    self.mesh = mesh
    self.input_shardings = input_shardings
    self.output_sharding = output_sharding
    # Optional names that subclasses may fill in.
    self.callable_function_name = None
    self.jit_function_name = None

  # ---- Profiler-facing unified API ----
  def get_compiled_function(self) -> Callable[..., jnp.ndarray]:
    """Return the callable to profile; subclasses must override this."""
    raise NotImplementedError

  def get_input_structs(self) -> List[Tuple[Tuple[int, ...], jnp.dtype]]:
    """Return the ``(shape, dtype)`` pairs given at construction."""
    return self.input_structs

  def get_kernel_name(self) -> str:
    """Return the kernel name given at construction."""
    return self.kernel_name

  def get_input_arrays(self) -> Optional[List[jnp.ndarray]]:
    """Return concrete inputs, or None to let Profiler fabricate random ones."""
    return None

  def shard_inputs(self, input_arrays: List[jnp.ndarray]) -> List[jnp.ndarray]:
    """Place each input with its sharding when sharding is enabled.

    Args:
      input_arrays: Arrays to place, paired positionally with
        ``input_shardings``.

    Returns:
      A new list of arrays passed through ``jax.device_put``, or
      ``input_arrays`` itself when sharding is disabled or no input shardings
      were given.
    """
    if not (self.enable_sharding and self.input_shardings):
      return input_arrays
    placed = []
    for array, sharding in zip(input_arrays, self.input_shardings):
      placed.append(jax.device_put(array, sharding))
    return placed
class KernelWrapper(KernelWrapperBase):
  """Wrapper that lowers and compiles a Python callable with ``jax.jit``.

  Compilation happens eagerly in ``__init__``, from ``jax.ShapeDtypeStruct``
  placeholders built out of ``input_structs``.
  """

  def __init__(
      self,
      kernel_name: str,
      function_to_wrap: Callable,
      input_structs: List[Tuple[Tuple[int, ...], jnp.dtype]],
      mesh: Optional[jax.sharding.Mesh] = None,
      input_shardings: Optional[Tuple[jax.sharding.Sharding, ...]] = None,
      output_sharding: Optional[jax.sharding.Sharding] = None,
      parameters: Optional[Dict[str, Any]] = None,
      enable_sharding: bool = False,
  ):
    """Store the configuration and compile the kernel.

    Args:
      kernel_name: Name attached to the kernel through ``jax.named_call``.
      function_to_wrap: Callable to compile. It receives the input arrays
        positionally and, when ``parameters`` is non-empty, also a
        ``parameters=`` keyword argument.
      input_structs: ``(shape, dtype)`` pair for every positional input.
      mesh: Mesh that is entered while lowering and calling, when sharding is
        enabled.
      input_shardings: Shardings paired positionally with ``input_structs``.
      output_sharding: Passed to ``jax.jit`` as ``out_shardings``.
      parameters: Extra keyword payload for ``function_to_wrap``; omitted or
        None means no parameters.
      enable_sharding: Whether the mesh and shardings are used.
    """
    super().__init__(
        kernel_name=kernel_name,
        input_structs=input_structs,
        mesh=mesh,
        input_shardings=input_shardings,
        output_sharding=output_sharding,
        enable_sharding=enable_sharding,
    )
    self.callable_function = function_to_wrap
    # Use a None sentinel instead of a `{}` default: a dict literal in the
    # signature is created once and would be shared by every instance.
    self.parameters = {} if parameters is None else parameters

    self.jit_lower = None
    self.jit_compiled_function = None

    # Compile immediately upon initialization
    self._compile()

  def _compile(self):
    """Lower and compile the wrapped callable for the declared inputs."""
    if self.enable_sharding and self.input_shardings:
      jax_input_structs = [
          jax.ShapeDtypeStruct(shape, dtype, sharding=sharding)
          for (shape, dtype), sharding in zip(
              self.input_structs, self.input_shardings
          )
      ]
    else:
      jax_input_structs = [
          jax.ShapeDtypeStruct(shape, dtype)
          for shape, dtype in self.input_structs
      ]

    # NOTE: Do not change the name of the function, it is used for profiling
    if self.parameters:

      def compiled_kernel_function(*jax_array_inputs):
        return self.callable_function(
            *jax_array_inputs, parameters=self.parameters
        )

    else:

      def compiled_kernel_function(*jax_array_inputs):
        return self.callable_function(*jax_array_inputs)

    if self.enable_sharding and self.mesh:
      with self.mesh:
        self.jit_lower = jax.jit(
            jax.named_call(compiled_kernel_function, name=self.kernel_name),
            in_shardings=self.input_shardings,
            out_shardings=self.output_sharding,
        ).lower(*jax_input_structs)
    else:
      self.jit_lower = jax.jit(
          jax.named_call(compiled_kernel_function, name=self.kernel_name)
      ).lower(*jax_input_structs)

    self.jit_compiled_function = self.jit_lower.compile()

  def get_compiled_function(self) -> Callable[..., jnp.ndarray]:
    """Return the compiled executable, run inside the mesh when sharded.

    Returns:
      The compiled function itself, or, when sharding is enabled and a mesh is
      set, a thin wrapper that enters the mesh around every call.
    """
    assert self.jit_compiled_function is not None, "Kernel not compiled"
    if self.enable_sharding and self.mesh:

      def compiled_with_mesh(*jax_array_inputs):
        # The mesh is re-read at call time, so guard against it having been
        # cleared after this wrapper was handed out.
        if self.mesh is not None:
          with self.mesh:
            return self.jit_compiled_function(*jax_array_inputs)
        return self.jit_compiled_function(*jax_array_inputs)

      return compiled_with_mesh
    return self.jit_compiled_function
class PrecompiledKernelWrapper(KernelWrapperBase):
  """Wrapper for a callable that is compiled or managed elsewhere.

  In contrast to ``KernelWrapper`` nothing is traced under ``jax.jit`` here:
  the callable is handed back exactly as given. That suits callables which
  already dispatch to pre-compiled executables or hold internal state that
  cannot be safely re-jitted (e.g. a stateful context method such as
  ``FusionMSMContext.multiscalar_multiply``).

  The concrete input arrays are supplied up front and returned unchanged, so
  Profiler uses them instead of fabricating random inputs from shape/dtype.
  This matters when the inputs must match a layout (sharding, padding) baked
  into the compiled kernel.
  """

  def __init__(
      self,
      kernel_name: str,
      callable_function: Callable,
      input_arrays: List[jnp.ndarray],
      enable_sharding: bool = False,
      callable_function_name: Optional[str] = None,
  ):
    """Record the callable together with its concrete inputs.

    Args:
      kernel_name: Name of the kernel.
      callable_function: The callable returned by ``get_compiled_function``.
      input_arrays: Concrete inputs; their shapes and dtypes also become the
        wrapper's ``input_structs``.
      enable_sharding: Forwarded to the base class.
      callable_function_name: Optional name of the callable; when given,
        ``jit_function_name`` is set to ``"jit_" + callable_function_name``.
    """
    concrete_inputs = list(input_arrays)
    structs = [(array.shape, array.dtype) for array in concrete_inputs]
    super().__init__(
        kernel_name=kernel_name,
        input_structs=structs,
        enable_sharding=enable_sharding,
    )
    self.callable_function = callable_function
    self._input_arrays = concrete_inputs
    self.callable_function_name = callable_function_name
    if callable_function_name:
      self.jit_function_name = f"jit_{callable_function_name}"
    else:
      self.jit_function_name = None

  def get_compiled_function(self) -> Callable[..., jnp.ndarray]:
    """Return the callable exactly as it was supplied."""
    return self.callable_function

  def get_input_arrays(self) -> List[jnp.ndarray]:
    """Concrete inputs preserved as-is (sharding, layout, device placement)."""
    return self._input_arrays

  def shard_inputs(self, input_arrays: List[jnp.ndarray]) -> List[jnp.ndarray]:
    """Return ``input_arrays`` untouched; the inputs arrive pre-placed."""
    return input_arrays
+ return input_arrays + + +class Profiler: + + def __init__( + self, + output_trace_path: str, + profile_naming: str, + configuration: Optional[Dict[str, Any]] = None, + ): + self.trace_dir = output_trace_path + self.profiler_name = profile_naming + self.profile_dir = os.path.join(self.trace_dir, self.profiler_name) + if not os.path.exists(self.profile_dir): + os.makedirs(self.profile_dir) + + self.configuration = configuration or {} + self.random_seed = self.configuration.get("random_seed", 0) + self.iterations = self.configuration.get("iterations", 1) + self.save_to_file = self.configuration.get("save_to_file", True) + self.enable_sharding = self.configuration.get("enable_sharding", False) + + self.profiles: List[Dict[str, Any]] = [] + self.profile_name_list: List[str] = [] + + # Storage for results + self.storage_file = os.path.join( + self.profile_dir, f"{self.profiler_name}_results.csv" + ) + + def add_profile( + self, + name: str, + kernel_wrapper: "KernelWrapperBase", + kernel_setting_cols: Optional[Dict[str, Any]] = None, + ): + if kernel_setting_cols is None: + kernel_setting_cols = {} + if name in self.profile_name_list: + raise ValueError(f"Profiler name {name} already exists") + + self.profile_name_list.append(name) + + profile_folder = os.path.join(self.profile_dir, name) + if not os.path.exists(profile_folder): + os.makedirs(profile_folder) + + self.profiles.append({ + "name": name, + "wrapper": kernel_wrapper, + "settings": kernel_setting_cols, + "folder": profile_folder, + "failed": False, + "trace_events": None, + "filtered_events": None, + "stats": None, + }) + + def _get_input_arrays(self, kernel_wrapper: "KernelWrapperBase"): + # If the wrapper carries concrete inputs (e.g. pre-sharded, pre-padded + # arrays bound to an already-compiled kernel), use them verbatim. 
+ provided = kernel_wrapper.get_input_arrays() + if provided is not None: + for arr in provided: + arr.block_until_ready() + return provided + + def get_max_value(dtype): + if dtype == jnp.uint8: + return 128 + elif dtype == jnp.uint16: + return 32768 + elif dtype == jnp.uint32: + return 4294967295 + elif dtype == jnp.uint64: + return 4294967295 + raise ValueError(f"Unsupported dtype: {dtype}") + + random_key = jax.random.key(self.random_seed) + input_arrays = [] + for shape, dtype in kernel_wrapper.get_input_structs(): + if jnp.issubdtype(dtype, jnp.floating): + input_arrays.append(jax.random.uniform(random_key, shape, dtype)) + elif jnp.issubdtype(dtype, jnp.integer): + input_arrays.append( + jax.random.randint( + random_key, shape, 0, get_max_value(dtype), dtype + ) + ) + elif jnp.issubdtype(dtype, jnp.bool_): + input_arrays.append(jax.random.bernoulli(random_key, 0.5, shape=shape)) + else: + raise ValueError(f"Unsupported dtype: {dtype}") + for input_array in input_arrays: + input_array.block_until_ready() + + if self.enable_sharding: + input_arrays = kernel_wrapper.shard_inputs(input_arrays) + + return input_arrays + + def profile_all_profilers(self): + for profile in self.profiles: + print(f"Profiling {profile['name']}") + try: + # Kernel wrapper is already compiled in its init + wrapper = cast(KernelWrapperBase, profile["wrapper"]) + compiled_function = wrapper.get_compiled_function() + input_arrays = self._get_input_arrays(wrapper) + + # Run each iteration in its own trace window so that the TPU + # trace buffer is not shared across iterations to avoid buffer + # overflow problem for kernel with long running time. 
+ iter_folders = [] + for i in range(self.iterations): + iter_folder = os.path.join(profile["folder"], f"iter_{i}") + os.makedirs(iter_folder, exist_ok=True) + with jax.profiler.trace(iter_folder): + compiled_function(*input_arrays).block_until_ready() + iter_folders.append(iter_folder) + profile["iter_folders"] = iter_folders + except Exception as e: + print(f"Error profiling {profile['name']}:\n {e}") + profile["failed"] = True + + def _parse_json_trace(self, profile): + # With per-iteration trace windows, merge trace events from all + # iteration sub-folders into one combined list. + iter_folders = profile.get("iter_folders", []) + if iter_folders: + all_trace_events = [] + trace_file_path = None + for folder in iter_folders: + parser = TraceParser(folder) + tfp = parser.find_trace_file() + if tfp is None: + continue + if trace_file_path is None: + trace_file_path = tfp + pjson = parser.read_trace_json() + if pjson is not None: + all_trace_events.extend(pjson.get("traceEvents", [])) + trace_events = all_trace_events + else: + # Legacy path: single trace folder. 
+ trace_parser = TraceParser(profile["folder"]) + trace_file_path = trace_parser.find_trace_file() + profile_json = trace_parser.read_trace_json() + if profile_json is None: + warnings.warn( + f"{profile['name']}: No trace events found in the data", UserWarning + ) + profile["failed"] = True + return None + trace_events = profile_json.get("traceEvents", []) + + if not trace_events: + warnings.warn( + f"{profile['name']}: No trace events found in the data", UserWarning + ) + profile["failed"] = True + return None + if self.save_to_file: + # Save into the same folder as the raw trace file + output_dir = os.path.dirname(trace_file_path) + profile["output_folder"] = output_dir + with open(os.path.join(output_dir, "trace_events.json"), "w") as f: + json.dump(trace_events, f, indent=2) + profile["trace_events"] = trace_events + return trace_events + + def _filter_trace_events(self, profile): + trace_events = profile["trace_events"] + if trace_events is None: + return None + + def merge_filtered_events_by_name(filtered_events): + grouped = {} + for event in filtered_events: + event_name = event.get("name", "unknown") + if ( + "args" in event.keys() + and "deduplicated_name" in event["args"].keys() + ): + event_name += "_" + event["args"]["deduplicated_name"] + elif ( + "custom-call" in event["name"] + and "args" in event.keys() + and "tf_op" in event["args"].keys() + ): + event_name += "_" + event["args"]["tf_op"] + if event_name not in grouped: + grouped[event_name] = [] + grouped[event_name].append(event) + + merged_filtered_events = {} + for event_name, events in grouped.items(): + merged = events[0].copy() + merged["dur"] = [e.get("dur") for e in events if "dur" in e] + merged["ts"] = [e.get("ts") for e in events if "ts" in e] + merged["repeat_count"] = len(events) + merged_filtered_events[event_name] = merged + return merged_filtered_events + + filtered_events_list = [] + # Check if NVIDIA is in device kind OR CPU is used as a fallback if + # explicit check needed + 
# But generally JAX trace events differ by backend. + # Assuming typical CPU/GPU separation. + device_kind = jax.devices()[0].device_kind + + if "NVIDIA" in device_kind: + for e in trace_events: + if "args" in e and "tf_op" in e["args"]: + # Loosen the check for compiled_kernel_function as it might be nested differently or named differently + if "compiled_kernel_function" in e["args"].get( + "hlo_module", "" + ) or "compiled_kernel_function" in e["args"].get("long_name", ""): + merged_event = False + # Try to merge with existing events + for f in filtered_events_list: + # Check if correlation_id exists before accessing it + if ( + "correlation_id" in f["args"] + and "correlation_id" in e["args"] + and f["args"]["correlation_id"] == e["args"]["correlation_id"] + and f["name"] == e["name"] + ): + f["dur"] = f["dur"] + e["dur"] + merged_event = True + if not merged_event: + filtered_events_list.append(e) + profile["filtered_events"] = merge_filtered_events_by_name( + filtered_events_list + ) + + elif "TPU" in device_kind: + wrapper = cast(KernelWrapperBase, profile["wrapper"]) + jit_function_name = wrapper.jit_function_name + for event in trace_events: + if ( + "pid" not in event.keys() or event["pid"] != 3 + ): # ToDo: change it into automatic PID detection based on "TPU:0". + continue + if ( + "name" in event.keys() + and "compiled_kernel_function" in event["name"] + and "args" in event.keys() + ): + filtered_events_list.append(event) + elif ( + "name" in event.keys() + and jit_function_name is not None + and jit_function_name in event["name"] + ): + # print(f"Found jit function name {jit_function_name} in event {event['name']}") + filtered_events_list.append(event) + elif "args" in event.keys() and "long_name" in event["args"].keys(): + filtered_events_list.append(event) + else: + continue + profile["filtered_events"] = merge_filtered_events_by_name( + filtered_events_list + ) + else: + # Fallback for CPU or other devices + # CPU traces might be different. 
Let's try to capture events related to our kernel. + for event in trace_events: + if "name" in event and "compiled_kernel_function" in event["name"]: + filtered_events_list.append(event) + profile["filtered_events"] = merge_filtered_events_by_name( + filtered_events_list + ) + + # Always save filtered events if we have any + if self.save_to_file: + # Make sure we don't crash if profile['filtered_events'] is None + events_to_dump = ( + profile["filtered_events"] + if profile["filtered_events"] is not None + else {} + ) + with open( + os.path.join(profile["output_folder"], "filtered_events.json"), "w" + ) as f: + json.dump(events_to_dump, f, indent=2) + + def _calculate_profiling_statistics(self, profile): + if profile["filtered_events"] is None: + return + + repeat_count = self.iterations + kernel_duration = [0] * repeat_count + + device_kind = jax.devices()[0].device_kind + + if "NVIDIA" in device_kind: + filtered_events = cast(Dict[str, Any], profile["filtered_events"]) + for event in filtered_events.values(): + if "compiled_kernel_function" in event["args"].get("hlo_module", ""): + durations = event["dur"] + if not isinstance(durations, list): + durations = [durations] + + if len(durations) == repeat_count: + kernel_duration = list_add(kernel_duration, durations) + elif ( + len(durations) > repeat_count + and len(durations) % repeat_count == 0 + ): + # Assume sequential execution of kernels within one iteration + chunk_size = len(durations) // repeat_count + aggregated_durations = [ + sum(durations[i * chunk_size : (i + 1) * chunk_size]) + for i in range(repeat_count) + ] + kernel_duration = list_add(kernel_duration, aggregated_durations) + else: + # Fallback: just take first N or handle mismatch. + # For now, adopting CPU strategy of taking first N but this is likely under-reporting. + # Ideally log a warning. 
+ kernel_duration = list_add( + kernel_duration, durations[:repeat_count] + ) + elif "TPU" in device_kind: + wrapper = cast(KernelWrapperBase, profile["wrapper"]) + jit_function_name = wrapper.jit_function_name + filtered_events = cast(Dict[str, Any], profile["filtered_events"]) + for event in filtered_events.values(): + matches = "compiled_kernel_function" in event["name"] or ( + jit_function_name is not None and jit_function_name in event["name"] + ) + if not matches: + continue + durations = event["dur"] + if not isinstance(durations, list): + durations = [durations] + # At larger problem sizes XLA can coalesce per-device event replays + # unevenly, producing either a multiple of ``repeat_count`` (one + # kernel invocation expanded into several sub-events per iteration) + # or fewer samples than ``repeat_count`` (some replays elided). + # Mirror the NVIDIA branch's bucket-and-fallback logic instead of + # blindly zip-adding, which used to assert out and abort the + # post-processing before the CSV was written. 
+ if len(durations) == repeat_count: + kernel_duration = list_add(kernel_duration, durations) + elif ( + len(durations) > repeat_count and len(durations) % repeat_count == 0 + ): + chunk_size = len(durations) // repeat_count + aggregated = [ + sum(durations[i * chunk_size : (i + 1) * chunk_size]) + for i in range(repeat_count) + ] + kernel_duration = list_add(kernel_duration, aggregated) + elif len(durations) < repeat_count: + padded = durations + [0] * (repeat_count - len(durations)) + kernel_duration = list_add(kernel_duration, padded) + else: + kernel_duration = list_add(kernel_duration, durations[:repeat_count]) + else: + # CPU logic - assuming direct name match from filtered events + filtered_events = cast(Dict[str, Any], profile["filtered_events"]) + for event in filtered_events.values(): + # On CPU, events might be simpler + if "compiled_kernel_function" in event.get("name", ""): + # DUR might be a single value or list depending on how it was merged + durations = event["dur"] + if not isinstance(durations, list): + durations = [durations] + + # If we have less durations than repeat_count, we might need to pad + # or it's a mismatch + # For now, let's just add what we have, assuming 1-to-1 or aggregated + if len(durations) == repeat_count: + kernel_duration = list_add(kernel_duration, durations) + elif len(durations) > repeat_count: + # Take first N + kernel_duration = list_add( + kernel_duration, durations[:repeat_count] + ) + else: + # Append 0s? 
Or just take what we have + padded = durations + [0] * (repeat_count - len(durations)) + kernel_duration = list_add(kernel_duration, padded) + + profile["stats"] = { + "kernel_all": kernel_duration, + } + + def post_process_all_profilers(self): + """Post-process trace events and calculate statistics for all profiles.""" + for profile in self.profiles: + if profile["failed"]: + continue + + events = self._parse_json_trace(profile) + if events is None: + continue + + self._filter_trace_events(profile) + self._calculate_profiling_statistics(profile) + + self.write_results() + + def get_profiling_dataframe_generator_all_profilers(self): + df_generator = DataFrameGenerator() + for profile in self.profiles: + if profile["failed"] or profile["stats"] is None: + continue + + p_df_gen = DataFrameGenerator() + p_df_gen.add_single_value( + "operation_name", profile["wrapper"].get_kernel_name() + ) + + for key, value in profile["settings"].items(): + p_df_gen.add_single_value(key, value) + + all_kernel_duration = profile["stats"]["kernel_all"] + for i, duration in enumerate(all_kernel_duration): + p_df_gen.add_single_value(f"sample_{i}", duration) + + df_generator.merge(p_df_gen) + return df_generator + + def write_results(self): + """Write profiling results to a CSV file and print to stdout.""" + storage_dataframe_generator = ( + self.get_profiling_dataframe_generator_all_profilers() + ) + # Check if file exists to determine if we need to write header + file_exists = os.path.exists(self.storage_file) + mode = "a" if file_exists else "w" + header = not file_exists + storage_dataframe_generator.to_dataframe().to_csv( + self.storage_file, mode=mode, header=header, index=False + ) + print( + storage_dataframe_generator.to_dataframe().to_csv() + ) # Need to see the content of the file in terminal as Google does not have file system + print(f"Results written to: {self.storage_file}") + + +def collect_logs(root_dir=".", output_csv_name="all_logs_collected"): + """Collects all CSV 
files found under directories named 'log' + + and aggregates them into a single CSV file. + Handles varying headers by taking the union of all found columns. + """ + all_files = [] + + # Fieldnames set to collect all unique columns + all_fieldnames = set() + # To preserve some order, we can use a list and add new ones as we see them + ordered_fieldnames = [] + + # First pass: identify files and collect all possible fieldnames + for dirpath, dirnames, filenames in os.walk(root_dir): + path_parts = dirpath.split(os.sep) + if "log" in path_parts: + for file in filenames: + if file.endswith(".csv"): + full_path = os.path.join(dirpath, file) + all_files.append(full_path) + try: + with open(full_path, "r", newline="") as csvfile: + reader = csv.reader(csvfile) + try: + header = next(reader) + for h in header: + if h not in all_fieldnames: + all_fieldnames.add(h) + ordered_fieldnames.append(h) + except StopIteration: + # Empty file + pass + except Exception as e: + print(f"Error reading header of {full_path}: {e}") + + if not all_files: + print("No CSV files found.") + return + + print(f"Found {len(all_files)} CSV files.") + print(f"Unified collected columns: {ordered_fieldnames}") + + output_file = os.path.join(root_dir, f"{output_csv_name}.csv") + total_rows = 0 + + try: + with open(output_file, "w", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=ordered_fieldnames) + writer.writeheader() + + for full_path in all_files: + try: + with open(full_path, "r", newline="") as infile: + reader = csv.DictReader(infile) + # The DictReader uses the file's own header mapping + # We just iterate and write to the master dict writer + for row in reader: + writer.writerow(row) + total_rows += 1 + except Exception as e: + print(f"Error processing {full_path}: {e}") + + print(f"Saved aggregated logs to {os.path.abspath(output_file)}") + print(f"Total rows collected: {total_rows}") + + except Exception as e: + print(f"Error writing output file: {e}") diff --git 
a/jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv b/jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv deleted file mode 100644 index 9ee1257..0000000 --- a/jaxite_ec/test_case/t1/zprize_msm_curve_377_bases_dim_1_seed_0.csv +++ /dev/null @@ -1 +0,0 @@ -0, x, 004cd4683316487f, b89b952325a20f91, 293cd479875ceb77, 2af0a9c29bacf6f9, f8673c91ec915534, aff8c262353486d7, zz: Fp384 "(003F62945CE627D1ECB0F654CD9E80F82DC0E62E484AE9CED02FC5DC5658CD2D6C3AB04E97AD20BD970B5C7E2178709C)", y, 00aa7ad6de759cb7, 270923c09134b2c6, 7f88b0b899bceace, 22071c8c7a09df51, 7f08b237fda96e55, b4d65d75663afe99, zz: Fp384 "(016113C0C182762A49721F7EAF286AFD6B0FA13DD9CE68CE5119C574EDF947D9AEBF6D36C26C3C208AB518E3079AC6FA)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv b/jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv deleted file mode 100644 index 80821f2..0000000 --- a/jaxite_ec/test_case/t1/zprize_msm_curve_377_res_dim_1_seed_0.csv +++ /dev/null @@ -1 +0,0 @@ -x, 009976f628cc7c42, d55f93b8c7dfbc61, 86a374e1e5a2fafb, a3e0ba0828a88f69, 9f20ac113f371d74, 77a360edb4c28207, zz: Fp384 "(0134A9468957E7FDC88CF9D69A4E7B3AE93432AE7FFFE93F6BE4F6F287BCA1880E185D98E304467DEC0709D9F7BC7AF0)", y, 01a7e5e4f8816331, 28dcee5c74675219, e54fd183a73529c5, 776dbcf5725a252e, 39fdb795acda186b, 0eadea2601fc454d, zz: Fp384 "(004A4A15E00C14B99875584249FCAF18B0AAA567BF67863D6DCF47CDC39A227154C3369DF6D9A63EECDB3BAA19FD2B58)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv b/jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv deleted file mode 100644 index 73ac78f..0000000 --- a/jaxite_ec/test_case/t1/zprize_msm_curve_377_scalars_dim_1_seed_0.csv +++ /dev/null @@ -1 +0,0 @@ -0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 "(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)" \ No 
newline at end of file diff --git a/jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv b/jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv deleted file mode 100644 index d61bb34..0000000 --- a/jaxite_ec/test_case/t2/zprize_msm_curve_377_bases_dim_2_seed_0.csv +++ /dev/null @@ -1,2 +0,0 @@ -0, x, 0024f0169fe2cfbb, 2dcc78c274a772fd, a0bf1eb986baf6d8, 0f9d3669482b16c0, e9d8bb1ef8cbea3e, 02f2a0491d3f7f41, zz: Fp384 "(01AC3A384FC584EFD3E7F2C5A2927E7D454875C874A051027B9E7363D08942533EDE85DAE295D8CAB2751085206BCA76)", y, 002da6d2840f4ce6, 1514adf2c130292c, 86aa9bea8cf858ba, 53b98e247e9dc552, 36abb853a15bff31, 803b023ee6aaf142, zz: Fp384 "(011DB83AEC88460820F4868A73B12309EE2E910526E62DB4ACCB303ABF50F86C3985A072ED07A4B81FFB82D8DD247283)" -1, x, 00b3b2f197a3b127, cde0e1d3d77e7d89, 49d6a6a3dac76741, fc4e1568a3b90dcc, f690d49a6208e299, fcf3195b0a1534a6, zz: Fp384 "(01546AF2ABB4E189E9BBC412FDBF2A8E5EC6E4A3B0AF132E21EE9CEC3EF5E226490FB98D662670FA3CFB3948B7E2A48C)", y, 012e6fd70da14f01, 3841e339306efa6e, c845ceff3aa42d46, 5d869b1205ec884c, 77718567b7535dd9, 1bc1fdfd9ce464a3, zz: Fp384 "(002961A558A885DF227FDB09F8BDF57AF179CB9437FF8828F13E9DF01AE55502F409AAF5058B88F2F7CCC7BC0676A5D4)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv b/jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv deleted file mode 100644 index b28b820..0000000 --- a/jaxite_ec/test_case/t2/zprize_msm_curve_377_res_dim_2_seed_0.csv +++ /dev/null @@ -1 +0,0 @@ -x, 001ba77010560355, d4a083cd44ea3b27, 47e444e832f177a9, e5a2cc6f64868bf3, 34875bb04ff9a153, 29a12ab6ca346603, zz: Fp384 "(001E9324B117F5E1E1EA2D63635906F087DD18C9C3840E28CD24FD6960DC4BAC7659B648FA29E89DA0DCDA4B9677254F)", y, 0038529273debc84, 567b59a70a96e53f, f2421d1eeaa395eb, e3d065d79346fc9b, 20a9f718531bd65d, 29662f4a78afce5d, zz: Fp384 "(010FDD33E114E8A119E7855FC6E457A0B2F1FC0436076E9557302DD86CD11EF6148721B9E4E63A01693378E985496444)" \ No newline 
at end of file diff --git a/jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv b/jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv deleted file mode 100644 index f6344a1..0000000 --- a/jaxite_ec/test_case/t2/zprize_msm_curve_377_scalars_dim_2_seed_0.csv +++ /dev/null @@ -1,2 +0,0 @@ -0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 "(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)" -1, 002e9d0d87c0a600, 1c9a1f731ec9a8d0, 3ca0557886321ce6, e5716b57188ca258, zz: Fp256 "(0C736B82B849EB8E4D8F81136A30803766F90E66AB059D12A6EDD2577E605788)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv b/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv deleted file mode 100644 index 8896d35..0000000 --- a/jaxite_ec/test_case/t4/zprize_msm_curve_377_bases_dim_4_seed_0.csv +++ /dev/null @@ -1,4 +0,0 @@ -0, x, 0054dadbda56447b, 1edf94ecc29bac9d, be41bffb21bbab70, 0d13b7683a5604e7, ae2f805045068087, 3bd87a659df906e5, zz: Fp384 "(0132A41F10B5AAE35C21740C1742AA70445C0DAF97DC1F92B7B51023BCC4ED52DE46227513DCFBCE518FD1AEEC21B83B)", y, 013bed8689b68cbd, 11eb13b43aa44895, 86836736dcaf43a5, a9675efb170a9c2c, be3e2a2333c55942, 31edbd28217dbb58, zz: Fp384 "(004F55B794051CB39882288086577AD6E7D199297A9DA267666D6FEC28F3D8781C95781639307F2A9B9C2522FCBD040B)" -1, x, 002f3f48c6b8f5fa, 720aa4decad34840, a94743660088442e, 712af8cd5eb5d3d1, 8da9cd5acf6bff08, a003e41885a627d0, zz: Fp384 "(012668C2A40C279AD883A98302D2BA3542871C8B48B21AF176A8C704DED10DB7E4753B830250EDDFA47FCD7B163343DC)", y, 0125bd5ea0a1989d, 3cad55ebd9ef75b3, 6f167f9389f8360d, 87346eda19e3e3b8, 50d092aa5b40cec1, daf39285f5fe38d8, zz: Fp384 "(01A4904AB5F9B7C2FBD154025161FEC3414AF1D3B764377D50741A06FC86742B9EC862BFCD8507486CFE67995CEDCDBD)" -2, x, 0024be0cb95af7cb, 212a56e5183de381, a551cc6f82fb71e2, 1e3996681d117eb1, f9278d11382dfdec, 01a164e98bc73565, zz: Fp384 
"(0145535A3BF5EF85CBA87B177AE7672B3ABD8C9A0C35108D9C92201AAFF925B3678D58F150CD1EA5395AC4D141CCDF28)", y, 012530da4d102355, 59cd4f2fb4b622cc, fad0061949e42051, 38bbcb79a57c7fd0, 67da66ec382efb95, ecf7cfddb7cfc0c9, zz: Fp384 "(007028D05152FD66FEA7662953D4C0B6F36891B2276EF5B7F48CA78EFF4892C3ACEB3FEC17D45C780506A27A18521DD6)" -3, x, 016efb97086a0478, 0ad22f31d98ac0b5, 934213bb4a45b710, bb75e577d1bfc091, 0880564021f4eb65, 58a3d749c1d1b878, zz: Fp384 "(0035B2C1EC332094CE8D30528D3E5E373021DB3BC54A97675AE885C574D1DD0084CE476EFD85EEBD4109A2D2339E8824)", y, 011f2bb8975d0cfe, 19132fe5974c13ea, eb279bb047b1abe8, e9ac98f26391a563, 0101680172684959, e9e6e94a59e90ffa, zz: Fp384 "(00CFEC2077BD515E1AA9B72FCC2AC8A79D4B7F333E9D139C2E9482905CE5C4294261C9ADF1EA8E05540D4834C2461536)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv b/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv deleted file mode 100644 index ce32abf..0000000 --- a/jaxite_ec/test_case/t4/zprize_msm_curve_377_res_dim_4_seed_0.csv +++ /dev/null @@ -1 +0,0 @@ -x, 014d7f8695379436, 16965d5ab3d91b3a, be23f9940b3b2857, 6924eb4496b95388, 2d22b77f18c3d89e, 2feb08e40a36bb91, zz: Fp384 "(0105E8EEFA0C3DE4D1E65A63489065FB0F8C66F1CBFF2422FE246E3CD1443DC699711B81B7790B7653547310240B2AF0)", y, 015845f50dda6c32, 87e7e0cbd525e428, ea1cd4b49a411788, eb7226604c36c531, d4d3a8e5c5d6d925, 0906ee99df45c728, zz: Fp384 "(00D7A61C6E85442A8DE1B4B14D18BD07FDC08AEFD9C71CAD928D404B6FD3F5C347057AE82ED6581F71EBF6407A1E07EE)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv b/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv deleted file mode 100644 index 8b24b20..0000000 --- a/jaxite_ec/test_case/t4/zprize_msm_curve_377_scalars_dim_4_seed_0.csv +++ /dev/null @@ -1,4 +0,0 @@ -0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 
"(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)" -1, 002e9d0d87c0a600, 1c9a1f731ec9a8d0, 3ca0557886321ce6, e5716b57188ca258, zz: Fp256 "(0C736B82B849EB8E4D8F81136A30803766F90E66AB059D12A6EDD2577E605788)" -2, 0e09b93519940300, 0a4867513ba6edb7, 530cce8b6a69232d, 0ecdd08a922d30e3, zz: Fp256 "(0738C6A142F9D317BFD26C26E9629C54E3E28C0CD3B100A6311C83A0BA678390)" -3, 0a85273b8e6bb8ed, dc0d90c24a4f5c72, 0f57216e312a2f32, 6ef1359bfc3df47f, zz: Fp256 "(04C4AFFA315644F295F99B40E9A87D32380603D675FBA350A0BB852DFEFE7878)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv b/jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv deleted file mode 100644 index a2a3b4e..0000000 --- a/jaxite_ec/test_case/t8/zprize_msm_curve_377_bases_dim_8_seed_0.csv +++ /dev/null @@ -1,8 +0,0 @@ -0, x, 018df51d28d7bb33, 12125d3dc2345560, dc6f8aa4f602ac7a, 9060043fe5dbc225, f9c8eba7ee4b72f0, 4054363299acd090, zz: Fp384 "(00FAF6788CB83075B17A9E3C865EA5DAA699646B8C42C7B0CCCF7D1306DB76C14D9AE857115C728EAA5053B4DD7E8907)", y, 00600f6808058b1d, 93e9bef7c2902d8c, 9480c5e3a177149d, 8fcd6015fc212417, 5073eadb38affa0a, abbd2a0e0b6894be, zz: Fp384 "(0167A1F5704269F768539A23A97A44611791593BD8556BE1718B8A08B54549698604AABA33DD7EBD8E935596E50380EA)" -1, x, 00daceeae593097d, e7dff89b697bff72, b4f1081a7c7da896, e16b722bb6c4d340, e91940d6dd520c64, a549a1cc3d4185d6, zz: Fp384 "(0005AA5A34376BA2D6D2982F5EADBCACEFA1FA2999B381AEAAF51F8EE541CBB439C56420ED43DC6F137A4116F42018D0)", y, 01373aad56ca01f0, 92e5fb3bd9b1a845, 7231a14c2005c972, 8cb4c2d02007c407, 7319d9315a06327c, a080605f19f5a0af, zz: Fp384 "(01936530B0244C450ADEB2A2F03119D090FE80E3E4C1878B6066E42A3BFF646C503872BDE9A0ACBDCDFA8DF81FFCF13D)" -2, x, 002ecd4e42ae033a, f64d0e37eabca682, dca865eb442d4487, 390bb206301b7dc9, e6e5addb49897da3, 885c7248fc2da32c, zz: Fp384 "(00099F98BF058B8521588A6D521FF5D95F0A4E6EC9D9F0658FD513BC654C8E2DA2475CA572C5FA1C90EC09D4A3ACC452)", y, 
00497ee9a51a8c00, 2e6865603a44a22c, f36f2ee7608b7ebf, 7211ab8e2046edbc, 00c91718b1a340d2, a5086213363a7ff9, zz: Fp384 "(00E8AA0A224C09333B12378C512414EAD8439E8F15F520096F97B3BD45BDCFE06D67E2E881F5AD575D9AAF49A0C6B7A4)" -3, x, 003c2546fd254a90, 0a3e18aab430f4ad, dad00063a6701bac, e9b3c577880ecf51, c9fbb89016ce87fe, b125886dfaa66b55, zz: Fp384 "(0164DDDBF27670CE389E2992C0E7DAB7741F1B925EDBDC254D2BC0830BAF8E0B186F80F0DD4DE0F0EA6176E55934D45B)", y, 0087161d71559880, 7f6dbe70a20e787d, 996d9a6282b303bc, 66e92f1cbb9bd9f9, e3b3bb47a7de8d87, 2fbb78e7edf485df, zz: Fp384 "(01908E9D77A0F8AD89AC41441F74248704E756BC59C38920617F51BFCDB738EE5B123876D489D09C9EB904A321A336EC)" -4, x, 00373631a6b80a93, b33760134d4de6be, 56ed27d5df9baba3, 755b753d0162f9b8, 1ba8a82579296bdb, 9aafe4d1b780ac70, zz: Fp384 "(0066CB93F6DBF96A94EBD6874DBD7B0E51D67A924C5AB5F96C2A45AF5EC75EA093843107E2EB1A45D645310AC3984B24)", y, 00a5af9f753d8472, b96fcf742d8981cd, 792ff8a32fab8a6a, 41d6fc8217f9f9ef, cf392541d749bf16, ef7345fe8183a959, zz: Fp384 "(0192AF6092F295756DDEDB7A77E7DB52B9FFCCB9D00EB6B61ED0E9E6B795484287CD9DFADD3B30784390941F86BD0FC2)" -5, x, 00ea9dc132d85e66, 87fda87d39f0499a, 17b4277d88cdbd0a, 27b8b9dcfea6be53, 34c23351b99f097d, b3cf7ac7ffd6ed2f, zz: Fp384 "(0084801FF559F4882F1180089D4FC1C81A0017932824F2ADAB09956C79AEF5219AAC66D806A086C838DB2424E67B1039)", y, 011f02cfd6c5d920, b577b59e650d501d, 66326514f4c44ebf, 7fc342da9c7ae7b6, 30325f5017e975e8, cbbf2d840d9c64e5, zz: Fp384 "(009B77B4C386398F4CB40B86E74BF5B310ADDC286678166AB660FDEB4A0FD5F1DD080FDB4EE3D8BC4553F5FA529A74DD)" -6, x, 019f02ca76047ec7, dc062d90f205d06c, 7dfbdc07f954e9d7, 85c4a3e6845721a1, 30f1080faa1484d8, 493b00c9cd7be16e, zz: Fp384 "(00B0630E7F192D20443A93860275447074CE77DF559907FA1900F378D4674649BF25F85C893E2A1916B1DA57594F2E17)", y, 006b2c62488dc2ba, 833a6b9f0c8f808b, 98c825e9e1583848, fc8b9db4aeb70f49, e5840cf05def3735, 88f30a7a997ca239, zz: Fp384 
"(01ACC84F362CF60A265C011F0FE4360A15F51BECF7E2C3923FE07C66D5D113104B56E8486C64204A2A9ECD75BA0C41A7)" -7, x, 0048c29459872078, eb2d1fcf4a139b51, 67da2ccdbaa8f2e3, 2a03ceed817240a7, fe43123affec9af4, c8486205c2ebdfba, zz: Fp384 "(018C09BCC580517E8C93CECCA83E200690FD66AD94DE753023401E3E337938715885C26FC631FF1833618D053E0872B0)", y, 00dbef2406460af5, 382cc49017f942af, 4a56c1a953eedf43, a1aaaa31c829cfd3, 96d012bb8bedbcc2, b31c1ced77585736, zz: Fp384 "(011080F0E32525EBB7F919B908C9F6F7B20E5B46A8CAEF059CE14E0739C9F164570264F91ECA4DE80E45520ED61B881C)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv b/jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv deleted file mode 100644 index dc0e754..0000000 --- a/jaxite_ec/test_case/t8/zprize_msm_curve_377_res_dim_8_seed_0.csv +++ /dev/null @@ -1 +0,0 @@ -x, 000e569ef8591fd6, 6cfc4f3224e47b68, bc4e8df117d07814, 160ba23f3fcab777, b1fc73a8395568c4, d21d407e354f8f52, zz: Fp384 "(008BFE20A05721F86AFEAEB8A94FCE08847074F2F555A3281E72545045976C72B33A10C9E62EC748C87087274D63ED22)", y, 00a952354041686e, dd824e49cb9926ab, 1d6a3f89b5ddb42d, a64339e196f523d8, 1ecff81311076324, 7ec7773b01f43018, zz: Fp384 "(01803103D87B815018D87204D4AC96B3BD49DC467802B7C8B7E6FC7E1FE0B20443DA072921D521F7B8BA7823C8B76048)" \ No newline at end of file diff --git a/jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv b/jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv deleted file mode 100644 index 19f34d8..0000000 --- a/jaxite_ec/test_case/t8/zprize_msm_curve_377_scalars_dim_8_seed_0.csv +++ /dev/null @@ -1,8 +0,0 @@ -0, 0f923fffd2a6f534, dc5b6a6901840fc0, fb65827e6efd22a8, 063cded681f5f7b2, zz: Fp256 "(11057EEFA14A393F5EBD62FD9747A76B77CDFC90E07F0BC6D8C87440D758452D)" -1, 002e9d0d87c0a600, 1c9a1f731ec9a8d0, 3ca0557886321ce6, e5716b57188ca258, zz: Fp256 "(0C736B82B849EB8E4D8F81136A30803766F90E66AB059D12A6EDD2577E605788)" -2, 0e09b93519940300, 0a4867513ba6edb7, 
530cce8b6a69232d, 0ecdd08a922d30e3, zz: Fp256 "(0738C6A142F9D317BFD26C26E9629C54E3E28C0CD3B100A6311C83A0BA678390)" -3, 0a85273b8e6bb8ed, dc0d90c24a4f5c72, 0f57216e312a2f32, 6ef1359bfc3df47f, zz: Fp256 "(04C4AFFA315644F295F99B40E9A87D32380603D675FBA350A0BB852DFEFE7878)" -4, 0ed8f6c7a5b1a650, 031ebc9b7a93492e, 89f282d49e7d2560, 7a5693b3d8ae2e87, zz: Fp256 "(0D19ECF1BAF67C8B31F01A16F3EA15BDF78812F5A00857F90C61C66F3445F9EC)" -5, 09014a7226acc95a, 4835d93dae8844b6, cda8ebe010d04060, bce87d13ab77d2f7, zz: Fp256 "(0136D026DE3C43F05C8035F0932506241268E9FA45095A2FA17C72C6FD1076C6)" -6, 0d40d252d29285e5, b39c381c97f2caed, 963b0d72b9280ba1, ee92ff87524aa9f5, zz: Fp256 "(05963104A4318BC9978D3B7746B2D8E3E9E36DC4CB751A4406B76D04BD4C3DB8)" -7, 060bae0e7184f261, 1e89f2311ccfb076, 9149396b939c46ae, 4a5c432085813bb4, zz: Fp256 "(0215F60E325265030BA4B084E9C465B1A22969C73E357B37F215AA734C66D6FA)" \ No newline at end of file diff --git a/jaxite_ec/util.py b/jaxite_ec/util.py deleted file mode 100644 index 17afce2..0000000 --- a/jaxite_ec/util.py +++ /dev/null @@ -1,628 +0,0 @@ -"""Utility functions for jaxite_ec. - -Note that: All functions that directly take Python int as input cannot be -jitted. 
-""" - -import csv -import json -import math -from typing import Any, Callable, List, Tuple - -import jax -import jax.numpy as jnp -import numpy as np - -# copybara: from google3.perftools.accelerators.xprof.api.python import xprof_analysis_client -# copybara: from google3.perftools.accelerators.xprof.api.python import xprof_session - -gcd = math.gcd - - -#################################### -# BLS12-377 Curve Configurations -#################################### - -MODULUS_377_INT = 0x01AE3A4617C510EAC63B05C06CA1493B1A22D9F300F5138F1EF3622FBA094800170B5D44300000008508C00000000001 -MU_377_INT = 0x98542343310183A5DB0F28160BBD3DCEEEB43799DDAC681ABCB52236169B40B43B5A1DE2710A9647E7F56317936BFF32 -TWIST_D_INT = 122268283598675559488486339158635529096981886914877139579534153582033676785385790730042363341236035746924960903179 - - -#################################### -# Global Configurations -#################################### - -BASE = 16 -BASE_TYPE = jnp.uint16 # this type must match the BASE, i.e. 
jnp.uint -U16_MASK = 0xFFFF -U32_MASK = 0xFFFFFFFF - -U8_CHUNK_NUM = 48 -U16_CHUNK_NUM = 24 -U32_CHUNK_NUM = 12 -U16_CHUNK_SHIFT_BITS = 16 -U32_CHUNK_SHIFT_BITS = 32 - -U16_EXT_CHUNK_NUM = 25 - - -BARRETT_SHIFT_U8 = 95 # BARRETT Params for k = 380 -CHUNK_PRECISION = 8 - -# Lazy Reduction Logics -MODULUS_377_S16_INT = MODULUS_377_INT << 16 - -# Pippenger Logics -COORDINATE_NUM = 4 - -# RNS Reduction Logics -# Hardware friendly moduli factors are 2**16 - v for v in the following list -RNS_MODULI_T = ( - 0, - 1, - 3, - 5, - 9, - 15, - 17, - 27, - 33, - 39, - 45, - 47, - 57, - 59, - 63, - 77, - 87, - 89, - 99, - 105, - 113, - 117, - 123, - 125, - 129, - 143, - 153, - 155, - 165, - 167, - 173, - 179, - 183, - 189, - 197, - 209, - 213, - 215, - 225, - 227, - 243, - 249, - 14, - 38, - 50, - 54, - 98, - 102, - 110, - 122, -) - -MODULI = tuple([ - 2**16 if i == 0 else 2**16 - int(i) if i % 2 == 1 else 2**15 - (int(i) // 2) - for i in RNS_MODULI_T -]) - - -RNS_PRECISION = 16 -NUM_MODULI = len(RNS_MODULI_T) -# Maximum number of consecutive additions/subtractions -ADDITION_BOUND = 4 - -# Warning: specific to target modulus and addition bound -MODULI_SUB = tuple([ - ((512 * NUM_MODULI * MODULUS_377_INT * ADDITION_BOUND) - 2**16) % m - for m in MODULI -]) -TWIST_D_RNS = tuple([TWIST_D_INT % MODULI[i] for i in range(len(MODULI))]) - - -#################################### -# Utility Functions -#################################### - - -def print_hex_values(int_list): - hex_values = " ".join((hex(value)) for value in int_list) - print(hex_values) - - -def array_to_int(jax_array: jax.Array, base) -> int: - """Converts a JAX array to a single Python integer.""" - result = 0 - - for i, elem in enumerate(jax_array): - result |= int(elem) << (i * base) - - return result - - -def int_to_array( - python_int, base=BASE, dtype=jnp.uint16, array_size=U16_CHUNK_NUM -): - """Converts a Python integer to a JAX array.""" - mask = (1 << base) - 1 - - elements = [] - while python_int > 0: - 
elements.append(python_int & mask) # Extract the lower bits - python_int >>= base # Shift to remove the extracted bits - - # we pad or trim the result to match the desired size - if array_size is not None: - assert array_size >= len(elements) - elements = elements[:array_size] + [0] * (array_size - len(elements)) - - return jnp.array(elements, dtype=dtype) - - -def array_to_int_list(jax_array, base): - """Converts JAX array to single integer.""" - result_list = [] - for i in range(jax_array.shape[0]): - value_vector = jax_array[i] - value_int = array_to_int(value_vector, base) - result_list.append(value_int) - return result_list - - -def int_list_to_array(int_list, base=BASE, array_size=U16_CHUNK_NUM): - """Converts a list of integers to a JAX array.""" - chunked_arrays = [] - for int_value in int_list: - chunked_arrays.append(int_to_array(int_value, base, array_size=array_size)) - return jnp.array(chunked_arrays) - - -def int_point_to_jax_point_pack( - coordinates: List[int], base=BASE, chunk_num=U16_CHUNK_NUM -): - result = [] - for i in range(len(coordinates)): - result.append(int_to_array(coordinates[i], base, array_size=chunk_num)) - return jnp.array(result) - - -def jax_point_pack_to_int_point(point: jax.Array): - coordinate_num = point.shape[0] - coordinates = [] - for i in range(coordinate_num): - c = array_to_int(point[i], BASE) - coordinates.append(c) - return coordinates - - -# RNS related data formal conversion -def int_to_array_rns(x): - return [x % m for m in MODULI] - - -def array_rns_to_int(residues): - rns_precompute_values = rns_precompute(MODULI) - return rns_reconstruct(residues, MODULI, rns_precompute_values) - - -def int_list_to_array_rns(int_list) -> jnp.ndarray: - """Converts a list of integers to a JAX array.""" - limbs = [] - for int_value in int_list: - limbs.append(int_to_array_rns(int_value)) - return jnp.array(limbs) - - -def array_rns_to_int_list(jax_array): - """Converts JAX array to single integer.""" - result_list = [] - for i in 
range(jax_array.shape[0]): - value_vector = jax_array[i] - value_int = array_rns_to_int(value_vector) - result_list.append(value_int) - return result_list - - -def int_point_to_jax_rns_point_pack(coordinates: List[int]): - result = [] - for i in range(len(coordinates)): - result.append(int_to_array_rns(coordinates[i])) - return jnp.array(result) - - -def jax_rns_point_pack_to_int_point(point: jax.Array): - coordinate_num = point.shape[0] - coordinates = [] - for i in range(coordinate_num): - c = array_rns_to_int(point[i]) - coordinates.append(c) - return coordinates - - -def int_point_batch_to_jax_point_pack( - points: List[List[int]], base=BASE, chunk_num=U16_CHUNK_NUM -): - result = [] - for i in range(len(points)): - result.append(int_point_to_jax_point_pack(points[i], base, chunk_num)) - return jnp.transpose(jnp.array(result), (1, 0, 2)) - - -def jax_point_pack_to_int_point_batch(point_pack: jnp.ndarray, base=BASE): - points = jnp.transpose(point_pack, (1, 0, 2)) - results = [] - for i in range(len(points)): - results.append(array_to_int_list(points[i], base)) - return results - - -def int_point_batch_to_jax_rns_point_pack(points: List[List[int]]): - result = [] - for i in range(len(points)): - result.append(int_point_to_jax_rns_point_pack(points[i])) - return jnp.transpose(jnp.array(result), (1, 0, 2)) - - -def jax_rns_point_pack_to_int_point_batch(point_pack: jnp.ndarray): - points = jnp.transpose(point_pack, (1, 0, 2)) - results = [] - for i in range(len(points)): - results.append(array_rns_to_int_list(points[i])) - return results - - -# RNS helpers -def total_modulus(moduli): - modulus = 1 # Compute the big modulus - for m in moduli: - modulus *= m - return modulus - - -def rns_precompute(moduli): - modulus = total_modulus(moduli) - precomputed = [] - for m in moduli: - rest = modulus // m # 0 mod all the other moduli - inverse = pow(rest % m, -1, m) # factor to make 1 mod this moduli - icrt_val = (rest * inverse) % modulus # combine - 
precomputed.append(icrt_val) - return precomputed - - -def rns_reconstruct(residues, moduli, precomputed): - assert len(residues) == len(moduli) - assert len(moduli) == len(precomputed) - output = 0 - for i, r in enumerate(residues): - output += precomputed[i] * int(r) - return output % total_modulus(moduli) - - -def to_rns(x, moduli): - assert x < total_modulus(moduli) - return [x % m for m in moduli] - - -def to_tuple(a): - """Create to convert numpy array into tuple.""" - try: - return tuple(to_tuple(i) for i in a) - except TypeError: - return a - - -# The following function achieves the same function as int_to_array, but it -# can be pre-run (Google restriction), and returns a tuple. -def int_to_precomputed_array( - python_int, base=BASE, dtype=jnp.uint16, array_size=U16_CHUNK_NUM -): - """Converts a Python integer to a JAX array.""" - mask = (1 << base) - 1 - - elements = [] - while python_int > 0: - elements.append(python_int & mask) # Extract the lower bits - python_int >>= base # Shift to remove the extracted bits - - # we pad or trim the result to match the desired size - if array_size is not None: - assert array_size >= len(elements) - elements = elements[:array_size] + [0] * (array_size - len(elements)) - - return to_tuple(np.array(elements, dtype=dtype).tolist()) - - -#################################### -# Performance Profiler Functions (Google Internal) -#################################### - - -def profile_jax_functions( - tasks: List[Tuple[Callable[..., Any], Tuple[Any, ...]]], - profile_name: str = "jax_profile", -): - """Profiles a list of JAX functions. - - Args: - tasks: A list of tuples, where each tuple contains a JAX function and its - arguments. - profile_name: The name of the profile. 
- - Usage: - tasks = [ - (jit_pdul_barrett_xyzz_pack, (point_a_jax,)), - ] - profile_name = "jit_pdul_barrett_xyzz_pack" - profile_jax_functions(tasks, profile_name) - """ - session_id = None - - # copybara: session = xprof_session.XprofSession() - # copybara: session.start_session() - try: - # Launch all JAX computations - results = [] - for func, args_tuple in tasks: - result = func(*args_tuple) - results.append(result) - - # Wait for all computations launched in the loop to complete - if results: - jax.block_until_ready(results) - - except Exception as e: # pylint: disable=broad-exception-caught - print(f"Error type: {type(e).__name__}") - print(f"Error details: {e}") - # Attempt to end the session even if there was an error - # copybara: session_id = session.end_session_and_get_session_id() - print("Xprof session ended due to error.") - if session_id: - print(f"{profile_name}: http://xprof/?session_id={session_id}") - finally: - if session_id is None: - # copybara: session_id = session.end_session_and_get_session_id() - print(f"{profile_name}: http://xprof/?session_id={session_id}") - # copybara: client = xprof_analysis_client.XprofAnalysisClient() - trace = ( - client.get_profile_data("trace_viewer.json", session_id) - if client - else None - ) - jtrace = json.loads(trace[1]) if trace else None - if jtrace: - for e in jtrace["traceEvents"]: - if profile_name in e["name"]: - print(f"{profile_name} latency: {e['dur']}\n") - - -#################################### -# Lazy Reduction -- Offline Precompute -#################################### - - -def construct_lazy_matrix(p, chunk_precision=8, chunk_num_u8=U8_CHUNK_NUM): - """Construct the lazy matrix. - - Args: - p: The modulus. - chunk_precision: The chunk precision. - chunk_num_u8: The number of chunks in the u8 value. - - Returns: - lazy_mat: The lazy matrix. - - Note that: this function runs on CPU of the TPU-VM, which cannot be jitted. 
- """ - lazy_mat_list = [] - for i in range(chunk_num_u8 + 4): - val = int(int(256) ** (chunk_num_u8 + i)) % p - lazy_mat_list.append( - int_to_precomputed_array(val, chunk_precision, array_size=chunk_num_u8) - ) - return to_tuple(lazy_mat_list) - - -MODULUS_377_LAZY_MAT = construct_lazy_matrix(MODULUS_377_INT) - - -#################################### -# RNS Reduction -- Offline Precompute -#################################### - - -def find_moduli(total_modulus, precision): - """Finds a list of moduli close to the given precision. - - Args: - total_modulus: The target modulus. - precision: The desired precision of the moduli. - - Returns: - A tuple containing two lists: - - overall_moduli: A list of moduli close to the given precision. - - overall_constant_offset: A list of constant offsets for the moduli. - """ - initial_moduli = 2**precision - overall_moduli = [] - overall_constant_offset = [] - overall_modulus = 1 - for i in range(2 ** (precision >> 1) - 1): - cur_moduli = initial_moduli - i - if math.gcd(cur_moduli, overall_modulus) == 1: - overall_moduli.append(cur_moduli) - overall_constant_offset.append(i) - overall_modulus *= cur_moduli - if overall_modulus > total_modulus: - return to_tuple(overall_moduli), to_tuple(overall_constant_offset) - - # Find 2**15 - v too - initial_moduli = 2 ** (precision - 1) - if overall_modulus < total_modulus: - for i in range(2 ** (precision >> 1) - 1): - cur_moduli = initial_moduli - i - if math.gcd(cur_moduli, overall_modulus) == 1: - overall_moduli.append(cur_moduli) - overall_constant_offset.append(i << 1) - overall_modulus *= cur_moduli - if overall_modulus > total_modulus: - return to_tuple(overall_moduli), to_tuple(overall_constant_offset) - - return to_tuple(overall_moduli), to_tuple(overall_constant_offset) - - -def rns_icrt_factors_compute(modulus, moduli): - precomputed = [] - for m in moduli: - rest = modulus // m # 0 mod all the other moduli - inverse = pow(rest % m, -1, m) # factor to make 1 mod this moduli - 
icrt_val = (rest * inverse) % modulus # combine - precomputed.append(icrt_val) - return precomputed - - -def rns_coefficients_precompute( - icrt_factors, - overall_moduli, - num_bytes, - moduli_precision, - overall_modulus, - q, -): - """Precompute RNS coefficients. - - Args: - icrt_factors: Precomputed inverse CRT factors. - overall_moduli: Array of moduli. - num_bytes: Number of bytes. - moduli_precision: Precision of the moduli. - overall_modulus: Overall modulus. - q: Target modulus. - - Returns: - Precomputed RNS coefficients and correction coefficients. - """ - num_residues = len(overall_moduli) - # icrt_factors_byteshifted -- (num_residues, num_bytes) - icrt_factors_byteshifted = [ - [ - (((1 << (8 * pre_id)) * factor) % overall_modulus) - for pre_id in range(num_bytes) - ] - for factor in icrt_factors - ] - # icrt_factors_byteshifted_modq -- (num_residues, num_bytes) - icrt_factors_byteshifted_modq = [ - [(chunk % q) for chunk in factors] for factors in icrt_factors_byteshifted - ] - # icrt_factors_byteshifted_modq_rns - # (num_residues, num_bytes, num_residues) [Convert each byte range into RNS] - icrt_factors_byteshifted_modq_rns = [ - [to_rns(chunk, overall_moduli) for chunk in factors] - for factors in icrt_factors_byteshifted_modq - ] - - rns_mat = np.array( - icrt_factors_byteshifted_modq_rns, dtype=np.uint16 - ).reshape(-1, num_residues) - - # calculate quotient estimation - fix_point = 1 << moduli_precision - - shifted_quotient_estimations = [] - for factors in icrt_factors_byteshifted: - for chunk in factors: - shifted_quotient_estimations.append( - [math.ceil((chunk * fix_point) / overall_modulus)] - ) - sqe_mat = np.array(shifted_quotient_estimations, dtype=np.uint16) - - cor_mat = np.array( - [to_rns(-overall_modulus % q, overall_moduli)], dtype=np.uint16 - ) - - # Convert rns_mat and sqe_mat into various bytes. - # Version 1: split precision into different chunks. 
- # rns_mat_u8 = rns_mat.view(np.uint8).reshape(*rns_mat.shape, num_bytes) - # seq_mat_u8 = sqe_mat.view(np.uint8).reshape(*sqe_mat.shape, num_bytes) - # rns_stack_mat_u8 = np.hstack(( - # rns_mat_u8[..., 0], - # seq_mat_u8[..., 0], - # rns_mat_u8[..., 1], - # seq_mat_u8[..., 1], - # )) - # Version 2: interleave precision -- tested to be faster. - rns_stack_mat_u8 = np.hstack( - (rns_mat.view(jnp.uint8), sqe_mat.view(jnp.uint8)) - ) - return to_tuple(rns_stack_mat_u8.tolist()), to_tuple(cor_mat.tolist()) - - -def get_parts(u16mat): - assert u16mat.dtype == np.uint16 - u16bytes = u16mat.view(np.uint8) - return [u16bytes[:, ::2], u16bytes[:, 1::2]] - - -M = MODULUS_377_INT * MODULUS_377_INT * 256 * 256 * 50 * 50 * 4 * 2 -moduli_precision = 16 -num_bytes = moduli_precision // 8 # 2 -# hardware friendly moduli is 2**precision - t -# overall_moduli is the jax.array of "2**precision - t" -# overall_constant_offset is the jax.array of "t" -overall_moduli, overall_constant_offset = find_moduli(M, moduli_precision) -M = 1 -for moduli in overall_moduli: - M *= moduli -M = int(M) -assert len(overall_moduli) == ( - (M.bit_length() + moduli_precision - 1) // moduli_precision -) - -icrt_factors = rns_icrt_factors_compute(M, overall_moduli) - -RNS_STACK_MAT_NEW, COR_MAT_NEW = rns_coefficients_precompute( - icrt_factors, - overall_moduli, - num_bytes, - moduli_precision, - M, - MODULUS_377_INT, -) - - -def construct_rns_matrix(q): - return rns_coefficients_precompute( - icrt_factors, overall_moduli, num_bytes, moduli_precision, M, q - ) - - -############################### -# Break High-precision Integer into Chunkcs -############################### - - -MODULI = overall_moduli -RNS_MODULI_T = overall_constant_offset -RNS_MAT = (RNS_STACK_MAT_NEW, COR_MAT_NEW) -MODULUS_377_INT_CHUNK = int_to_precomputed_array( - MODULUS_377_INT, base=BASE, array_size=U16_CHUNK_NUM -) -MU_377_INT_CHUNK = int_to_precomputed_array( - MU_377_INT, base=BASE, array_size=U16_CHUNK_NUM -) 
-TWIST_D_INT_CHUNK = int_to_precomputed_array( - TWIST_D_INT, base=BASE, array_size=U16_EXT_CHUNK_NUM -) -MODULUS_377_S16_INT_CHUNK = int_to_precomputed_array( - MODULUS_377_S16_INT, base=BASE, array_size=U16_EXT_CHUNK_NUM -) diff --git a/jaxite_ec/utils.py b/jaxite_ec/utils.py new file mode 100644 index 0000000..21cabb6 --- /dev/null +++ b/jaxite_ec/utils.py @@ -0,0 +1,804 @@ +import csv +from functools import lru_cache +import hashlib +import math +import os +import pickle +from typing import Any, Callable, List, Optional +import warnings +import jax +from jax import export +from jax import sharding as shd +import jax.numpy as jnp +import numpy as np +import toml + +# ============================================================================= +# Load configurations +# ============================================================================= +_config_path = os.path.join( + os.path.dirname(os.path.abspath(__file__)), "configurations.toml" +) +with open(_config_path, "r", encoding="utf-8") as _f: + _config = toml.load(_f) +_serialized_jax_kernel_dir = _config.get( + "serialized_jax_kernel_dir", "./deployments/" +) +if "TEST_TMPDIR" in os.environ: + _serialized_jax_kernel_dir = os.path.join( + os.environ["TEST_TMPDIR"], "deployments" + ) + +if not os.path.exists(_serialized_jax_kernel_dir): + os.makedirs(_serialized_jax_kernel_dir) +_hash_length = _config.get("hash_length", 8) + + +# ============================================================================= +# Number theory transform related helper functions +# ============================================================================= +def gen_twiddle_matrix(rows, cols, q, omega): + """Precompute the twiddle matrix T of shape (rows, cols), where T[r, c] = omega^(r*c) mod q. + + Args: + rows: The number of rows in the matrix. + cols: The number of columns in the matrix. + q: The modulus. + omega: The primitive root of unity. + + Returns: + The twiddle matrix. 
+ """ + warnings.warn( + "gen_twiddle_matrix is deprecated. Use" + " jaxite_ntt.number_theory_transform.utils.gen_twiddle_matrix instead.", + UserWarning, + stacklevel=2, + ) + twiddle_matrix = np.zeros((rows, cols), dtype=int) + for r in range(rows): + for c in range(cols): + twiddle_matrix[r, c] = pow(int(omega), int(r * c), int(q)) + return twiddle_matrix + + +def gen_twiddle_matrix_inv(rows, cols, q, omega): + """Precompute the inverse twiddle matrix T_inv of shape (rows, cols). + + T_inv[r, c] = omega^{- (r*c)} mod q. + + Args: + rows: The number of rows in the matrix. + cols: The number of columns in the matrix. + q: The modulus. + omega: The primitive root of unity. + + Returns: + The inverse twiddle matrix. + """ + warnings.warn( + "gen_twiddle_matrix_inv is deprecated. Use" + " jaxite_ntt.number_theory_transform.utils.gen_twiddle_matrix_inv" + " instead.", + UserWarning, + stacklevel=2, + ) + twiddle_matrix_inv = np.zeros((rows, cols), dtype=int) + for r in range(rows): + for c in range(cols): + twiddle_matrix_inv[r, c] = pow(int(omega), int(-r * c), int(q)) + return twiddle_matrix_inv + + +def prime_factors(n): + """Return the set of prime factors of n.""" + factors = set() + # Divide out factors of 2 + while n % 2 == 0: + factors.add(2) + n //= 2 + # Check odd factors from 3 to sqrt(n) + p = 3 + while p**2 <= n: + while n % p == 0: + factors.add(p) + n //= p + p += 2 + if n > 1: + factors.add(n) + return factors + + +def find_generator(q): + """Find a primitive root modulo q. + + Args: + q (int): The prime modulus. + + Returns: + A generator of GF(q)^*. + + Raises: + ValueError: If no generator is found, indicating q is not prime. + """ + phi = q - 1 + factors = prime_factors(phi) + + # Test candidates from 2 to q-1. 
+ for g in range(2, q): + is_generator = all(pow(g, phi // p, q) != 1 for p in factors) + if is_generator: + return g + raise ValueError("No generator found, check that q is prime.") + + +def root_of_unity(m: int, q: int) -> int: + """Canonical primitive m-th root of unity modulo q that **works with NTT**. + + Args: + m (int): The order of the root of unity. + q (int): The prime modulus. + + Returns: + int: The canonical primitive m-th root of unity modulo q. + + Usage: + root_of_unity(16, 134219681) # This works with NTT. + computed_psi = [root_of_unity(m, q) for q in original_modulus] + """ + assert (q - 1) % m == 0, "q-1 must be divisible by m" + # Step 1: multiplicative generator of Z_q^* + g = find_generator(q) + # Step 2: raise to (q-1)/m to get an m-th root candidate + r = pow(g, (q - 1) // m, q) + # Step 3: among r^k with gcd(k,m)=1, pick the minimal value whose order is exactly m + # For m=2^t, order check is psi^(m/2) == q-1 (i.e., == -1 mod q) + candidates = [] + half = m // 2 + for k in range(1, m): + if math.gcd(k, m) != 1: + continue + psi = pow(r, k, q) + if pow(psi, half, q) == q - 1 and pow(psi, m, q) == 1: + candidates.append(psi) + assert candidates, "No primitive m-th root found" + return int(min(candidates)) + + +# ============================================================================= +# Modular arithmetic related helper functions +# ============================================================================= +def modular_inverse(a: int, m: int): + t, new_t = 0, 1 + r, new_r = m, a + + while new_r != 0: + quotient = r // new_r + t, new_t = new_t, t - quotient * new_t + r, new_r = new_r, r - quotient * new_r + + if r > 1: + raise ValueError(f"{a} is not invertible modulo {m}") + if t < 0: + t += m + + return t + + +def compute_crt_factors(moduli): + modular = math.prod(moduli) + ms = [modular // m for m in moduli] + ms_inv = [modular_inverse(ms[i], moduli[i]) for i in range(len(moduli))] + return [(ms[i] * ms_inv[i]) % modular for i in 
range(len(moduli))] + + +def to_rns(x, moduli): + return [x % m for m in moduli] + + +def rns_reconstruct(residues, moduli, crt_factors): + return sum( + [residues[i] * crt_factors[i] for i in range(len(residues))] + ) % math.prod(moduli) + + +def find_moduli_specified_number(total_number, precision): + """Finds a list of moduli close to the given precision. + + The moduli are all odd and coprime. + + Args: + total_number: The total number of moduli requirement. + precision: The desired precision of the moduli. + + Returns: + A tuple containing two lists: + - overall_moduli: A list of moduli close to the given precision. + """ + initial_moduli = 2**precision + overall_moduli = [] + overall_modulus = 1 + for i in range(1, 2 ** (precision >> 1) - 1): + cur_moduli = initial_moduli - i + if cur_moduli % 2 == 1 and math.gcd(cur_moduli, overall_modulus) == 1: + overall_moduli.append(cur_moduli) + overall_modulus *= cur_moduli + if len(overall_moduli) >= total_number: + return to_tuple(overall_moduli) + + # Find 2**31 - v + initial_moduli = 2 ** (precision - 1) + if len(overall_moduli) < total_number: + for i in range(1, 2 ** (precision >> 1) - 1): + cur_moduli = initial_moduli - i + if cur_moduli % 2 == 1 and math.gcd(cur_moduli, overall_modulus) == 1: + overall_moduli.append(cur_moduli) + overall_modulus *= cur_moduli + if len(overall_moduli) >= total_number: + return to_tuple(overall_moduli) + + return to_tuple(overall_moduli) + + +def find_primes_with_bits(number, bits): + """Returns a list of 'number' prime numbers, each with exactly 'bits' bits. + + Args: + number (int): The number of primes to find. + bits (int): The bit length of each prime. + + Returns: + List[int]: List of prime numbers with the specified bit length. 
+ """ + + def is_prime(n): + if n < 2: + return False + if n == 2 or n == 3: + return True + if n % 2 == 0 or n % 3 == 0: + return False + i = 5 + w = 2 + while i * i <= n: + if n % i == 0: + return False + i += w + w = 6 - w + return True + + primes = [] + lower = 1 << (bits - 1) + upper = (1 << bits) - 1 + candidate = lower | 1 # ensure odd + + while candidate <= upper and len(primes) < number: + if is_prime(candidate): + primes.append(candidate) + candidate += 2 # only check odd numbers + + return primes + + +def modular_matrix_np_u32_to_u8_bat_4d(matrix: np.ndarray, modulus: int): + rows, cols = matrix.shape + assert modulus <= 2**31 + matrix_u64 = matrix.astype(np.uint64) + matrix_u64_byteshifted = np.array( + [matrix_u64 << (8 * byte_idx) for byte_idx in range(4)], dtype=np.uint64 + ) + # shape is (4, rows, cols) + matrix_u64_byteshifted = matrix_u64_byteshifted.transpose(1, 0, 2) + matrix_u64_byteshifted_mod_modulus = ( + matrix_u64_byteshifted % modulus + ).astype(np.uint32) + matrix_u8 = matrix_u64_byteshifted_mod_modulus.view(np.uint8).reshape( + rows, 4, cols, 4 + ) + return matrix_u8 + + +# ============================================================================= +# MSM related helper functions +# ============================================================================= +def read_external_msm_file(path, type: str): + if type == "scalars": + scalars = [] + with open( + path, "r", newline="", encoding="utf-8" + ) as csvfile: # Handle potential encoding issues + csv_reader = csv.reader(csvfile) + for row in csv_reader: + scalars.append(int(row[-1][13:-2], 16)) + return scalars + + elif type == "points": + points = [] + with open( + path, "r", newline="", encoding="utf-8" + ) as csvfile: # Handle potential encoding issues + csv_reader = csv.reader(csvfile) + for row in csv_reader: + points.append([int(row[8][13:-2], 16), int(row[-1][13:-2], 16)]) + return points + + elif type == "result_ref": + result_ref = [] + with open( + path, "r", newline="", 
encoding="utf-8" + ) as csvfile: # Handle potential encoding issues + csv_reader = csv.reader(csvfile) + for row in csv_reader: + result_ref.append(int(row[7][13:-2], 16)) + result_ref.append(int(row[-1][13:-2], 16)) + return result_ref + + +def split_list(lst, chunk_size): + """Splits a list into equal-sized chunks.""" + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + +def slice_scalars( + scalars: List, scalar_bits: int, slice_length: int +) -> List[List]: + window_num = int(math.ceil(scalar_bits / slice_length)) + mask = (1 << slice_length) - 1 + slices_list = [[] for _ in range(window_num)] + for scalar in scalars: + for i in range(window_num): + slices_list[i].append(scalar & mask) + scalar >>= slice_length + return slices_list + + +# ============================================================================= +# shape and tuple related helper functions +# ============================================================================= +def nested_list_depth(x: Any) -> int: + """Return the nesting depth of a list. + + Examples: + nested_list_depth([1, 2, 3]) -> 1 + nested_list_depth([[1, 2], [3, 4]]) -> 2 + nested_list_depth([[[1], [2]]]) -> 3 + + Args: + x: The value to inspect. + + Returns: + 0 if x is not a list, otherwise 1 + the maximum depth of its elements. 
+ """ + if not isinstance(x, list): + return 0 + if len(x) == 0: + return 1 + return 1 + max(nested_list_depth(item) for item in x) + + +def to_tuple(a): + """Create to convert numpy array into tuple.""" + if isinstance(a, (list, tuple, np.ndarray)): + return tuple(to_tuple(i) for i in a) + return int(a) + + +def pad_jax_array(array: jnp.ndarray, target_shape: tuple) -> jax.Array: + if array.shape == target_shape: + return array + assert len(array.shape) == len( + target_shape + ), "array and target_shape must have the same number of dimensions" + pad_width = [] + for cur, tgt in zip(array.shape, target_shape): + assert tgt >= cur, f"target size {tgt} is smaller than current size {cur}" + pad_width.append((0, tgt - cur)) + return jnp.pad(array, pad_width, mode="constant", constant_values=0) + + +# ============================================================================= +# Code structure +# ============================================================================= +class JaxParameters: + word_bits: Any = None + rns_moduli_inv_word: Any = None + word_mask: Any = None + half_word_mask: Any = None + half_word_bits: Any = None + rns_moduli_high: Any = None + rns_moduli_low: Any = None + rns_moduli: Any = None + crns_precision: Any = None + crns_vector_g: Any = None + crns_stacked_mat_E_with_f_T: Any = None + rns_moduli_negate: Any = None + rns_moduli_sub: Any = None + twist_d: Any = None + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + def set_parameter(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + +class JaxKernelContextBase: + + def __init__(self, use_compiled_kernels: bool = False): + self.compiled_kernels = {} + self.use_compiled_kernels = use_compiled_kernels + self.use_sharding = False + self.sharding_mesh = None + self.mesh_axes = () + + def set_use_compiled_kernels(self, use_compiled_kernels: bool): + self.use_compiled_kernels = use_compiled_kernels + + def serialize(self, parameters) -> 
Any: + pass + + def compile(self, parameters) -> Any: + pass + + def context_hash(self) -> str: + raise NotImplementedError("Subclasses must implement context_hash") + + def set_use_sharding(self, use_sharding: bool): + if use_sharding: + self.use_sharding = True + mesh, partition_spec = create_sharding() + axis_names = mesh.axis_names + partition = axis_names if len(axis_names) > 1 else axis_names[0] + self.sharding_mesh = mesh + self.sharding_partition = partition + self.sharding_partition_spec = partition_spec + self.mesh_axes = tuple(mesh.axis_names) + else: + self.use_sharding = False + self.sharding_mesh = None + self.sharding_partition = None + self.sharding_partition_spec = None + self.mesh_axes = () + + def make_named_sharding(self, spec) -> "jax.sharding.NamedSharding | None": + """Wrap a PartitionSpec in a NamedSharding bound to the current mesh. + + Returns None if sharding is disabled. + """ + mesh = getattr(self, "sharding_mesh", None) + if mesh is None: + return None + return jax.sharding.NamedSharding(mesh, spec) + + def shard_constraint(self, x, spec): + """Apply ``with_sharding_constraint(x, NamedSharding(mesh, spec))``. + + No-op when sharding is disabled. Intended for use inside jitted + kernels to pin intermediate layouts. + """ + mesh = getattr(self, "sharding_mesh", None) + if mesh is None: + return x + return jax.lax.with_sharding_constraint( + x, jax.sharding.NamedSharding(mesh, spec) + ) + + def create_named_sharding( + self, shape: tuple, axes: list[int] + ) -> tuple[jax.sharding.NamedSharding, tuple]: + """Create an efficient NamedSharding for the given shape and shard axes. + + (Generated by Claude) + + For 1 axis: shards across all mesh devices on that axis. Skips sharding + (returns replicated) if the axis is too small for the mesh size. Pads + the axis so that shape[axis] % total_devices == 0. + + For 2 axes: tries both mesh-to-data-axis mappings and picks the one + with less padding waste. Pads each sharded axis as needed. 
+ + Args: + shape: Array shape. + axes: 1 or 2 axis indices to shard along. >=3 is not allowed. + + Returns: + (NamedSharding, padded_shape) — the sharding and the shape after + padding (same as input shape if no padding was needed). + """ + assert ( + hasattr(self, "sharding_mesh") and self.sharding_mesh is not None + ), "Sharding must be enabled first via enable_sharding()" + + if len(axes) == 0: + spec = [None] * len(shape) + partition_spec = jax.sharding.PartitionSpec(*spec) + return ( + jax.sharding.NamedSharding(self.sharding_mesh, partition_spec), + shape, + ) + + assert len(axes) >= 1, "Must specify at least 1 axis" + assert len(axes) <= 2, "Sharding along >= 3 axes is not supported" + + mesh = self.sharding_mesh + mesh_axis_names = mesh.axis_names # e.g. ('x', 'y') + mesh_axis_sizes = {name: mesh.shape[name] for name in mesh_axis_names} + padded_shape = list(shape) + + # Minimum elements per device below which sharding is not worthwhile + _MIN_ELEMS_PER_DEVICE = 1 + + if len(axes) == 1: + axis = axes[0] + total_devices = math.prod(mesh_axis_sizes.values()) + + # Small axis: replicate instead of sharding + if shape[axis] < total_devices * _MIN_ELEMS_PER_DEVICE: + spec = [None] * len(shape) + partition_spec = jax.sharding.PartitionSpec(*spec) + warnings.warn( + f"Shape {shape} is too small for sharding, replicating the axis" + f" {axis}", + stacklevel=2, + ) + return jax.sharding.NamedSharding(mesh, partition_spec), tuple(shape) + + # Pad so shape[axis] is divisible by total_devices + if shape[axis] % total_devices != 0: + warnings.warn( + f"Shape {shape} is not divisible by total_devices {total_devices}," + f" padding the axis {axis} to" + f" {math.ceil(shape[axis] / total_devices) * total_devices}", + stacklevel=2, + ) + padded_shape[axis] = ( + math.ceil(shape[axis] / total_devices) * total_devices + ) + + # Shard this single data axis across all mesh axes + spec = [None] * len(shape) + spec[axis] = tuple(mesh_axis_names) + partition_spec = 
jax.sharding.PartitionSpec(*spec) + + else: # len(axes) == 2 + axis0, axis1 = axes[0], axes[1] + total_devices = math.prod(mesh_axis_sizes.values()) + + def _compute_candidate(spec_list, shard_plan): + """Compute (waste, spec, padded) for a shard_plan: list of (mesh_label, data_axis). + + mesh_label is either a single mesh axis name or a tuple of all mesh axis + names. + """ + p = list(shape) + waste = 0 + for mesh_label, d_axis in shard_plan: + if isinstance(mesh_label, tuple): + ms = math.prod(mesh_axis_sizes[n] for n in mesh_label) + else: + ms = mesh_axis_sizes[mesh_label] + s = shape[d_axis] + if s % ms != 0: + waste += (math.ceil(s / ms) * ms) - s + p[d_axis] = math.ceil(s / ms) * ms + return waste, spec_list, tuple(p) + + # Candidate A: mesh_x -> axis0, mesh_y -> axis1 + spec_a = [None] * len(shape) + spec_a[axis0] = mesh_axis_names[0] + spec_a[axis1] = mesh_axis_names[1] + cand_a = _compute_candidate( + spec_a, [(mesh_axis_names[0], axis0), (mesh_axis_names[1], axis1)] + ) + + # Candidate B: mesh_x -> axis1, mesh_y -> axis0 + spec_b = [None] * len(shape) + spec_b[axis1] = mesh_axis_names[0] + spec_b[axis0] = mesh_axis_names[1] + cand_b = _compute_candidate( + spec_b, [(mesh_axis_names[0], axis1), (mesh_axis_names[1], axis0)] + ) + + # Candidate C: all mesh axes -> axis0 only + spec_c = [None] * len(shape) + spec_c[axis0] = tuple(mesh_axis_names) + cand_c = _compute_candidate(spec_c, [(tuple(mesh_axis_names), axis0)]) + + # Candidate D: all mesh axes -> axis1 only + spec_d = [None] * len(shape) + spec_d[axis1] = tuple(mesh_axis_names) + cand_d = _compute_candidate(spec_d, [(tuple(mesh_axis_names), axis1)]) + + candidates = [cand_a, cand_b, cand_c, cand_d] + + # Filter out candidates where a sharded axis is too small + def _is_viable(spec_list): + for i, s in enumerate(spec_list): + if s is None: + continue + if isinstance(s, tuple): + ms = math.prod(mesh_axis_sizes[n] for n in s) + else: + ms = mesh_axis_sizes[s] + if shape[i] < ms * _MIN_ELEMS_PER_DEVICE: + 
return False + return True + + viable = [(w, sp, pp) for w, sp, pp in candidates if _is_viable(sp)] + if not viable: + # All too small — replicate + spec = [None] * len(shape) + partition_spec = jax.sharding.PartitionSpec(*spec) + return jax.sharding.NamedSharding(mesh, partition_spec), tuple(shape) + + best_waste, best_spec, best_padded = min(viable, key=lambda x: x[0]) + spec = best_spec + padded_shape = list(best_padded) + + if best_waste > 0: + warnings.warn( + f"Shape {shape} requires padding for sharding, padded to" + f" {tuple(padded_shape)}" + ) + + partition_spec = jax.sharding.PartitionSpec(*spec) + + return jax.sharding.NamedSharding(mesh, partition_spec), tuple(padded_shape) + + +def jax_jit_lower_compile(func: Callable, *args, **kwargs) -> Callable: + # in_shardings = tuple( + # arg.sharding if isinstance(arg, jax.ShapeDtypeStruct) and getattr(arg, "sharding", None) is not None + # else None + # for arg in args + # ) + # if any(s is not None for s in in_shardings): + # assert + # return jax.jit(func, in_shardings=in_shardings).lower(*args).compile() + return jax.jit(func).lower(*args).compile() + + +def store_jax_kernel(func: Callable, *args, **kwargs) -> None: + name = kwargs.get("name") + path = os.path.join(_serialized_jax_kernel_dir, f"{name}.jax") + + serialized_jax_kernel = export.export(jax.jit(func))(*args).serialize() + with open(path, "wb") as f: + f.write(serialized_jax_kernel) + print(f"stored jax kernel to {path}") + + +def load_jax_kernel( + name: str, alternative_callable: Optional[Callable] = None +) -> Optional[Callable]: + path = os.path.join(_serialized_jax_kernel_dir, f"{name}.jax") + if not os.path.exists(path): + return alternative_callable + with open(path, "rb") as f: + serialized_jax_kernel = export.deserialize(bytearray(f.read())) + print(f"loaded jax kernel from {path}") + return serialized_jax_kernel.call + + +def store_jax_executable(func: Callable, *args, **kwargs) -> None: + return store_jax_kernel(func, *args, **kwargs) 
def load_jax_executable(name: str) -> Optional[Callable]:
  """Load a stored kernel by name; alias for ``load_jax_kernel(name)``.

  Args:
    name: Kernel name, forwarded unchanged to ``load_jax_kernel``.

  Returns:
    Whatever ``load_jax_kernel(name)`` returns. No fallback callable is
    supplied, so that function's ``alternative_callable`` keeps its default.
  """
  return load_jax_kernel(name)


# =============================================================================
# Hashing utilities
# =============================================================================

# 62-character alphabet: digits + lowercase + uppercase
_HASH_CHARS = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
_HASH_BASE = len(_HASH_CHARS)  # 62
# Used when configurations.toml does not provide (or cannot provide) a value.
_DEFAULT_HASH_LENGTH = 16


@lru_cache(maxsize=1)
def _get_hash_length() -> int:
  """Read ``hash_length`` from ``configurations.toml``, defaulting to 16.

  The file is looked up in the directory of this module. The result is cached,
  so the file is read at most once per process.

  Returns:
    The configured ``hash_length`` as an int. The default (16) is returned
    when the key is absent, and also, with a warning, when the ``toml``
    package cannot be imported or the file cannot be opened.
  """
  try:
    # Imported lazily so that importing this module does not require toml.
    import toml
  except ImportError:
    warnings.warn(
        "toml is not installed; using default hash_length"
        f" {_DEFAULT_HASH_LENGTH}",
        stacklevel=2,
    )
    return _DEFAULT_HASH_LENGTH

  config_path = os.path.join(
      os.path.dirname(os.path.abspath(__file__)), "configurations.toml"
  )
  try:
    with open(config_path, "r", encoding="utf-8") as f:
      config = toml.load(f)
  except OSError:
    warnings.warn(
        f"Cannot read {config_path}; using default hash_length"
        f" {_DEFAULT_HASH_LENGTH}",
        stacklevel=2,
    )
    return _DEFAULT_HASH_LENGTH
  return int(config.get("hash_length", _DEFAULT_HASH_LENGTH))


def _serialize(arg: Any) -> str:
  """Produce a deterministic, type-tagged string for a common Python value.

  Lists, tuples, sets and dicts are handled recursively. Every value is
  prefixed with a type tag so that, for example, ``1`` and ``True`` or
  ``[1]`` and ``(1,)`` serialize differently.

  Args:
    arg: The value to serialize.

  Returns:
    The tagged string form of ``arg``. Types without a dedicated branch fall
    back to ``"<type name>:<repr>"``.
  """
  # bool must be tested before int because bool is a subclass of int.
  if isinstance(arg, bool):
    return f"bool:{arg}"
  if isinstance(arg, int):
    return f"int:{arg}"
  if isinstance(arg, float):
    return f"float:{repr(arg)}"
  if isinstance(arg, str):
    # embed length so "ab","c" != "a","bc"
    return f"str{len(arg)}:{arg}"
  if isinstance(arg, bytes):
    return f"bytes:{arg.hex()}"
  if isinstance(arg, (list, tuple)):
    tag = "list" if isinstance(arg, list) else "tuple"
    inner = ",".join(_serialize(x) for x in arg)
    return f"{tag}[{len(arg)}:{inner}]"
  if isinstance(arg, (set, frozenset)):
    # Set iteration order (and therefore repr) is not stable across
    # processes for str elements, so sort the serialized members.
    tag = "set" if isinstance(arg, set) else "frozenset"
    inner = ",".join(sorted(_serialize(x) for x in arg))
    return f"{tag}[{len(arg)}:{inner}]"
  if isinstance(arg, dict):
    # Sort by repr of the key so insertion order does not affect the result.
    items = ";".join(
        f"{_serialize(k)}->{_serialize(v)}"
        for k, v in sorted(arg.items(), key=lambda kv: repr(kv[0]))
    )
    return f"dict{{{len(arg)}:{items}}}"
  # Fallback for other types (e.g. numpy scalars, custom objects)
  return f"{type(arg).__name__}:{repr(arg)}"


def hash_args(*args: Any) -> str:
  """Hash any number of Python values into a fixed-length alphanumeric string.

  The output length is ``_get_hash_length()``, i.e. ``hash_length`` from
  ``configurations.toml`` (16 by default). Characters are drawn from
  ``[0-9a-zA-Z]`` (base-62), giving 62^length possible values; 16 characters
  yield about 4.7e28 distinct hashes.

  Supported argument types: ``int``, ``float``, ``bool``, ``str``, ``bytes``,
  ``list``, ``tuple``, ``set``, ``frozenset`` (arbitrarily nested), ``dict``,
  and any type with a stable ``repr``.

  Args:
    *args: Values to hash.

  Returns:
    The first ``length`` base-62 digits of the BLAKE2b digest of the
    serialized arguments. A 64-byte digest encodes to at most 86 base-62
    digits, so larger configured lengths cannot be fully satisfied.

  Example:
    >>> h = hash_args(42, "hello", [1, 2, 3])
    >>> len(h) == _get_hash_length() and h.isalnum()
    True
  """
  length = _get_hash_length()
  payload = "|".join(_serialize(a) for a in args)
  digest = hashlib.blake2b(payload.encode("utf-8")).digest()  # 64 bytes

  # Base-62 encode the big-endian integer, least-significant digit first.
  num = int.from_bytes(digest, "big")
  digits: list[str] = []
  while num:
    num, remainder = divmod(num, _HASH_BASE)
    digits.append(_HASH_CHARS[remainder])
  encoded = "".join(reversed(digits))

  # A digest with leading zero bytes encodes to fewer digits; left-pad with
  # the zero digit so the result keeps its fixed length before truncation.
  return encoded.rjust(length, _HASH_CHARS[0])[:length]


# =============================================================================
# Sharding utilities
# =============================================================================
def create_sharding():
  """Create a 2-axis device mesh for the available devices.

  The mesh shape is chosen from the device count: 8 -> (2, 4), 4 -> (2, 2),
  2 -> (2, 1); any other count uses (1, 1). The axes are named ``"x"`` and
  ``"y"``. The mesh is also passed to ``shd.set_mesh``.
  NOTE(review): presumably this makes it the ambient mesh; confirm against
  the ``shd`` module imported at the top of the file.

  Returns:
    A ``(mesh, partition_spec)`` tuple, where ``partition_spec`` is the
    ``jax.sharding.PartitionSpec`` class itself (not an instance).

  Raises:
    RuntimeError: If ``jax.devices()`` returns no devices.
  """
  available_devices = jax.devices()
  if not available_devices:
    raise RuntimeError("No devices available for sharding test.")

  mesh_shape_by_device_count = {8: (2, 4), 4: (2, 2), 2: (2, 1)}
  mesh_shape = mesh_shape_by_device_count.get(len(available_devices), (1, 1))

  mesh = jax.make_mesh(mesh_shape, ("x", "y"))
  shd.set_mesh(mesh)

  partition_spec = jax.sharding.PartitionSpec
  return mesh, partition_spec