From cf6a7a9b36f15b1868e6c481fc18db8b4f11eda9 Mon Sep 17 00:00:00 2001 From: The Jaxite Team Date: Wed, 9 Oct 2024 13:19:36 -0700 Subject: [PATCH] Update SR_onboarding tutorial PiperOrigin-RevId: 684148769 --- jaxite/jaxite_lib/matrix_utils.py | 463 ++++++++++++++++++++++++ jaxite/jaxite_lib/matrix_utils_test.py | 30 ++ jaxite/jaxite_lib/zkp/elliptic_curve.py | 372 +++++++++++++++++++ jaxite/jaxite_lib/zkp/finite_field.py | 286 +++++++++++++++ jaxite/jaxite_lib/zkp/hp_int.py | 271 ++++++++++++++ jaxite/jaxite_lib/zkp/pippenger.py | 144 ++++++++ 6 files changed, 1566 insertions(+) create mode 100644 jaxite/jaxite_lib/zkp/elliptic_curve.py create mode 100644 jaxite/jaxite_lib/zkp/finite_field.py create mode 100644 jaxite/jaxite_lib/zkp/hp_int.py create mode 100644 jaxite/jaxite_lib/zkp/pippenger.py diff --git a/jaxite/jaxite_lib/matrix_utils.py b/jaxite/jaxite_lib/matrix_utils.py index 543d3de..29e59db 100644 --- a/jaxite/jaxite_lib/matrix_utils.py +++ b/jaxite/jaxite_lib/matrix_utils.py @@ -98,6 +98,469 @@ def i32_as_u8_matmul(lhs: jnp.ndarray, rhs: jnp.ndarray) -> jnp.ndarray: return jnp.sum(i8_products << shift_factors, axis=(1, 2)) +def hpmatmul_conv_adapt_outer_product(x: jax.Array, y: jax.Array) -> jax.Array: + """Interleaved u8 matmul with fused einsum kernels. + + Args: + x: The left matrix. + y: The right matrix. + + Returns: + The result matrix. + """ + assert x.dtype == jnp.uint32 + assert y.dtype == jnp.uint32 + + lhs: jax.Array = int32_to_int8_arr(x) + rhs: jax.Array = int32_to_int8_arr(y) + + i8_products = jnp.einsum( + "mnp,nkq->mkpq", + lhs, + rhs, + preferred_element_type=jnp.int32, + ) + shift_factors = jnp.array( + [ + [0, 8, 16, 24], + [8, 16, 24, 32], + [16, 24, 32, 40], + [24, 32, 40, 48], + ], + dtype=jnp.uint32, + ) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(2, 3)) + + +@jax.jit +def hpmatmul_conv_adapt_conv(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: + """Interleaved u8 matmul with padded 1D convolution. + + (reformulated as 2D convolution) + + How do we map workload into Conv? + + Left Mat Right Mat -> + <- in channel (C)-> <-Output Channel(O)-> -> + - xxxxxxxxxxxxxxxxxx - xxxxxxxxxxxxxxxxxx -> - + ^ xxxxxxxxxxxxxxxxxx ^ xxxxxxxxxxxxxxxxxx -> ^ + | xxxxxxxxxxxxxxxxxx | xxxxxxxxxxxxxxxxxx -> | + batch xxxxxxxxxxxxxxxxxx In xxxxxxxxxxxxxxxxxx -> batch + (N) xxxxxxxxxxxxxxxxxx channel xxxxxxxxxxxxxxxxxx -> (N) + | xxxxxxxxxxxxxxxxxx (I) xxxxxxxxxxxxxxxxxx -> | + v xxxxxxxxxxxxxxxxxx v xxxxxxxxxxxxxxxxxx -> v + - xxxxxxxxxxxxxxxxxx - xxxxxxxxxxxxxxxxxx -> - + + Result Mat + <-Output channel(C)-> + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + xxxxxxxxxxxxxxxxxx + + Each x in the above example is a 1DConv + + <---W---> <---W---> <---W---> + xxxxxxxxx @ xxxxxxxxx = xxxxxxxxx + + Args: + x: The left matrix. + y: The right matrix. + + Returns: + The result matrix. + """ + + assert x.dtype == jnp.uint32 + assert y.dtype == jnp.uint32 + + lhs: jax.Array = jax.lax.bitcast_convert_type(x, new_dtype=jnp.uint8) # bnmp + rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) # nk1q + # https://github.com/google/jax/issues/11483 + rhs = jax.lax.rev(rhs, [2]) + # rhs = jlax.rev(rhs, dimensions=[3]) + + # basically an einsum of "mnp,nkq->mk(p+q)" but jax einsum doesn't support + # convolution yet + u8_products = jax.lax.conv_general_dilated( + lhs, + rhs, + window_strides=(1,), + padding=((3, 3),), + dimension_numbers=("NCW", "IOW", "NCW"), + preferred_element_type=jnp.uint32, + ) + + shift_factors = jnp.array([0, 8, 16, 24, 32, 40, 48], dtype=jnp.uint32) + return jnp.sum(u8_products.astype(jnp.uint64) << shift_factors, axis=(2,)) + + +def chunk_decomposition(x, chunkwidth=8): + """Precision-level data conversion. + + Args: + x: The input data. + chunkwidth: The chunkwidth. + + Returns: + The decomposed data. + """ + dtype = jnp.uint8 + if chunkwidth == 16: + dtype = jnp.uint16 + elif chunkwidth == 32: + dtype = jnp.uint32 + + elements = [] + mask = (1 << chunkwidth) - 1 + # Mask to extract the lower bits (e.g., 32 bits -> 0xFFFFFFFF) + + # Extract each element from the integer + while x > 0: + elements.append(x & mask) # Extract the lower bits + x >>= chunkwidth # Shift to remove the extracted bits + + # Convert the list to a JAX array + return jnp.array(elements, dtype=dtype) + + +def rechunkify_after_chunkwise_add(arr_a, chunkwidth): + """Recalculate chunks after chunkwise addition to handle carry. + + Context: + We divide a single high-precision data into multiple low-precision + chunks. In JAX, each chunk is represented as a single element of + built-in data type, such as jnp.uint16. + + # Assuming the original input data type is jnp.uint16. # + During the construction, value of each chunk comes from a specific 16 + bits of original input data. Such decomposition has a nice property + that a direct concatenation of all chunks would restore the original + data. + However, chunk-wise calculation (such as chunk-wise addition and + multiplication) might increase the bitwidth of each chunk. E.g. addition + of two 16 bits could potential generates a single 17-bit result, and + multiplication of two 16 bits could lead to a 32-bit result, leading + to bitwidth overflow. When bitwidth overflows happens, an expensive + shift_add is required to merge all chunk-wise results back to the + original high-precision data. Therefore, this API performs carry + propogation to rescale each chunk back to the original bitwidth + (e.g. 16 bits.) + + Input: each chunk is longer than chunkwidth but smaller than + 2*chunkwidth (assuming the bitwidth of each chunk is increased by 1 + after chunk-wise addition.) + Assumption: The bitwidth of each element in arr_a is <= 2 * chunkwidth. + + Args: + arr_a: The input array. + chunkwidth: The chunkwidth. + + Returns: + The recalculated chunks. + """ + dtype_double_length = jnp.uint16 + if chunkwidth == 16: + dtype_double_length = jnp.uint32 + elif chunkwidth == 32: + dtype_double_length = jnp.uint64 + + # assert isinstance(arr_a, jnp.array) + # assume the precision of partial sum is <= 2 * precision of input value. + bitmask = (1 << chunkwidth) - 1 + + # # Data Type Illustration + # We need to accumulate these data + # - Could directly perform bitwidth concatenation to generate the final + # result if there is no overlap across each partial sum + # LSB MSB + # |-----------------> bit + # | a0 + # | ==-- + # | a1 + # | ==-- + # | a2 + # | ==-- + # | a3 + # v ==-- + + # whole a0 a1 a2 a3 + # precision ==-- ==-- ==-- ==-- + + # lower a0 a1 a2 a3 + # half == == == == + + # upper a0 a1 a2 a3 + # half -- -- -- -- + + # # Chunk Splitting -> upper and lower half + # padding to align + # lower a0 a1 a2 a3 0 + # half == == == == == + + # upper 0 a0 a1 a2 a3 + # half -- -- -- -- -- + + # # Vectorized Accumulation + # lower a0 a1 a2 a3 0 + # half == == == == == + # + + + + + + # upper 0 a0 a1 a2 a3 + # half -- -- -- -- -- + + # -> result b0 b1 b2 b3 b4 + # -- 1/0-- 1/0-- 1/0-- -- + # (b1 and b4 does not have carry for sure.) + + # Each result chunk might have one more bit for carry. + # Perform one more chunk decomposition and accumulation. + + # # One more Chunk Splitting for partial sum "b" to take care of carry bit. + # carry b0 b1 b2 b3 b4 + # 0 1/0 1/0 1/0 0 + + # carry b4 b0 b1 b2 b3 + # right 0 0 1/0 1/0 1/0 + # shift + # (wrap around rotation, b4 is always zero so will be correct) + # + + + + + + # lower b0 b1 b2 b3 b4 + # half -- -- -- -- -- + # = = = = = + # c0 c1 c2 c3 c4 + # -> -- -- -- -- 1/0-- + # (! c4 might overflow, need one more chunk decomposition) + + # c0 c1 c2 c3 c4 c5 + # -> -- -- -- -- -- 1/0 + + # Chunk Splitting -> upper and lower half + arr_a_lower_half = jnp.bitwise_and(arr_a, bitmask) + arr_a_upper_half = jnp.right_shift(arr_a, chunkwidth) + + # Padding to align + arr_a_lower_half_pad = jnp.pad(arr_a_lower_half, (0, 1)) + arr_a_upper_half_pad = jnp.pad(arr_a_upper_half, (1, 0)) + + # Vectorized Accumulation + arr_b = jnp.add( + arr_a_lower_half_pad.astype(dtype_double_length), + arr_a_upper_half_pad.astype(dtype_double_length), + ) + + arr_b_lower_half = jnp.bitwise_and(arr_b, bitmask) + arr_b_carry = jnp.right_shift(arr_b, chunkwidth) + arr_b_carry = jnp.roll(arr_b_carry, 1) + + # Vectorized Accumulation + arr_c = jnp.add(arr_b_lower_half, arr_b_carry) + + # break top chunk into upper and lower to avoid overflow. + arr_c = jnp.pad(arr_c, (0, 1)) + arr_c = arr_c.at[-1].set(jnp.right_shift(arr_c[-2], chunkwidth)) + arr_c = arr_c.at[-2].set(jnp.bitwise_and(arr_c[-2], bitmask)) + + return arr_c + + +def smul_as_dense_gemv_bag( + x, total_in_precision=32, chunkwidth=8, q=4294967291 +): + """This is the implementation of BAG; Major improvement to achieve dense matrix. + + Args: + x: The input matrix. + total_in_precision: The total precision of the input matrix. + chunkwidth: The chunkwidth. + q: The modulus. + + Returns: + The dense matrix. + + Steps: + 1. break x into [x0, x1, x2, x3] + 2. reform [x0, x1, x2, x3] into the output + [ + x0 r00 r00 r00 # 2^0 + x1 x0+r01 r01 r01 # 2^8 + x2 x1+r02 x0+r02 r02 # 2^16 + x3 x2+r03 x1+r03 x0+r03 # 2^24 + ] + """ + dtype = jnp.uint8 + dtype_double_length = jnp.uint16 + chunk_upper_bound = (1 << 8) - 1 + if chunkwidth == 16: + dtype = jnp.uint16 + dtype_double_length = jnp.uint32 + chunk_upper_bound = (1 << 16) - 1 + elif chunkwidth == 32: + dtype = jnp.uint32 + dtype_double_length = jnp.uint64 + chunk_upper_bound = (1 << 32) - 1 + + total_chunk_num = int(jnp.ceil(total_in_precision / chunkwidth)) + + # the number of row in left matrix + height = total_chunk_num + total_chunk_num - 1 + x_dtype = chunk_decomposition(int(x), chunkwidth) + x_dense = jnp.zeros( + (total_chunk_num + total_chunk_num - 1, total_chunk_num), + dtype=dtype_double_length, + ) + for j in range(total_chunk_num): + upper_idx = min(total_chunk_num, x_dtype.shape[0] + j) + x_dense = x_dense.at[j:upper_idx, j].set(x_dtype[0 : upper_idx - j]) + + # [ + # x0 # 2^0 + # x1 x0 # 2^8 + # x2 x1 x0 # 2^16 + # x3 x2 x1 x0 # 2^24 + # ----------- + # x3 x2 x1 # 2^32 iterate all elements in the bottom block + # x3 x2 # 2^40 + # x3 # 2^48 + # ] + + # Perform BAG to the following block of the matrix + # j 2 1 0 + # x3 x2 x1 # 2^32 i=0 + # x3 x2 # 2^40 i=1 + # x3 # 2^48 i=2 + for i in range(x_dtype.shape[0] - 1): + for j in range(x_dtype.shape[0] - 1 - i): + basis = (total_chunk_num + i) * chunkwidth + projected_data = (int(x_dtype[i + j + 1]) << basis) % q + r = chunk_decomposition(projected_data, chunkwidth).astype( + dtype_double_length + ) + + x_dense = x_dense.at[: len(r), total_chunk_num - 1 - j].set( + jnp.add(r, x_dense[: len(r), total_chunk_num - 1 - j]) + ) + + for j in range(x_dtype.shape[0] - 1): + # Iterate over different columns + if not jnp.all(x_dense[:, total_chunk_num - 1 - j] <= chunk_upper_bound): + arr_new_chunkified = rechunkify_after_chunkwise_add( + x_dense[:, total_chunk_num - 1 - j], chunkwidth + ) + x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( + arr_new_chunkified[:height] + ) + + while not jnp.all(x_dense <= chunk_upper_bound): + for j in range(total_chunk_num - 1): + # Iterate over different columns + if not jnp.all(x_dense[:, total_chunk_num - 1 - j] <= chunk_upper_bound): + arr_new_chunkified = rechunkify_after_chunkwise_add( + x_dense[:, total_chunk_num - 1 - j], chunkwidth + ) + x_dense = x_dense.at[:, total_chunk_num - 1 - j].set( + arr_new_chunkified[:height] + ) + + # j 2 1 0 + # x3 x2 x1 # 2^32 i=0 + # x3 x2 # 2^40 i=1 + # x3 # 2^48 i=2 + + for i in range(total_chunk_num - 1): + if x_dense[total_chunk_num + i, total_chunk_num - 1 - j] > 0: + basis = (total_chunk_num + i) * chunkwidth + projected_data = ( + int(x_dense[total_chunk_num + i, total_chunk_num - 1 - j]) + << basis + ) % q + r = chunk_decomposition(projected_data, chunkwidth).astype( + dtype_double_length + ) + x_dense = x_dense.at[: len(r), total_chunk_num - 1 - j].set( + jnp.add(r, x_dense[: len(r), total_chunk_num - 1 - j]) + ) + + x_dense = x_dense.at[ + total_chunk_num + i, total_chunk_num - 1 - j + ].set(0) + + return x_dense[:total_chunk_num, :].astype(dtype) + + +def hpmatmul_offline_compile_bag(mat_a, q): + """Convert the input (m,n) matrix into (m,n,p,q), i.e. + + replace each element in the original matrix by a p*q matrix (p==q). + + Args: + mat_a: The input matrix. + q: The modulus. + + Returns: + The converted matrix. + """ + assert mat_a.dtype == jnp.uint32 # This version is defined for 32-bit input. + if isinstance(mat_a, list): + m, n = len(mat_a), len(mat_a[0]) + else: + m, n = mat_a.shape[0], mat_a.shape[1] + total_in_precision = 32 + chunkwidth = 8 + # Convert left-side matrix + total_chunk_num = int(jnp.ceil(total_in_precision / chunkwidth)) + + left_mat = jnp.zeros( + (m, n, total_chunk_num, total_chunk_num), dtype=jnp.uint16 + ) + + if isinstance(mat_a, list): + for i in range(m): + for k in range(n): + left_mat = left_mat.at[i, k, :, :].set( + smul_as_dense_gemv_bag( + mat_a[i][k], + total_in_precision=total_in_precision, + chunkwidth=chunkwidth, + q=q, + ) + ) + else: + for i in range(m): + for k in range(n): + left_mat = left_mat.at[i, k, :, :].set( + smul_as_dense_gemv_bag( + mat_a[i, k], + total_in_precision=total_in_precision, + chunkwidth=chunkwidth, + q=q, + ) + ) + + return left_mat + + +@jax.jit +def hpmatmul_bag_adapt(lhs: jax.Array, y: jax.Array): + """Input (m, n) Left Matrix -> (m, n, p, q) Left Matrix, where each element in the original (m, n) matrix is replaced by a (p, q) matrix.""" + # assert lhs.dtype == jnp.uint8 + # assert y.dtype == jnp.uint32 + + rhs: jax.Array = jax.lax.bitcast_convert_type(y, new_dtype=jnp.uint8) + + i8_products = jnp.einsum( + "mnpq,nkq->mkp", + lhs, + rhs, + preferred_element_type=jnp.int32, + ) + shift_factors = jnp.array([0, 8, 16, 24], dtype=jnp.uint32) + return jnp.sum(i8_products.astype(jnp.uint64) << shift_factors, axis=(2,)) + + # https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.tril.html # For n=3, generates the following # [[ 1 -1 -1] diff --git a/jaxite/jaxite_lib/matrix_utils_test.py b/jaxite/jaxite_lib/matrix_utils_test.py index c897829..039b771 100644 --- a/jaxite/jaxite_lib/matrix_utils_test.py +++ b/jaxite/jaxite_lib/matrix_utils_test.py @@ -3,6 +3,7 @@ import hypothesis from hypothesis import strategies +import jax import jax.numpy as jnp from jaxite.jaxite_lib import jax_helpers from jaxite.jaxite_lib import matrix_utils @@ -269,6 +270,35 @@ def test_scale_by_x_power_n_minus_1(self, power, poly): ) np.testing.assert_array_equal(expected, actual) + def test_hpmatmul_Conv_Adapt_Conv(self): + """Test the correctness of the Conv-Adapt-Conv algorithm.""" + key = jax.random.key(0) + mat_a_shape = (16, 16) + mat_b_shape = (mat_a_shape[1], 16) + upper_value = (1 << 31) - 1 + modulus_32 = 4294967291 + mat_a = jax.random.randint( + key, mat_a_shape, 0, upper_value, dtype=jnp.uint32 + ) + mat_b = jax.random.randint( + key, mat_b_shape, 0, upper_value, dtype=jnp.uint32 + ) + mat_result_outerproduct = matrix_utils.hpmatmul_conv_adapt_outer_product( + mat_a, mat_b + ) + compiled_mat_a = matrix_utils.hpmatmul_offline_compile_bag( + mat_a, modulus_32 + ) + mat_result_bag = matrix_utils.hpmatmul_bag_adapt(compiled_mat_a, mat_b) + mat_result_conv = matrix_utils.hpmatmul_conv_adapt_conv(mat_a, mat_b) + if np.testing.assert_array_equal(mat_result_outerproduct, mat_result_conv): + if np.testing.assert_array_equal(mat_result_bag, mat_result_outerproduct): + print('pass') + else: + print('mat_result_bag and mat_result_outerproduct do not match') + else: + print('mat_result_outerproduct and mat_result_conv do not match') + if __name__ == '__main__': absltest.main() diff --git a/jaxite/jaxite_lib/zkp/elliptic_curve.py b/jaxite/jaxite_lib/zkp/elliptic_curve.py new file mode 100644 index 0000000..e6db9d5 --- /dev/null +++ b/jaxite/jaxite_lib/zkp/elliptic_curve.py @@ -0,0 +1,372 @@ +"""elliptic_curve class implementation.""" + +from abc import ABC, abstractmethod +from collections.abc import Sequence +import copy +from enum import Enum, auto +from typing import List + +from absl import app +from traitlets import Bool + +USE_GMP = True +USE_BARRETT = True +if USE_GMP: + from hp_int import GMPHPint as HPint +else: + from hp_int import HPint as HPint + +if USE_BARRETT: + from finite_field import FiniteFieldElementBarrett as FieldEle +else: + from finite_field import FiniteFieldElement as FieldEle + + +class CoordinateSystemType(Enum): + """Enum to represent different types of coordinate systems for elliptic curves.""" + + NONE = auto() + WEIERSTRASS_AFFINE = auto() + WEIERSTRASS_PROJECTIVE = auto() + + +class ECCPoint: + + def __init__( + self, + coordinates: List, + coordinate_system: 'EllipticCurveCoordinateSystem', + zero: Bool = False, + ) -> None: + self.coordinate_system = coordinate_system + self.zero = zero + if self.zero: + self.coordinates = None + else: + self.coordinates = self.coordinate_system.generate_formal_coordinates( + coordinates + ) + self.type = coordinate_system.get_type() + + def __getitem__(self, index: int) -> FieldEle: + """Allows access to elements in self.coordinates using index.""" + return self.coordinates[index] + + def __setitem__(self, index: int, value: FieldEle) -> None: + """Allows modification of elements in self.coordinates using index.""" + self.coordinates[index] = value + + def __eq__(self, other: 'ECCPoint') -> bool: + if not isinstance(other, ECCPoint): + return NotImplemented + + # return (self.coordinates == other.coordinates and + # self.coordinate_system == other.coordinate_system) + return self.coordinates == other.coordinates + + def __add__(self, other: 'ECCPoint') -> 'ECCPoint': + return self.coordinate_system.point_add(self, other) + + def __lshift__(self, shift) -> 'ECCPoint': + return self.coordinate_system.point_lshift(self, shift) + + def is_zero(self): + return self.zero + + def get_type(self) -> CoordinateSystemType: + return self.type + + def set_type(self, cs_type: CoordinateSystemType): + self.type = cs_type + + def set_coordinate_system( + self, coordinate_system: 'EllipticCurveCoordinateSystem' + ): + self.coordinate_system = coordinate_system + self.type = coordinate_system.get_type() + + def append(self, coordinate: FieldEle): + self.coordinates.append(coordinate) + + def copy(self): + obj = copy.copy(self) + if not self.is_zero(): + obj.coordinates = self.coordinate_system.generate_formal_coordinates( + self.coordinates + ) + return obj + + def convert_to_affine(self): + self = self.coordinate_system.convert_to_affine(self) + + def __str__(self) -> str: + if self.is_zero(): + return 'Point, O' + ret = 'Point, ' + for i in range(len(self.coordinates)): + ret += self.coordinates[i].hex_value_str() + ', ' + return ret + + +class EllipticCurveCoordinateSystem(ABC): + + def __init__(self, config: dict) -> None: + super().__init__() + self.config = config + self.ff0: FieldEle = FieldEle(0, config['prime']) + self.ff1: FieldEle = FieldEle(1, config['prime']) + self.type = CoordinateSystemType.NONE + + @abstractmethod + def generate_formal_coordinates(self, coordinates: List) -> List[FieldEle]: + pass + + def get_type(self): + return self.type + + def generate_point(self, coordinates: List, zero: bool = False) -> ECCPoint: + return ECCPoint(coordinates, self, zero) + + @abstractmethod + def point_add(self, pointA: ECCPoint, pointB: ECCPoint): + pass + + @abstractmethod + def point_lshift(self, pointA: ECCPoint, index): + pass + + @abstractmethod + def convert_to_affine(self, pointA: ECCPoint): + pass + + +class ECCWeierstrassSystem(EllipticCurveCoordinateSystem): + + def __init__(self, config: dict) -> None: + super().__init__(config) + self.prime = BigInt(config['prime']) + self.order = BigInt(config['order']) + + self.generator: List[FieldEle] = [] + for coordinate in config['generator']: + element = self.ff0.copy(coordinate) + self.generator.append(element) + + self.a = self.ff0.copy(config['a']) + self.b = self.ff0.copy(config['b']) + + def generate_formal_coordinates(self, coordinates: List) -> List[FieldEle]: + formal_coordinate: List[FieldEle] = [] + for coordinate in coordinates: + if isinstance(coordinate, FieldEle): + formal_coordinate.append(coordinate) + else: + formal_coordinate.append(self.ff0.copy(coordinate)) + return formal_coordinate + + +class ECCWeierstrassAffine(ECCWeierstrassSystem): + + def __init__(self, config: dict) -> None: + super().__init__(config) + self.type = CoordinateSystemType.WEIERSTRASS_AFFINE + + def add_general(self, pointA: ECCPoint, pointB: ECCPoint): + slope = (pointB[1] - pointA[1]) / (pointB[0] - pointA[0]) + cx = (slope * slope) - pointA[0] - pointB[0] + cy = slope * (pointA[0] - cx) - pointA[1] + return ECCPoint([cx, cy], self) + + def double_general(self, pointA: ECCPoint): + x1 = pointA[0] + y1 = pointA[1] + slope = (x1 * x1 * 3 + self.a) / (y1 * 2) + cx = slope * slope - x1 - x1 + cy = slope * (x1 - cx) - y1 + return ECCPoint([cx, cy], self) + + def point_add(self, pointA: ECCPoint, pointB: ECCPoint): + if pointB.is_zero(): + return pointA.copy() + elif pointA.is_zero(): + return pointB.copy() + + if pointA == pointB: + result = self.double_general(pointA) + else: + result = self.add_general(pointA, pointB) + + return result + + def point_lshift(self, pointA: ECCPoint, shift: int): + if pointA.is_zero(): + return pointA.copy() + + for i in range(shift): + pointA = self.double_general(pointA) + return pointA + + def convert_to_affine(self, pointA: ECCPoint) -> ECCPoint: + return pointA + + +class ECCWeierstrassProjective(ECCWeierstrassSystem): + + def __init__(self, config: dict) -> None: + super().__init__(config) + self.type = CoordinateSystemType.WEIERSTRASS_PROJECTIVE + + def generate_formal_coordinates(self, coordinates: List) -> List[FieldEle]: + coordinate_length = len(coordinates) + assert coordinate_length == 2 or coordinate_length == 3 + formal_coordinates: List[FieldEle] = [] + for coordinate in coordinates: + if isinstance(coordinate, FieldEle): + formal_coordinates.append(coordinate) + else: + formal_coordinates.append(self.ff0.copy(coordinate)) + if coordinate_length == 2: + formal_coordinates.append(self.ff1.copy()) + assert len(formal_coordinates) == 3 + return formal_coordinates + + def add_z2_eq_1(self, pointA: ECCPoint, pointB: ECCPoint): + if pointB[2] == self.ff1: + X1, Y1, Z1 = pointA + X2, Y2, Z2 = pointB + elif pointA[2] == self.ff1: + X1, Y1, Z1 = pointB + X2, Y2, Z2 = pointA + else: + raise NotImplementedError + + u = Y2 * Z1 - Y1 + uu = u * u + v = X2 * Z1 - X1 + vv = v * v + vvv = v * vv + R = vv * X1 + A = uu * Z1 - vvv - (R + R) + X3 = v * A + Y3 = u * (R - A) - vvv * Y1 + Z3 = vvv * Z1 + return ECCPoint([X3, Y3, Z3], self) + + def add_general(self, pointA: ECCPoint, pointB: ECCPoint): + X1, Y1, Z1 = pointA + X2, Y2, Z2 = pointB + + b3 = self.b * self.ff0.copy(3) + a = self.a + + # Perform the operations + t0 = X1 * X2 + t1 = Y1 * Y2 + t2 = Z1 * Z2 + t3 = (X1 + Y1) * (X2 + Y2) + t4 = t0 + t1 + t3 = t3 - t4 + t4 = (X1 + Z1) * (X2 + Z2) + t5 = t0 + t2 + t4 = t4 - t5 + t5 = (Y1 + Z1) * (Y2 + Z2) + X3 = t1 + t2 + t5 = t5 - X3 + Z3 = a * t4 + X3 = b3 * t2 + Z3 = X3 + Z3 + X3 = t1 - Z3 + Z3 = t1 + Z3 + Y3 = X3 * Z3 + t1 = t0 + t0 + t1 = t1 + t0 + t2 = a * t2 + t4 = b3 * t4 + t1 = t1 + t2 + t2 = t0 - t2 + t2 = a * t2 + t4 = t4 + t2 + t0 = t1 * t4 + Y3 = Y3 + t0 + t0 = t5 * t4 + X3 = t3 * X3 + X3 = X3 - t0 + t0 = t3 * t1 + Z3 = t5 * Z3 + Z3 = Z3 + t0 + + return ECCPoint([X3, Y3, Z3], self) + + def double_general(self, pointA: ECCPoint): + X1, Y1, Z1 = pointA + a = self.a + ff2 = self.ff0.copy(2) + + # Perform the operations based on the pseudocode + XX = X1 * X1 + ZZ = Z1 * Z1 + w = a * ZZ + XX * self.ff0.copy(3) + s = Y1 * Z1 * ff2 + ss = s * s + sss = s * ss + R = Y1 * s + RR = R * R + B = (X1 + R) * (X1 + R) - XX - RR + h = w * w - B * ff2 + X3 = h * s + Y3 = w * (B - h) - RR * ff2 + Z3 = sss + + return ECCPoint([X3, Y3, Z3], self) + + def point_add(self, pointA: ECCPoint, pointB: ECCPoint) -> ECCPoint: + if pointB.is_zero(): + return pointA.copy() + elif pointA.is_zero(): + return pointB.copy() + + if pointA == pointB: + result = self.double_general(pointA) + elif pointA[2] == self.ff1 or pointB[2] == self.ff1: + result = self.add_z2_eq_1(pointA, pointB) + else: + result = self.add_general(pointA, pointB) + return result + + def point_lshift(self, pointA: ECCPoint, shift: int): + if pointA.is_zero(): + return pointA.copy() + + for i in range(shift): + pointA = self.double_general(pointA) + return pointA + + def convert_from_affine(self, pointA: ECCPoint) -> ECCPoint: + assert ( + pointA.coordinate_system.get_type() + == CoordinateSystemType.WEIERSTRASS_AFFINE + ) + new_point = pointA.copy() + new_point.append(self.ff1) + new_point.set_coordinate_system(self) + return new_point + + def convert_to_affine(self, pointA: ECCPoint) -> ECCPoint: + assert pointA.get_type() == self.type + new_point = pointA.copy() + new_point.coordinates.clear() + Z_invert = self.ff1 / pointA[2] + new_point.append(pointA[0] * Z_invert) + new_point.append(pointA[1] * Z_invert) + new_point.append(self.ff1) + new_point.set_type(CoordinateSystemType.WEIERSTRASS_AFFINE) + return new_point + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError('Too many command-line arguments.') + + +if __name__ == '__main__': + app.run(main) diff --git a/jaxite/jaxite_lib/zkp/finite_field.py b/jaxite/jaxite_lib/zkp/finite_field.py new file mode 100644 index 0000000..8a749a3 --- /dev/null +++ b/jaxite/jaxite_lib/zkp/finite_field.py @@ -0,0 +1,286 @@ +"""Modular Reduction with finite_field Operations for ZKP.""" + +from collections.abc import Sequence +import copy +import math + +from absl import app + +# from config_file import USE_GMP +USE_GMP = True +if USE_GMP: + from hp_int import GMPHPint as HPint +else: + from hp_int import HPint as HPint + + +class FiniteFieldElement: + """Finite Field Element Operation Library.""" + + def __init__(self, value, prime, k=None): + if isinstance(value, HPint): + self.value = value + else: + self.value = HPint(value) + + if isinstance(prime, HPint): + self.prime = prime + else: + self.prime = HPint(prime) + + if value < 0 or value >= prime: + raise ValueError(f"Value {value} not in range 0 to {prime - 1}") + + def set_value(self, value): + """Set the value of the finite field element, with validation.""" + if isinstance(value, HPint): + new_value = value + else: + new_value = HPint(value) + + if new_value < HPint(0) or new_value >= self.prime: + raise ValueError(f"Value {new_value} not in range 0 to {self.prime - 1}") + + self.value = new_value + + def get_prime(self): + return copy.deepcopy(self.prime) + + def copy(self, value=None, transform=False): + """Create a deep copy of the current finite field element.""" + obj = copy.copy(self) + if value == None: + obj.value = HPint(self.value) + else: + obj.value = HPint(value) + return obj + + def __add__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot add two numbers in different Fields") + result = (self.value + other.value) % self.prime + return FiniteFieldElement(result.value, self.prime.value) + + def __sub__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot subtract two numbers in different Fields") + result = (self.value - other.value) % self.prime + return FiniteFieldElement(result.value, self.prime.value) + + def __mul__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot multiply two numbers in different Fields") + result = (self.value * other.value) % self.prime + return FiniteFieldElement(result.value, self.prime.value) + + def __truediv__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot divide two numbers in different Fields") + # Use Fermat's Little Theorem to find the inverse: a^(p-1) ≡ 1 (mod p) + inverse = other.value.__pow__(self.prime.value - 2, self.prime.value) + result = (self.value * inverse) % self.prime + return FiniteFieldElement(result.value, self.prime.value) + + def __pow__(self, exponent): + result = self.value.__pow__(exponent, self.prime.value) + return FiniteFieldElement(result.value, self.prime.value) + + def __eq__(self, other): + return self.value == other.value and self.prime == other.prime + + def __str__(self): + return f"FieldElement_{self.prime.value}({self.value.value})" + + def __repr__(self): + return ( + f"FiniteFieldElement(value={self.value.value}," + f" prime={self.prime.value})" + ) + + def __hex__(self): + return hex(int(self.value.value)) + + def hex_value_str(self) -> str: + return self.value.hex_value_str() + + +class FiniteFieldElementBarrett(FiniteFieldElement): + + def __init__(self, value, prime, k=None): + super().__init__(value, prime) + if k == None: + if isinstance(prime, HPint): + self.two_k = prime.ceil_log2() * 2 + else: + self.two_k = HPint(math.ceil(math.log2(prime))) * 2 + else: + self.two_k = HPint(2 * k) + self.mu = HPint(2) ** self.two_k / prime + + def barrett_reduction(self, x): + # q = (x * mu) >> 2k + q = (x * self.mu) >> self.two_k + # r = x - q * prime + r = x - q * self.prime + if r >= self.prime: + r -= self.prime + return r + + def __add__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot add two numbers in different Fields") + result = self.value + other.value + if result > self.prime: + result -= self.prime + new_instance = self.copy() + new_instance.value = result + return new_instance + + def __sub__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot subtract two numbers in different Fields") + if self.value < other.value: + result = self.value + self.prime - other.value + else: + result = self.value - other.value + new_instance = self.copy() + new_instance.value = result + return new_instance + + def __mul__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot multiply two numbers in different Fields") + result = self.value * other.value + reduced_result = self.barrett_reduction(result) + new_instance = self.copy() + new_instance.value = reduced_result + return new_instance + + def __truediv__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot divide two numbers in different Fields") + inverse = other.value.__pow__(self.prime.value - 2, self.prime.value) + result = self.value * inverse + reduced_result = self.barrett_reduction(result) + new_instance = self.copy() + new_instance.value = reduced_result + return new_instance + + +class FiniteFieldElementMontgomery(FiniteFieldElement): + + def __init__(self, value, prime, k=None): + super().__init__(value, prime, k) + if k == None: + if isinstance(prime, HPint): + self.k = prime.ceil_log2() + else: + self.k = HPint(math.ceil(math.log2(prime))) + else: + self.k = HPint(k) + + self.r = HPint(2) ** self.k + # self.r_inverse = (self.r ** (self.prime - 2)) % self.prime + self.r_inverse = self.r.__pow__(self.prime - 2, self.prime) + self.n_prime = (self.r * self.r_inverse - 1) / self.prime + self.r_mask = self.r - 1 + self.value = self.montgomeryize(self.value) + self.montgomeryized = True + self.one_bar = self.montgomeryize(HPint(1)) + + def montgomery_reduction(self, x): + m = ((x & self.r_mask) * self.n_prime) & self.r_mask + u = (x + m * self.prime) >> self.k + if u >= self.prime: + u -= self.prime + return u + + def montgomeryize(self, x): + x_bar = (x * self.r) % self.prime + return x_bar + + def de_montgomeryize(self, x_bar): + x = self.montgomery_reduction(x_bar) + return x + + def change_montgomery_form(self): + if self.montgomeryized: + self.value = self.de_montgomeryize(self.value) + self.montgomeryized = False + else: + self.value = self.montgomeryize(self.value) + self.montgomeryized = True + return self + + def __add__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot add two numbers in different Fields") + assert other.montgomeryized + result = self.value + other.value + if result > self.prime: + result -= self.prime + new_instance = self.copy() + new_instance.value = result + return new_instance + + def __sub__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot subtract two numbers in different Fields") + assert other.montgomeryized + if self.value < other.value: + result = self.value + self.prime - other.value + else: + result = self.value - other.value + new_instance = self.copy() + new_instance.value = result + return new_instance + + def __mul__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot multiply two numbers in different Fields") + assert other.montgomeryized + result = self.value * other.value + reduced_result = self.montgomery_reduction(result) + new_instance = self.copy() + new_instance.value = reduced_result + return new_instance + + def __truediv__(self, other): + if self.prime != other.prime: + raise ValueError("Cannot divide two numbers in different Fields") + + if other.montgomeryized: + other_value = self.de_montgomeryize(other.value) + else: + other_value = other.value + + inverse = other_value.__pow__(self.prime.value - 2, self.prime.value) + + inverse_bar = self.montgomeryize(inverse) + + result = self.value * inverse_bar + reduced_result = self.montgomery_reduction(result) + new_instance = self.copy() + new_instance.value = reduced_result + return new_instance + + def copy(self, value=None, transform=False): + """Create a deep copy of the current finite field element.""" + obj = copy.copy(self) + if value == None: + obj.value = HPint(self.value) + else: + if transform: + obj.value = self.montgomeryize(HPint(value)) + else: + obj.value = HPint(value) + return obj + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main) diff --git a/jaxite/jaxite_lib/zkp/hp_int.py b/jaxite/jaxite_lib/zkp/hp_int.py new file mode 100644 index 0000000..6c42acc --- /dev/null +++ b/jaxite/jaxite_lib/zkp/hp_int.py @@ -0,0 +1,271 @@ +"""High-precision integer class for ZKP.""" + +import math + +import gmpy2 + + +class HPint: + """An integer that supports arbitrary precision arithmetic.""" + + def __init__(self, value) -> None: + if isinstance(value, int): + self.value = value + elif isinstance(value, HPint): + self.value = value.value + else: + raise TypeError("Unsupported type for HPint initialization") + + def __add__(self, other): + if isinstance(other, (HPint, int)): + return HPint( + self.value + (other.value if isinstance(other, HPint) else other) + ) + return NotImplemented + + def __sub__(self, other): + if isinstance(other, (HPint, int)): + return HPint( + self.value - (other.value if isinstance(other, HPint) else other) + ) + return NotImplemented + + def __mul__(self, other): + if isinstance(other, (HPint, int)): + return HPint( + self.value * (other.value if isinstance(other, HPint) else other) + ) + return NotImplemented + + def __truediv__(self, other): + if isinstance(other, (HPint, int)): + if (other.value if isinstance(other, HPint) else other) == 0: + raise ZeroDivisionError("division by zero") + return HPint( + self.value // (other.value if isinstance(other, HPint) else other) + ) + return NotImplemented + + def __mod__(self, other): + if isinstance(other, (HPint, int)): + return HPint( + self.value % (other.value if isinstance(other, HPint) else other) + ) + return NotImplemented + + def __eq__(self, other): + if isinstance(other, (HPint, int)): + return self.value == (other.value if isinstance(other, HPint) else other) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, (HPint, int)): + return self.value != (other.value if isinstance(other, HPint) else other) + return NotImplemented + + def __lt__(self, other): + if isinstance(other, (HPint, int)): + return self.value < (other.value if isinstance(other, HPint) else other) + return NotImplemented + + def __le__(self, other): + if isinstance(other, (HPint, int)): + return self.value <= (other.value if isinstance(other, HPint) else other) + return NotImplemented + + def __gt__(self, other): + if isinstance(other, (HPint, int)): + return self.value > (other.value if isinstance(other, HPint) else other) + return NotImplemented + + def __ge__(self, other): + if isinstance(other, (HPint, int)): + return self.value >= (other.value if isinstance(other, HPint) else other) + return NotImplemented + + def __pow__(self, exponent, modulus=None): + if isinstance(exponent, (HPint, int)): + if isinstance(exponent, HPint): + exponent = exponent.value + if modulus is None: + return HPint(pow(self.value, exponent)) + else: + return HPint(pow(self.value, exponent, modulus)) + else: + raise TypeError("Exponent must be an integer") + + def __lshift__(self, shift): + """Left shift operator (<<)""" + return HPint(self.value << shift) + + def __rshift__(self, shift): + """Right shift operator (>>)""" + if isinstance(shift, HPint): + shift = shift.value + return HPint(self.value >> shift) + + def __and__(self, other): + if isinstance(other, (HPint, int)): + return HPint( + self.value & (other.value if isinstance(other, HPint) else other) + ) + return NotImplemented + + def ceil_log2(self) -> float: + """Calculate the base-2 logarithm of the HPint.""" + if self.value <= 0: + raise ValueError("log2 is only defined for positive integers") + return HPint(math.ceil(math.log2(self.value))) + + def __int__(self): + return int(self.value) + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"HPint({self.value})" + + def hex_value_str(self) -> str: + return hex(self.value) + + +class GMPHPint(HPint): + + def __init__(self, value) -> None: + if isinstance(value, (int, gmpy2.mpz)): + self.value = gmpy2.mpz(value) + elif isinstance(value, HPint): + self.value = gmpy2.mpz(value.value) + else: + raise TypeError("Unsupported type for GMPHPint initialization") + + def __add__(self, other): + if isinstance(other, (GMPHPint, int)): + return GMPHPint( + self.value + + gmpy2.mpz(other.value if isinstance(other, GMPHPint) else other) + ) + return NotImplemented + + def __sub__(self, other): + if isinstance(other, (GMPHPint, int)): + return GMPHPint( + self.value + - gmpy2.mpz(other.value if isinstance(other, GMPHPint) else other) + ) + return NotImplemented + + def __mul__(self, other): + if isinstance(other, (GMPHPint, int)): + return GMPHPint( + self.value + * gmpy2.mpz(other.value if isinstance(other, GMPHPint) else other) + ) + return NotImplemented + + def __truediv__(self, other): + if isinstance(other, (GMPHPint, int)): + if gmpy2.mpz(other.value if isinstance(other, GMPHPint) else other) == 0: + raise ZeroDivisionError("division by zero") + return GMPHPint( + self.value + // gmpy2.mpz(other.value if isinstance(other, GMPHPint) else other) + ) + return NotImplemented + + def __mod__(self, other): + if isinstance(other, (GMPHPint, int)): + return GMPHPint( + self.value + % gmpy2.mpz(other.value if isinstance(other, GMPHPint) else other) + ) + return NotImplemented + + def __eq__(self, other): + if isinstance(other, (GMPHPint, int)): + return self.value == gmpy2.mpz( + other.value if isinstance(other, GMPHPint) else other + ) + return NotImplemented + + def __ne__(self, other): + if isinstance(other, (GMPHPint, int)): + return self.value != gmpy2.mpz( + other.value if isinstance(other, GMPHPint) else other + ) + return NotImplemented + + def __lt__(self, other): + if isinstance(other, (GMPHPint, int)): + return self.value < gmpy2.mpz( + other.value if isinstance(other, GMPHPint) else other + ) + return NotImplemented + + def __le__(self, other): + if isinstance(other, (GMPHPint, int)): + return self.value <= gmpy2.mpz( + other.value if isinstance(other, GMPHPint) else other + ) + return NotImplemented + + def __gt__(self, other): + if isinstance(other, (GMPHPint, int)): + return self.value > gmpy2.mpz( + other.value if isinstance(other, GMPHPint) else other + ) + return NotImplemented + + def __ge__(self, other): + if isinstance(other, (GMPHPint, int)): + return self.value >= gmpy2.mpz( + other.value if isinstance(other, GMPHPint) else other + ) + return NotImplemented + + def __pow__(self, exponent, modulus=None): + if isinstance(exponent, (GMPHPint, int, gmpy2.mpz)): + if isinstance(exponent, GMPHPint): + exponent = gmpy2.mpz(exponent.value) + if isinstance(modulus, GMPHPint): + modulus = gmpy2.mpz(modulus.value) + if modulus is None: + return GMPHPint(self.value**exponent) + else: + return GMPHPint(gmpy2.powmod(self.value, exponent, modulus)) + else: + print(type(exponent)) + raise TypeError("Exponent must be an integer") + + def __lshift__(self, shift): + """Left shift operator (<<)""" + if isinstance(shift, HPint): + shift = shift.value + return GMPHPint(self.value << shift) + + def __rshift__(self, shift): + """Right shift operator (>>)""" + if isinstance(shift, HPint): + shift = shift.value + return GMPHPint(self.value >> shift) + + def __and__(self, other): + if isinstance(other, (GMPHPint, int)): + return GMPHPint( + self.value + & gmpy2.mpz(other.value if isinstance(other, GMPHPint) else other) + ) + return NotImplemented + + def ceil_log2(self): + """Calculate the base-2 logarithm of the GMPHPint.""" + if self.value <= 0: + raise ValueError("log2 is only defined for positive integers") + return GMPHPint(gmpy2.ceil(gmpy2.log2(self.value))) + + def __str__(self): + return str(self.value) + + def __repr__(self): + return f"GMPHPint({self.value})" diff --git a/jaxite/jaxite_lib/zkp/pippenger.py b/jaxite/jaxite_lib/zkp/pippenger.py new file mode 100644 index 0000000..e812cfb --- /dev/null +++ b/jaxite/jaxite_lib/zkp/pippenger.py @@ -0,0 +1,144 @@ +"""pippenger Implementation for Multi-scalar-multiplication (MSM) -- ZKP""" + +from collections.abc import Sequence +from typing import List + +from absl import app +from config_file import USE_BARRETT, USE_GMP +from elliptic_curve import * + +if USE_GMP: + from big_integer import GMPBigInteger as BigInt +else: + from big_integer import BigInteger as BigInt +from msm_reader import MSMReader +import logging + + +class PippengerBucket: + + def __init__( + self, coordinate_system: EllipticCurveCoordinateSystem, slice_id=0 + ): + self.slice_id = slice_id + # self.empty = True + self.point = coordinate_system.generate_point(None, True) + + def add_point(self, point: ECCPoint): + self.point += point + + def get_point(self): + return self.point + + +class PippengerWindow: + + def __init__( + self, + coordinate_system: EllipticCurveCoordinateSystem, + slice_length, + window_id=0, + ) -> None: + self.coordinate_system = coordinate_system + self.slice_length = slice_length + self.bucket_num = 2**slice_length + self.window_id = window_id + self.buckets: List[PippengerBucket] = [ + PippengerBucket(coordinate_system, i) for i in range(self.bucket_num) + ] + self.point = coordinate_system.generate_point(None, True) + + def bucket_reduction(self) -> ECCPoint: + window_sum: ECCPoint = self.coordinate_system.generate_point(None, True) + temp_sum: ECCPoint = self.coordinate_system.generate_point(None, True) + for i in range(self.bucket_num - 1, 0, -1): + temp_sum += self.buckets[i].get_point() + window_sum += temp_sum + self.point = window_sum + return window_sum + + def get_point(self): + return self.point + + def __getitem__(self, index): + return self.buckets[index] + + def __setitem__(self, index, value): + self.buckets[index] = value + + +class PippengerMSM: + + def __init__( + self, + coordinate_system: EllipticCurveCoordinateSystem, + slice_length, + window_num, + ) -> None: + self.coordinate_system: EllipticCurveCoordinateSystem = coordinate_system + self.slice_length = slice_length + self.window_num = window_num + self.bucket_num_in_window = 2**self.slice_length + self.slice_mask = self.bucket_num_in_window - 1 + self.bucket_num_all = self.bucket_num_in_window * self.window_num + self.windows: List[PippengerWindow] = [ + PippengerWindow(coordinate_system, self.slice_length, i) + for i in range(self.window_num) + ] + self.result = None + + def msm_run(self, msm_reader: MSMReader): + scalar = msm_reader.get_next_scalar() + coordinates = msm_reader.get_next_base() + + while scalar != None and coordinates != None: + logging.debug(f"scalar: {hex(scalar)}") + point = self.coordinate_system.generate_point(coordinates) + self.msm_horizental_stream(scalar, point) + scalar = msm_reader.get_next_scalar() + coordinates = msm_reader.get_next_base() + self.bucket_reduction() + self.result = self.window_merge() + self.result = self.result.coordinate_system.convert_to_affine() + return self.result + + def msm_horizental_stream(self, scalar: int, point: ECCPoint): + window_id = 0 + current_scalar = scalar + while current_scalar != 0: + logging.debug(f"window_id: {window_id}") + bucket_id = current_scalar & self.slice_mask + self.bucket_accumulation(window_id, bucket_id, point) + logging.debug( + f"bucket: {hex(bucket_id)}," + f" {self.windows[window_id][bucket_id].get_point()}" + ) + current_scalar = current_scalar >> self.slice_length + window_id += 1 + + def bucket_accumulation(self, window_id, bucket_id, point: ECCPoint): + self.windows[window_id][bucket_id].add_point(point) + + def bucket_reduction(self): + for i in range(0, self.window_num): + logging.debug(f"window {i} BR") + self.windows[i].bucket_reduction() + logging.debug(f"{self.windows[i].get_point()}") + + def window_merge(self): + merged = self.coordinate_system.generate_point(None, True) + for i in range(self.window_num - 1, -1, -1): + logging.debug(f"window {i} WM") + point = self.windows[i].get_point() + merged = (merged << self.slice_length) + point + + return merged + + +def main(argv: Sequence[str]) -> None: + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + + +if __name__ == "__main__": + app.run(main)