diff --git a/README.md b/README.md index 41ac58e..996228b 100644 --- a/README.md +++ b/README.md @@ -54,6 +54,15 @@ alpha = b"input data" ad = b"additional data" ``` +Deterministic key generation from a seed (matching ark-vrf): + +```python +from dot_ring import Bandersnatch, secret_from_seed + +seed = (0).to_bytes(32, "little") +public_key, secret_scalar = secret_from_seed(seed, Bandersnatch) +``` + ### IETF VRF ```python diff --git a/dot_ring/__init__.py b/dot_ring/__init__.py index ba40c43..74e07b5 100644 --- a/dot_ring/__init__.py +++ b/dot_ring/__init__.py @@ -48,6 +48,7 @@ from dot_ring.curve.specs.p384 import P384_NU, P384_RO from dot_ring.curve.specs.p521 import P521_NU, P521_RO from dot_ring.curve.specs.secp256k1 import Secp256k1_NU, Secp256k1_RO +from dot_ring.keygen import secret_from_seed from dot_ring.vrf.ietf.ietf import IETF_VRF from dot_ring.vrf.pedersen.pedersen import PedersenVRF from dot_ring.vrf.ring.ring_vrf import RingVRF @@ -114,4 +115,5 @@ "BLS12_381_G2", "BLS12_381_G2_RO", "BLS12_381_G2_NU", + "secret_from_seed", ] diff --git a/dot_ring/keygen.py b/dot_ring/keygen.py new file mode 100644 index 0000000..3f96c58 --- /dev/null +++ b/dot_ring/keygen.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +from typing import Literal, cast + +from dot_ring.curve.curve import CurveVariant +from dot_ring.curve.specs.bandersnatch import Bandersnatch +from dot_ring.ring_proof.helpers import Helpers +from dot_ring.vrf.ietf.ietf import IETF_VRF + + +def _hash_seed(curve: CurveVariant, seed: bytes, counter: int) -> bytes: + hasher = curve.curve.H_A() + hasher.update(seed) + if counter: + hasher.update(bytes([counter])) + if curve.curve._uses_xof(): + length = curve.curve._default_xof_len() + return cast(bytes, hasher.digest(length)) + return cast(bytes, hasher.digest()) + + +def secret_from_seed(seed: bytes, curve: CurveVariant = Bandersnatch) -> tuple[bytes, bytes]: + """ + Deterministically derive a secret scalar and public key from a seed. + + Mirrors ark-vrf's Secret::from_seed: + - Hash seed with curve's hash function + - Interpret hash output as little-endian integer, reduce modulo curve order + - If zero, append a counter byte and rehash + + Returns: + (public_key_bytes, secret_scalar_bytes) + """ + if not isinstance(seed, (bytes, bytearray)): + raise TypeError("seed must be bytes") + if not isinstance(curve, CurveVariant): + raise TypeError("curve must be a CurveVariant") + + seed_bytes = bytes(seed) + order = curve.curve.ORDER + scalar_len = (order.bit_length() + 7) // 8 + + counter = 0 + while True: + digest = _hash_seed(curve, seed_bytes, counter) + sk_int = int.from_bytes(digest, "little") % order + if sk_int != 0: + break + counter = (counter + 1) & 0xFF + if counter == 0: + raise RuntimeError("failed to derive non-zero secret scalar") + + sk_bytes = Helpers.int_to_str( + sk_int, + cast(Literal["little", "big"], curve.curve.ENDIAN), + scalar_len, + ) + pk_bytes = IETF_VRF[curve].get_public_key(sk_bytes) + return pk_bytes, sk_bytes diff --git a/tests/test_keygen.py b/tests/test_keygen.py new file mode 100644 index 0000000..8ff47b7 --- /dev/null +++ b/tests/test_keygen.py @@ -0,0 +1,38 @@ +import pytest + +from dot_ring import Bandersnatch, secret_from_seed +from dot_ring.vrf.ietf.ietf import IETF_VRF + + +@pytest.mark.parametrize( + "seed, expected_pk, expected_sk", + [ + ( + 0, + "5e465beb01dbafe160ce8216047f2155dd0569f058afd52dcea601025a8d161d", + "51c1537c18eea5c5969cb2ae45c1224cc245de5c5b8e6e25f48fb99f2786ee05", + ), + ( + 100, + "caf7eb70d84e27511179c83ac352f8d3e9b9661371520c54c9ad56781f374a32", + "ad20931d3f8cee57206bc1c3e5dad50677afb9fb712217c6a980867d3a56451c", + ), + ], +) +def test_secret_from_seed_vectors(seed: int, expected_pk: str, expected_sk: str) -> None: + pk, sk = secret_from_seed(seed.to_bytes(32, "little"), Bandersnatch) + assert pk.hex() == expected_pk + assert sk.hex() == expected_sk + + +def test_secret_from_seed_public_key_roundtrip() -> None: + seed = (2**32 - 1).to_bytes(32, "little") + pk, sk = secret_from_seed(seed, Bandersnatch) + assert pk == IETF_VRF[Bandersnatch].get_public_key(sk) + + +def test_secret_from_seed_type_errors() -> None: + with pytest.raises(TypeError): + secret_from_seed("not-bytes") # type: ignore[arg-type] + with pytest.raises(TypeError): + secret_from_seed(b"\x00" * 32, "not-a-curve") # type: ignore[arg-type]